In [2]:
import json
import time
import torch
import pprint
import mdtraj
import random
import pandas
import nglview
import argparse
import numpy as np
import matplotlib.pyplot as plt

from openmm import *
from tqdm.auto import tqdm
from matplotlib import animation 
from openmm.app import *
from openmm.unit import *
from matplotlib.colors import LogNorm
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import Dataset




In [3]:
# Run selection. All three values are strings because they are substituted
# verbatim into the log/data paths built in the loading cell.
molecule, temperature, state = "alanine", "300.0", "c5"

In [4]:
# Where this run's outputs live, and the PDB file for the selected state.
result_dir = f"../log/{molecule}/{temperature}/{state}"
pdb_file = f"../data/{molecule}/{state}.pdb"

print("Loading trajectory...")
# The topology for the DCD trajectory is supplied from the PDB file via `top`.
loaded_traj = mdtraj.load(f"{result_dir}/traj.dcd", top=pdb_file)
print("Done.!!")

# Keep this as the last expression so the cell displays the viewer widget.
nglview.show_mdtraj(loaded_traj)

Loading trajectory...
Done.!!


NGLWidget(max_frame=99999999)

In [5]:
arg_file = f"{result_dir}/args.json"

# Parse the JSON arguments stored in the run directory; the file handle is
# only needed for the parse itself, so reporting happens after it is closed.
with open(arg_file, 'r') as f:
    config = json.load(f)

print(">> Loaded config")
pprint.pprint(config)

>> Loaded config
{'config': 'config/alanine/debug.json',
 'force_field': 'amber99',
 'freq_csv': 1000,
 'freq_dcd': 1,
 'freq_stdout': 10000,
 'molecule': 'alanine',
 'platform': 'OpenCL',
 'precision': 'mixed',
 'solvent': 'tip3p',
 'state': 'c5',
 'temperature': 300.0,
 'time': 100000000}


In [30]:
class MD_Dataset(Dataset):
    """Dataset of consecutive-frame pairs taken from a trajectory.

    Item ``i`` is the tuple ``(x[i], y[i])`` where ``x[i]`` holds the
    coordinates of frame ``i`` and ``y[i]`` those of frame ``i + 1``.

    Args:
        traj: Trajectory-like object exposing ``xyz``, an array indexed as
            ``(frame, atom, coordinate)``.
        config: Run configuration dict. The keys ``molecule``, ``state``,
            ``temperature``, ``time``, ``force_field``, ``solvent``,
            ``platform`` and ``precision`` are copied onto the instance.
        max_frames: Upper bound on the number of leading frames used. The
            default of 100 matches the previously hard-coded limit; pass
            ``None`` to use every frame. The value is clamped to the number
            of frames actually present in ``traj``.

    Raises:
        ValueError: If fewer than two frames are available, since no
            (current, next) pair can be formed.
        AssertionError: If the stored tensors do not match the trajectory
            (see ``sanity_check``).
    """

    def __init__(self, traj, config, max_frames=100):
        self.molecule = config['molecule']
        self.state = config['state']
        self.temperature = config['temperature']
        self.time = config['time']
        self.force_field = config['force_field']
        self.solvent = config['solvent']
        self.platform = config['platform']
        self.precision = config['precision']

        # Read coordinates from the `traj` argument rather than from a
        # notebook-level global, so the class works with any trajectory.
        frames = traj.xyz
        available = frames.shape[0]
        n_frames = available if max_frames is None else min(max_frames, available)
        if n_frames < 2:
            raise ValueError(
                f"Need at least 2 frames to build (current, next) pairs, got {n_frames}"
            )

        # Two shifted slices of the coordinate array replace the per-frame
        # loop; torch.tensor copies, so x and y do not alias the trajectory.
        self.x = torch.tensor(frames[:n_frames - 1])
        self.y = torch.tensor(frames[1:n_frames])

        self.sanity_check(traj)

    def sanity_check(self, loaded_traj):
        """Verify that ``x[t]`` is frame ``t`` and ``y[t]`` is frame ``t + 1``.

        Args:
            loaded_traj: Trajectory-like object exposing ``xyz``; the stored
                tensors are compared against its leading frames.

        Raises:
            AssertionError: On any size or value mismatch. Value mismatches
                report the first offending frame index.
        """
        print("Running sanity check...")
        print(f">> x size: {self.x.shape}")
        print(f">> y size: {self.y.shape}")

        n_pairs = self.x.shape[0]
        assert self.y.shape[0] == n_pairs, "x and y hold a different number of frames"

        reference = torch.tensor(loaded_traj.xyz[:n_pairs + 1])
        assert reference.shape[0] == n_pairs + 1, (
            "Trajectory has fewer frames than the dataset expects"
        )

        # `shift` converts a pair index into the trajectory frame index that
        # the error message should name (x[t] <-> frame t, y[t] <-> frame t+1).
        for label, stored, expected, shift in (
            ("x", self.x, reference[:-1], 0),
            ("y", self.y, reference[1:], 1),
        ):
            assert stored.shape == expected.shape, (
                f"{label} shape {tuple(stored.shape)} does not match "
                f"trajectory shape {tuple(expected.shape)}"
            )
            if not torch.equal(stored, expected):
                differs = (stored != expected).reshape(n_pairs, -1).any(dim=1)
                first_bad = int(differs.nonzero()[0])
                raise AssertionError(f"Frame {first_bad + shift}, {label} not equal")

    def __getitem__(self, index):
        """Return the ``(current, next)`` coordinate tensors for ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of (current, next) pairs."""
        return self.x.shape[0]

In [31]:
# Build the consecutive-frame dataset from the loaded trajectory and run config.
dataset = MD_Dataset(traj=loaded_traj, config=config)


Loading data:   0%|          | 0/99 [00:00<?, ?it/s]

Running sanity check...
>> x size: torch.Size([99, 22, 3])
>> y size: torch.Size([99, 22, 3])


Sanity check:   0%|          | 0/99 [00:00<?, ?it/s]