In [15]:
import numpy as np
import torch
import mdtraj as md
import pickle
import glob
from natsort import natsorted
import os
import subprocess
from tqdm.notebook import tqdm

In [2]:
# Every group file under ./data/groups, in natural sort order.
group_list = natsorted(glob.glob('./data/groups/*.txt'))

# Base paths (file path minus "_continued.txt") of groups that have a
# continuation file.
continued_groups = {
    path.replace('_continued.txt', '')
    for path in group_list
    if '_continued' in path
}

# Drop a plain group file only when a "_continued" version of the same
# group exists; every "_continued" file is kept.
final_group_list = [
    path
    for path in group_list
    if not ('_continued' not in path and path.replace('.txt', '') in continued_groups)
]

# Trajectory paths, one per line, with surrounding whitespace removed.
with open('./data/xtcs_list', 'r') as fh:
    xtcslist = [entry.strip() for entry in fh]

# Topology (parm) paths, one per line, with surrounding whitespace removed.
with open('./data/strip_parms_list', 'r') as fh:
    stripparms = [entry.strip() for entry in fh]
def gen_parm_traj_dict(group_files=None, trajectories=None, topologies=None,
                       out_path='./data/parm_traj_dict.pkl'):
    """Build and pickle a mapping from system name to its trajectories and topology.

    Each group file is read as comma-separated ``system,projno,run`` lines.
    For every line, the trajectory paths that mention both the project number
    and ``run<run>`` are collected; if any are found, the first topology path
    containing the system name is paired with them.

    Parameters
    ----------
    group_files : iterable of str, optional
        Group file paths. Defaults to the notebook-level ``final_group_list``.
    trajectories : iterable of str, optional
        Candidate trajectory paths. Defaults to the notebook-level ``xtcslist``.
    topologies : iterable of str, optional
        Candidate topology paths. Defaults to the notebook-level ``stripparms``.
    out_path : str, optional
        Where the resulting dict is pickled.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly three comma-separated fields.
    IndexError
        If a system has trajectories but no topology path contains its name.

    Notes
    -----
    The pickled dict has the form ``{system: {'trajs': [...], 'parm': path}}``.
    If a system name appears on several lines, the last line with matching
    trajectories wins (earlier entries are overwritten).
    """
    import re  # local import so this cell does not depend on the import cell

    if group_files is None:
        group_files = final_group_list
    if trajectories is None:
        trajectories = xtcslist
    if topologies is None:
        topologies = stripparms

    parm_traj_dict = {}
    for group in group_files:
        with open(group, 'r') as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    # Tolerate blank / trailing lines instead of failing on unpack.
                    continue
                fields = [field.strip() for field in line.split(',')]
                if len(fields) != 3:
                    raise ValueError(
                        f'{group}:{lineno}: expected "system,projno,run", got {line!r}'
                    )
                system, projno, run = fields

                # Digit boundaries stop run "1" from also capturing "run10",
                # "run11", ... (and likewise for the project number), which a
                # plain substring test would do.
                proj_re = re.compile(rf'(?<!\d){re.escape(projno)}(?!\d)')
                run_re = re.compile(rf'run{re.escape(run)}(?!\d)')
                trajs = [j for j in trajectories if proj_re.search(j) and run_re.search(j)]
                if not trajs:
                    continue

                # NOTE(review): the topology is still matched by substring on the
                # system name; confirm that no system name is a prefix of another.
                parm = [j for j in topologies if system in j]
                if not parm:
                    raise IndexError(
                        f'{group}:{lineno}: no topology path contains system name {system!r}'
                    )
                parm_traj_dict[system] = {'trajs': trajs, 'parm': parm[0]}

    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(out_path, 'wb') as out:
        pickle.dump(parm_traj_dict, out)

# gen_parm_traj_dict()

In [3]:
# Hyperparameters

maxlen = 8000  # not referenced by the other visible cells
timestep = 64  # stride between consecutive frames of one window

# Split boundaries expressed as frame offsets within one stride:
# train takes the first 80%, valid runs up to one past the 90% mark,
# and test stops one short of the stride.
_train_end = int(timestep * 0.8)
_valid_end = int(timestep * 0.9) + 1
_test_end = int(timestep) - 1

windows = {
    'train': [0, _train_end],
    'valid': [_train_end, _valid_end],
    'test': [_valid_end, _test_end],
}

n_dims = 6  # 3 positions, 3 velocities

In [4]:
# input_files = natsorted(glob.glob('./data/pdb/p12000_run0*.xtc'))
# n_trajs = len(input_files)

In [5]:
# Echo the train/valid/test offset ranges so they can be checked by eye.
windows

{'train': [0, 51], 'valid': [51, 58], 'test': [58, 63]}

In [6]:
def generate_train_valid_test(input_files, parm, sysname,
                              out_dir='/home/prateek/storage/ML_Allostery/NRI-MD/data/processed_data'):
    """Write train/valid/test feature pickles for each trajectory in ``input_files``.

    Each trajectory is loaded with topology ``parm``, reduced to its CA atoms,
    and truncated to a whole multiple of the notebook-level ``timestep``.
    For every split in the notebook-level ``windows`` dict and every frame
    offset in that split's ``[start, end)`` range, frames are sampled at stride
    ``timestep`` starting from the offset. Positions are the sampled
    coordinates; velocities are the displacement to the immediately following
    frame.

    Parameters
    ----------
    input_files : iterable of str
        Trajectory file paths.
    parm : str
        Topology file passed to ``md.load`` as ``top``.
    sysname : str
        System name, used as the output filename prefix.
    out_dir : str, optional
        Directory receiving ``{sysname}_{trajname}_{mode}.pkl``. The default
        is the location the original notebook wrote to.

    Notes
    -----
    Each pickle holds a float64 array of shape
    ``(end - start, n_frames // timestep, n_CA_atoms, n_dims)`` with positions
    in ``[..., :3]`` and velocities in ``[..., 3:]``.
    """
    for traj_path in input_files:
        trajname = traj_path.split('/')[-1].split('.')[0]
        traj = md.load(traj_path, top=parm)
        ca_atoms = traj.topology.select('name CA')
        traj = traj.atom_slice(ca_atoms)
        n_residues = len(ca_atoms)
        n_frames_per_window = traj.n_frames // timestep
        if traj.n_frames % timestep != 0:
            # Report trajectories whose length is not a multiple of the stride,
            # then drop the trailing remainder so every offset yields the same
            # number of frames.
            print(sysname, traj.n_frames)
            traj = traj[:n_frames_per_window * timestep]

        # Scale once up front (mdtraj coordinates are in nm, so x10 gives
        # Angstrom) and index the raw array below instead of building two new
        # Trajectory objects per offset.
        xyz = traj.xyz * 10

        for mode, (start, end) in windows.items():
            features = np.zeros(
                (end - start, n_frames_per_window, n_residues, n_dims),
                dtype=np.float64,
            )
            for nwindow, offset in enumerate(range(start, end)):
                frame_idx = np.arange(offset, traj.n_frames, timestep)
                coords = xyz[frame_idx]
                # Velocity = displacement to the next frame. This needs
                # frame_idx + 1 to stay in range, i.e. offsets must stay below
                # timestep - 1 (the 'test' window ends there).
                vels = xyz[frame_idx + 1] - coords
                features[nwindow, :, :, :3] = coords
                features[nwindow, :, :, 3:] = vels

            out_path = os.path.join(out_dir, f'{sysname}_{trajname}_{mode}.pkl')
            # Context manager so the pickle is flushed and the handle closed.
            with open(out_path, 'wb') as out:
                pickle.dump(features, out)

In [14]:
def gen_dataset_for_all_systems():
    """Generate missing train/valid/test pickles for every system.

    Loads ``./data/parm_traj_dict.pkl`` and, for each trajectory of each
    system, calls ``generate_train_valid_test`` once if any of that
    trajectory's per-mode output files is missing. Trajectories whose outputs
    all exist are skipped, so the function can be re-run after an interruption.

    Requires ``tqdm`` to be the progress-bar callable (as imported via
    ``from tqdm.notebook import tqdm``), not the ``tqdm`` module itself.
    """
    out_dir = '/home/prateek/storage/ML_Allostery/NRI-MD/data/processed_data'

    # NOTE: pickle.load can execute arbitrary code; only load a file this
    # notebook produced itself.
    with open('./data/parm_traj_dict.pkl', 'rb') as f:
        parm_traj_dict = pickle.load(f)

    for system_name, trajparm in parm_traj_dict.items():
        trajs, parm = trajparm['trajs'], trajparm['parm']
        for traj in tqdm(trajs, total=len(trajs), desc=f"Processing {system_name}"):
            trajname = traj.split('/')[-1].split('.')[0]
            expected = [
                os.path.join(out_dir, f'{system_name}_{trajname}_{mode}.pkl')
                for mode in windows
            ]
            # generate_train_valid_test writes all modes in one pass, so one
            # call per trajectory is enough whenever anything is missing.
            if not all(os.path.exists(path) for path in expected):
                generate_train_valid_test([traj], parm, system_name)
gen_dataset_for_all_systems()

TypeError: 'module' object is not callable. Did you mean: 'tqdm.tqdm(...)'?

In [None]:
# for mode in windows.keys():
#     mode2 = f'_{mode}' if mode != 'train' else ''
#     p = pickle.load(open(f'./data/pdb/ca_1.pdb_{mode}.pkl','rb'))
#     q = np.load(f'./data/features{mode2}.npy')
#     p_2 = np.transpose(p, (0, 1, 3, 2))
#     print(p_2.shape, q.shape, p_2.shape == q.shape)

In [211]:
# NOTE(review): `p_2` and `q` are only assigned in the commented-out
# comparison cell, so this cell fails on a fresh kernel until they are defined.
atol = 1e-5
# Evaluate the element-wise comparison once instead of three times.
close_mask = np.isclose(p_2, q, atol=atol)
# Percentage of elements where p_2 and q agree within `atol`.
np.count_nonzero(close_mask) / close_mask.size * 100

2.050092764378479

In [214]:
# IPython automagic (%pwd): show the kernel's current working directory.
pwd

'/home/prateek/storage/ML_Allostery/NRI-MD'

In [150]:
# a

In [151]:
# np.where(coords == 0)

In [152]:
# 41+12*60

In [230]:
# Hard-coded sample of group files used to check the "_continued" filter.
file_list = [
    './data/groups/group_50_60.txt',
    './data/groups/group_60_70.txt',
    './data/groups/group_70_80.txt',
    './data/groups/group_80_90.txt',
    './data/groups/group_80_90_continued.txt',
    './data/groups/group_90_100.txt',
    './data/groups/group_90_100_continued.txt',
    './data/groups/group_100_110.txt',
    './data/groups/group_110_120.txt',
    './data/groups/group_110_120_continued.txt',
    './data/groups/group_120_130.txt',
    './data/groups/group_120_130_continued.txt',
    './data/groups/group_130_140.txt',
    './data/groups/group_130_140_continued.txt',
    './data/groups/group_140_150.txt',
    './data/groups/group_140_150_continued.txt',
    './data/groups/group_150_160_continued.txt',
    './data/groups/group_170_180_continued.txt',
    './data/groups/group_180_190_continued.txt',
    './data/groups/group_190_200_continued.txt',
    './data/groups/group_200_210_continued.txt',
    './data/groups/group_210_220_continued.txt',
    './data/groups/group_220_230_continued.txt',
    './data/groups/group_320_330_continued.txt',
    './data/groups/group_350_360_continued.txt',
]

# Base paths of the groups that have a "_continued" file.
continued_groups = {
    path.replace('_continued.txt', '')
    for path in file_list
    if '_continued' in path
}

# A plain file is dropped only when its "_continued" counterpart exists.
filtered_list = [
    path
    for path in file_list
    if not ('_continued' not in path and path.replace('.txt', '') in continued_groups)
]

# Show the surviving files, one per line.
print(*filtered_list, sep='\n')


./data/groups/group_50_60.txt
./data/groups/group_60_70.txt
./data/groups/group_70_80.txt
./data/groups/group_80_90_continued.txt
./data/groups/group_90_100_continued.txt
./data/groups/group_100_110.txt
./data/groups/group_110_120_continued.txt
./data/groups/group_120_130_continued.txt
./data/groups/group_130_140_continued.txt
./data/groups/group_140_150_continued.txt
./data/groups/group_150_160_continued.txt
./data/groups/group_170_180_continued.txt
./data/groups/group_180_190_continued.txt
./data/groups/group_190_200_continued.txt
./data/groups/group_200_210_continued.txt
./data/groups/group_210_220_continued.txt
./data/groups/group_220_230_continued.txt
./data/groups/group_320_330_continued.txt
./data/groups/group_350_360_continued.txt


In [44]:
# Scratch check: presumably maxlen // timestep, i.e. frames sampled per offset
# for an 8000-frame trajectory at stride 64 (confirm intent).
8000 // 64

125