# ASMSA: Train AAE model with the tuned hyperparameters

**Previous steps**
- [prepare.ipynb](prepare.ipynb): Download and sanity check input files
- [tune.ipynb](tune.ipynb): Perform initial hyperparameter tuning for this molecule

**Next step**
- [md.ipynb](md.ipynb): Use a trained model in MD simulation with Gromacs

## Notebook setup

In [None]:
# Number of CPU threads TensorFlow and PyTorch are allowed to use.
threads = 2

import os
# Set before tensorflow/torch are imported below, presumably so the libraries
# see the limit when they initialise.
os.environ['OMP_NUM_THREADS']=str(threads)
import tensorflow as tf

# PyTorch favours OMP_NUM_THREADS in environment
import torch

# Tensorflow needs explicit config calls
tf.config.threading.set_inter_op_parallelism_threads(threads)
tf.config.threading.set_intra_op_parallelism_threads(threads)

In [None]:
from asmsa.tuning_analyzer import TuningAnalyzer
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
import urllib.request
from tensorflow import keras
import keras_tuner
import asmsa.visualizer as visualizer
import asmsa


from asmsa.plot_training import LiveTrainingPlot

## Input files

All input files are prepared (up- or downloaded) in **prepare.ipynb**. 


In [None]:
# Execute inputs.py in this notebook's namespace to define the input-file
# variables used further below (e.g. `conf`, `traj`). The file handle is
# closed explicitly. NOTE: exec runs arbitrary code - only use a trusted
# inputs.py (presumably the one written by prepare.ipynb).
with open('inputs.py') as inputs_file:
    exec(inputs_file.read())

## Load datasets
Load filtered trajectory datasets that were processed in **prepare.ipynb**. Trajectories are in internal coordinates format.

In [None]:
# load train dataset (internal coordinates prepared in prepare.ipynb)
X_train = tf.data.Dataset.load('datasets/intcoords/train')

# get batched version of dataset to feed to AAE model for training;
# drop_remainder=True discards the last incomplete batch.
# NOTE(review): `hps` is first assigned in the hyperparameter-selection cell
# further down; unless inputs.py defines it, run that cell before this one.
X_train_batched = X_train.batch(hps['batch_size'],drop_remainder=True)

# get numpy version (all samples, unbatched) for visualization purposes
X_train_np = np.stack(list(X_train))
X_train_np.shape

In [None]:
# load test dataset (internal coordinates prepared in prepare.ipynb)
X_test = tf.data.Dataset.load('datasets/intcoords/test')

# get batched version of dataset to feed to AAE model for prediction;
# drop_remainder=True discards the last incomplete batch.
# NOTE(review): `hps` is first assigned in the hyperparameter-selection cell
# further down; unless inputs.py defines it, run that cell before this one.
X_test_batched = X_test.batch(hps['batch_size'],drop_remainder=True)

# get numpy version (all samples, unbatched) for testing purposes
X_test_np = np.stack(list(X_test))
X_test_np.shape

In [None]:
# Load the validation split.
# The NumPy copy is built from the *unbatched* dataset so that X_val_np has
# the same per-sample layout as X_train_np / X_test_np and keeps the samples
# that drop_remainder=True would discard.
# NOTE(review): `hps` is first assigned in the hyperparameter-selection cell
# further down; unless inputs.py defines it, run that cell before this one.
X_val_unbatched = tf.data.Dataset.load('datasets/intcoords/validate')
# Batched version, passed as validation_data during training.
X_val = X_val_unbatched.batch(hps['batch_size'],drop_remainder=True)
X_val_np = np.stack(list(X_val_unbatched))
X_val_np.shape

In [None]:
# Get best HP from latest tuning 
# e.g: "analysis/xxx-yyy/"
# ... or don't specify, by default use the last analysis
# NOTE(review): no analysis directory is passed below, so the analyzer's
# default is used; presumably a directory would be given to TuningAnalyzer(...)
# - check its signature.

analyzer = TuningAnalyzer()
# Presumably displays the best 3 trials of the analysis.
analyzer.get_best_hp(num_trials=3)

In [None]:
# Select HP to use by specifying trial_id
#  e.g: trial_id = '483883b929b3445bff6dee9759c4d50ee3a4ba7f0db22e665c49b5f942d9693b'
# ... or leave it empty to use the trial with the lowest score
trial_id = ''

# analyzer.sorted_trials[0] is treated as the lowest-score trial, i.e. the
# list is assumed to be ordered best-first.
best_trial = analyzer.sorted_trials[0]
hps = None
if trial_id:
    # First trial whose id matches, or None when the id is unknown.
    hps = next(
        (trial['hp'] for trial in analyzer.sorted_trials if trial['trial_id'] == trial_id),
        None,
    )
    if hps is None:
        print(f'Could not find trial with specified ID, using one with the lowest score - {best_trial["trial_id"]}')
else:
    print(f'No trial ID specified, using one with the lowest score - {best_trial["trial_id"]}')

if hps is None:
    hps = best_trial['hp']

print(hps)

In [None]:
# Pick best number of encoder and discriminator seeds from plots in tune.ipynb 
# (passed as enc_seed / disc_seed to asmsa.AAEModel below).
best_enc_seed=128
best_disc_seed=32 

## Train

### Distribution prior
Train with common prior distributions. See https://www.tensorflow.org/probability/api_docs/python/tfp/distributions for all available distributions. Ideally, use the tuned hyperparameters for training.

In [None]:
# Choose the prior p(z).
# Exactly one option must be active: the model cell below passes `prior=prior`.
# Uniform is enabled by default, matching the 'encoder-unif.pt' /
# 'decoder-unif.pt' names used when exporting the model; change as needed.

#prior = tfp.distributions.Normal(loc=0, scale=1)
prior = tfp.distributions.Uniform()

# ...or build your custom prior, e.g. a mixture of three 2-D Gaussians
# (uncomment and adjust; the mixture weights must sum to 1):
# tfd = tfp.distributions
# means = tf.constant([[0.7, 0.0], [-0.7, 0.0], [0.0, 0.7]], dtype=tf.float32)
# scales = tf.constant([[0.15, 0.15], [0.15, 0.15], [0.15, 0.15]], dtype=tf.float32)
# components = tfd.MultivariateNormalDiag(loc=means, scale_diag=scales)
# mix = tfd.Categorical(probs=[1/3, 1/3, 1/3])
# prior = tfd.MixtureSameFamily(mixture_distribution=mix, components_distribution=components)


In [None]:
# Prepare model using the best hyperparameters from analysis.
# The input shape is the number of internal-coordinate features per sample.

test = asmsa.AAEModel((X_train_np.shape[1],),
                       prior=prior,
                       hp=hps,
                       enc_seed=best_enc_seed,
                       disc_seed=best_disc_seed,
                       with_density=False
                      )
test.compile()

In [None]:
# train it (can be repeated several times to add more epochs)

# Metric groups drawn together by LiveTrainingPlot: each training loss next to
# its validation counterpart.
metric_groups = {
    'Autoencoder Loss': ['AE loss min', 'val_val_AE loss min'],
    'Discriminator Loss': ['disc loss min', 'val_val_disc loss min']
}

# Stop when the validation AE loss has not improved by at least 1e-4 for 20
# epochs, and roll back to the best weights seen.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_val_AE loss min",
    min_delta=0.0001,
    patience=20,
    verbose=1,
    mode="min",
    restore_best_weights=True,
)

test.fit(X_train_batched, 
          epochs=1000,
          verbose=2, 
          validation_data=X_val,
          callbacks=[
              early_stop_cb,
              LiveTrainingPlot(metric_groups=metric_groups, freq=1),
              #visualizer.VisualizeCallback(test,freq=10,inputs=X_train_np[15000:25000],figsize=(12,3)) 
          ])

#Turn on the visualizer if you would like to see the latent space evolution every "freq" epochs. We advise turning off the LiveTrainingPlot to avoid crowded output 

In [None]:
# final visualization, pick a slice of the input data for demo purposes
#visualizer.Visualizer(figsize=(12,3)).make_visualization(test.call_enc(X_train_np[15000:20000]).numpy())
# on test data: encode the whole test split and visualize its latent projection
visualizer.Visualizer(figsize=(12,3)).make_visualization(test.call_enc(X_test_np).numpy())

In [None]:
# load the training trajectory (x_train.xtc) for further visualizations and
# computations. Its frames are colour-coded against the latent projection of
# X_train_np below - presumably they correspond one-to-one; confirm.
tr = md.load('x_train.xtc',top=conf)
# superpose on C-alpha atoms
idx=tr[0].top.select("name CA")

# for trivial cases like AlanineDipeptide
#idx=tr[0].top.select("element != H") 

tr.superpose(tr[0],atom_indices=idx)
# mdtraj xyz is (n_frames, n_atoms, 3); move the frame axis last
geom = np.moveaxis(tr.xyz ,0,-1)
geom.shape

In [None]:
# Rgyr and RMSD colour-coded in the latent space of the distribution-prior
# model (rough view); compute any other properties according to your needs.
# NOTE(review): assumes frame i of `tr` corresponds to row i of X_train_np.
lows = test.call_enc(X_train_np).numpy()
base = md.load(conf)
rg = md.compute_rg(tr)
rmsd = md.rmsd(tr, base[0])

cmap = plt.get_cmap('rainbow')
plt.figure(figsize=(12, 4))
for panel, (values, label) in enumerate([(rg, "Rg"), (rmsd, "RMSD")], start=1):
    plt.subplot(1, 2, panel)
    plt.scatter(lows[:, 0], lows[:, 1], marker='.', c=values, cmap=cmap, s=1)
    plt.colorbar(cmap=cmap)
    plt.title(label)
plt.show()

### Image prior

Use an image as the prior distribution. Again, use the tuned hyperparameters for better training performance.

In [None]:
# Download the black-and-white image used as the latent prior of the next model.
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=1I2WP92MMWS5s5vin_4cvmruuV-1W77Hl", "mushroom_bw.png")

In [None]:
# AAE model with the same tuned hyperparameters and seeds, but the downloaded
# image as prior. NOTE(review): unlike the distribution-prior model,
# with_density is not passed here, so AAEModel's default applies - confirm intended.
mmush = asmsa.AAEModel((X_train_np.shape[1],),
                       hp=hps,
                       enc_seed=best_enc_seed,
                       disc_seed=best_disc_seed,
                       prior='mushroom_bw.png'
                      )
mmush.compile()

In [None]:
# Early stopping for the image-prior model: same monitor as before, but stop
# after 15 epochs without an improvement of at least 1e-4.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_val_AE loss min",
    min_delta=0.0001,
    patience=15,
    verbose=1,
    mode="min",
    restore_best_weights=True,
)

In [None]:
mmush.fit(X_train_batched, # X_train_dens, # X_train_batched,
          epochs=1000,
          verbose=2, 
          validation_data=X_val,
          callbacks=[
              early_stop_cb,
              LiveTrainingPlot(metric_groups=metric_groups, freq=1),
              #visualizer.VisualizeCallback(mmush,freq=25,inputs=X_train_np[15000:25000],figsize=(12,3))
          ])
#Turn on the visualizer if you would like to see the latent space evolution every "freq" epochs. We advise turning off the LiveTrainingPlot to avoid crowded output 

In [None]:
# Rgyr and RMSD colour-coded in the latent space of the image-prior model
# (rough view); compute any other properties according to your needs.
# `tr` holds the training trajectory (x_train.xtc), so encode the matching
# training features, subsampled with the same step; encoding the test split
# would give latent points that do not line up with the colour arrays.
step=4
tr2 = tr[::step]
lows = mmush.call_enc(X_train_np[::step]).numpy()
rg = md.compute_rg(tr2)
base = md.load(conf)
rmsd = md.rmsd(tr2,base[0])
cmap = plt.get_cmap('rainbow')
plt.figure(figsize=(12,4))
plt.subplot(121)
plt.scatter(lows[:,0],lows[:,1],marker='.',c=rg,cmap=cmap)
plt.colorbar(cmap=cmap)
plt.title("Rg")
plt.subplot(122)
plt.scatter(lows[:,0],lows[:,1],marker='.',c=rmsd,cmap=cmap)
plt.colorbar(cmap=cmap)
plt.title("RMSD")
plt.show()

## Save the encoder and decoder models

In [None]:
import tf2onnx
import onnx2torch
import tempfile

def _convert_to_onnx(model, destination_path):
    """Export a Keras model to an ONNX file with tf2onnx.

    The model is wrapped in a ``tf.function`` whose input signature mirrors the
    input tensor of the model's first layer; the wrapper returns a one-entry
    dict keyed by the name of the model's last layer.

    Args:
        model: Keras model to export (the AAE encoder or decoder here).
        destination_path: path the ONNX file is written to.
    """
    # NOTE(review): relies on the private Keras attribute `_input_tensor` of
    # the first layer (presumably an InputLayer); may differ across Keras versions.
    source = model.layers[0]._input_tensor
    spec = tf.TensorSpec(shape=source.shape, dtype=source.dtype, name=source.name)
    result_key = model.layers[-1].name

    def _forward(batch):
        return {result_key: model(batch)}

    traced = tf.function(_forward, input_signature=[spec])
    tf2onnx.convert.from_function(
        traced, input_signature=[spec], output_path=destination_path
    )

In [None]:
# Model to export: the distribution-prior AAE (use `model = mmush` for the image-prior one).
model = test

In [None]:
# Convert the Keras encoder to TorchScript: Keras -> ONNX (temporary file)
# -> PyTorch (onnx2torch) -> traced module saved to disk.
# NOTE(review): reopening a NamedTemporaryFile by name while it is still open
# works on Unix but not on Windows.
with tempfile.NamedTemporaryFile() as onnx:
    _convert_to_onnx(model.enc,onnx.name)
    torch_enc = onnx2torch.convert(onnx.name)

# Traced with a single unbatched sample; the next cells check the traced
# module on a batch of 10000 inputs.
example_input = torch.randn([X_train_np.shape[1]])
traced_script_module = torch.jit.trace(torch_enc, example_input)

traced_script_module.save('encoder-unif.pt')

In [None]:
# Run the saved TorchScript encoder and the original Keras encoder on the same
# random inputs (uniform in [0, 1)) so the outputs can be compared.
lenc = torch.jit.load('encoder-unif.pt')
example_input = np.random.rand(10000,X_train_np.shape[1])
rtf = model.enc(example_input)
rpt = lenc(torch.tensor(example_input,dtype=torch.float32))

In [None]:
# Largest absolute element-wise difference between Keras and TorchScript encoder outputs.
maxerr = np.max(np.abs(rtf - rpt.detach().numpy()))
maxerr

In [None]:
# Convert the Keras decoder to TorchScript via ONNX, the same way as the encoder.
with tempfile.NamedTemporaryFile() as onnx:
    _convert_to_onnx(model.dec,onnx.name)
    torch_dec = onnx2torch.convert(onnx.name)

# NOTE(review): hard-codes a 2-dimensional latent space; adjust if the
# model's latent dimension differs.
example_input = torch.randn([2])
traced_script_module = torch.jit.trace(torch_dec, example_input)

traced_script_module.save('decoder-unif.pt')

In [None]:
# Run the saved TorchScript decoder and the Keras decoder on the same random
# 2-D latent points (uniform in [0, 1)).
ldec = torch.jit.load('decoder-unif.pt')
example_input = np.random.rand(10000,2)
rtf = model.dec(example_input)
rpt = ldec(torch.tensor(example_input,dtype=torch.float32))

In [None]:
# Absolute error of the TorchScript decoder against the Keras decoder.
err = np.abs(np.asarray(rtf) - rpt.detach().numpy())
# Per-feature training mean as a row vector so it broadcasts over the rows of
# `err` (presumably one row per decoded sample).
train_mean = np.loadtxt('datasets/intcoords/mean.txt',dtype=np.float32).reshape(1,-1)
# Relative error is undefined where the mean is zero: leave those entries NaN
# instead of producing inf/NaN warnings, and ignore them in the maximum.
denom = np.abs(train_mean)
rerr = np.divide(err, denom, out=np.full_like(err, np.nan), where=denom > 0)
np.max(err),np.nanmax(rerr)

## Final visualization

In [None]:
# Geometries of the test split from datasets/geoms/test, stacked into one array
# with axis 2 moved to the front.
# NOTE(review): only shape[1] and shape[2] are used (to build the tracing
# example below) - confirm the axis order against prepare.ipynb.
test_geom = np.moveaxis(np.stack(list(tf.data.Dataset.load('datasets/geoms/test'))),2,0)
test_geom.shape

In [None]:
# Load the trajectory named by `traj` (presumably defined in inputs.py);
# this replaces the training trajectory previously held in `tr`.
tr = md.load(traj, top=conf)
tr.xyz.shape

In [None]:
# Per-feature mean and scale of the internal coordinates, used to standardise
# features inside CompleteModel (rebinds train_mean as a flat array).
train_mean = np.loadtxt('datasets/intcoords/mean.txt',dtype=np.float32)
train_scale = np.loadtxt('datasets/intcoords/scale.txt',dtype=np.float32)

In [None]:
# TorchScript feature extractor (presumably written by prepare.ipynb) and the
# encoder exported above.
mol_model = torch.jit.load('features.pt')
torch_encoder = torch.jit.load('encoder-unif.pt')

In [None]:
class CompleteModel(torch.nn.Module):
    """End-to-end model: features -> standardisation -> latent encoder.

    Args:
        mol_model: callable mapping coordinates to features.
        torch_encoder: encoder applied to the standardised features.
        train_mean: per-feature mean, reshaped to a column vector.
        train_scale: per-feature scale, reshaped to a column vector.
    """

    def __init__(self, mol_model, torch_encoder, train_mean, train_scale):
        super().__init__()
        self.mol_model = mol_model
        self.torch_encoder = torch_encoder
        # Column vectors, so they broadcast along the leading (feature) axis
        # of the mol_model output.
        column = (-1, 1)
        self.train_mean = torch.from_numpy(np.reshape(train_mean, column))
        self.train_scale = torch.from_numpy(np.reshape(train_scale, column))

    def forward(self, x):
        # Move axis 0 (presumably frames) to the end before feature extraction.
        frames_last = x.moveaxis(0, -1)
        features = self.mol_model(frames_last)
        standardized = (features - self.train_mean) / self.train_scale
        # Transpose back so each row is one sample for the encoder.
        return self.torch_encoder(standardized.T)

complete_model = CompleteModel(mol_model, torch_encoder, train_mean, train_scale)

# Trace with a random example of leading size 1 (presumably one frame), using
# the remaining dimensions of test_geom.
example_input = torch.randn([1,test_geom.shape[1], test_geom.shape[2]])
traced_script_module = torch.jit.trace(complete_model, example_input)

model_file_name = "model.pt"
traced_script_module.save(model_file_name)

In [None]:
# Project the trajectory coordinates into the latent space with the exported
# end-to-end model and store the result.
# Inference only: under no_grad the output does not require grad, so .numpy()
# works (the encoder/decoder checks above use .detach() for the same reason).
m = torch.jit.load('model.pt')
with torch.no_grad():
    lows = m(torch.tensor(tr.xyz)).numpy()
np.savetxt("lows.txt", lows)
lows.shape

In [None]:
# Reload the latent projection saved above and colour it by radius of
# gyration and by RMSD from the reference structure `conf`.
lows = np.loadtxt("lows.txt")
rg = md.compute_rg(tr)
base = md.load(conf)
rmsd = md.rmsd(tr,base[0])
cmap = plt.get_cmap('rainbow')
plt.figure(figsize=(12,4))
plt.subplot(121)
plt.scatter(lows[:,0],lows[:,1],marker='.',c=rg,cmap=cmap,s=1)
plt.colorbar(cmap=cmap)
plt.title("Rg")
plt.subplot(122)
plt.scatter(lows[:,0],lows[:,1],marker='.',c=rmsd,cmap=cmap,s=1)
plt.colorbar(cmap=cmap)
plt.title("RMSD")
# NOTE: placeholder output name - rename as needed.
plt.savefig("xxx.png")


### Other Properties
* Color the latent space above with the variables calculated in this section to explore the computed properties in the low-dimensional space.

#### Alpha helix
* **traj** must be loaded from the training .xtc and .pdb files.

In [None]:
# Alpha-helix content from simplified DSSP: per-frame fraction of residues
# assigned 'H', then averaged over all frames.
# NOTE(review): replace the "xxx" placeholders with the training files; this
# also rebinds `traj`, which was used earlier as the path passed to md.load.
traj = md.load_xtc("xxx.xtc", top="xxx.pdb")

dssp = md.compute_dssp(traj, simplified=True) 
alpha_content_per_frame = np.mean(dssp == 'H', axis=1)
average_alpha_helix_content = np.mean(alpha_content_per_frame)

print(f"Avarage alpha elics content: {average_alpha_helix_content:.3f}")

#### Contact pairs
* **x**: residue index (mdtraj `resid`, 0-based).
* **y**: name of the atom in residue x to use (e.g. CA, CB).

In [None]:
# Distance between two atoms in every frame. Replace the placeholders x / y
# in each selection with a residue index and an atom name; as written the
# selections are presumably not parseable. [0] takes the first matching atom.
atom_indices = (traj.topology.select("resid x and name y")[0],  
                traj.topology.select("resid x and name y")[0])

distances = md.compute_distances(traj, [atom_indices])  

In [None]:
# Show the selected atom pair and its distance in every frame. `distances`
# has one column per pair given to md.compute_distances; only one pair
# (atom_indices) was given, so its values are in column 0.
print(f'pair: {atom_indices} distance: {distances[:, 0]}')

#### Angles
* **x**: same as above.
* **y**:  same as above

In [None]:
# Angle formed by three atoms (the middle one is the vertex) in every frame.
# Replace the placeholders x / y as in the contact-pair cell.
atom_indices = traj.topology.select("resid x and name y")[0], \
               traj.topology.select("resid x and name y")[0], \
               traj.topology.select("resid x and name y")[0]

# Radians
angles = md.compute_angles(traj, [atom_indices])  

# Degrees (single angle -> column 0)
angles_deg = np.rad2deg(angles[:, 0])

#### Dihedrals
* **x**: same as above.
* **y**:  same as above

In [None]:
# Dihedral angle defined by four atoms in every frame.
# Replace the placeholders x / y as in the contact-pair cell.
atom1 = traj.topology.select("resid x and name y")[0]
atom2 = traj.topology.select("resid x and name y")[0]
atom3 = traj.topology.select("resid x and name y")[0]
atom4 = traj.topology.select("resid x and name y")[0]

# Radians
dihedrals = md.compute_dihedrals(traj, [[atom1, atom2, atom3, atom4]])
# Degrees (single dihedral -> column 0)
dihedrals_deg = np.rad2deg(dihedrals[:, 0])  