In [1]:
from simtk.openmm import unit, app
from tqdm import tqdm_notebook

import os
import pickle
import mdtraj as md
import numpy as np




In [7]:
def new_positions(htf, hybrid_positions):
    n_atoms_new = htf._topology_proposal.n_atoms_new
    hybrid_indices = [htf._new_to_hybrid_map[idx] for idx in range(n_atoms_new)]
    return hybrid_positions[hybrid_indices, :]

def old_positions(htf, hybrid_positions):
    n_atoms_old = htf._topology_proposal.n_atoms_old
    hybrid_indices = [htf._old_to_hybrid_map[idx] for idx in range(n_atoms_old)]
    return hybrid_positions[hybrid_indices, :]

def get_trajs(i, phase, out_dir, htf, state):
    new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
    old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    from perses.analysis.utils import open_netcdf
    nc = open_netcdf(os.path.join(out_dir, f"{i}_{phase}.nc"))
    nc_checkpoint = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_checkpoint.nc"))
    checkpoint_interval = nc_checkpoint.CheckpointInterval
    all_positions = nc_checkpoint.variables['positions']
    n_iter, n_replicas, n_atoms, _ = np.shape(all_positions)
    box_vectors = np.array(nc_checkpoint['box_vectors'])
#     print("n_iter: ", n_iter)
    n_iter = 100

    all_pos_new = np.zeros(shape=(n_iter, new_top.n_atoms, 3))
    all_pos_old = np.zeros(shape=(n_iter, old_top.n_atoms, 3))
    all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))
    for iteration in tqdm_notebook(range(n_iter)): # NOTE THAT I AM ONLY EXTRACTING THE FIRST 1 NS SNAPSHOTS
        replica_id = np.where(nc.variables['states'][iteration*checkpoint_interval] == state)[0]
#         replica_id = 0 # This should only be hard coded to replica 0, for rest simulations, where I am only saving the positions at state = 0
        pos = all_positions[iteration,replica_id,:,:][0] *unit.nanometers
        all_pos_new[iteration] = new_positions(htf, pos).value_in_unit_system(unit.md_unit_system) # Get new positions only
        all_pos_hybrid[iteration] = pos # Get hybrid positions
        all_pos_old[iteration] = old_positions(htf, pos).value_in_unit_system(unit.md_unit_system)

    traj_old = md.Trajectory(all_pos_old, old_top)
    traj_new = md.Trajectory(all_pos_new, new_top)
    
    # Set unit cell vectors in traj 
    box_vectors = [val.value_in_unit_system(unit.md_unit_system) for val in htf.hybrid_system.getDefaultPeriodicBoxVectors()]
    box_vectors_formatted = np.ndarray(shape=(3,3), buffer=np.array(box_vectors)).astype(np.float32) # note that mdraj expect np.ndarray, not np.arrays
    box_vectors_repeated = np.repeat(np.array([box_vectors_formatted]), n_iter, axis=0)
    traj_old.unitcell_vectors = box_vectors_repeated
    traj_new.unitcell_vectors = box_vectors_repeated
    
#     print("imaging old traj")
#     traj_old = traj_old.image_molecules()
    
#     print("saving old traj")
#     traj_old.save(os.path.join(out_dir, f"{i}_{phase}_old_state_{state}.dcd"))
# #     traj_old[0].save(os.path.join(out_dir, f"{i}_{phase}_old.pdb"))
#     app.PDBxFile.writeFile(traj_old.topology.to_openmm(), traj_old.openmm_positions(0), file=open(os.path.join(out_dir, f"{i}_{phase}_old_state_{state}.cif"), "w"), keepIds=True)
    
    print(traj_new)
    
    print("imaging new traj")
    traj_new = traj_new.image_molecules()
    
    print("saving new traj")
    traj_new.save(os.path.join(out_dir, f"{i}_{phase}_new_state_{state}.dcd"))
#     traj_new[0].save(os.path.join(out_dir, f"{i}_{phase}_new.pdb"))
    app.PDBxFile.writeFile(traj_new.topology.to_openmm(), traj_new.openmm_positions(99), file=open(os.path.join(out_dir, f"{i}_{phase}_new_state_{state}.cif"), "w"), keepIds=True)


In [None]:
with open("/data/chodera/zhangi/perses_benchmark/repex/32/0/0/0_complex.pickle", "rb") as f:
    htf = pickle.load(f)

INFO:rdkit:Enabling RDKit 2021.03.5 jupyter extensions


In [15]:
outdir = "/data/chodera/zhangi/perses_benchmark/repex/32/0/0/"
i = 0
phase = 'complex'
get_trajs(i, phase, outdir, htf)

n_iter:  105


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for iteration in tqdm_notebook(range(n_iter)):


  0%|          | 0/105 [00:00<?, ?it/s]

# Get traj of 32/0/2 at state 2 (lambda = 0.21818)

In [10]:
with open("/data/chodera/zhangi/perses_benchmark/repex/32/0/2/2_complex.pickle", "rb") as f:
    htf = pickle.load(f)
    

In [11]:
outdir = "/data/chodera/zhangi/perses_benchmark/repex/32/0/2/"
i = 2
phase = 'complex'
get_trajs(i, phase, outdir, htf, 2)



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for iteration in tqdm_notebook(range(n_iter)): # NOTE THAT I AM ONLY EXTRACTING THE FIRST 1 NS SNAPSHOTS


  0%|          | 0/100 [00:00<?, ?it/s]

<mdtraj.Trajectory with 100 frames, 185384 atoms, 57660 residues, and unitcells>
imaging new traj
saving new traj


# Get traj of 32/2/0 at state 5 (lambda = 0.217)

In [8]:
with open("/data/chodera/zhangi/perses_benchmark/repex/32/2/0/0_complex.pickle", "rb") as f:
    htf = pickle.load(f)

In [9]:
outdir = "/data/chodera/zhangi/perses_benchmark/repex/32/2/0/"
i = 0
phase = 'complex'
get_trajs(i, phase, outdir, htf, 5)



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for iteration in tqdm_notebook(range(n_iter)): # NOTE THAT I AM ONLY EXTRACTING THE FIRST 1 NS SNAPSHOTS


  0%|          | 0/100 [00:00<?, ?it/s]

<mdtraj.Trajectory with 100 frames, 185384 atoms, 57660 residues, and unitcells>
imaging new traj
saving new traj
