# Configs

In [1]:
import json
import time
import torch
import pprint
import mdtraj
import random
import pandas
import nglview
import numpy as np
import matplotlib.pyplot as plt

from openmm import *
from tqdm.auto import tqdm
from matplotlib import animation 
from openmm.app import *
from openmm.unit import *
from matplotlib.colors import LogNorm
from mpl_toolkits.mplot3d import Axes3D

import jax.numpy as jnp
from jax import grad, value_and_grad, vmap




In [7]:
# Which run to analyse: molecule name, temperature (as used in directory
# names, e.g. "300.0") and starting state.
molecule, temperature, state = "alanine", "300.0", "c5"

# Simulation outputs (args.json, traj.dcd) are read from result_dir; the
# starting-state structure is read from pdb_file.
result_dir = f"../log/{molecule}/{temperature}/{state}"
pdb_file = f"../data/{molecule}/{state}.pdb"

In [8]:
# Show the arguments the simulation run saved next to its outputs
# (pprint.pprint sorts dictionary keys).
arg_file = f"{result_dir}/args.json"
with open(arg_file) as f:
    arg_data = json.load(f)

pprint.pprint(arg_data)

{'config': 'config/alanine/debug.json',
 'force_field': 'amber99',
 'freq_csv': 1000,
 'freq_dcd': 1,
 'freq_stdout': 10000,
 'molecule': 'alanine',
 'platform': 'OpenCL',
 'precision': 'mixed',
 'solvent': 'tip3p',
 'state': 'c5',
 'temperature': 300.0,
 'time': 100000000}


# Trajectory from simulation

In [None]:
# Load the simulated DCD trajectory, using the starting-state PDB as topology,
# and report how long the load took. perf_counter is monotonic and
# high-resolution, so the measured duration is unaffected by system clock
# adjustments (unlike time.time()).
start = time.perf_counter()
print("Loading trajectory...")
loaded_traj = mdtraj.load(
    f"{result_dir}/traj.dcd",
    top=pdb_file,
)
end = time.perf_counter()
print(f"{end-start} seconds")
print("Trajectory loaded.!!")

In [None]:
def plot_ram_from_sim(loaded_traj, pdb_file, state):
    """Draw a Ramachandran (phi/psi) scatter plot of a simulated trajectory.

    Every frame of ``loaded_traj`` becomes one point; the phi/psi of the
    structure in ``pdb_file`` is overlaid as a larger red marker.

    Args:
        loaded_traj: mdtraj trajectory whose backbone phi/psi are plotted.
        pdb_file: path to the PDB file of the starting state.
        state: label shown in the figure title.
    """
    # A single plt.subplots() call: an extra plt.figure() beforehand would
    # leave an empty figure open on every call.
    fig, ax = plt.subplots(figsize=(6, 6))

    # compute_phi / compute_psi return (atom_indices, angles); the angles are
    # in radians and converted to degrees for plotting.
    phis = mdtraj.compute_phi(loaded_traj)[1].ravel() * 180 / np.pi
    psis = mdtraj.compute_psi(loaded_traj)[1].ravel() * 180 / np.pi

    state_traj = mdtraj.load(pdb_file)
    phi_start = mdtraj.compute_phi(state_traj)[1].ravel() * 180 / np.pi
    psi_start = mdtraj.compute_psi(state_traj)[1].ravel() * 180 / np.pi

    # Scatter collections default to zorder=1 and equal-zorder artists are
    # drawn in insertion order, so the start marker needs a higher zorder to
    # stay visible on top of the trajectory points.
    ax.scatter(phis, psis, s=100)
    ax.scatter(phi_start, psi_start, c='red', s=100, zorder=2)

    ax.set_title(f"State {state}")
    ax.set_xlim(-180, 180)
    ax.set_ylim(-180, 180)
    ax.set_xticks(np.linspace(-180, 180, 5))
    ax.set_yticks(np.linspace(-180, 180, 5))
    ax.set_xlabel("Phi [deg]")
    ax.set_ylabel("Psi [deg]")
    fig.tight_layout()

In [None]:
plot_ram_from_sim(loaded_traj=loaded_traj, pdb_file=pdb_file, state=state);

# Trajectory from dataset

In [9]:
import os
import sys
# Make the parent of the working directory importable (needed for
# `util.dataset` below); append it only once so re-running is harmless.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)
    
from torch.utils.data import Dataset
from util.dataset import MD_Dataset

In [16]:
# Report how many samples each temperature's dataset holds.
# The loop variable is deliberately NOT named `temperature`: reusing that name
# overwrote the configured value, so on Restart & Run All the following cell
# silently loaded the "600.0" dataset instead of the configured one.
for dataset_temperature in ["100.0", "200.0", "300.0", "400.0", "500.0", "600.0"]:
    print(f"Temperature: {dataset_temperature}")
    try:
        dataset_dir = f"../dataset/{molecule}/{dataset_temperature}"
        # NOTE: torch.load unpickles the stored object; only load trusted files.
        data = torch.load(f"{dataset_dir}/{state}-random.pt")
        print(f"Number of samples: {len(data)}")
    except Exception as e:
        # Best-effort survey: report the failure for this temperature and go on.
        # NOTE(review): the recorded failures ("Can't get attribute 'MD_Dataset'
        # on <module 'dataset' (namespace)>") suggest the pickle references a
        # top-level `dataset` module that resolves to a namespace package here
        # rather than the module defining MD_Dataset — confirm.
        print("Exception: ", e)

Temperature: 100.0
Exception:  Can't get attribute 'MD_Dataset' on <module 'dataset' (namespace)>
Temperature: 200.0
Exception:  Can't get attribute 'MD_Dataset' on <module 'dataset' (namespace)>
Temperature: 300.0
Exception:  Can't get attribute 'MD_Dataset' on <module 'dataset' (namespace)>
Temperature: 400.0
Exception:  Can't get attribute 'MD_Dataset' on <module 'dataset' (namespace)>
Temperature: 500.0
Exception:  Can't get attribute 'MD_Dataset' on <module 'dataset' (namespace)>
Temperature: 600.0
Exception:  Can't get attribute 'MD_Dataset' on <module 'dataset' (namespace)>


In [14]:
dataset_dir = f"../dataset/{molecule}/{temperature}"
# NOTE: torch.load unpickles the stored object; only load trusted files.
data = torch.load(f"{dataset_dir}/{state}-random.pt")

# Each sample unpacks into four items; only the first (x, the per-frame
# coordinates — torch.Size([22, 3]) in the recorded output) is kept.
frames = []
for t in tqdm(range(len(data))):
    x, y, goal, delta_k = data[t]
    frames.append(x)

# Fail with a clear message instead of an IndexError on frames[0].
assert frames, f"no samples in {dataset_dir}/{state}-random.pt"
print(frames[0].shape)

  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([22, 3])


In [None]:
# 0-based atom indices of the four atoms defining each backbone dihedral.
# plot_ram_from_dataset treats angle_1 as psi and angle_2 as phi.
angle_1, angle_2 = [6, 8, 14, 16], [4, 6, 8, 14]

def dihedral_old(i, j, k, l):
    """Signed dihedral angle (radians) of the atom quadruple(s) i-j-k-l (JAX).

    Not called in the visible notebook; plot_ram_from_dataset uses the NumPy
    ``dihedral`` instead.

    Args:
        i, j, k, l: coordinate arrays of shape (..., 3). Leading dimensions
            are batch dimensions and must broadcast against each other; a
            single unbatched quadruple of shape (3,) is also accepted.

    Returns:
        Array of shape (...,) with angles in [-pi, pi].
    """
    b1, b2, b3 = j - i, k - j, l - k

    # jnp.cross operates on the last axis and broadcasts over the leading
    # ones, so no vmap is required.
    c1 = jnp.cross(b2, b3)
    c2 = jnp.cross(b1, b2)

    # atan2(|b2| * b1 . (b2 x b3), (b1 x b2) . (b2 x b3))
    p1 = (b1 * c1).sum(-1) * jnp.sqrt((b2 * b2).sum(-1))
    p2 = (c1 * c2).sum(-1)

    return jnp.arctan2(p1, p2)

def dihedral(p):
    """Signed dihedral angle (radians) of four consecutive points.

    Formulation from http://stackoverflow.com/q/20305272/1128289: the first
    and last bond vectors are projected onto the plane perpendicular to the
    central bond, and the angle between the projections is taken with atan2.

    Args:
        p: array-like of shape (4, 3), the four points in order. Lists, NumPy
            arrays and CPU tensors are accepted (converted with np.asarray).

    Returns:
        Angle in radians in [-pi, pi]. Degenerate (collinear) inputs give NaN
        or an unreliable value, since a projection then has ~zero length.
    """
    p = np.asarray(p)
    b = p[:-1] - p[1:]
    # Flip the first bond so that it points away from the central bond.
    b[0] *= -1
    # Reject the outer bonds onto the plane perpendicular to the central bond.
    v = np.array([v - (v.dot(b[1]) / b[1].dot(b[1])) * b[1] for v in [b[0], b[2]]])
    # Normalize the two rejected vectors (row-wise).
    v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1, 1)
    b1 = b[1] / np.linalg.norm(b[1])
    x = np.dot(v[0], v[1])
    m = np.cross(v[0], b1)
    y = np.dot(m, v[1])
    return np.arctan2(y, x)

In [None]:
def plot_ram_from_dataset(frames, pdb_file, state):
    """Draw a Ramachandran (phi/psi) scatter plot of dataset frames.

    For each frame, psi and phi are computed with ``dihedral`` from the atom
    indices in the module-level ``angle_1`` (psi) and ``angle_2`` (phi). The
    phi/psi of the structure in ``pdb_file`` is overlaid as a red marker.

    Args:
        frames: iterable of per-frame coordinate tensors indexable as
            ``frame[atom_indices, :]`` (torch.Size([22, 3]) in this
            notebook's recorded output); each must support ``.cpu()``.
        pdb_file: path to the PDB file of the starting state.
        state: label shown in the figure title.
    """
    # A single plt.subplots() call: an extra plt.figure() beforehand would
    # leave an empty figure open on every call.
    fig, ax = plt.subplots(figsize=(6, 6))

    phis = []
    psis = []
    for frame in frames:
        # Convert once per frame, then select the four atoms of each angle.
        coords = np.array(frame.cpu())
        psis.append(dihedral(coords[angle_1, :]) * 180 / np.pi)
        phis.append(dihedral(coords[angle_2, :]) * 180 / np.pi)

    # compute_phi / compute_psi return (atom_indices, angles in radians).
    state_traj = mdtraj.load(pdb_file)
    phi_start = mdtraj.compute_phi(state_traj)[1].ravel()
    psi_start = mdtraj.compute_psi(state_traj)[1].ravel()

    # Scatter collections default to zorder=1 and equal-zorder artists are
    # drawn in insertion order, so the start marker needs a higher zorder to
    # stay visible on top of the dataset points.
    ax.scatter(phis, psis, s=100)
    ax.scatter(phi_start * 180 / np.pi, psi_start * 180 / np.pi, c='red', s=100, zorder=2)

    ax.set_title(f"State {state}")
    ax.set_xlim(-180, 180)
    ax.set_ylim(-180, 180)
    ax.set_xticks(np.linspace(-180, 180, 5))
    ax.set_yticks(np.linspace(-180, 180, 5))
    ax.set_xlabel("Phi [deg]")
    ax.set_ylabel("Psi [deg]")
    fig.tight_layout()

In [None]:
plot_ram_from_dataset(frames=frames, pdb_file=pdb_file, state=state);