In [1]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging 

# Set up logger: logging.getLogger() with no name returns the root logger.
# Its threshold is lowered to INFO so that the _logger.info(...) calls made in
# later cells are emitted instead of being filtered out.
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

from matplotlib import pyplot as plt
from simtk.openmm import app
from tqdm import tqdm
import argparse
import random

In [2]:
# Run configuration. phase, name, state, length, out_dir and i are interpolated
# into the input/output file names built in the cells below; resid is not
# referenced anywhere in the visible cells.
phase = 'complex'
name = 'tyr'
state = 1
resid = '501'
out_dir = "/data/chodera/zhangi/perses_benchmark/neq/14/84/"
length = 1
# Run identifier = last path component of out_dir ('84' here).
# normpath() drops any trailing separator first, so the result is the same
# whether or not out_dir ends in '/'. (basename(dirname(...)) only worked with
# the trailing slash and otherwise silently returned the parent directory.)
i = os.path.basename(os.path.normpath(out_dir))

In [3]:
def new_positions(htf, hybrid_positions):
    """Pick the rows of ``hybrid_positions`` that belong to the new system.

    Parameters
    ----------
    htf : object
        Must expose ``_topology_proposal.n_atoms_new`` and a
        ``_new_to_hybrid_map`` that can be indexed by every new-system atom
        index ``0 .. n_atoms_new - 1``.
    hybrid_positions : array-like
        Per-atom positions of the hybrid system; must support
        ``hybrid_positions[list_of_indices, :]``.

    Returns
    -------
    Same kind of object as ``hybrid_positions``, holding one row per
    new-system atom, ordered by new-system atom index.
    """
    atom_count = htf._topology_proposal.n_atoms_new
    selection = []
    for new_idx in range(atom_count):
        # Translate the new-system atom index into its hybrid-system index.
        selection.append(htf._new_to_hybrid_map[new_idx])
    return hybrid_positions[selection, :]
    
def old_positions(htf, hybrid_positions):
    """Pick the rows of ``hybrid_positions`` that belong to the old system.

    Parameters
    ----------
    htf : object
        Must expose ``_topology_proposal.n_atoms_old`` and an
        ``_old_to_hybrid_map`` that can be indexed by every old-system atom
        index ``0 .. n_atoms_old - 1``.
    hybrid_positions : array-like
        Per-atom positions of the hybrid system; must support
        ``hybrid_positions[list_of_indices, :]``.

    Returns
    -------
    Same kind of object as ``hybrid_positions``, holding one row per
    old-system atom, ordered by old-system atom index.
    """
    atom_count = htf._topology_proposal.n_atoms_old
    selection = []
    for old_idx in range(atom_count):
        # Translate the old-system atom index into its hybrid-system index.
        selection.append(htf._old_to_hybrid_map[old_idx])
    return hybrid_positions[selection, :]

def get_dihedrals(i, name, length, out_dir, htf, phase=None, replica_id=1):
    """Load one replica's hybrid-system positions from a checkpoint file.

    Despite its name, this function computes no dihedral angles. It reads the
    ``positions`` and ``box_vectors`` variables from
    ``{out_dir}/{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc``.

    Parameters
    ----------
    i : str or int
        Run identifier used as the file-name prefix.
    name : str
        Name inserted, lower-cased, into the file name.
    length : int or float
        Label inserted before ``ns`` in the file name.
    out_dir : str
        Directory that contains the checkpoint file.
    htf : object
        Not used by this function; kept so existing calls keep working.
    phase : str, optional
        Phase label for the file name. When omitted, the module-level
        ``phase`` variable is used, which is what this function read before
        the parameter existed.
    replica_id : int, optional
        Index along the replica axis of the stored positions. Defaults to 1,
        the value that used to be hard-coded.

    Returns
    -------
    n_iter : int
        Size of the first axis of the ``positions`` variable.
    all_pos_hybrid : numpy.ndarray, shape (n_iter, n_atoms, 3)
        Positions of the selected replica, as the raw numbers stored in the
        file (no unit conversion is applied).
    box_vectors : numpy.ndarray
        Full contents of the ``box_vectors`` variable (not sliced by replica).

    Raises
    ------
    NameError
        If ``phase`` is not passed and no module-level ``phase`` exists.
    """
    # Function-scope imports, as before, so the cell that defines this function
    # does not itself require perses to be importable.
    from perses.analysis.utils import open_netcdf
    from tqdm import tqdm

    if phase is None:
        # Backward compatibility with callers that rely on the notebook-level
        # `phase` variable instead of passing it explicitly.
        try:
            phase = globals()["phase"]
        except KeyError:
            raise NameError("name 'phase' is not defined") from None

    checkpoint_path = os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc")
    nc_checkpoint = open_netcdf(checkpoint_path)
    try:
        all_positions = nc_checkpoint.variables['positions']
        n_iter, n_replicas, n_atoms, _ = np.shape(all_positions)
        box_vectors = np.array(nc_checkpoint['box_vectors'])

        all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))
        # Copy frame by frame so only one frame is read from the file at a
        # time and the progress bar stays meaningful.
        for iteration in tqdm(range(n_iter)):
            all_pos_hybrid[iteration] = all_positions[iteration, replica_id, :, :]
    finally:
        # NOTE(review): assumes open_netcdf returns an object with close()
        # (as netCDF4.Dataset has) - confirm against perses.
        nc_checkpoint.close()

    return n_iter, all_pos_hybrid, box_vectors
    

In [4]:
# Load the pickled `htf` object for this run / phase / state; it is passed to
# get_dihedrals in the next cell.
# NOTE: unpickling can execute arbitrary code - only load pickles you trust.
htf_path = os.path.join(out_dir, f"{i}_{phase}_{state}.pickle")
with open(htf_path, 'rb') as f:
    htf = pickle.load(f)



In [5]:
# get_dihedrals returns the number of frames, the hybrid-system positions for
# every frame, and the box vectors read from the checkpoint file.
n_iter, all_pos_hybrid, box_vectors = get_dihedrals(i, name, length, out_dir, htf)


100%|██████████| 1001/1001 [02:10<00:00,  7.69it/s]


In [6]:
# Keep every 10th frame, starting at frame 1 (frame 0 is skipped); for a
# 1001-frame trajectory this yields 100 snapshots.
subset_pos = all_pos_hybrid[1::10]
_logger.info(f"subset_pos shape: {subset_pos.shape}")

# Write the selected snapshots next to the simulation output as a .npy file.
snapshot_path = os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns_snapshots_2000K.npy")
with open(snapshot_path, 'wb') as f:
    np.save(f, subset_pos)


INFO:root:subset_pos shape: (100, 183508, 3)
