In [None]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import mdtraj as md
import math
from cgae.utils import write_traj, save_traj
from cgae.cgae_dense import *
import numpy as np

In [None]:
# Hyperparameters: every tunable value of the notebook lives in `par`.
par = {}
par["n_atom"] = 32       # number of atoms (encoder input / decoder output dimension)
par["N_cg"] = 3          # number of coarse-grained sites (encoder output dimension)
par["lr"]  = 1e-4        # learning rate handed to the optimizer
par["batch_size"] = 20   # frames per mini-batch
par["Tstart"] = 4.0      # starting temperature
par["max_epoch"] = 800   # number of training epochs
par["n_mol"] = 1         # number of molecules
par["epoch_regularize"] = 400    # first epoch to include instantaneous force regularization
par["rho"] = 0.005       # relative weight for force regularization
par["decay_ratio"] = 0.4 # the decay rate for the temperature annealing
par["seed"] = 0          # RNG seed, so that Restart & Run All reproduces the same run
device = 'cpu'

# Seed the global torch / NumPy generators once, up front, so the stochastic
# steps later in the notebook are repeatable across kernel restarts.
torch.manual_seed(par["seed"])
np.random.seed(par["seed"])

In [None]:
# Load the reference structure plus the coordinate / force arrays from disk.
otp = md.load("data/otp.pdb")
# to_dataframe() returns a tuple; its first item is the table that carries
# the per-atom 'element' column used for the plot labels later on.
otp_top = otp.topology.to_dataframe()[0]
otp_element = otp_top["element"].tolist()
traj = np.load("data/otp_xyz.npy")
force = np.load("data/otp_force.npy")

In [None]:
# Prepare data: keep the tail of the trajectory, rescale it and cut it into
# mini-batches.
#
# NOTE: this cell rebinds `traj` and `force`. Re-run the load cell before
# re-running this one, otherwise the scaling below is applied a second time.
n_frames = 3000                      # only the last n_frames frames are used
traj = traj[-n_frames:] * 10         # presumably nm -> Angstrom (TODO confirm units)
force = force[-n_frames:] * 0.0239   # presumably kJ/mol/nm -> kcal/mol/Angstrom (TODO confirm units)
assert traj.shape[0] == force.shape[0], "coordinates and forces must cover the same frames"

N_cg = par["N_cg"]
n_atom = par["n_atom"]
n_mol = par["n_mol"]
batch_size = par["batch_size"]

# Drop the trailing frames that do not fill a complete batch.
n_batch = traj.shape[0] // batch_size
n_sample = n_batch * batch_size

# xyz:   (n_batch, batch_size, n_mol, n_atom, 3)
xyz = traj[:n_sample].reshape(-1, batch_size, n_mol, n_atom, 3)
# force: (n_batch, batch_size, n_mol * n_atom, 3)
# The leading axis is pinned to n_batch (instead of -1) and n_mol is folded
# into the atom axis so that force[i] always pairs with xyz[i]; for n_mol == 1
# this is exactly the previous (n_batch, batch_size, n_atom, 3) layout.
force = force[:n_sample].reshape(n_batch, batch_size, n_mol * n_atom, 3)
assert xyz.shape[0] == force.shape[0] == n_batch

# `device` is configured once in the hyperparameter cell; it is deliberately
# not re-assigned here so that a change made there is not silently undone.

In [None]:
# Build the encoder / decoder pair on the configured device.
encoder = Encoder(
    in_dim=par["n_atom"],
    out_dim=par["N_cg"],
    hard=False,
    device=device,
).to(device)
decoder = Decoder(in_dim=par["N_cg"], out_dim=par["n_atom"]).to(device)

# Temperature schedule for the Gumbel softmax: exponential decay towards the
# floor `tmin`, one entry per epoch.
# NOTE(review): the first entry equals Tstart + tmin, not Tstart.
# NOTE(review): linspace includes the end point, so the spacing of `temp` is
# max_epoch / (max_epoch - 1) rather than exactly 1 -- confirm this is intended.
t0 = par["Tstart"]
tmin = 0.2
decay_epoch = int(par["max_epoch"] * par["decay_ratio"])
temp = np.linspace(0, par["max_epoch"], par["max_epoch"])
t_sched = torch.Tensor(t0 * np.exp(-temp / decay_epoch) + tmin).to(device)

# Loss function and a single Adam optimizer over both networks.
criterion = nn.MSELoss()
optimizer = optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    lr=par["lr"],
)

# Per-epoch histories: total loss, reconstruction term, force term.
loss_log, loss_ae_log, loss_fm_log = [], [], []

In [None]:
# Train the auto-encoder.
#
# Per batch: encode the atomistic frames to CG coordinates, decode them back,
# and measure (a) the reconstruction MSE and (b) the mean squared force obtained
# by mapping the atomistic forces through a sampled CG assignment.
# From epoch `par["epoch_regularize"]` on, (b) is added to the objective with
# weight `par["rho"]`.
#
# After the loop, `cg_xyz` / `decoded` / `CG` still hold the values of the last
# batch of the last epoch; the cells below rely on that.

n_batches = xyz.shape[0]
if n_batches == 0:
    raise ValueError("no complete batch available for training")

# The data never changes between epochs, so convert it to float32 tensors once
# instead of once per batch per epoch.
xyz_batches = torch.Tensor(xyz.reshape(n_batches, -1, n_atom, 3))
force_batches = torch.Tensor(force.reshape(force.shape[0], -1, n_atom, 3))

for epoch in range(par["max_epoch"]):
    tau = t_sched[epoch]  # Gumbel-softmax temperature for this epoch
    loss_epoch = 0.0
    loss_ae_epoch = 0.0
    loss_fm_epoch = 0.0

    for batch, f0 in zip(xyz_batches, force_batches):
        batch = batch.to(device)
        f0 = f0.to(device)

        cg_xyz = encoder(batch, tau)
        # Separate assignment sample, drawn at a lower temperature (0.7 * tau),
        # used only to map the atomistic forces.
        CG = gumbel_softmax(encoder.weight1.t(), tau * 0.7, device=device).t()

        decoded = decoder(cg_xyz)
        loss_ae = criterion(decoded, batch)

        # Map atomistic forces with the assignment matrix, then average the
        # squared magnitude (sum over axis 2, mean over the rest).
        f = torch.matmul(CG, f0)
        loss_fm = f.pow(2).sum(2).mean()

        # rho is the *relative* weight of the force term, so it must scale only
        # that term. Scaling both terms by rho (as before) leaves their ratio
        # at 1 and is close to a no-op under Adam.
        # The values stored in `loss_log` are therefore on the scale of the
        # reconstruction loss now.
        loss = loss_ae
        if epoch >= par["epoch_regularize"]:
            loss = loss + par["rho"] * loss_fm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
        loss_ae_epoch += loss_ae.item()
        loss_fm_epoch += loss_fm.item()

    # Batch averages for this epoch.
    loss_epoch = loss_epoch / n_batches
    loss_ae_epoch = loss_ae_epoch / n_batches
    loss_fm_epoch = loss_fm_epoch / n_batches

    loss_log.append(loss_epoch)
    loss_ae_log.append(loss_ae_epoch)
    loss_fm_log.append(loss_fm_epoch)

    print("epoch %d reconstruction  %.3f instantaneous forces %.3f  tau  %.3f"  % (epoch, loss_ae_epoch, loss_fm_epoch, tau.item()))

    # Plot the mapping: one row per CG site, one column per atom.
    CG = gumbel_softmax(encoder.weight1.t(), tau, device=device).t()
    fig, ax = plt.subplots()
    ax.imshow(CG.detach().cpu().numpy(), aspect=4)
    ax.set_xticks(np.arange(n_atom))
    ax.set_xticklabels(otp_element)
    ax.set_yticks(np.arange(N_cg))
    ax.set_yticklabels(["CG" + str(k + 1) for k in range(N_cg)])
    ax.set_title("CG assignment after epoch %d" % epoch)
    plt.show()
    # One figure is created per epoch; close it so they do not accumulate.
    plt.close(fig)

In [None]:
# Write the CG coordinates and the reconstructed coordinates to .xyz files.
# `cg_xyz` and `decoded` are whatever the training loop left behind, i.e. the
# last batch of the last epoch.
# NOTE(review): Z=[1] * N_cg presumably labels every CG site with a placeholder
# type -- confirm against save_traj.
for site_labels, coords, out_name in (
    ([1] * N_cg, cg_xyz, 'CG.xyz'),
    (otp_element, decoded, 'decode.xyz'),
):
    save_traj(Z=site_labels, traj=coords.detach().cpu().numpy(), name=out_name)

In [None]:
# Display the first entry of the CG coordinates produced by the last training
# batch (a quick sanity check of the encoder output).
cg_xyz[0]