In [8]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging 
from tqdm import tqdm
# Set up logger: getLogger() with no name returns the root logger, so the INFO
# threshold set here applies to every library logger that propagates to the root.
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

In [9]:
def get_trajs_for_state(i, aa, phase, length, out_dir, index, vanilla=False, htf=None):
    """Collect the checkpointed positions of one thermodynamic state.

    Opens ``{i}_{phase}_{aa}_{length}.nc`` and the matching ``_checkpoint.nc`` file in
    ``out_dir``. For every checkpointed iteration it looks up which replica occupied
    state ``index`` (via the ``states`` variable) and reads that replica's positions.

    Adapted from Hannah:
    https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py

    Parameters
    ----------
    i, aa, phase, length : str
        Components of the file names read from ``out_dir``.
    out_dir : str
        Directory that holds the netCDF (and, for vanilla runs, topology pickle) files.
    index : int
        Thermodynamic state index to extract.
    vanilla : bool, default False
        If False, positions are hybrid positions and are split into new/old positions
        with ``htf``. If True, positions are stored as-is and the pickled topology
        ``{i}_{aa}_vanilla_topology.pickle`` is loaded.
    htf : object, optional
        Hybrid topology factory providing ``new_positions``, ``old_positions`` and
        ``_topology_proposal``. Required when ``vanilla`` is False.

    Returns
    -------
    tuple
        ``(all_pos_new, all_pos_old)`` when ``vanilla`` is False, otherwise
        ``(all_pos, topology)``. Position arrays are plain float ndarrays of shape
        ``(n_iter, n_atoms, 3)``.

    Raises
    ------
    ValueError
        If ``htf`` is missing for a non-vanilla run, or if no replica is found in
        state ``index`` at some checkpointed iteration.
    """
    from contextlib import closing
    from perses.analysis.utils import open_netcdf

    if not vanilla and htf is None:
        raise ValueError("htf must be provided when vanilla=False")

    prefix = os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}")

    # closing() guarantees both netCDF handles are released, even if the loop is interrupted.
    with closing(open_netcdf(f"{prefix}.nc")) as nc, \
            closing(open_netcdf(f"{prefix}_checkpoint.nc")) as nc_checkpoint:
        checkpoint_interval = nc_checkpoint.CheckpointInterval
        all_positions = nc_checkpoint.variables['positions']
        states = nc.variables['states']
        n_iter = np.shape(all_positions)[0]

        def _positions_in_state(iteration):
            """Return the stored positions of the replica that is in state `index`."""
            # `states` is indexed by iteration while positions are indexed by checkpoint,
            # hence the scaling by the checkpoint interval.
            replica_ids = np.where(states[iteration * checkpoint_interval] == index)[0]
            if len(replica_ids) == 0:
                raise ValueError(
                    f"no replica found in state {index} at checkpoint {iteration}"
                )
            return all_positions[iteration, int(replica_ids[0]), :, :]

        if not vanilla:
            new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
            old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

            all_pos_new = np.zeros(shape=(n_iter, new_top.n_atoms, 3))
            all_pos_old = np.zeros(shape=(n_iter, old_top.n_atoms, 3))
            for iteration in tqdm(range(n_iter)):
                pos = _positions_in_state(iteration) * unit.nanometers
                # Strip units before storing in the float arrays.
                all_pos_new[iteration] = htf.new_positions(pos).value_in_unit_system(unit.md_unit_system)
                all_pos_old[iteration] = htf.old_positions(pos).value_in_unit_system(unit.md_unit_system)

            return all_pos_new, all_pos_old

        # Vanilla run: load the pickled topology.
        # NOTE: pickle.load executes arbitrary code; only use on trusted files.
        with open(os.path.join(out_dir, f"{i}_{aa}_vanilla_topology.pickle"), "rb") as f:
            topology = pickle.load(f)
        all_pos = np.zeros(shape=(n_iter, topology.getNumAtoms(), 3))

        for iteration in tqdm(range(n_iter)):
            # Store the raw stored values (treated as nanometers by the hybrid branch)
            # instead of assigning a unit-bearing Quantity into a float array.
            all_pos[iteration] = _positions_in_state(iteration)

        return all_pos, topology
    


In [10]:
# Settings for run 18, solvent phase.
out_dir = "/data/chodera/zhangi/perses_benchmark/neq/11/18/"
# The run id is the name of the innermost directory of out_dir.
i = os.path.split(os.path.dirname(out_dir))[1]
phase, endstate, index = 'solvent', 0, 0
aa, length = 'THR'.lower(), '5ns'

# Load the pickled hybrid topology factory for this phase/endstate.
# NOTE: pickle.load executes arbitrary code; only use on trusted files.
htf_path = os.path.join(out_dir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_path, "rb") as htf_file:
    htf = pickle.load(htf_file)


In [11]:
# Extract the positions of state `index` from the hybrid simulation, split into
# new- and old-system positions. (With vanilla=True and no htf, the function
# returns (positions, topology) instead.)
pos_new, pos_old = get_trajs_for_state(
    i=i, aa=aa, phase=phase, length=length, out_dir=out_dir, index=index,
    vanilla=False, htf=htf,
)

 17%|█▋        | 844/5001 [06:29<31:58,  2.17it/s]  


KeyboardInterrupt: 

In [12]:
# Wrap the new-system positions in an mdtraj trajectory and write it as DCD.
traj_new = md.Trajectory(pos_new, md.Topology.from_openmm(htf._topology_proposal.new_topology))
# os.path.join keeps the path correct whether or not out_dir ends with a separator;
# `index` in the file name is the index of the state.
traj_new.save(os.path.join(out_dir, f"new_{phase}_{index}.dcd"))

In [13]:
# Save the first frame of the new-system trajectory as a PDB.
# os.path.join keeps the path correct whether or not out_dir ends with a separator.
traj_new[0].save(os.path.join(out_dir, f"new_{phase}_{index}.pdb"))

In [None]:
# Wrap the old-system positions in an mdtraj trajectory and write it as DCD.
# (For a vanilla run, build the trajectory from the positions and topology returned
# by get_trajs_for_state(..., vanilla=True) instead.)
traj_old = md.Trajectory(pos_old, md.Topology.from_openmm(htf._topology_proposal.old_topology))
# os.path.join keeps the path correct whether or not out_dir ends with a separator;
# `index` in the file name is the index of the state.
traj_old.save(os.path.join(out_dir, f"old_{phase}_{index}.dcd"))

In [None]:
# Save the first frame of the old-system trajectory as a PDB.
# os.path.join keeps the path correct whether or not out_dir ends with a separator.
traj_old[0].save(os.path.join(out_dir, f"old_{phase}_{index}.pdb"))

# Get trajectories of the 100 chosen snapshots

## THR->ALA solvent THR snapshots

In [12]:
# Load the pre-selected snapshot positions for run 18 (solvent phase).
subset_pos = np.load("/data/chodera/zhangi/perses_benchmark/neq/11/18/18_solvent_thr_5ns_snapshots.npy")

In [16]:
# Map every hybrid snapshot onto the old-system topology.
old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

# Size the buffer from the data rather than assuming exactly 100 snapshots.
all_pos_old = np.zeros(shape=(len(subset_pos), old_top.n_atoms, 3))
# Use a dedicated loop variable: `i` holds the run id that is used to build file names.
for frame_idx, pos in enumerate(subset_pos):
    # Strip units explicitly before storing in the float array (same conversion
    # as in get_trajs_for_state).
    all_pos_old[frame_idx] = htf.old_positions(pos * unit.nanometers).value_in_unit_system(unit.md_unit_system)

In [17]:
# Bundle the old-system snapshot coordinates with their topology.
traj_old = md.Trajectory(xyz=all_pos_old, topology=old_top)


In [18]:
# Write all snapshots as a DCD and the first snapshot as a PDB.
# os.path.join keeps the paths correct whether or not out_dir ends with a separator.
traj_old.save(os.path.join(out_dir, "18_solvent_thr_5ns_snapshots.dcd"))
traj_old[0].save(os.path.join(out_dir, "18_solvent_thr_5ns_snapshots.pdb"))

## ALA->THR solvent THR snapshots

In [48]:
# Settings for run 19 (ALA->THR), solvent phase.
out_dir = "/data/chodera/zhangi/perses_benchmark/neq/11/19/"
# The run id is the name of the innermost directory of out_dir.
i = os.path.split(os.path.dirname(out_dir))[1]
phase, endstate, index = 'solvent', 1, 0
aa, length = 'THR'.lower(), '5ns'

In [49]:
# Load the pickled hybrid topology factory for this phase/endstate.
# NOTE: pickle.load executes arbitrary code; only use on trusted files.
htf_path = os.path.join(out_dir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_path, "rb") as htf_file:
    htf = pickle.load(htf_file)

# Load the pre-selected snapshot positions for run 19 (solvent phase).
subset_pos = np.load("/data/chodera/zhangi/perses_benchmark/neq/11/19/19_solvent_thr_5ns_snapshots.npy")

In [51]:
# Map every hybrid snapshot onto the new-system topology.
new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)

# Size the buffer from the data rather than assuming exactly 100 snapshots.
all_pos_new = np.zeros(shape=(len(subset_pos), new_top.n_atoms, 3))
# Use a dedicated loop variable: `i` holds the run id that is used to build file names.
for frame_idx, pos in enumerate(subset_pos):
    # Strip units explicitly before storing in the float array (same conversion
    # as in get_trajs_for_state).
    all_pos_new[frame_idx] = htf.new_positions(pos * unit.nanometers).value_in_unit_system(unit.md_unit_system)

In [52]:
# Bundle the new-system snapshot coordinates with their topology.
traj_new = md.Trajectory(xyz=all_pos_new, topology=new_top)


In [53]:
# Write all snapshots as a DCD and the first snapshot as a PDB.
# os.path.join keeps the paths correct whether or not out_dir ends with a separator.
traj_new.save(os.path.join(out_dir, "19_solvent_thr_5ns_snapshots.dcd"))
traj_new[0].save(os.path.join(out_dir, "19_solvent_thr_5ns_snapshots.pdb"))

## THR->ALA vacuum THR snapshots

In [59]:
# Settings for run 18 (THR->ALA), vacuum phase.
out_dir = "/data/chodera/zhangi/perses_benchmark/neq/11/18/"
# The run id is the name of the innermost directory of out_dir.
i = os.path.split(os.path.dirname(out_dir))[1]
phase, endstate, index = 'vacuum', 0, 0
aa, length = 'THR'.lower(), '5ns'

In [60]:
# Load the pickled hybrid topology factory for this phase/endstate.
# NOTE: pickle.load executes arbitrary code; only use on trusted files.
htf_path = os.path.join(out_dir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_path, "rb") as htf_file:
    htf = pickle.load(htf_file)

# Load the pre-selected snapshot positions for run 18 (vacuum phase).
subset_pos = np.load("/data/chodera/zhangi/perses_benchmark/neq/11/18/18_vacuum_thr_5ns_snapshots.npy")

In [61]:
# Map every hybrid snapshot onto the old-system topology.
old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

# Size the buffer from the data rather than assuming exactly 100 snapshots.
all_pos_old = np.zeros(shape=(len(subset_pos), old_top.n_atoms, 3))
# Use a dedicated loop variable: `i` holds the run id that is used to build file names.
for frame_idx, pos in enumerate(subset_pos):
    # Strip units explicitly before storing in the float array (same conversion
    # as in get_trajs_for_state).
    all_pos_old[frame_idx] = htf.old_positions(pos * unit.nanometers).value_in_unit_system(unit.md_unit_system)

In [62]:
# Bundle the old-system snapshot coordinates with their topology.
traj_old = md.Trajectory(xyz=all_pos_old, topology=old_top)


In [64]:
# Write all snapshots as a DCD and the first snapshot as a PDB.
# os.path.join keeps the paths correct whether or not out_dir ends with a separator.
traj_old.save(os.path.join(out_dir, "18_vacuum_thr_5ns_snapshots.dcd"))
traj_old[0].save(os.path.join(out_dir, "18_vacuum_thr_5ns_snapshots.pdb"))

## ALA->THR vacuum THR snapshots

In [65]:
# Settings for run 19 (ALA->THR), vacuum phase.
out_dir = "/data/chodera/zhangi/perses_benchmark/neq/11/19/"
# The run id is the name of the innermost directory of out_dir.
i = os.path.split(os.path.dirname(out_dir))[1]
phase, endstate, index = 'vacuum', 1, 0
aa, length = 'THR'.lower(), '5ns'

In [66]:
# Load the pickled hybrid topology factory for this phase/endstate.
# NOTE: pickle.load executes arbitrary code; only use on trusted files.
htf_path = os.path.join(out_dir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_path, "rb") as htf_file:
    htf = pickle.load(htf_file)

# Load the pre-selected snapshot positions for run 19 (vacuum phase).
subset_pos = np.load("/data/chodera/zhangi/perses_benchmark/neq/11/19/19_vacuum_thr_5ns_snapshots.npy")

In [67]:
# Map every hybrid snapshot onto the new-system topology.
new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)

# Size the buffer from the data rather than assuming exactly 100 snapshots.
all_pos_new = np.zeros(shape=(len(subset_pos), new_top.n_atoms, 3))
# Use a dedicated loop variable: `i` holds the run id that is used to build file names.
for frame_idx, pos in enumerate(subset_pos):
    # Strip units explicitly before storing in the float array (same conversion
    # as in get_trajs_for_state).
    all_pos_new[frame_idx] = htf.new_positions(pos * unit.nanometers).value_in_unit_system(unit.md_unit_system)

In [68]:
# Bundle the new-system snapshot coordinates with their topology.
traj_new = md.Trajectory(xyz=all_pos_new, topology=new_top)


In [69]:
# Write all snapshots as a DCD and the first snapshot as a PDB.
# os.path.join keeps the paths correct whether or not out_dir ends with a separator.
traj_new.save(os.path.join(out_dir, "19_vacuum_thr_5ns_snapshots.dcd"))
traj_new[0].save(os.path.join(out_dir, "19_vacuum_thr_5ns_snapshots.pdb"))