# Customizable Design Notebook

This notebook contains a scaffold of a simple optimization for a custom state-level property.
Please implement your own state-level observable in `observable_fn` and set the value of `target_observable`.

## Imports

We import necessary libraries, e.g. `jax` and `jax-md`.
This section can likely remain unchanged, unless you need special libraries for your observable implementation.

In [1]:
import pdb
import numpy as onp
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import argparse
from pathlib import Path
from copy import deepcopy
import pprint
import time

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.scipy as jsp
from jax import vmap, jit, value_and_grad, lax, random
from jax_md import space, simulate
import optax

from idp_design.energy_prob import get_energy_fn
import idp_design.utils as utils
import idp_design.checkpoint as checkpoint
import idp_design.observable as observable

## Constants

Here we define constants for the optimization.
These include:
- the sequence length
- hyperparameters for sampling reference states via simulation (e.g. the number of steps)
- hyperparameters for optimization (e.g. learning rate, number of iterations).

You may wish to change these parameters depending on your optimization, e.g. longer sequences will require longer simulations to sample a representative ensemble.

In [1]:
key = random.PRNGKey(0)
res_masses = utils.masses

# Sequence parameters
seq_length = 50

bonded_nbrs = jnp.array([(i, i+1) for i in range(seq_length-1)])
unbonded_nbrs = list()
for pair in itertools.combinations(jnp.arange(seq_length), 2):
    unbonded_nbrs.append(pair)
unbonded_nbrs = jnp.array(unbonded_nbrs)

unbonded_nbrs_set = set([tuple(pr) for pr in onp.array(unbonded_nbrs)])
bonded_nbrs_set = set([tuple(pr) for pr in onp.array(bonded_nbrs)])
unbonded_nbrs = jnp.array(list(unbonded_nbrs_set - bonded_nbrs_set))

# Simulation parameters
n_sims = 10
n_eq_steps = 10000
n_sample_steps = 250000
sample_every = 1000
assert(n_sample_steps % sample_every == 0)
num_points_per_batch = n_sample_steps // sample_every
n_ref_states = num_points_per_batch * n_sims

kT = 300*utils.kb
beta = 1 / kT
dt = 0.2
gamma = 0.001

n_iters = 50
lr = 0.1
min_neff_factor = 0.90
min_n_eff = int(n_ref_states * min_neff_factor)
max_approx_iters = 5

gumbel_end = 0.01
gumbel_start = 1.0
gumbel_temps = onp.linspace(gumbel_start, gumbel_end, n_iters)
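
As a quick sanity check, the derived quantities in the cell above can be worked out by hand. The sketch below only repeats that arithmetic with the same values, using plain Python and NumPy; it introduces no new parameters.

```python
import numpy as onp

n_sims = 10
n_sample_steps = 250000
sample_every = 1000
min_neff_factor = 0.90
n_iters = 50

# Each simulation contributes one state every `sample_every` steps.
num_points_per_batch = n_sample_steps // sample_every   # 250
n_ref_states = num_points_per_batch * n_sims             # 2500
# Reference states are resampled once the effective sample size drops below this.
min_n_eff = int(n_ref_states * min_neff_factor)          # 2250

# Linear temperature schedule from 1.0 down to 0.01, one value per iteration.
gumbel_temps = onp.linspace(1.0, 0.01, n_iters)
print(num_points_per_batch, n_ref_states, min_n_eff)
print(gumbel_temps[0], gumbel_temps[-1], len(gumbel_temps))
```

If you change `n_sims`, `n_sample_steps`, or `sample_every`, the size of the reference ensemble and the resampling threshold change with them.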



## Helper Functions

Here we define helper functions.

First, we define miscellaneous helper functions such as
- Normalization of the logits to a probabilistic sequence representation. Note that we optimize an `(n, 20)` set of logits that can be deterministically mapped to a probabilistic sequence.
- An energy function that computes the expected energy of the probabilistic sequence

We then define helper functions specific to simulation via `jax-md` based on the hyperparameters defined above.

In [None]:
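def normalize(logits, temp, norm_key):
    # Temperature-scaled softmax over the residue axis. `norm_key` is kept in
    # the signature, but no Gumbel noise is added, so the mapping from logits
    # to `pseq` is deterministic.
    pseq = jax.nn.softmax(logits / temp)

    return pseq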

displacement_fn, shift_fn = space.free()

subterms_fn, energy_fn = get_energy_fn(bonded_nbrs, unbonded_nbrs, displacement_fn)
energy_fn = jit(energy_fn)
mapped_energy_fn = vmap(energy_fn, (0, None)) # To evaluate a set of states for a given pseq

checkpoint_every = 50
scan = checkpoint.get_scan(checkpoint_every)
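
To see what the normalization does, here is a minimal, self-contained sketch of a temperature-scaled softmax. The name `normalize_sketch` and the toy logits are hypothetical and only for illustration; the sketch mirrors the softmax in `normalize` but drops the unused key argument.

```python
import jax
import jax.numpy as jnp

def normalize_sketch(logits, temp):
    # Row-wise softmax of temperature-scaled logits.
    return jax.nn.softmax(logits / temp, axis=-1)

# Constant logits map to a uniform distribution over the 20 residue types.
uniform_logits = jnp.full((5, 20), 100.0)
pseq_uniform = normalize_sketch(uniform_logits, 1.0)

# Lowering the temperature sharpens each row towards its argmax.
logits = jnp.array([[2.0, 1.0, 0.0], [0.0, 0.5, 1.0]])
pseq_warm = normalize_sketch(logits, 1.0)
pseq_cold = normalize_sketch(logits, 0.01)
print(pseq_warm.max(axis=1), pseq_cold.max(axis=1))
```

This is why a decreasing temperature schedule pushes the probabilistic sequence towards a discrete one over the course of the optimization.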

### Simulation

In [None]:
@jit
def eq_fn(eq_key, R, pseq, mass):
    init_fn, step_fn = simulate.nvt_langevin(energy_fn, shift_fn, dt, kT, gamma)
    init_state = init_fn(eq_key, R, pseq=pseq, mass=mass)
    def fori_step_fn(t, state):
        return step_fn(state, pseq=pseq)
    fori_step_fn = jit(fori_step_fn)

    eq_state = lax.fori_loop(0, n_eq_steps, fori_step_fn, init_state)
    return eq_state.position

@jit
def sample_fn(sample_key, R_eq, pseq, mass):
    init_fn, step_fn = simulate.nvt_langevin(energy_fn, shift_fn, dt, kT, gamma)
    init_state = init_fn(sample_key, R_eq, pseq=pseq, mass=mass)

    def fori_step_fn(t, state):
        return step_fn(state, pseq=pseq)
    fori_step_fn = jit(fori_step_fn)

    @jit
    def scan_fn(state, step):
        state = lax.fori_loop(0, sample_every, fori_step_fn, state)
        return state, state.position

    _, traj = lax.scan(scan_fn, init_state, jnp.arange(num_points_per_batch))
    return traj

@jit
def batch_sim(ref_key, R, pseq, mass):

    ref_key, eq_key = random.split(ref_key)
    eq_keys = random.split(eq_key, n_sims)
    eq_states = vmap(eq_fn, (0, None, None, None))(eq_keys, R, pseq, mass)

    sample_keys = random.split(ref_key, n_sims)
    sample_trajs = vmap(sample_fn, (0, 0, None, None))(sample_keys, eq_states, pseq, mass)

    sample_traj = sample_trajs.reshape(-1, seq_length, 3)
    return sample_traj
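
The sampling pattern in `sample_fn` (an inner `lax.fori_loop` nested in an outer `lax.scan`) can be hard to read at first. The sketch below reproduces only that control flow with a toy step function in place of the Langevin integrator; `toy_step` and the small values of `sample_every` and `num_points` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

sample_every = 4   # hypothetical, much smaller than in the notebook
num_points = 3     # hypothetical number of saved frames

def toy_step(t, x):
    # Stand-in for one integrator step: move every coordinate by +1.
    return x + 1.0

def scan_fn(x, _):
    # Advance `sample_every` steps, then emit the current state as a sample.
    x = lax.fori_loop(0, sample_every, toy_step, x)
    return x, x

x0 = jnp.zeros((2, 3))  # two "particles" in 3D
_, traj = lax.scan(scan_fn, x0, jnp.arange(num_points))
print(traj.shape)  # (3, 2, 3): one frame after 4, 8 and 12 steps
```

The leading axis of `traj` indexes the saved frames, which is why `batch_sim` can flatten the per-simulation trajectories with a single `reshape`.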

## Define Custom Observable

Here, you may define a **custom observable** as the target for optimization.
Observables are defined at the state-level and you may optimize an arbitrary
(continuous and differentiable) function of the expected value of the observable
across the entire ensemble. This notebook assumes that you only wish to optimize
a simple root mean square error (RMSE) from a target value. Note that certain
observables may require enhanced sampling to obtain a representative distribution,
which we do not implement here.

In [None]:
def observable_fn(R):
    """
    This defines a state-level observable for which you would like to 
    optimize a sequence with a target expected value. `R` is an `(n, 3)`
    JAX array denoting the positions of the `n` particles. 
    
    For example, to optimize the end-to-end distance, this function would 
    return:
    
    `jnp.linalg.norm(displacement_fn(R[0], R[-1]))`
    """
    raise NotImplementedError
target_observable = FIXME # The target value of the observable
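
As a starting point, here are two possible state-level observables written as a self-contained sketch. We assume free space, where the displacement between two particles reduces to a coordinate difference; the function names and the toy chain are hypothetical, and the radius of gyration here is unweighted by mass.

```python
import jax.numpy as jnp

def end_to_end_distance(R):
    # Distance between the first and last particle of an (n, 3) chain.
    return jnp.linalg.norm(R[-1] - R[0])

def radius_of_gyration(R):
    # Root-mean-square distance of the particles from their centroid.
    centered = R - jnp.mean(R, axis=0)
    return jnp.sqrt(jnp.mean(jnp.sum(centered**2, axis=1)))

# A straight 5-bead chain along z with unit spacing.
R = jnp.stack([jnp.zeros(5), jnp.zeros(5), jnp.arange(5.0)], axis=1)
print(end_to_end_distance(R))   # 4.0
print(radius_of_gyration(R))    # sqrt(2), about 1.414
```

Either function could serve as the body of `observable_fn`, with `target_observable` set to the desired ensemble average in the same units as the positions.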

## Define Function for Getting Reference States

Given helper functions for simulation via `jax-md` and our observable, 
we define a function that samples reference states for use in differentiable
trajectory reweighting (DiffTRe). This includes running a simulation, sampling
states with some periodicity (defined as `sample_every` above), and returning
- the sampled states
- the calculated observable of each reference state
- the calculated energy of each reference state

In [None]:
def get_ref_states(params, i, R, iter_key, temp):
    curr_logits = params['logits']
    iter_key, norm_key = random.split(iter_key)
    curr_pseq = normalize(curr_logits, temp, norm_key)

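    # NOTE: `ref_traj_dir` (a `pathlib.Path` output directory) is not defined
    # in this notebook and must be set, and must exist, before calling this function.
    iter_dir = ref_traj_dir / f"iter{i}"
    iter_dir.mkdir(parents=False, exist_ok=False)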

    curr_mass = utils.get_pseq_mass(curr_pseq, res_masses=res_masses)

    iter_key, batch_key = random.split(iter_key)
    start = time.time()
    sample_traj = batch_sim(batch_key, R, curr_pseq, curr_mass)
    end = time.time()
    print(f"- Batched simulation took {end - start} seconds")
    sample_traj = utils.tree_stack(sample_traj)

    # Map the state-level observable over every sampled state
    sample_observables = vmap(observable_fn)(sample_traj)
    mean_observable = onp.mean(sample_observables)

    sample_energies = mapped_energy_fn(sample_traj, curr_pseq)

    return sample_traj, sample_energies, jnp.array(sample_observables), mean_observable
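
Evaluating the observable on every reference state is a plain `vmap` over the leading axis of the trajectory. The following sketch shows this on a hypothetical batch of three small states; `toy_observable` and the toy coordinates are assumptions for illustration only.

```python
import jax.numpy as jnp
from jax import vmap

def toy_observable(R):
    # Hypothetical state-level observable: end-to-end distance of one (n, 3) state.
    return jnp.linalg.norm(R[-1] - R[0])

# Three hypothetical states of a 4-bead chain with spacings 1, 2 and 3 along z.
z = jnp.arange(4.0)
states = jnp.stack(
    [jnp.stack([jnp.zeros(4), jnp.zeros(4), s * z], axis=1) for s in (1.0, 2.0, 3.0)]
)  # shape (3, 4, 3)

per_state = vmap(toy_observable)(states)  # shape (3,), one value per state
mean_value = jnp.mean(per_state)          # unweighted ensemble average
print(per_state, mean_value)
```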

## Define Loss Function

The second required component for DiffTRe is a loss function that takes a reference ensemble as input,
reweights it with respect to the current parameters to estimate the expected value of the observable,
and returns the loss computed from that estimate.

In [None]:
def loss_fn(params, ref_states, ref_energies, ref_observables, temp, loss_key):
    logits = params['logits']
    loss_key, norm_key = random.split(loss_key)
    pseq = normalize(logits, temp, norm_key)

    energy_scan_fn = lambda state, ts: (None, energy_fn(ts, pseq=pseq))
    _, new_energies = scan(energy_scan_fn, None, ref_states)

    weights, n_eff = utils.compute_weights(ref_energies, new_energies, beta)
    weighted_observables = weights * ref_observables # element-wise multiplication
    expected_observable = jnp.sum(weighted_observables)


    mse = (expected_observable - target_observable)**2
    rmse = jnp.sqrt(mse)
    loss = rmse

    return loss, (n_eff, expected_observable, pseq)
grad_fn = value_and_grad(loss_fn, has_aux=True)
grad_fn = jit(grad_fn)
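
The reweighting step can be illustrated with a minimal sketch. We assume here that the weights follow the standard DiffTRe form, i.e. a softmax over `-beta * (U_new - U_ref)`, with the effective sample size taken as the exponential of the weight entropy. `utils.compute_weights` is not shown in this notebook and its implementation may differ; the name `reweight_sketch` and the toy energies are hypothetical.

```python
import jax
import jax.numpy as jnp

def reweight_sketch(ref_energies, new_energies, beta):
    # w_i is proportional to exp(-beta * (U_new(x_i) - U_ref(x_i))).
    weights = jax.nn.softmax(-beta * (new_energies - ref_energies))
    # Effective sample size as the exponential of the weight entropy.
    n_eff = jnp.exp(-jnp.sum(weights * jnp.log(weights)))
    return weights, n_eff

ref_energies = jnp.array([1.0, 2.0, 3.0, 4.0])
observables = jnp.array([10.0, 20.0, 30.0, 40.0])

# Unchanged parameters: uniform weights, n_eff equals the number of states.
w_same, n_eff_same = reweight_sketch(ref_energies, ref_energies, beta=1.0)

# Perturbed energies: weights shift and n_eff drops below 4.
shift = jnp.array([0.0, 0.5, 1.0, 1.5])
w_new, n_eff_new = reweight_sketch(ref_energies, ref_energies + shift, beta=1.0)

print(jnp.sum(w_same * observables), n_eff_same)
print(jnp.sum(w_new * observables), n_eff_new)
```

A falling effective sample size signals that the reference ensemble no longer represents the current parameters, which is what the `min_n_eff` threshold checks for.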

## Setup Optimization

Finally, we set up the remaining pieces of our optimization:
- the initialized `(n, 20)` logits (that can be deterministically mapped to a probabilistic sequence)
- an optimizer
- an initial state

In [None]:
init_logits = onp.full((seq_length, 20), 100.0)
init_logits = jnp.array(init_logits, dtype=jnp.float64)

# Setup the optimization
params = {"logits": init_logits}
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)

# NOTE: `out_box_size` is not defined in this notebook and must be set before running this cell.
R_init = list()
init_spring_r0 = utils.spring_r0
for i in range(seq_length):
    R_init.append([out_box_size/2, out_box_size/2, out_box_size/2+init_spring_r0*i])
R_init = jnp.array(R_init)

## Generate Initial Reference States

Before we can do our first gradient update, we have to sample an initial set of reference states.

In [None]:
key, iter_key = random.split(key)
ref_states, ref_energies, ref_observables, mean_observable = get_ref_states(params, 0, R_init, iter_key, temp=gumbel_temps[0])

## Optimize

We can then perform gradient descent iteratively, resampling reference states when necessary!

In [None]:
num_resample_iters = 0
for i in range(n_iters):
    print(f"\nIteration {i}:")
    key, loss_key = random.split(key)
    (loss, aux), grads = grad_fn(params, ref_states, ref_energies, ref_observables, gumbel_temps[i], loss_key)
    n_eff = aux[0]
    num_resample_iters += 1

    if n_eff < min_n_eff or num_resample_iters > max_approx_iters:
        print(f"- N_eff was {n_eff}... resampling reference states...")
        num_resample_iters = 0

        key, iter_key = random.split(key)
        ref_states, ref_energies, ref_observables, mean_observable = get_ref_states(
            params, i, utils.recenter(ref_states[-1], out_box_size), iter_key, gumbel_temps[i]
        )

        (loss, aux), grads = grad_fn(params, ref_states, ref_energies, ref_observables, gumbel_temps[i], loss_key)
    (n_eff, expected_observable, pseq) = aux


    max_residues = jnp.argmax(pseq, axis=1)
    argmax_seq = ''.join([utils.RES_ALPHA[res_idx] for res_idx in max_residues])
    print(f"- Argmax seq: {argmax_seq}")
    
    print(f"- Loss: {loss}")
    print(f"- Current value: {expected_observable}")

    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
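
The resampling rule used in the loop above can be isolated into a small predicate. This is only a sketch of that condition; the helper name `should_resample` is hypothetical, and the threshold values are the ones derived from the constants defined earlier.

```python
def should_resample(n_eff, iters_since_resample, min_n_eff, max_approx_iters):
    # Resample if the reference ensemble is no longer representative
    # or if it has been reused for too many gradient steps.
    return bool(n_eff < min_n_eff or iters_since_resample > max_approx_iters)

min_n_eff, max_approx_iters = 2250, 5
print(should_resample(2400.0, 3, min_n_eff, max_approx_iters))  # False
print(should_resample(2100.0, 1, min_n_eff, max_approx_iters))  # True
print(should_resample(2400.0, 6, min_n_eff, max_approx_iters))  # True
```

Raising `min_neff_factor` or lowering `max_approx_iters` makes resampling more frequent, trading longer run times for a more accurate gradient estimate.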
