## Grab THR snapshots from the high-temperature cache

In [1]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging

# Set up logger
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

from simtk.openmm import app
from tqdm import tqdm
import random

In [12]:
def get_dihedrals(i, aa, length, out_dir, htf, dihedral_indices_new, dihedral_indices_old,
                  phase=None, state_index=19, iterations=range(3000, 3010), n_iter=5000):
    """Extract hybrid positions for one state and compute a dihedral in the new and old topologies.

    Parameters
    ----------
    i : str or int
        First field of the trajectory file names ``{i}_{phase}_{aa}_{length}ns.nc``.
    aa : str
        Third field of the trajectory file names.
    length : int or str
        Fourth field of the trajectory file names (followed by ``ns``).
    out_dir : str
        Directory containing the ``.nc`` and ``_checkpoint.nc`` files.
    htf : object
        Must expose ``_topology_proposal.new_topology`` / ``old_topology`` and the
        ``new_positions`` / ``old_positions`` methods (presumably a perses hybrid
        topology factory -- confirm against the pickled object).
    dihedral_indices_new, dihedral_indices_old : sequence of int
        Atom indices of the dihedral in the new and old topology, respectively.
    phase : str, optional
        Second field of the trajectory file names. When omitted, the module-level
        ``phase`` variable is used, which is what earlier callers relied on.
    state_index : int, optional
        Value matched against the ``states`` variable to pick which replica's
        positions are read at each iteration.
    iterations : iterable of int, optional
        Checkpoint iterations to extract.
    n_iter : int, optional
        Number of rows allocated in the position arrays. Rows whose index is not
        in ``iterations`` stay zero-filled.

    Returns
    -------
    dihedrals_all : list of numpy.ndarray
        Two entries, in the order ``[new topology, old topology]``, each the output
        of ``md.compute_dihedrals`` over all ``n_iter`` rows.
    all_pos_hybrid_master : list of numpy.ndarray
        One-element list holding the hybrid positions, shape ``(n_iter, n_atoms, 3)``.

    Raises
    ------
    NameError
        If ``phase`` is not passed and no module-level ``phase`` exists.
    """
    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    from perses.analysis.utils import open_netcdf
    from tqdm import tqdm

    if phase is None:
        # Backward compatibility: existing calls set `phase` as a notebook-level variable.
        try:
            phase = globals()["phase"]
        except KeyError:
            raise NameError("name 'phase' is not defined; pass phase=... explicitly") from None

    new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
    old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

    nc = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}ns.nc"))
    nc_checkpoint = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}ns_checkpoint.nc"))
    checkpoint_interval = nc_checkpoint.CheckpointInterval
    all_positions = nc_checkpoint.variables['positions']
    n_atoms = np.shape(all_positions)[2]
    states = nc.variables['states']  # looked up once instead of on every iteration

    all_pos_new = np.zeros(shape=(n_iter, new_top.n_atoms, 3))
    all_pos_old = np.zeros(shape=(n_iter, old_top.n_atoms, 3))
    all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))
    for iteration in tqdm(iterations):
        # `states` is indexed by iteration, positions by checkpoint, hence the interval factor.
        replica_id = np.where(states[iteration * checkpoint_interval] == state_index)[0]
        pos = all_positions[iteration, replica_id, :, :][0] * unit.nanometers
        all_pos_new[iteration] = htf.new_positions(pos).value_in_unit_system(unit.md_unit_system)  # Get new positions only
        all_pos_hybrid[iteration] = pos  # Get hybrid positions
        all_pos_old[iteration] = htf.old_positions(pos).value_in_unit_system(unit.md_unit_system)
    # Kept as a one-element list because callers index the result with [0].
    all_pos_hybrid_master = [all_pos_hybrid]

    dihedrals_all = []
    for pos, top, indices in zip([all_pos_new, all_pos_old], [new_top, old_top], [dihedral_indices_new, dihedral_indices_old]):
        traj = md.Trajectory(np.array(pos), top)
        dihedrals = md.compute_dihedrals(traj, np.array([indices]))
        dihedrals_all.append(dihedrals)
    return dihedrals_all, all_pos_hybrid_master

In [3]:
def sort_snapshots(dihedrals):
    """Group frame indices into three bins according to their dihedral angle.

    Parameters
    ----------
    dihedrals : iterable
        One angle per frame (presumably radians, as produced by
        ``md.compute_dihedrals`` -- confirm against the caller). Elements may be
        plain floats or one-element arrays.

    Returns
    -------
    dict
        ``{0: [...], 1: [...], 2: [...]}`` of frame indices, where
        bin 0 holds angles ``>= 2`` or ``<= -2``, bin 1 holds angles in
        ``(-2, 0)`` and bin 2 holds angles in ``(0, 2)``.

    Notes
    -----
    Angles exactly equal to ``0.`` are skipped, which drops the zero-filled rows
    of frames that were never extracted. Angles exactly equal to ``+/-2`` go to
    bin 0 so that every non-zero finite angle lands in a bin. NaN matches no bin.
    """
    d_indices = {0: [], 1: [], 2: []}
    for i, dihedral_angle in enumerate(dihedrals):
        if dihedral_angle == 0.:
            continue
        if dihedral_angle >= 2 or dihedral_angle <= -2:  # angle is -3 or +3
            d_indices[0].append(i)
        elif -2 < dihedral_angle < 0:  # angle is -1
            d_indices[1].append(i)
        elif 0 < dihedral_angle < 2:  # angle is 1
            d_indices[2].append(i)
    return d_indices

### Get snapshots from T42A apo

In [77]:
# Source run for the T42A apo snapshots.
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/36/"
name = "THR"
endstate = 0
phase = "apo"
length = 5
# Run index is the last path component; normpath makes this independent of a trailing slash.
i = os.path.basename(os.path.normpath(outdir))
aa = name.lower()


In [78]:
# Load the pickled object for this run/phase/endstate (presumably the hybrid topology factory).
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle"), 'rb') as f:
    htf = pickle.load(f)

# Atom indices of the dihedral passed to get_dihedrals for the old and new topologies.
indices_old = [669, 670, 673, 674]
indices_new = [669, 670, 673, 676]

In [86]:
# Compute the new/old-topology dihedrals and collect the hybrid positions.
# NOTE: get_dihedrals also reads the notebook-level `phase` when building file names.
dihedrals, all_pos_hybrid = get_dihedrals(i, aa, length, outdir, htf, indices_new, indices_old)


100%|██████████| 10/10 [00:18<00:00,  1.87s/it]


In [87]:
# get_dihedrals returns [new topology, old topology]; index 1 selects the old topology.
dihedrals_old = dihedrals[1]

In [88]:
# Bin the frame indices by dihedral angle (keys 0, 1, 2).
d_indices_T42A_apo = sort_snapshots(dihedrals_old)

In [89]:
# Display the binned frame indices.
d_indices_T42A_apo

{0: [1000, 1008], 1: [1006, 1007], 2: [1001, 1002, 1003, 1004, 1005, 1009]}

In [90]:
# Shape of the hybrid-position array: (frames, atoms, 3).
all_pos_hybrid[0].shape

(5000, 14881, 3)

In [91]:
# Take the first snapshot for every angle
# Raise a descriptive IndexError (same exception type as before) when a bin is empty.
_empty_bins = [bin_id for bin_id in (0, 1, 2) if not d_indices_T42A_apo[bin_id]]
if _empty_bins:
    raise IndexError(
        f"no snapshots in dihedral bin(s) {_empty_bins} of d_indices_T42A_apo; "
        "extract more iterations before selecting snapshots"
    )
pos = all_pos_hybrid[0][[d_indices_T42A_apo[bin_id][0] for bin_id in (0, 1, 2)]]

In [92]:
# Save the selected hybrid positions as a .npy file in the neq/13/0/ directory.
# The literal {0} in the f-string renders as "0" in the file name.
with open(os.path.join("/data/chodera/zhangi/perses_benchmark/neq/13/0/", f"{0}_{phase}_{endstate}.npy"), "wb") as f:
    np.save(f, pos)

### Get snapshots from T42A complex

In [61]:
# Source run for the T42A complex snapshots.
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/36/"
name = "THR"
endstate = 0
phase = "complex"
length = 5
# Run index is the last path component; normpath makes this independent of a trailing slash.
i = os.path.basename(os.path.normpath(outdir))
aa = name.lower()


In [62]:
# Load the pickled object for this run/phase/endstate (presumably the hybrid topology factory).
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle"), 'rb') as f:
    htf = pickle.load(f)

# Atom indices of the dihedral passed to get_dihedrals for the old and new topologies.
indices_old = [669, 670, 673, 674]
indices_new = [669, 670, 673, 676]

In [70]:
# Compute the new/old-topology dihedrals and collect the hybrid positions.
# NOTE: get_dihedrals also reads the notebook-level `phase` when building file names.
dihedrals, all_pos_hybrid = get_dihedrals(i, aa, length, outdir, htf, indices_new, indices_old)


100%|██████████| 10/10 [00:37<00:00,  3.74s/it]


In [71]:
# get_dihedrals returns [new topology, old topology]; index 1 selects the old topology.
dihedrals_old = dihedrals[1]

In [72]:
# Bin the frame indices by dihedral angle (keys 0, 1, 2).
d_indices_T42A_complex = sort_snapshots(dihedrals_old)

In [73]:
# Display the binned frame indices.
d_indices_T42A_complex

{0: [2001], 1: [2004, 2008], 2: [2000, 2002, 2003, 2005, 2006, 2007, 2009]}

In [74]:
# Shape of the hybrid-position array: (frames, atoms, 3).
all_pos_hybrid[0].shape

(5000, 29478, 3)

In [75]:
# Take the first snapshot for every angle.
# Raise a descriptive IndexError (same exception type as before) when a bin is empty.
_empty_bins = [bin_id for bin_id in (0, 1, 2) if not d_indices_T42A_complex[bin_id]]
if _empty_bins:
    raise IndexError(
        f"no snapshots in dihedral bin(s) {_empty_bins} of d_indices_T42A_complex; "
        "extract more iterations before selecting snapshots"
    )
pos = all_pos_hybrid[0][[d_indices_T42A_complex[bin_id][0] for bin_id in (0, 1, 2)]]

In [76]:
# Save the selected hybrid positions as a .npy file in the neq/13/0/ directory.
# The literal {0} in the f-string renders as "0" in the file name.
with open(os.path.join("/data/chodera/zhangi/perses_benchmark/neq/13/0/", f"{0}_{phase}_{endstate}.npy"), "wb") as f:
    np.save(f, pos)

### Get snapshots from A42T apo

In [115]:
# Source run for the A42T apo snapshots.
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/37/"
name = "THR"
endstate = 1
phase = "apo"
length = 5
# Run index is the last path component; normpath makes this independent of a trailing slash.
i = os.path.basename(os.path.normpath(outdir))
aa = name.lower()


In [116]:
# Load the pickled object for this run/phase/endstate (presumably the hybrid topology factory).
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle"), 'rb') as f:
    htf = pickle.load(f)

# Atom indices of the dihedral passed to get_dihedrals for the old and new topologies.
indices_old = [669, 670, 673, 676]
indices_new = [669, 670, 673, 681]

In [117]:
# Compute the new/old-topology dihedrals and collect the hybrid positions.
# NOTE: get_dihedrals also reads the notebook-level `phase` when building file names.
dihedrals, all_pos_hybrid = get_dihedrals(i, aa, length, outdir, htf, indices_new, indices_old)


100%|██████████| 10/10 [00:18<00:00,  1.88s/it]


In [118]:
# get_dihedrals returns [new topology, old topology]; index 0 selects the new topology.
dihedrals_new = dihedrals[0]

In [119]:
# Bin the frame indices by dihedral angle (keys 0, 1, 2).
d_indices_A42T_apo = sort_snapshots(dihedrals_new)

In [120]:
# Display the binned frame indices.
d_indices_A42T_apo

{0: [1001, 1008], 1: [1000, 1002, 1003, 1005, 1007, 1009], 2: [1004, 1006]}

In [121]:
# Shape of the hybrid-position array: (frames, atoms, 3).
all_pos_hybrid[0].shape

(5000, 14881, 3)

In [122]:
# Take the first snapshot for every angle
# Raise a descriptive IndexError (same exception type as before) when a bin is empty.
_empty_bins = [bin_id for bin_id in (0, 1, 2) if not d_indices_A42T_apo[bin_id]]
if _empty_bins:
    raise IndexError(
        f"no snapshots in dihedral bin(s) {_empty_bins} of d_indices_A42T_apo; "
        "extract more iterations before selecting snapshots"
    )
pos = all_pos_hybrid[0][[d_indices_A42T_apo[bin_id][0] for bin_id in (0, 1, 2)]]

In [123]:
# Save the selected hybrid positions as a .npy file in the neq/13/1/ directory.
# The literal {1} in the f-string renders as "1" in the file name.
with open(os.path.join("/data/chodera/zhangi/perses_benchmark/neq/13/1/", f"{1}_{phase}_{endstate}.npy"), "wb") as f:
    np.save(f, pos)

### Get snapshots from A42T complex

In [4]:
# Source run for the A42T complex snapshots.
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/37/"
name = "THR"
endstate = 1
phase = "complex"
length = 5
# Run index is the last path component; normpath makes this independent of a trailing slash.
i = os.path.basename(os.path.normpath(outdir))
aa = name.lower()


In [5]:
# Load the pickled object for this run/phase/endstate (presumably the hybrid topology factory).
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle"), 'rb') as f:
    htf = pickle.load(f)

# Atom indices of the dihedral passed to get_dihedrals for the old and new topologies.
indices_old = [669, 670, 673, 676]
indices_new = [669, 670, 673, 681]

In [None]:
# Compute the new/old-topology dihedrals and collect the hybrid positions.
# NOTE: get_dihedrals also reads the notebook-level `phase` when building file names.
dihedrals, all_pos_hybrid = get_dihedrals(i, aa, length, outdir, htf, indices_new, indices_old)


 60%|██████    | 6/10 [00:22<00:14,  3.72s/it]

In [None]:
# get_dihedrals returns [new topology, old topology]; index 0 selects the new topology.
dihedrals_new = dihedrals[0]

In [None]:
# Bin the frame indices by dihedral angle (keys 0, 1, 2).
d_indices_A42T_complex = sort_snapshots(dihedrals_new)

In [None]:
# Display the binned frame indices.
d_indices_A42T_complex

In [10]:
# Shape of the hybrid-position array: (frames, atoms, 3).
all_pos_hybrid[0].shape

(5000, 29478, 3)

In [11]:
# Take the first snapshot for every angle.
# Raise a descriptive IndexError (same exception type as before) when a bin is empty.
_empty_bins = [bin_id for bin_id in (0, 1, 2) if not d_indices_A42T_complex[bin_id]]
if _empty_bins:
    raise IndexError(
        f"no snapshots in dihedral bin(s) {_empty_bins} of d_indices_A42T_complex; "
        "extract more iterations before selecting snapshots"
    )
pos = all_pos_hybrid[0][[d_indices_A42T_complex[bin_id][0] for bin_id in (0, 1, 2)]]

IndexError: list index out of range

In [None]:
# Save the selected hybrid positions as a .npy file in the neq/13/1/ directory.
# The literal {1} in the f-string renders as "1" in the file name.
with open(os.path.join("/data/chodera/zhangi/perses_benchmark/neq/13/1/", f"{1}_{phase}_{endstate}.npy"), "wb") as f:
    np.save(f, pos)