In [None]:
# Work from the Trp-cage project directory: every relative path used below
# (inputs.py, datasets/intcoords/..., ../DE-Shaw/..., ../ASMSA_DE/...) resolves against it.
# NOTE(review): home-relative hardcoded path — adjust to your checkout.
%cd ~/ASMSA/trpcage

In [None]:
# Number of CPU threads TensorFlow and PyTorch are allowed to use.
threads = 2

import os
# Set before torch / tensorflow are imported so their runtimes see it at load time.
os.environ['OMP_NUM_THREADS']=str(threads)
import tensorflow as tf

# PyTorch favours OMP_NUM_THREADS in environment
import torch

# TensorFlow needs explicit config calls (they only take effect if issued
# before TensorFlow starts executing ops).
tf.config.threading.set_inter_op_parallelism_threads(threads)
tf.config.threading.set_intra_op_parallelism_threads(threads)

In [None]:
from tensorflow import keras as k
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt
import math
import datetime
from asmsa_callbacks import callbacks

In [None]:
# Pull the shared run configuration from inputs.py into the notebook's globals
# (presumably this is where `conf`, the topology used by md.load below, is defined).
# NOTE(review): exec runs arbitrary code — only use a trusted inputs.py.
with open('inputs.py') as inputs_file:  # closes the handle; the original leaked it
    exec(inputs_file.read())

In [None]:
def _compute_number_of_neurons(layers, seed):
    neurons = [seed]

    tmp = seed
    for _ in range(layers-1):
        tmp = int(tmp / 2)
        neurons.append(tmp)
    return neurons

In [None]:
class AAEClassModel(k.models.Model):
    """Adversarial autoencoder with a categorical and a continuous latent code.

    Architecture after Makhzani et al., "Adversarial Autoencoders"
    (https://doi.org/10.48550/arXiv.1511.05644, Sect. 7 / Fig. 10):

    * ``enc``: x -> [y, z]; ``y`` is a softmax over ``n_classes`` and ``z`` a
      continuous ``latent_dim``-vector;
    * ``heads``: y -> a ``latent_dim`` point (learned linear "cluster heads");
    * ``sum``: [y, z] -> z + heads(y), the latent representation;
    * ``dec``: latent -> reconstruction of x;
    * ``y_disc`` / ``z_disc``: critics with one unbounded output that compare
      ``y`` against random one-hot vectors and ``z`` against N(0, I) samples.

    Calling the model returns the latent representation of its input.
    """

    def __init__(self,
            n_features,
            n_classes,
            enc_layers, enc_seed,
            disc_layers, disc_seed,
            af='gelu',
            bn_momentum=0.8, leak_alpha=0.2,
            latent_dim=2,
            dist_threshold=5.
        ):
        """Build all sub-networks.

        Args:
            n_features: length of one input feature vector.
            n_classes: number of categories (size of the softmax ``y``).
            enc_layers, enc_seed: encoder depth (mirrored by the decoder) and
                width of its first layer; each further layer halves the width.
            disc_layers, disc_seed: the same for both discriminators.
            af: activation of the encoder/decoder dense layers.
            bn_momentum: momentum of the BatchNormalization layers.
            leak_alpha: negative slope of the discriminators' LeakyReLUs.
            latent_dim: dimension of ``z`` and of the latent space.
            dist_threshold: cluster heads closer than this are pushed apart.
        """
        super().__init__()
        self.n_classes = n_classes
        self.latent_dim = latent_dim
        # Squared, because it is compared against squared pairwise distances.
        self.dist_threshold2 = dist_threshold*dist_threshold

        enc_neurons = _compute_number_of_neurons(enc_layers, enc_seed)
        disc_neurons = _compute_number_of_neurons(disc_layers, disc_seed)

        # --- encoder: x -> (y, z)
        inp = k.Input(shape=(n_features,),name='inp')
        out = inp
        for n in range(enc_layers):
            out = k.layers.Dense(enc_neurons[n],activation=af,name=f'enc_{n}')(out)
            out = k.layers.BatchNormalization(momentum=bn_momentum, name=f'enc_bn_{n}')(out)

        z = k.layers.Dense(latent_dim,name='latent_z')(out)
        y = k.layers.Dense(n_classes,activation='softmax',name='latent_y')(out)

        # --- cluster heads: linear map y -> latent position of the class
        chs = k.layers.Dense(latent_dim,name='cluster_heads')
        hz = chs(y)
        # XXX: ad-hoc initial spread of the heads, scaled by dist_threshold and by
        # roughly n_classes**(1/latent_dim) heads per axis.
        d = (math.pow(n_classes,1./latent_dim) - 1) * dist_threshold / 2 / 3
        chs.set_weights([
            np.random.uniform(low=-d,high=d,size=(n_classes,latent_dim)),
            np.random.normal(size=(latent_dim,))
        ])

        latent = k.layers.Add(name='add_z_hz')([z,hz])

        # --- decoder: latent -> x, mirroring the encoder widths
        out = latent
        for n in reversed(range(enc_layers)):
            out = k.layers.Dense(enc_neurons[n],activation=af,name=f'dec_{n}')(out)
            out = k.layers.BatchNormalization(momentum=bn_momentum, name=f'dec_bn_{n}')(out)

        dec_out = k.layers.Dense(n_features,name='dec_out')(out)

        # --- categorical critic on y (single linear output, no sigmoid)
        out = y
        for n in range(disc_layers):
            out = k.layers.Dense(disc_neurons[n],name=f'y_disc_{n}')(out)
            out = k.layers.LeakyReLU(negative_slope=leak_alpha,name=f'y_disc_relu_{n}')(out)

        y_disc_out = k.layers.Dense(1,name='y_disc_out')(out)

        # --- continuous critic on z (single linear output, no sigmoid)
        out = z
        for n in range(disc_layers):
            out = k.layers.Dense(disc_neurons[n],name=f'z_disc_{n}')(out)
            out = k.layers.LeakyReLU(negative_slope=leak_alpha,name=f'z_disc_relu_{n}')(out)

        z_disc_out = k.layers.Dense(1,name='z_disc_out')(out)

        self.enc = k.Model(inputs=inp,outputs=[y,z])
        self.sum = k.Model(inputs=[y,z],outputs=latent)
        self.dec = k.Model(inputs=latent,outputs=dec_out)
        self.heads = k.Model(inputs=y,outputs=hz)
        self.y_disc = k.Model(inputs=y,outputs=y_disc_out)
        self.z_disc = k.Model(inputs=z,outputs=z_disc_out)

    def compile(self,optimizer=None,lr=None):
        """Configure the shared optimizer and the reconstruction loss.

        Args:
            optimizer: Keras optimizer used for every update in ``train_step``.
                Defaults to AdamW(learning_rate=1e-4, weight_decay=1e-5,
                beta_1=0.9, beta_2=0.999).
            lr: optional learning rate.  Sets the default optimizer's learning
                rate, or overrides that of a supplied optimizer (it used to be
                accepted and silently ignored).
        """
        if optimizer is None:
            optimizer = tf.keras.optimizers.AdamW(
                learning_rate=1e-4 if lr is None else lr,
                weight_decay=1e-5,
                beta_1=0.9,
                beta_2=0.999,
            )
        elif lr is not None:
            optimizer.learning_rate = lr

        self.ae_loss = k.losses.Huber() #MeanSquaredError()

        super().compile(optimizer=optimizer,loss=self.ae_loss)
        # train_step applies gradients to several different variable subsets, so
        # register all of them with the optimizer at once.  The cluster-head
        # weights are included through self.sum (latent = z + heads(y)).
        self.optimizer.build(self.enc.trainable_weights+self.sum.trainable_weights+self.dec.trainable_weights+self.y_disc.trainable_weights+self.z_disc.trainable_weights)

        self.enc.compile()
        self.dec.compile()
        self.y_disc.compile()
        self.z_disc.compile()

    @staticmethod
    def _jitter(scores):
        """Multiply critic scores element-wise by uniform noise in [1, 1.05)."""
        return scores * tf.random.uniform(tf.shape(scores), 1., 1.05)

    def _cluster_head_loss(self, onehots):
        """Penalty for cluster heads closer than ``dist_threshold``.

        ``onehots`` is an (n_classes, n_classes) diagonal matrix, so row i of
        ``self.heads(onehots)`` is (a scaled version of) head i.  Every pair with
        squared distance below ``dist_threshold**2`` contributes ``exp(-d**2)``;
        the zero self-distances add a constant that carries no gradient.
        """
        ch = self.heads(onehots)
        norms = tf.reduce_sum(tf.square(ch), axis=1, keepdims=True)  # shape=(N,1)
        dists_squared = norms - 2 * tf.matmul(ch, ch, transpose_b=True) + tf.transpose(norms)
        small_dists = tf.boolean_mask(tf.exp(-dists_squared), dists_squared < self.dist_threshold2)
        return tf.reduce_sum(small_dists)/self.dist_threshold2/self.n_classes/self.n_classes * 21. #XXX

    @tf.function
    def train_step(self,in_batch):
        """One training step; every update goes through ``self.optimizer``.

        Order: autoencoder reconstruction, y critic, encoder fooling the y critic,
        z critic, encoder fooling the z critic, cluster-head separation.
        Sub-models are called with ``training=True`` so BatchNormalization
        uses batch statistics and updates its moving averages; without it Keras
        runs them in inference mode inside a custom train_step.

        Args:
            in_batch: batch tensor, or a tuple whose first element is the batch.

        Returns:
            ``{'ae_loss': reconstruction loss}``.
        """
        batch = in_batch[0] if isinstance(in_batch, tuple) else in_batch
        # Dynamic size: also works when the static batch dimension is unknown.
        batch_size = tf.shape(batch)[0]

        # autoencoder
        with tf.GradientTape() as aet:
            y,z = self.enc(batch, training=True)
            rec = self.dec(self.sum([y,z], training=True), training=True)
            ael = self.ae_loss(batch,rec)

        aew = self.enc.trainable_weights + self.sum.trainable_weights + self.dec.trainable_weights
        aeg = aet.gradient(ael,aew)
        self.optimizer.apply_gradients(zip(aeg,aew))

        # categorical critic: encoder y are the "fake" samples, random one-hot
        # vectors the "real" ones.  Loss on raw scores, mean(fake) - mean(real)
        # (not a cross-entropy).
        idx = tf.random.uniform((batch_size,), minval=0, maxval=self.n_classes, dtype=tf.int32)
        randy = tf.one_hot(idx, depth=self.n_classes)

        with tf.GradientTape() as yt:
            nyl = tf.reduce_mean(self._jitter(self.y_disc(y, training=True)),axis=0)
            pyl = -tf.reduce_mean(self._jitter(self.y_disc(randy, training=True)),axis=0)
            y_disc_loss = (nyl + pyl) * 1e-5 #XXX

        yg = yt.gradient(y_disc_loss,self.y_disc.trainable_weights)
        self.optimizer.apply_gradients(zip(yg,self.y_disc.trainable_weights))

        # encoder tries to fool the categorical critic
        with tf.GradientTape() as yct:
            yc = self._jitter(self.y_disc(self.enc(batch, training=True)[0], training=True))
            ycl = -tf.reduce_mean(yc, axis=0) * 1e-5 #XXX

        ycg = yct.gradient(ycl,self.enc.trainable_weights)
        self.optimizer.apply_gradients(zip(ycg,self.enc.trainable_weights))

        # continuous critic: encoder z vs. standard-normal samples
        randz = tf.random.normal(shape=(batch_size, self.latent_dim))

        with tf.GradientTape() as zt:
            nzl = tf.reduce_mean(self._jitter(self.z_disc(z, training=True)),axis=0)
            pzl = -tf.reduce_mean(self._jitter(self.z_disc(randz, training=True)),axis=0)
            z_disc_loss = nzl + pzl

        zg = zt.gradient(z_disc_loss,self.z_disc.trainable_weights)
        self.optimizer.apply_gradients(zip(zg,self.z_disc.trainable_weights))

        # encoder tries to fool the continuous critic
        with tf.GradientTape() as zct:
            zc = self._jitter(self.z_disc(self.enc(batch, training=True)[1], training=True))
            zcl = -tf.reduce_mean(zc, axis=0)

        zcg = zct.gradient(zcl,self.enc.trainable_weights)
        self.optimizer.apply_gradients(zip(zcg,self.enc.trainable_weights))

        # keep cluster heads apart (diagonal entries jittered in [0.95, 1.05))
        randc = tf.linalg.diag(tf.random.uniform((self.n_classes,),0.95,1.05))
        with tf.GradientTape() as ht:
            hl = self._cluster_head_loss(randc)

        hg = ht.gradient(hl,self.heads.trainable_weights)
        self.optimizer.apply_gradients(zip(hg,self.heads.trainable_weights))

        return {'ae_loss' : ael}

    @tf.function
    def test_step(self, in_batch):
        """Validation step used to monitor training.

        Runs the autoencoder in inference mode and reports only the
        reconstruction loss (Keras logs it as ``val_ae_loss``).
        """
        batch = in_batch[0] if isinstance(in_batch, tuple) else in_batch

        y, z = self.enc(batch, training=False)
        rec = self.dec(self.sum([y, z]), training=False)
        return {'ae_loss': self.ae_loss(batch, rec)}

    @tf.function
    def call(self,inp):
        """Latent representation ``z + heads(y)`` of ``inp``."""
        return self.sum(self.enc(inp))

    def call_enc(self,inp):
        """Alias of ``call``: latent representation of ``inp``."""
        return self.call(inp)
            

In [None]:
# XXX: essential hyper-parameters, set by hand
batch_size = 64         # samples per training batch
n_classes = 4           # number of categorical clusters in the latent space

best_enc_seed = 128     # width of the first encoder layer (halved per further layer)
best_disc_seed = 128    # width of the first discriminator layer (halved per further layer)
ae_layers = 2           # encoder depth; the decoder mirrors it
disc_layers = 3         # discriminator depth
# learning_rate = 0.00002   (disabled)

In [None]:
X_train = tf.data.Dataset.load('datasets/intcoords/train')

# Batched view fed to the AAE during training (the last, incomplete batch is dropped).
X_train_batched = X_train.batch(batch_size, drop_remainder=True)

# All training samples stacked into one NumPy array, for visualization.
X_train_np = np.stack([sample for sample in X_train])
X_train_np.shape

In [None]:
X_test = tf.data.Dataset.load('datasets/intcoords/test')

# All test samples stacked into one NumPy array, for visualization.
X_test_np = np.stack([sample for sample in X_test])
X_test_np.shape

In [None]:
# Validation split: a batched copy for fit(validation_data=...), and every sample
# as a 2-D NumPy array (same layout as X_train_np / X_test_np). Stacking the
# batched dataset instead gave a 3-D array and dropped the incomplete last batch.
X_val_unbatched = tf.data.Dataset.load('datasets/intcoords/validate')
X_val = X_val_unbatched.batch(batch_size,drop_remainder=True)
X_val_np = np.stack(list(X_val_unbatched))
X_val_np.shape

In [None]:
m = AAEClassModel(
    n_features=X_train_np.shape[1],
    n_classes=n_classes,
    enc_layers=ae_layers,
    enc_seed=best_enc_seed,
    disc_layers=disc_layers,
    disc_seed=best_disc_seed,
)

In [None]:
# One log directory per run, stamped with the start time.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/class_traom/{run_stamp}"

# "val_ae_loss" is the validation reconstruction loss reported by the model's test_step.
cb = callbacks(log_dir, m, X_test_np, freq=20, monitor="val_ae_loss")

In [None]:
# Use the model's default optimizer (AdamW, learning rate 1e-4).
m.compile(optimizer=None)

In [None]:
# Print the layer / parameter overview of the model.
m.summary()

In [None]:
# Train for 500 epochs, validating on X_val.
m.fit(
    X_train_batched,
    validation_data=X_val,
    epochs=500,
    verbose=2,
    callbacks=cb,
)

In [None]:
import asmsa.visualizer as visualizer

# Latent coordinates of every 10th training sample (subsampled to keep the plot light).
train_latent = m(X_train_np[::10, :]).numpy()
visualizer.Visualizer(figsize=(12, 3)).make_visualization(train_latent)

In [None]:
# Latent coordinates of the whole test set.
test_latent = m(X_test_np).numpy()
visualizer.Visualizer(figsize=(12, 3)).make_visualization(test_latent)

In [None]:
# Full trajectory; `conf` (the topology) is presumably defined by inputs.py.
tr = md.load(
    '../DE-Shaw/trpcage_red.xtc',
    top=conf,
)

In [None]:
# Whole dataset (all frames).
X_all = tf.data.Dataset.load('datasets/intcoords/X_all')

# Batched view for batched prediction (the last, incomplete batch is dropped).
X_all_batched = X_all.batch(batch_size, drop_remainder=True)

# Every sample as one NumPy array; rows presumably align with the frames of `tr`.
X_all_np = np.stack([sample for sample in X_all])
X_all_np.shape

In [None]:
# Precomputed TICA projection from a separate analysis; column i is TICA component i+1.
tica = np.load(file="../ASMSA_DE/tica3.npy")

In [None]:
# Latent space coloured by radius of gyration and by RMSD to the reference structure.
lows = m(X_all_np).numpy()
rg = md.compute_rg(tr)
base = md.load(conf)
rmsd = md.rmsd(tr,base[0])
cmap = plt.get_cmap('rainbow')

fig, (ax_rg, ax_rmsd) = plt.subplots(1, 2, figsize=(12, 4))
for ax, values, title in ((ax_rg, rg, "Rg"), (ax_rmsd, rmsd, "RMSD")):
    sc = ax.scatter(lows[:, 0], lows[:, 1], marker='.', c=values, cmap=cmap, s=1)
    # The colorbar takes its colormap from the scatter; MDTraj reports both in nm.
    fig.colorbar(sc, ax=ax, label=f"{title} [nm]")
    ax.set(title=title, xlabel="latent 1", ylabel="latent 2")
plt.show()

In [None]:
# Same Rg / RMSD views with the qualitative 'Dark2' colormap (values fall into 8 colour bands).
lows = m(X_all_np).numpy()
rg = md.compute_rg(tr)
base = md.load(conf)
rmsd = md.rmsd(tr,base[0])
cmap = plt.get_cmap('Dark2')

fig, (ax_rg, ax_rmsd) = plt.subplots(1, 2, figsize=(12, 4))
for ax, values, title in ((ax_rg, rg, "Rg"), (ax_rmsd, rmsd, "RMSD")):
    sc = ax.scatter(lows[:, 0], lows[:, 1], marker='.', c=values, cmap=cmap, s=1)
    # The colorbar takes its colormap from the scatter; MDTraj reports both in nm.
    fig.colorbar(sc, ax=ax, label=f"{title} [nm]")
    ax.set(title=title, xlabel="latent 1", ylabel="latent 2")
plt.show()

In [None]:
# Secondary structure per frame. It is the same file `tr` was loaded from, so reuse it.
traj = tr

# simplified=True -> 'H' (helix), 'E' (strand), 'C' (coil); 'NA' for non-protein residues
dssp = md.compute_dssp(traj, simplified=True)

# Fraction of residues assigned 'H' in each frame, and its average over the trajectory.
alpha_content_per_frame = np.mean(dssp == 'H', axis=1)
average_alpha_helix_content = np.mean(alpha_content_per_frame)
average_alpha_helix_content

In [None]:
# Latent space coloured by the per-frame alpha-helix fraction.
# Only `lows` is needed; the unused Rg/RMSD/reference recomputation was dropped.
lows = m(X_all_np).numpy()

fig, ax = plt.subplots(figsize=(6, 4))
sc = ax.scatter(lows[:, 0], lows[:, 1], marker='.', c=alpha_content_per_frame, cmap='rainbow', s=1)
fig.colorbar(sc, ax=ax, label="helix fraction")
ax.set(title="alpha_content", xlabel="latent 1", ylabel="latent 2")
plt.show()

In [None]:
# Latent space coloured by the first two TICA components.
lows = m(X_all_np).numpy()
# The TICA projection comes from a separate analysis; its rows must match these frames.
assert len(tica) == len(lows), f"tica has {len(tica)} rows but there are {len(lows)} frames"

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for i, ax in enumerate(axes):
    sc = ax.scatter(lows[:, 0], lows[:, 1], marker='.', c=tica[:, i], cmap='rainbow', s=1)
    fig.colorbar(sc, ax=ax)
    ax.set(title=f"tica{i + 1}", xlabel="latent 1", ylabel="latent 2")
plt.show()

In [None]:
# Cluster-head positions in the latent space. Each head is coloured by the log of
# its class's summed soft assignment (softmax y) over all frames.
heads = m.heads(np.eye(n_classes, dtype=np.float32)).numpy()
class_weight = np.sum(m.enc(X_all_np)[0].numpy(), axis=0)
c = np.log(class_weight)

fig, ax = plt.subplots()
sc = ax.scatter(heads[:, 0], heads[:, 1], c=c, cmap='Dark2')
fig.colorbar(sc, ax=ax, label="log(summed class assignment)")
ax.set(title="Cluster heads", xlabel="latent 1", ylabel="latent 2")
plt.show()

In [None]:
# Continuous encoder component z for all frames, before the cluster heads are added.
z = m.enc(X_all_np)[1].numpy()

fig, ax = plt.subplots()
ax.scatter(z[:, 0], z[:, 1], marker='.', s=1)
ax.set(title="Encoder z (without cluster heads)", xlabel="z 1", ylabel="z 2")
plt.show()