<a href="https://colab.research.google.com/github/wangleiphy/ml4p/blob/main/projects/alanine_dipeptide.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# some necessary packages
!pip install -q dm-haiku # neural network library
!pip install -q optax     # optimization library
!pip install -q nglview   # visualize molecules
!pip install -q ase    

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/5.7 MB[0m [31m7.4 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/5.7 MB[0m [31m37.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.7/5.7 MB[0m [31m65.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m50.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m78.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for nglview (pyproject.toml) ... [?25l[?25hdon

In [None]:
# Colab-only setup: switch on the custom widget manager so that third-party
# ipywidgets can render in output cells (presumably needed for the nglview
# molecule viewer installed above -- TODO confirm).
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
from functools import partial

## Data

In [None]:
!wget http://ftp.imp.fu-berlin.de/pub/cmb-data/alanine-dipeptide-3x250ns-heavy-atom-positions.npz

--2023-03-09 13:44:43--  http://ftp.imp.fu-berlin.de/pub/cmb-data/alanine-dipeptide-3x250ns-heavy-atom-positions.npz
Resolving ftp.imp.fu-berlin.de (ftp.imp.fu-berlin.de)... 160.45.117.8
Connecting to ftp.imp.fu-berlin.de (ftp.imp.fu-berlin.de)|160.45.117.8|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 90000544 (86M)
Saving to: ‘alanine-dipeptide-3x250ns-heavy-atom-positions.npz’


2023-03-09 13:44:50 (13.1 MB/s) - ‘alanine-dipeptide-3x250ns-heavy-atom-positions.npz’ saved [90000544/90000544]



In [None]:
# wget (previous cell) saves the archive into the current working directory, so
# load it through a relative path instead of the Colab-specific '/content/' prefix.
data = np.load('alanine-dipeptide-3x250ns-heavy-atom-positions.npz')

In [None]:
# The archive holds three arrays ('arr_0', 'arr_1', 'arr_2'); they are used
# here as the training, validation and test sets respectively.
train_data = data['arr_0']
validation_data = data['arr_1']
test_data = data['arr_2']

In [None]:
# (num_frames, 30): every frame is stored as one flat vector of 30 coordinates
train_data.shape

(250000, 30)

In [None]:
natoms, dim = 10, 3  # 10 atoms per frame, 3 Cartesian coordinates per atom
# unflatten every frame: (num_frames, natoms*dim) -> (num_frames, natoms, dim)
train_data = train_data.reshape(-1, natoms, dim)

In [None]:
from ase import Atoms
from ase.visualize import view

# ASE expects positions in Angstrom. The stored coordinates appear to be in
# nanometres: bonded neighbours in a frame are ~0.13-0.15 apart, i.e. ordinary
# covalent bond lengths of 1.3-1.5 Angstrom after multiplying by 10. Scaling by
# the box length (23.222 Angstrom, see
# https://markovmodel.github.io/mdshare/ALA2/#alanine-dipeptide) would stretch
# those bonds beyond 3 Angstrom. NOTE(review): confirm the units on that page.
L = 10.0 # nm -> Angstrom conversion factor (name kept: later cells reuse L)

# element symbols of the 10 atoms, in the order in which they are stored
atoms = Atoms('CCONCCCONC', positions=train_data[0]*L)
view(atoms, viewer='ngl')

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'C', 'O', 'N'), value=…

We translate each configuration so that its first atom sits at the origin $(0,0,0)$.

In [None]:
# subtract atom 0's position from every atom of the same frame;
# the slice [:, :1] keeps the atom axis so the subtraction broadcasts
train_data = train_data - train_data[:, :1]

In [None]:
# inspect the first frame: after the shift its first row is (0, 0, 0)
train_data[0]

Array([[ 0.        ,  0.        ,  0.        ],
       [-0.10866126, -0.08496571, -0.05294933],
       [-0.23061313, -0.05237889, -0.03257903],
       [-0.06925163, -0.17692566, -0.14216499],
       [-0.16980399, -0.26459217, -0.21098728],
       [-0.11107272, -0.31215835, -0.34160462],
       [-0.22918014, -0.3769319 , -0.12853031],
       [-0.1992424 , -0.4020753 , -0.00867443],
       [-0.3358656 , -0.4447801 , -0.18495616],
       [-0.4171049 , -0.5499315 , -0.11980057]], dtype=float32)

# Model

In [None]:
def make_transformer(key, n, dim, num_layers, num_heads, key_size):
    """Build a causally masked transformer and initialise its parameters.

    The network maps coordinates of shape (seq_len, dim) to an array of shape
    (seq_len, 2*dim): per input row, `dim` values `mu` followed by `dim`
    strictly positive values `sigma` (made positive with a softplus). A
    lower-triangular attention mask lets row i attend only to rows 0..i.

    Args:
        key: jax PRNG key used to initialise the parameters.
        n: number of rows (atoms) of the dummy input used for initialisation.
        dim: number of coordinates per atom.
        num_layers: number of (attention + MLP) residual blocks.
        num_heads: number of attention heads in every block.
        key_size: size of the per-head keys/queries.

    Returns:
        A tuple (params, apply_fn); apply_fn(params, x) evaluates the network.
    """

    # read this to understand why we need these lines https://sjmielke.com/jax-purify.htm
    @hk.without_apply_rng
    @hk.transform
    def network(x):
        assert x.ndim == 2  # (seq_len, dim)
        # Use the length of the actual input, which may differ from `n`;
        # a distinct name avoids shadowing the outer parameter.
        seq_len = x.shape[0]

        model_size = 2*dim # one mu and one sigma per coordinate
        # causal mask for the attention matrix; leading axis broadcasts over heads
        mask = jnp.tril(np.ones((1, seq_len, seq_len)))

        initializer = hk.initializers.TruncatedNormal(0.01)
        h = hk.Linear(model_size, w_init=initializer)(x)
        for _ in range(num_layers):
            # https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/attention.py
            attn_block = hk.MultiHeadAttention(num_heads=num_heads,
                                               key_size=key_size,
                                               model_size=model_size,
                                               w_init=initializer)
            h = attn_block(h, h, h, mask) + h  # residual self-attention

            dense_block = hk.Sequential([hk.Linear(4 * model_size, w_init=initializer),
                                         jax.nn.gelu,
                                         hk.Linear(model_size, w_init=initializer)])
            h = dense_block(h) + h  # residual MLP

        mu, sigma = jnp.split(h, 2, axis=-1)
        sigma = jax.nn.softplus(sigma) # to ensure positivity
        return jnp.concatenate([mu, sigma], axis=-1)

    # Dummy input that only fixes the parameter shapes. It is sized from the
    # `n` argument (not a global), and is all zeros so that `key` is consumed
    # exactly once, by `network.init`.
    dummy_x = jnp.zeros((n, dim))
    params = network.init(key, dummy_x)
    return params, network.apply

In [None]:
# transformer hyper-parameters
num_layers = 4
num_heads = 8
key_size = 16

key = jax.random.PRNGKey(42)
# Split off a dedicated key for initialisation, so the key that initialises the
# parameters is not the same one that later cells use for shuffling and sampling.
key, init_key = jax.random.split(key)
params, model = make_transformer(init_key, natoms, dim, num_layers, num_heads, key_size)

In [None]:
from jax.flatten_util import ravel_pytree
# flatten the whole parameter tree into a single vector and count its entries
ravel_pytree(params)[0].size # number of parameters in the model

15144

In [None]:
# expected shape (natoms, 2*dim): mu and sigma for every coordinate of every atom
model(params, train_data[0]).shape

(10, 6)

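We can check that the model is indeed autoregressive: the output for atom $i$ may depend only on the inputs of atoms $0, \dots, i$, so the Jacobian of the outputs with respect to the inputs must be lower triangular.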

In [None]:
def test_fn(x):
    """Reduce the network output to a single scalar per atom."""
    return model(params, x).sum(axis=-1)

# jac[i, j, :] = d(summed output of atom i) / d(coordinates of atom j)
jac = jax.jacfwd(test_fn)(train_data[0])
jac.shape

(10, 10, 3)

In [None]:
# Store the norm under a new name instead of overwriting `jac`: re-running this
# cell would otherwise reduce the wrong axis of an already-reduced array.
jac_norm = jnp.linalg.norm(jac, axis=-1)  # magnitude of d(output i)/d(input j)
# 1 where output i depends on input j; autoregressive <=> lower triangular
(jac_norm != 0.).astype(int)

Array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)

# Loss

For maximum likelihood estimation we minimize the negative log-likelihood $$\mathcal{L} = -\mathop{\mathbb{E}}_{x\sim \mathrm{data}} \left[ \ln p(x)\right] $$

In [None]:
def make_mle_loss(model):
    """Create the negative log-likelihood loss for an autoregressive Gaussian model.

    Args:
        model: callable `model(params, x)` that takes one configuration `x` of
            shape (n, dim) and returns an (n, 2*dim) array whose last axis
            holds `mu` followed by `sigma`.

    Returns:
        `loss_fn(params, x)` taking a batch `x` of shape (batch, n, dim) and
        returning the negated mean Gaussian log-density, averaged over the
        batch, atoms 1..n-1 and coordinates.
    """

    def _single_logp(params, x):
        # the output at row i parametrises the distribution of row i+1, so the
        # last output row has no target and is dropped
        prediction = model(params, x)[:-1]
        mean, scale = jnp.split(prediction, 2, axis=-1)
        # the first row of x is skipped as a target: atom 0 is fixed at 000
        return jax.scipy.stats.norm.logpdf(x[1:], loc=mean, scale=scale)

    # share params across the batch, map over the leading axis of x
    batched_logp = jax.vmap(_single_logp, in_axes=(None, 0), out_axes=0)

    def loss_fn(params, x):
        return -jnp.mean(batched_logp(params, x))

    return loss_fn

# bind the loss to the transformer's apply function
loss_fn = make_mle_loss(model)

In [None]:
# sanity check: evaluate the loss on the first five configurations
loss_fn(params, train_data[:5])

Array(0.61155874, dtype=float32)

# Optimization

In [None]:
# Adam with a fixed step size; its state is built from the current parameters
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def update(params, opt_state, data):
    """Run one optimisation step on a mini-batch.

    Args:
        params: current model parameters.
        opt_state: current optimizer state.
        data: mini-batch, forwarded to `loss_fn` as its second argument.

    Returns:
        (new_params, new_opt_state, loss), where `loss` is evaluated at the
        parameters before the step.
    """
    value, grad = jax.value_and_grad(loss_fn)(params, data)
    # Also hand over the current params: transformations that depend on the
    # weights (e.g. weight decay) need them, and plain Adam simply ignores them.
    updates, opt_state = optimizer.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, value


In [None]:
# Derive a fresh subkey for the initial shuffle instead of consuming `key`
# itself, which is split again below (a PRNG key must not be used twice).
key, subkey = jax.random.split(key)
train_data = jax.random.permutation(subkey, train_data)
train_data = train_data[:10000] # to save time we use only 10000 of them to train

batchsize = 100
assert len(train_data)%batchsize==0

for epoch in range(10):
    # reshuffle the training set at the start of every epoch
    key, subkey = jax.random.split(key)
    train_data = jax.random.permutation(subkey, train_data)

    total_loss = 0.0
    counter = 0
    for batch_index in range(0, len(train_data), batchsize):
        # named `batch` so that the dataset handle `data` loaded earlier is not overwritten
        batch = train_data[batch_index:batch_index+batchsize]
        params, opt_state, loss = update(params, opt_state, batch)
        total_loss += loss
        counter += 1

    # mean training loss over the mini-batches of this epoch
    print(epoch, total_loss/counter)


0 -1.2442603
1 -1.2600908
2 -1.277534
3 -1.298891
4 -1.3170371
5 -1.3422662
6 -1.3507986
7 -1.3897886
8 -1.4162788
9 -1.4675642


# Sample

Now we can try to sample from the trained model. Recall that, for each atom, the model predicts the mean $\mu$ and the standard deviation $\sigma$ of a Gaussian distribution.

In [None]:
@jax.vmap
def inference(x): # here x can be data with various length
    """Return the Gaussian parameters predicted for the next atom.

    Mapped over the leading (batch) axis by `jax.vmap`; per sample, `x` holds
    the atoms known so far. Only the last output row is used, and it is split
    into (mu, sigma).
    """
    last_row = model(params, x)[-1] # only use the last one
    mu, sigma = jnp.split(last_row, 2, axis=-1)
    return mu, sigma

We sample in an autoregressive fashion: starting from atom 0 at the origin, then atom 1, then atom 2 ... 

In [None]:
# all-zero start: atom 0 stays at the origin, the other atoms are filled in below
samples = jnp.zeros((batchsize, natoms, dim))
for i in range(1, natoms):
    # condition on atoms 0..i-1 to get the Gaussian parameters for atom i
    mu, sigma = inference(samples[:, :i])
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, (batchsize, dim))
    x = mu + sigma * noise  # draw from N(mu, sigma^2) by shifting and scaling unit noise
    samples = samples.at[:, i].set(x)

In [None]:
# first generated configuration; row 0 is the fixed atom at the origin
samples[0]

Array([[ 0.        ,  0.        ,  0.        ],
       [-0.36303204, -0.10760424, -0.6692486 ],
       [ 0.13592339, -0.04337658,  1.0170809 ],
       [-0.3794167 , -0.7640658 , -0.43109998],
       [-0.17063145, -0.03453469, -0.3230593 ],
       [-0.3525418 , -1.2325457 , -1.102741  ],
       [ 0.2501487 , -1.2095321 ,  0.6943912 ],
       [-0.02452633,  0.49321824,  0.26617572],
       [ 0.12774971,  0.5856017 ,  0.76268756],
       [ 0.4654139 , -0.54912704,  0.2514172 ]], dtype=float32)

Have a look at a generated sample.

In [None]:
# same element order and the same scale factor L as for the training frame
# shown earlier, so the two views are directly comparable
atoms = Atoms('CCONCCCONC', positions=samples[0]*L)
view(atoms, viewer='ngl')

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'C', 'O', 'N'), value=…