In [1]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging

In [2]:
# Set up logger: use the root logger and let INFO-level records pass its level check.
_logger = logging.getLogger(name=None)
_logger.setLevel("INFO")

from matplotlib import pyplot as plt
from simtk.openmm import app
from tqdm import tqdm
import argparse
import random

In [3]:
# Which endstate simulation to analyse (3 -> old residue, 4 -> new residue).
sim_number = 4
outdir = "/data/chodera/zhangi/perses_benchmark/neq/14/64/"
resid = '501'  # residue id (string) of the mutated residue on chain 0
old_aa_name = 'ASN'
new_aa_name = 'TYR'

if sim_number == 3:
    phase = 'complex'
    state = 0
elif sim_number == 4:
    phase = 'complex'
    state = 1
else:
    # Without this, `phase`/`name`/`state` would be undefined (or silently left
    # over from a previous kernel run) and later cells would load the wrong files.
    raise ValueError(f"Unsupported sim_number {sim_number}; expected 3 or 4")

# Residue name of the simulated endstate: state 0 is the old residue, state 1 the
# new one. Deriving it keeps it in sync with old_aa_name / new_aa_name.
name = (old_aa_name, new_aa_name)[state]

length = 1  # simulation length; appears in file names as f"{length}ns"
# Job identifier = last path component of outdir (e.g. "64"); normpath makes this
# work whether or not outdir has a trailing slash.
i = os.path.basename(os.path.normpath(outdir))


In [4]:
def new_positions(htf, hybrid_positions):
    """Select the rows of ``hybrid_positions`` that belong to new-topology atoms.

    Row k of the result is the hybrid position of new-topology atom k, looked up
    through ``htf._new_to_hybrid_map``.
    """
    atom_count = htf._topology_proposal.n_atoms_new
    new_to_hybrid = htf._new_to_hybrid_map
    rows = []
    for new_index in range(atom_count):
        rows.append(new_to_hybrid[new_index])
    return hybrid_positions[rows, :]

def old_positions(htf, hybrid_positions):
    """Select the rows of ``hybrid_positions`` that belong to old-topology atoms.

    Row k of the result is the hybrid position of old-topology atom k, looked up
    through ``htf._old_to_hybrid_map``.
    """
    atom_count = htf._topology_proposal.n_atoms_old
    old_to_hybrid = htf._old_to_hybrid_map
    rows = []
    for old_index in range(atom_count):
        rows.append(old_to_hybrid[old_index])
    return hybrid_positions[rows, :]


In [5]:
def get_dihedrals(i, name, length, out_dir, htf, dihedral_indices_new, dihedral_indices_old):
    """Compute one dihedral per endstate from the positions stored in a checkpoint file.

    Positions of replica 0 are read from
    ``{out_dir}/{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc``, mapped from
    the hybrid system onto the new and old topologies, and the requested dihedral
    is computed for each with mdtraj.

    Parameters
    ----------
    i : str
        Prefix of the simulation output files.
    name : str
        Residue name of the simulated endstate; lowercased in the file name.
    length : int
        Simulation length as it appears in the file name (``{length}ns``).
    out_dir : str
        Directory containing the checkpoint file.
    htf : object
        Hybrid topology factory providing ``_topology_proposal`` (old/new OpenMM
        topologies) and the atom maps used by ``new_positions``/``old_positions``.
    dihedral_indices_new, dihedral_indices_old : list of int
        Four atom indices, in new/old topology numbering, defining the dihedral.

    Returns
    -------
    dihedrals_all : list of np.ndarray
        ``[new_dihedrals, old_dihedrals]`` as returned by ``md.compute_dihedrals``
        (one column, one row per stored iteration).
    n_iter : int
        Number of stored iterations.
    all_pos_hybrid, all_pos_new, all_pos_old : np.ndarray
        Per-iteration positions (nm) of the hybrid, new and old systems.

    Raises
    ------
    FileNotFoundError
        If the checkpoint file cannot be opened.

    Notes
    -----
    The file name uses the module-level ``phase`` variable, not a parameter.
    """
    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    from perses.analysis.utils import open_netcdf
    from tqdm import tqdm

    new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
    old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

    checkpoint_path = os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc")
    nc_checkpoint = open_netcdf(checkpoint_path)
    if nc_checkpoint is None:
        # Guard in case open_netcdf reports a missing file by returning None
        # instead of raising; fail with a clear message rather than an AttributeError.
        raise FileNotFoundError(checkpoint_path)
    try:
        all_positions = nc_checkpoint.variables['positions']
        n_iter, _n_replicas, n_atoms, _ = np.shape(all_positions)

        all_pos_new = np.zeros(shape=(n_iter, new_top.n_atoms, 3))
        all_pos_old = np.zeros(shape=(n_iter, old_top.n_atoms, 3))
        all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))

        # NOTE(review): this always follows replica 0, not the replica sitting in
        # state 0. To follow state 0 instead, open the non-checkpoint .nc file and use
        # np.where(nc.variables['states'][iteration * nc_checkpoint.CheckpointInterval] == 0)[0].
        replica_id = 0
        for iteration in tqdm(range(n_iter)):
            raw = all_positions[iteration, replica_id, :, :]  # hybrid positions, nm
            pos = raw * unit.nanometers
            all_pos_hybrid[iteration] = raw
            all_pos_new[iteration] = new_positions(htf, pos).value_in_unit_system(unit.md_unit_system)
            all_pos_old[iteration] = old_positions(htf, pos).value_in_unit_system(unit.md_unit_system)
    finally:
        # Release the netCDF handle once every frame has been copied into memory.
        nc_checkpoint.close()

    dihedrals_all = []
    for pos, top, indices in zip([all_pos_new, all_pos_old], [new_top, old_top], [dihedral_indices_new, dihedral_indices_old]):
        traj = md.Trajectory(np.array(pos), top)
        dihedrals_all.append(md.compute_dihedrals(traj, np.array([indices])))

    return dihedrals_all, n_iter, all_pos_hybrid, all_pos_new, all_pos_old

def plot_dihedrals(dihedrals, outfile):
    """Save a histogram of dihedral angles with sqrt(N) error bars.

    Parameters
    ----------
    dihedrals : array-like
        Dihedral angles; the x-axis is fixed to [-pi, pi].
    outfile : str
        Path of the image to write (saved at 300 dpi).
    """
    # Plot histogram with error bars : https://stackoverflow.com/questions/35390276/how-to-add-error-bars-to-histogram-diagram-in-python
    # Use a fresh figure so a figure left open by an earlier cell is never drawn on.
    fig, ax = plt.subplots()
    try:
        entries, edges, _ = ax.hist(dihedrals)
        bin_centers = 0.5 * (edges[:-1] + edges[1:])  # calculate bin centers
        ax.errorbar(bin_centers, entries, yerr=np.sqrt(entries), fmt='r.')  # Poisson (sqrt) error bars
        ax.set_xlim(-np.pi, np.pi)
        ax.set_xlabel("dihedral")
        ax.set_ylabel("count")
        fig.savefig(outfile, dpi=300)
    finally:
        # Close exactly this figure (not just "the current one") to avoid leaking figures in loops.
        plt.close(fig)
                     
def plot_time_series(dihedrals, n_iter, outfile):
    """Save a scatter plot of a dihedral angle against iteration number.

    Parameters
    ----------
    dihedrals : array-like
        One dihedral value per iteration; the y-axis is fixed to [-pi, pi].
    n_iter : int
        Number of iterations (x values are ``range(n_iter)``).
    outfile : str
        Path of the image to write (saved at 300 dpi).
    """
    # Use a fresh figure so a figure left open by an earlier cell is never drawn on.
    fig, ax = plt.subplots()
    try:
        ax.scatter(range(n_iter), dihedrals)
        ax.set_ylabel("dihedral")
        ax.set_xlabel("iteration number")
        ax.set_ylim(-np.pi, np.pi)
        fig.savefig(outfile, dpi=300)
    finally:
        # Close exactly this figure to avoid leaking figures across calls.
        plt.close(fig)




In [6]:
# Load the pickled hybrid topology factory (`htf`) for this phase/endstate.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
htf_path = os.path.join(outdir, f"{i}_{phase}_{state}.pickle")
with open(htf_path, 'rb') as pickle_file:
    htf = pickle.load(pickle_file)



In [9]:
# Atom names (in dihedral order) of the side-chain dihedral monitored per residue type.
thr_dihedral = ['N', 'CA', 'CB', 'OG1']
other_dihedral = ['N', 'CA', 'CB', 'CG']
ala_dihedral = ['N', 'CA', 'CB', 'HB1']
asp_dihedral = ['CA', 'CB', 'CG', 'OD2']
ile_dihedral = ['N', 'CA', 'CB', 'CG2']
gly_dihedral = ['O', 'C', 'CA', 'HA2']
ser_dihedral = ['N', 'CA', 'CB', 'OG']

dihedral_atoms_by_residue = {
    **{aa: other_dihedral for aa in ("PHE", "TYR", "TRP", "GLU", "LYS", "ARG", "GLN", "ASN")},
    "THR": thr_dihedral,
    "ALA": ala_dihedral,
    "ASP": asp_dihedral,
    "ILE": ile_dihedral,
    "VAL": ile_dihedral,
    "GLY": gly_dihedral,
    "SER": ser_dihedral,
}

# dihedral_atoms[0] -> old residue, dihedral_atoms[1] -> new residue.
# Unknown residues raise instead of being skipped, which would misalign the two entries.
dihedral_atoms = []
for aa_name in [old_aa_name, new_aa_name]:
    if aa_name not in dihedral_atoms_by_residue:
        raise ValueError(
            f"No dihedral definition for residue {aa_name!r}; "
            f"known residues: {sorted(dihedral_atoms_by_residue)}"
        )
    dihedral_atoms.append(dihedral_atoms_by_residue[aa_name])

In [10]:
def _find_residue(topology, residue_id, chain_index=0):
    """Return the residue with id ``residue_id`` on chain ``chain_index`` of an OpenMM topology.

    Raises ValueError when nothing matches, rather than silently reusing a residue
    variable left over from an earlier kernel run. If several residues match, the
    last one is returned (same as a loop that keeps overwriting).
    """
    matches = [res for res in topology.residues()
               if res.id == residue_id and res.chain.index == chain_index]
    if not matches:
        raise ValueError(f"No residue {residue_id!r} on chain {chain_index}")
    return matches[-1]


def _dihedral_indices(residue, atom_names):
    """Return the atom indices of ``atom_names`` in ``residue``, in the order of ``atom_names``.

    Order matters: a dihedral is defined by the sequence of its four atoms, which need
    not match the order atoms appear in the residue (e.g. ['O', 'C', 'CA', 'HA2']).
    """
    index_by_name = {atom.name: atom.index for atom in residue.atoms()}
    missing = [atom_name for atom_name in atom_names if atom_name not in index_by_name]
    if missing:
        raise ValueError(f"Atoms {missing} not found in residue {residue.name} {residue.id}")
    return [index_by_name[atom_name] for atom_name in atom_names]


residue_old = _find_residue(htf._topology_proposal.old_topology, resid)
residue_new = _find_residue(htf._topology_proposal.new_topology, resid)

indices_old = _dihedral_indices(residue_old, dihedral_atoms[0])
indices_new = _dihedral_indices(residue_new, dihedral_atoms[1])
_logger.info(f"old indices: {indices_old}")
_logger.info(f"new indices: {indices_new}")

dihedrals, n_iter, all_pos_hybrid, all_pos_new, all_pos_old = get_dihedrals(i, name, length, outdir, htf, indices_new, indices_old)

# # Save every 10th snapshot
# subset_pos = all_pos_hybrid[1::10] # Make array of hybrid positions for 100 uncorrelated indices
# _logger.info(f"subset_pos shape: {subset_pos.shape}")
# with open(os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_snapshots.npy"), 'wb') as f:
#     np.save(f, subset_pos)

INFO:root:old indices: [2605, 2607, 2609, 2612]
INFO:root:new indices: [2605, 2607, 2609, 2612]
100%|██████████| 1001/1001 [03:29<00:00,  4.79it/s]


In [None]:
# Wrap the per-iteration endstate positions returned by get_dihedrals in mdtraj
# trajectories. No unitcell_* arguments are passed, so these carry no box information.
proposal = htf._topology_proposal
new_top = md.Topology.from_openmm(proposal.new_topology)
old_top = md.Topology.from_openmm(proposal.old_topology)
traj_new, traj_old = (
    md.Trajectory(np.array(xyz), topology)
    for xyz, topology in ((all_pos_new, new_top), (all_pos_old, old_top))
)



In [7]:
# Sanity check of the box: attach the hybrid system's default periodic box vectors
# to the new-topology positions from htf and write a single-frame PDB.
default_box = htf.hybrid_system.getDefaultPeriodicBoxVectors()
box_vectors = np.array([vec.value_in_unit_system(unit.md_unit_system) for vec in default_box])
vectors = np.stack([box_vectors])  # add a leading frame axis -> (1, 3, 3)

new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
# htf.new_positions here is a method on htf, not the notebook-level helper of the same name.
traj = md.Trajectory(np.array(htf.new_positions(htf.hybrid_positions)), new_top)
traj.unitcell_vectors = vectors

traj[0].save(os.path.join(outdir, "check_box_vectors.pdb"))

In [None]:
# Write a first-frame PDB (topology reference) and the full DCD for each endstate,
# new first, then old.
stem = f"{i}_{phase}_{name.lower()}_{length}"
for label, trajectory in (("new", traj_new), ("old", traj_old)):
    trajectory[0].save(os.path.join(outdir, f"{stem}_{label}.pdb"))
    trajectory.save(os.path.join(outdir, f"{stem}_{label}.dcd"))

In [12]:
# # Plot
# dihedrals_new = dihedrals[0]
# dihedrals_old = dihedrals[1]
# plot_dihedrals(dihedrals_old, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{old_aa_name.lower()}_correlated.png"))
# plot_time_series(dihedrals_old, n_iter, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{old_aa_name.lower()}_timeseries.png"))
# plot_dihedrals(dihedrals_new, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{new_aa_name.lower()}_correlated.png"))
# plot_time_series(dihedrals_new, n_iter, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{new_aa_name.lower()}_timeseries.png"))

In [13]:
# Box vectors stored on the proposal's new topology (compare with the hybrid system defaults).
new_topology = htf._topology_proposal.new_topology
new_topology.getPeriodicBoxVectors()

Quantity(value=(Vec3(x=13.804333500000002, y=0.0, z=0.0), Vec3(x=-4.601444128722184, y=13.014850568080696, z=0.0), Vec3(x=-4.601444128722184, y=-6.507424496441187, z=11.271191673136768)), unit=nanometer)

In [15]:
# Default periodic box vectors of the hybrid system, one Quantity per box vector.
hybrid_system = htf.hybrid_system
hybrid_system.getDefaultPeriodicBoxVectors()

[Quantity(value=Vec3(x=13.804333500000002, y=0.0, z=0.0), unit=nanometer),
 Quantity(value=Vec3(x=-4.601444128722184, y=13.014850568080696, z=0.0), unit=nanometer),
 Quantity(value=Vec3(x=-4.601444128722184, y=-6.507424496441187, z=11.271191673136768), unit=nanometer)]