## Grab thr snapshots from 298 K cache

In [2]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging

# Set up logger
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

from simtk.openmm import app
from tqdm import tqdm
import random

In [61]:
def get_dihedrals(i, aa, length, out_dir, htf, dihedral_indices_new, dihedral_indices_old,
                  phase=None, n_iter=5000, stride=500):
    """Compute old- and new-topology dihedrals from checkpointed simulation positions.

    Every ``stride``-th checkpoint frame below ``n_iter`` is read for the replica
    found in state 0, split into old/new-topology positions with ``htf``, and the
    requested dihedral is evaluated with mdtraj.

    Parameters
    ----------
    i : str or int
        Prefix of the netCDF file names (the notebook passes the experiment
        directory name).
    aa : str
        Residue tag used in the file names.
    length : int or str
        Simulation length used in the ``{length}ns`` part of the file names.
    out_dir : str
        Directory containing ``{i}_{phase}_{aa}_{length}ns.nc`` and the matching
        ``..._checkpoint.nc`` file.
    htf : object
        Must provide ``_topology_proposal.new_topology`` / ``old_topology`` and
        ``new_positions()`` / ``old_positions()`` (presumably a perses
        HybridTopologyFactory -- confirm against the pickled object).
    dihedral_indices_new : sequence of 4 int
        Atom indices of the dihedral in the new topology.
    dihedral_indices_old : sequence of 4 int
        Atom indices of the dihedral in the old topology.
    phase : str, optional
        Phase tag used in the file names. When omitted, the module-level
        ``phase`` variable is used, which is what this function always did
        implicitly.
    n_iter : int, optional
        Number of rows allocated in the returned arrays and upper bound
        (exclusive) of the sampled frame indices.
    stride : int, optional
        Spacing between sampled frames.

    Returns
    -------
    dihedrals_all : list
        ``[dihedrals_new, dihedrals_old]``, each the result of
        ``md.compute_dihedrals`` over all ``n_iter`` rows.
    all_pos_hybrid_master : list
        One-element list holding the ``(n_iter, n_atoms, 3)`` hybrid positions.

    Raises
    ------
    NameError
        If ``phase`` is not passed and no module-level ``phase`` exists.
    IndexError
        If the last sampled frame is not present in the checkpoint file.

    Notes
    -----
    Rows that are not sampled stay all-zero in every position array, so their
    dihedral entries carry no information and must be filtered by the caller.
    """
    from perses.analysis.utils import open_netcdf
    from tqdm import tqdm

    if phase is None:
        # Backward compatibility: older calls rely on a notebook-level `phase`.
        try:
            phase = globals()["phase"]
        except KeyError:
            raise NameError("name 'phase' is not defined; pass phase=... explicitly") from None

    new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
    old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    # NOTE(review): the two files are never closed explicitly; confirm whether
    # open_netcdf returns an object that should be closed after use.
    nc = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}ns.nc"))
    nc_checkpoint = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}ns_checkpoint.nc"))
    checkpoint_interval = nc_checkpoint.CheckpointInterval
    all_positions = nc_checkpoint.variables['positions']
    n_saved, _, n_atoms, _ = np.shape(all_positions)

    sampled_iterations = range(0, n_iter, stride)
    if len(sampled_iterations) > 0 and sampled_iterations[-1] >= n_saved:
        # Fail before the slow loop instead of part-way through it.
        raise IndexError(
            f"frame {sampled_iterations[-1]} requested but the checkpoint file "
            f"only holds {n_saved} frames"
        )

    state_index = 0  # only the replica currently in state 0 is extracted
    all_pos_new = np.zeros(shape=(n_iter, new_top.n_atoms, 3))
    all_pos_old = np.zeros(shape=(n_iter, old_top.n_atoms, 3))
    all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))
    for iteration in tqdm(sampled_iterations):
        # The main file is indexed per iteration, the checkpoint file per
        # checkpoint, hence the checkpoint_interval scaling for 'states'.
        replica_id = np.where(nc.variables['states'][iteration * checkpoint_interval] == state_index)[0]
        pos = all_positions[iteration, replica_id, :, :][0] * unit.nanometers
        all_pos_new[iteration] = htf.new_positions(pos).value_in_unit_system(unit.md_unit_system)  # new-topology atoms only
        all_pos_hybrid[iteration] = pos  # full hybrid system
        all_pos_old[iteration] = htf.old_positions(pos).value_in_unit_system(unit.md_unit_system)  # old-topology atoms only

    # Order matters to callers: index 0 is the new topology, index 1 the old one.
    dihedrals_all = []
    for positions, top, indices in zip([all_pos_new, all_pos_old],
                                       [new_top, old_top],
                                       [dihedral_indices_new, dihedral_indices_old]):
        traj = md.Trajectory(np.array(positions), top)
        dihedrals_all.append(md.compute_dihedrals(traj, np.array([indices])))

    # Wrapped in a list because callers index the hybrid positions with [0].
    return dihedrals_all, [all_pos_hybrid]

In [19]:
def get_dihedrals_from_decorrelated(i, phase, name, length, out_dir, htf, dihedral_indices_new, dihedral_indices_old):
    """Compute old- and new-topology dihedrals from a saved array of hybrid snapshots.

    Loads ``{i}_{phase}_{name.lower()}_{length}ns_snapshots.npy`` from
    ``out_dir``, maps every hybrid snapshot onto the old and new topologies with
    ``htf`` and evaluates the requested dihedral on each.

    Parameters
    ----------
    i : str or int
        Prefix of the snapshot file name.
    phase : str
        Phase tag used in the file name.
    name : str
        Residue name; lower-cased before it is inserted into the file name.
    length : int or str
        Simulation length used in the ``{length}ns`` part of the file name.
    out_dir : str
        Directory that contains the snapshot file.
    htf : object
        Must provide ``_topology_proposal.new_topology`` / ``old_topology`` and
        ``new_positions()`` / ``old_positions()`` (presumably a perses
        HybridTopologyFactory -- confirm against the pickled object).
    dihedral_indices_new : sequence of 4 int
        Atom indices of the dihedral in the new topology.
    dihedral_indices_old : sequence of 4 int
        Atom indices of the dihedral in the old topology.

    Returns
    -------
    dihedrals_all : list
        ``[dihedrals_new, dihedrals_old]``, each the result of
        ``md.compute_dihedrals``.
    pos_all : numpy.ndarray
        The hybrid snapshots exactly as loaded from disk.
    """
    new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
    old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

    # Use the arguments (out_dir, name) rather than notebook-level globals, so
    # the function reads the file its caller actually asked for.
    snapshot_path = os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns_snapshots.npy")
    with open(snapshot_path, "rb") as f:
        pos_all = np.load(f)

    n_snapshots = len(pos_all)
    pos_new = np.zeros(shape=(n_snapshots, new_top.n_atoms, 3))
    pos_old = np.zeros(shape=(n_snapshots, old_top.n_atoms, 3))
    # A dedicated loop variable avoids overwriting the `i` argument.
    for frame in tqdm(range(n_snapshots)):
        hybrid_positions = pos_all[frame] * unit.nanometer
        # Strip units explicitly, as get_dihedrals does, before storing into the
        # plain float arrays.
        pos_old[frame] = htf.old_positions(hybrid_positions).value_in_unit_system(unit.md_unit_system)
        pos_new[frame] = htf.new_positions(hybrid_positions).value_in_unit_system(unit.md_unit_system)

    # Order matters to callers: index 0 is the new topology, index 1 the old one.
    dihedrals_all = []
    for positions, top, indices in zip([pos_new, pos_old],
                                       [new_top, old_top],
                                       [dihedral_indices_new, dihedral_indices_old]):
        traj = md.Trajectory(np.array(positions), top)
        dihedrals_all.append(md.compute_dihedrals(traj, np.array([indices])))

    return dihedrals_all, pos_all

In [4]:
def sort_snapshots(dihedrals):
    """Group snapshot indices into three bins according to their dihedral angle.

    Parameters
    ----------
    dihedrals : iterable
        One angle per snapshot (presumably radians, as produced by
        ``md.compute_dihedrals`` -- confirm with the caller).

    Returns
    -------
    dict
        Keys 0, 1 and 2 map to lists of snapshot indices:
        0 -> angle above 2 or below -2,
        1 -> angle strictly between -2 and 0,
        2 -> angle strictly between 0 and 2.
        Angles equal to 0 are skipped, and angles of exactly 2 or -2 match no
        bin and are dropped as well.
    """
    bins = {label: [] for label in (0, 1, 2)}
    for frame, angle in enumerate(dihedrals):
        if angle == 0.:
            # Exact zero is treated as "no data" and never binned.
            continue
        if abs(angle) > 2:  # around -3 or +3
            bins[0].append(frame)
        elif -2 < angle < 0:  # around -1
            bins[1].append(frame)
        elif 0 < angle < 2:  # around +1
            bins[2].append(frame)
    return bins

### Get snapshots from T42A apo

In [50]:
# Settings for the T42A apo run; these values are used to build the input file names
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/38/"
phase, endstate = "apo", 0
name, length = "THR", 5
aa = name.lower()
# outdir ends with a slash, so dirname() drops it and split() yields the last directory name
i = os.path.split(os.path.dirname(outdir))[1]


In [51]:
# Load the pickled `htf` object for this run, phase and endstate.
# NOTE(review): pickle.load can execute code embedded in the file -- only load trusted files.
htf_pickle_path = os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_pickle_path, mode='rb') as f:
    htf = pickle.load(f)

# Four atom indices defining the dihedral: the "old" set is applied to the old
# topology and the "new" set to the new topology
indices_old = [669, 670, 673, 674]
indices_new = [669, 670, 673, 676]

In [52]:
# Dihedrals for the saved snapshots; `dihedrals` is [new, old] and
# `all_pos_hybrid` is the snapshot array as loaded from disk
dihedrals, all_pos_hybrid = get_dihedrals_from_decorrelated(
    i=i,
    phase=phase,
    name=aa,
    length=length,
    out_dir=outdir,
    htf=htf,
    dihedral_indices_new=indices_new,
    dihedral_indices_old=indices_old,
)

100%|██████████| 100/100 [01:59<00:00,  1.20s/it]


In [53]:
# Index 1 holds the dihedrals measured on the old topology (index 0 is the new topology)
dihedrals_old = dihedrals[1]

In [54]:
# Group snapshot indices into the three dihedral-angle bins (keys 0, 1, 2)
d_indices_T42A_apo = sort_snapshots(dihedrals_old)

In [55]:
# Display the bin -> snapshot-index mapping to see which bins are populated
d_indices_T42A_apo

{0: [20],
 1: [0, 3, 9, 10, 16, 17, 19, 25, 29, 30, 31, 39, 57, 74, 75, 79, 83, 89, 91],
 2: [1,
  2,
  4,
  5,
  6,
  7,
  8,
  11,
  12,
  13,
  14,
  15,
  18,
  21,
  22,
  23,
  24,
  26,
  27,
  28,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  76,
  77,
  78,
  80,
  81,
  82,
  84,
  85,
  86,
  87,
  88,
  90,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99]}

In [58]:
# Check the snapshot array layout (looks like snapshots x atoms x 3 -- TODO confirm)
all_pos_hybrid.shape

(100, 14881, 3)

In [59]:
# Keep the first snapshot of bin 1 and the first snapshot of bin 2
# (bin 0 is not used here)
first_in_bins = [d_indices_T42A_apo[bin_id][0] for bin_id in (1, 2)]
pos = all_pos_hybrid[first_in_bins]

In [69]:
# Confirm how many snapshots were selected (first axis)
pos.shape

(2, 14881, 3)

In [60]:
# Write the selected snapshots as a .npy file into the 13/2 experiment directory
save_dir = "/data/chodera/zhangi/perses_benchmark/neq/13/2/"
save_name = f"{2}_{phase}_{endstate}.npy"
with open(os.path.join(save_dir, save_name), "wb") as f:
    np.save(f, pos)

### Get snapshots from T42A complex

In [62]:
# Settings for the T42A complex run; these values are used to build the input file names
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/38/"
phase, endstate = "complex", 0
name, length = "THR", 5
aa = name.lower()
# outdir ends with a slash, so dirname() drops it and split() yields the last directory name
i = os.path.split(os.path.dirname(outdir))[1]


In [63]:
# Load the pickled `htf` object for this run, phase and endstate.
# NOTE(review): pickle.load can execute code embedded in the file -- only load trusted files.
htf_pickle_path = os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_pickle_path, mode='rb') as f:
    htf = pickle.load(f)

# Four atom indices defining the dihedral: the "old" set is applied to the old
# topology and the "new" set to the new topology
indices_old = [669, 670, 673, 674]
indices_new = [669, 670, 673, 676]

In [64]:
# This phase reads positions straight from the netCDF checkpoint files;
# `dihedrals` is [new, old] and `all_pos_hybrid` is a one-element list of arrays.
# Note that get_dihedrals also reads the notebook-level `phase` variable.
dihedrals, all_pos_hybrid = get_dihedrals(
    i=i,
    aa=aa,
    length=length,
    out_dir=outdir,
    htf=htf,
    dihedral_indices_new=indices_new,
    dihedral_indices_old=indices_old,
)


100%|██████████| 10/10 [01:18<00:00,  7.83s/it]


In [65]:
# Index 1 holds the dihedrals measured on the old topology (index 0 is the new topology)
dihedrals_old = dihedrals[1]

In [66]:
# Group snapshot indices into the three dihedral-angle bins; entries equal to 0
# (rows get_dihedrals never filled in) are skipped by sort_snapshots
d_indices_T42A_complex = sort_snapshots(dihedrals_old)

In [67]:
# Display the bin -> snapshot-index mapping to see which bins are populated
d_indices_T42A_complex

{0: [], 1: [], 2: [0, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500]}

In [70]:
# get_dihedrals returns a one-element list; inspect the (n_iter, n_atoms, 3) array inside it
all_pos_hybrid[0].shape

(5000, 29478, 3)

In [71]:
# Keep the first bin-2 snapshot from the single array returned by get_dihedrals;
# indexing with a one-element list preserves the leading snapshot axis
first_bin2 = d_indices_T42A_complex[2][0]
pos = all_pos_hybrid[0][[first_bin2]]

In [73]:
# Confirm how many snapshots were selected (first axis)
pos.shape

(1, 29478, 3)

In [74]:
# Write the selected snapshot as a .npy file into the 13/2 experiment directory
save_dir = "/data/chodera/zhangi/perses_benchmark/neq/13/2/"
save_name = f"{2}_{phase}_{endstate}.npy"
with open(os.path.join(save_dir, save_name), "wb") as f:
    np.save(f, pos)

### Get snapshots from A42T apo

In [28]:
# Settings for the A42T apo run; these values are used to build the input file names
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/39/"
phase, endstate = "apo", 1
name, length = "THR", 5
aa = name.lower()
# outdir ends with a slash, so dirname() drops it and split() yields the last directory name
i = os.path.split(os.path.dirname(outdir))[1]


In [29]:
# Load the pickled `htf` object for this run, phase and endstate.
# NOTE(review): pickle.load can execute code embedded in the file -- only load trusted files.
htf_pickle_path = os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_pickle_path, mode='rb') as f:
    htf = pickle.load(f)

# Four atom indices defining the dihedral: the "old" set is applied to the old
# topology and the "new" set to the new topology
indices_old = [669, 670, 673, 676]
indices_new = [669, 670, 673, 681]

In [30]:
# Dihedrals for the saved snapshots; `dihedrals` is [new, old] and
# `all_pos_hybrid` is the snapshot array as loaded from disk
dihedrals, all_pos_hybrid = get_dihedrals_from_decorrelated(
    i=i,
    phase=phase,
    name=aa,
    length=length,
    out_dir=outdir,
    htf=htf,
    dihedral_indices_new=indices_new,
    dihedral_indices_old=indices_old,
)

100%|██████████| 100/100 [01:15<00:00,  1.33it/s]


In [31]:
# Index 0 holds the dihedrals measured on the new topology (index 1 is the old topology)
dihedrals_new = dihedrals[0]

In [32]:
# Group snapshot indices into the three dihedral-angle bins (keys 0, 1, 2)
d_indices_A42T_apo = sort_snapshots(dihedrals_new)

In [33]:
# Display the bin -> snapshot-index mapping to see which bins are populated
d_indices_A42T_apo

{0: [],
 1: [3,
  12,
  16,
  24,
  27,
  30,
  36,
  40,
  44,
  47,
  50,
  51,
  53,
  59,
  62,
  66,
  78,
  86,
  88,
  89],
 2: [0,
  1,
  2,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  13,
  14,
  15,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  25,
  26,
  28,
  29,
  31,
  32,
  33,
  34,
  35,
  37,
  38,
  39,
  41,
  42,
  43,
  45,
  46,
  48,
  49,
  52,
  54,
  55,
  56,
  57,
  58,
  60,
  61,
  63,
  64,
  65,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  87,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99]}

In [34]:
# Check the snapshot array layout (looks like snapshots x atoms x 3 -- TODO confirm)
all_pos_hybrid.shape

(100, 14881, 3)

In [35]:
# Keep the first snapshot of bin 1 and the first snapshot of bin 2
# (bin 0 is not used here)
first_in_bins = [d_indices_A42T_apo[bin_id][0] for bin_id in (1, 2)]
pos = all_pos_hybrid[first_in_bins]

In [37]:
# Write the selected snapshots as a .npy file into the 13/3 experiment directory
save_dir = "/data/chodera/zhangi/perses_benchmark/neq/13/3/"
save_name = f"{3}_{phase}_{endstate}.npy"
with open(os.path.join(save_dir, save_name), "wb") as f:
    np.save(f, pos)

### Get snapshots from A42T complex

In [75]:
# Settings for the A42T complex run; these values are used to build the input file names
outdir = "/data/chodera/zhangi/perses_benchmark/neq/12/39/"
phase, endstate = "complex", 1
name, length = "THR", 5
aa = name.lower()
# outdir ends with a slash, so dirname() drops it and split() yields the last directory name
i = os.path.split(os.path.dirname(outdir))[1]


In [76]:
# Load the pickled `htf` object for this run, phase and endstate.
# NOTE(review): pickle.load can execute code embedded in the file -- only load trusted files.
htf_pickle_path = os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle")
with open(htf_pickle_path, mode='rb') as f:
    htf = pickle.load(f)

# Four atom indices defining the dihedral: the "old" set is applied to the old
# topology and the "new" set to the new topology
indices_old = [669, 670, 673, 676]
indices_new = [669, 670, 673, 681]

In [77]:
# Dihedrals for the saved snapshots; `dihedrals` is [new, old] and
# `all_pos_hybrid` is the snapshot array as loaded from disk
dihedrals, all_pos_hybrid = get_dihedrals_from_decorrelated(
    i=i,
    phase=phase,
    name=aa,
    length=length,
    out_dir=outdir,
    htf=htf,
    dihedral_indices_new=indices_new,
    dihedral_indices_old=indices_old,
)

100%|██████████| 100/100 [02:23<00:00,  1.43s/it]


In [78]:
# Index 0 holds the dihedrals measured on the new topology (index 1 is the old topology)
dihedrals_new = dihedrals[0]

In [79]:
# Group snapshot indices into the three dihedral-angle bins (keys 0, 1, 2)
d_indices_A42T_complex = sort_snapshots(dihedrals_new)

In [80]:
# Display the bin -> snapshot-index mapping to see which bins are populated
d_indices_A42T_complex

{0: [],
 1: [84],
 2: [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99]}

In [81]:
# Check the snapshot array layout (looks like snapshots x atoms x 3 -- TODO confirm)
all_pos_hybrid.shape

(100, 29478, 3)

In [82]:
# Keep only the first bin-2 snapshot; indexing with a one-element list
# preserves the leading snapshot axis
first_bin2 = d_indices_A42T_complex[2][0]
pos = all_pos_hybrid[[first_bin2]]

In [84]:
# Confirm how many snapshots were selected (first axis)
pos.shape

(1, 29478, 3)

In [83]:
# Write the selected snapshot as a .npy file into the 13/3 experiment directory
save_dir = "/data/chodera/zhangi/perses_benchmark/neq/13/3/"
save_name = f"{3}_{phase}_{endstate}.npy"
with open(os.path.join(save_dir, save_name), "wb") as f:
    np.save(f, pos)