In [1]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging 

# Set up logger
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

from matplotlib import pyplot as plt
from simtk.openmm import app
from tqdm import tqdm, tqdm_notebook
import argparse
import random
import time


In [7]:
def get_dihedrals(i, name, length, out_dir, htf, dihedral_indices_new, dihedral_indices_old, max_iter=10):
    """Collect hybrid-system positions from a replica-exchange checkpoint file.

    Despite its name, this version computes no dihedrals: the first element
    of the returned tuple is always ``None``.

    NOTE: the file names are built from the module-level variable ``phase``,
    which is not a parameter and must be defined before this is called.

    Parameters
    ----------
    i, name, length
        Components of the file names
        ``{i}_{phase}_{name.lower()}_{length}ns.nc`` and
        ``{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc``.
    out_dir : str
        Directory containing the two netCDF files.
    htf, dihedral_indices_new, dihedral_indices_old
        Not used by this version; kept so existing calls keep working.
    max_iter : int or None, default 10
        Upper bound on the number of checkpoint frames read. ``None`` reads
        every frame in the checkpoint file. The bound is clamped to the number
        of frames actually stored, so short files no longer index out of range.

    Returns
    -------
    tuple
        ``(None, n_iter, all_pos_hybrid)`` where ``all_pos_hybrid`` has shape
        ``(n_iter, n_atoms, 3)``.
    """
    from contextlib import closing

    from tqdm import tqdm

    # Load nc file
    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    from perses.analysis.utils import open_netcdf

    # closing() guarantees both handles are released even if reading fails.
    # Assumes open_netcdf returns netCDF4 Dataset-like objects exposing close().
    with closing(open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns.nc"))) as nc, \
            closing(open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc"))) as nc_checkpoint:
        checkpoint_interval = nc_checkpoint.CheckpointInterval
        all_positions = nc_checkpoint.variables['positions']
        n_iter, _, n_atoms, _ = np.shape(all_positions)
        if max_iter is not None:
            n_iter = min(n_iter, max_iter)

        # Hybrid positions only (per the original note, this includes solvent atoms)
        all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))

        for iteration in tqdm(range(n_iter)):
            # Checkpoint frame k corresponds to entry k*checkpoint_interval of 'states';
            # pick the replica whose state is 0 at that point.
            replica_id = np.where(nc.variables['states'][iteration*checkpoint_interval] == 0)[0]
            pos = all_positions[iteration,replica_id,:,:][0] *unit.nanometers
            all_pos_hybrid[iteration] = pos

    return None, n_iter, all_pos_hybrid
    
def plot_dihedrals(dihedrals, outfile):
    """Save a histogram of ``dihedrals`` with sqrt(count) error bars.

    Parameters
    ----------
    dihedrals : array-like
        Values passed straight to ``plt.hist``.
    outfile : str
        Path the figure is written to (300 dpi).
    """
    # Histogram-with-error-bars recipe: https://stackoverflow.com/questions/35390276/how-to-add-error-bars-to-histogram-diagram-in-python
    counts, bin_edges, _patches = plt.hist(dihedrals)
    # Error bars sit at the midpoint of each bin
    midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.errorbar(midpoints, counts, yerr=np.sqrt(counts), fmt='r.')
    plt.xlim(-np.pi, np.pi)
    plt.savefig(outfile, dpi=300)
    # Close the current figure so the next plot starts clean
    plt.close()
    
def plot_time_series(dihedrals, n_iter, outfile):
    """Scatter-plot ``dihedrals`` against iteration number and save the figure.

    Parameters
    ----------
    dihedrals : array-like
        One value per iteration; also passed to ``feptasks.compute_timeseries``.
    n_iter : int
        Number of iterations; the x axis is ``range(n_iter)``.
    outfile : str
        Path the figure is written to (300 dpi).

    Returns
    -------
    The fifth value returned by ``feptasks.compute_timeseries(dihedrals)``,
    which callers use as the uncorrelated frame indices.
    """
    from perses.dispersed import feptasks

    # Only the last of the five returned values is used here.
    _t0, _g, _neff_max, _a_t, decorrelated = feptasks.compute_timeseries(dihedrals)

    iterations = range(n_iter)
    plt.scatter(iterations, dihedrals)
    plt.xlabel("iteration number")
    plt.ylabel("dihedral")
    plt.ylim(-np.pi, np.pi)
    plt.savefig(outfile, dpi=300)
    plt.close()

    return decorrelated
    
def plot_dihedrals_uncorrelated(dihedrals, uncorrelated_indices, outfile):
    """Save a histogram of the selected entries of ``dihedrals`` with sqrt(count) error bars.

    Parameters
    ----------
    dihedrals : indexable array
        Full series of values.
    uncorrelated_indices : sequence of int
        Indices used to subsample ``dihedrals`` before histogramming.
    outfile : str
        Path the figure is written to (300 dpi).
    """
    # Histogram-with-error-bars recipe: https://stackoverflow.com/questions/35390276/how-to-add-error-bars-to-histogram-diagram-in-python
    subsample = dihedrals[uncorrelated_indices]
    counts, bin_edges, _patches = plt.hist(subsample)
    # Error bars sit at the midpoint of each bin
    midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.errorbar(midpoints, counts, yerr=np.sqrt(counts), fmt='r.')
    plt.xlim(-np.pi, np.pi)
    plt.savefig(outfile, dpi=300)
    plt.close()

In [10]:
# Where the simulation output lives and which run / phase / endstate to analyse
outdir = "/data/chodera/zhangi/perses_benchmark/neq/14/1/"
i, phase, endstate = 1, "complex", 0
length = 1  # appears in file names as "{length}ns"

# Residue being mutated: its id and the old / new residue names
resid = '439'
old_aa_name, new_aa_name = 'ASN', 'LYS'
name = 'asn'  # residue name used in the trajectory file names

In [4]:
# Load the pickled object for this run / phase / endstate; later cells use it as `htf`.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only load files you trust.
with open(os.path.join(outdir, f"{i}_{phase}_{endstate}.pickle"), 'rb') as f:
    htf = pickle.load(f)


In [11]:
# Atom names, in dihedral order, of the side-chain torsion tracked for each residue type
thr_dihedral = ['N', 'CA', 'CB', 'OG1']  # threonine's hydroxyl oxygen is named OG1 (OG is serine)
other_dihedral = ['N', 'CA', 'CB', 'CG']
ala_dihedral = ['N', 'CA', 'CB', 'HB1']
asp_dihedral = ['CA', 'CB', 'CG', 'OD2']
ile_dihedral = ['N', 'CA', 'CB', 'CG2']  # isoleucine has CG1/CG2 but no atom named CG

# dihedral_atoms[0] describes the old residue, dihedral_atoms[1] the new one
dihedral_atoms = []
for aa_name in [old_aa_name, new_aa_name]:
    if aa_name in ["PHE", "TYR", "TRP", "GLU", "LYS", "ARG", "GLN", "ASN"]:
        dihedral_atoms.append(other_dihedral)
    elif aa_name == "THR":
        dihedral_atoms.append(thr_dihedral)
    elif aa_name == "ALA":
        dihedral_atoms.append(ala_dihedral)
    elif aa_name == 'ASP':
        dihedral_atoms.append(asp_dihedral)
    elif aa_name == 'ILE':
        dihedral_atoms.append(ile_dihedral)
    else:
        # Without this, nothing is appended and old/new entries silently shift position.
        raise ValueError(f"No dihedral definition for residue type {aa_name!r}")


def _find_residue(topology, residue_id, chain_index=0):
    """Return the residue with id `residue_id` in chain `chain_index`; raise if there is none."""
    match = None
    for res in topology.residues():
        if res.id == residue_id and res.chain.index == chain_index:
            match = res  # if several residues match, the last one wins
    if match is None:
        raise ValueError(f"Residue id {residue_id!r} not found in chain {chain_index}")
    return match


def _dihedral_indices(residue, atom_names):
    """Return the atom indices of `atom_names` within `residue`, in the order given."""
    index_by_name = {atom.name: atom.index for atom in residue.atoms()}
    missing = [atom_name for atom_name in atom_names if atom_name not in index_by_name]
    if missing:
        raise ValueError(f"Residue {residue.name}{residue.id} has no atom(s) named {missing}")
    # Follow the dihedral definition order rather than the topology's atom order.
    return [index_by_name[atom_name] for atom_name in atom_names]


residue_old = _find_residue(htf._topology_proposal.old_topology, resid)
residue_new = _find_residue(htf._topology_proposal.new_topology, resid)
indices_old = _dihedral_indices(residue_old, dihedral_atoms[0])
indices_new = _dihedral_indices(residue_new, dihedral_atoms[1])
_logger.info(f"old indices: {indices_old}")
_logger.info(f"new indices: {indices_new}")



INFO:root:old indices: [1633, 1635, 1637, 1640]
INFO:root:new indices: [1633, 1635, 1637, 1640]


In [12]:
# The new-topology indices are passed before the old ones, matching get_dihedrals' parameter order.
# The first return value is discarded.
_, n_iter, all_pos_hybrid = get_dihedrals(i, name, length, outdir, htf, indices_new, indices_old)


100%|██████████| 10/10 [00:01<00:00,  6.83it/s]


In [13]:
# Build a trajectory from the collected hybrid positions
traj = md.Trajectory(all_pos_hybrid, htf.hybrid_topology)

# Keep only the atoms matched by the selection.
# Raw string: "\+" and "\-" are not valid Python escapes, so a plain literal triggers an
# invalid-escape warning; the selection text passed to mdtraj is unchanged.
traj.atom_slice(traj.topology.select(r"water or resname 'na\+' or resn 'cl\-'"), inplace=True)

# Write all frames as DCD, plus the first frame as a PDB
traj.save(os.path.join(outdir, f"{i}_{phase}_{endstate}_solvent_hybrid_analysis_1.dcd"))
traj[0].save(os.path.join(outdir, f"{i}_{phase}_{endstate}_solvent_hybrid_analysis_1.pdb"))


# Figure out what's wrong with the full analysis pipeline

In [1]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging 

# Set up logger
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

from matplotlib import pyplot as plt
from simtk.openmm import app
from tqdm import tqdm
import argparse
import random



In [10]:
def new_positions(htf, hybrid_positions):
    """Extract the new-system atoms from an array of hybrid positions.

    Rows are selected with ``htf._new_to_hybrid_map`` for new-system atom
    indices ``0 .. htf._topology_proposal.n_atoms_new - 1``, in that order.
    """
    atom_count = htf._topology_proposal.n_atoms_new
    new_to_hybrid = htf._new_to_hybrid_map
    selection = [new_to_hybrid[atom_idx] for atom_idx in range(atom_count)]
    return hybrid_positions[selection, :]
    
def old_positions(htf, hybrid_positions):
    """Extract the old-system atoms from an array of hybrid positions.

    Rows are selected with ``htf._old_to_hybrid_map`` for old-system atom
    indices ``0 .. htf._topology_proposal.n_atoms_old - 1``, in that order.
    """
    atom_count = htf._topology_proposal.n_atoms_old
    old_to_hybrid = htf._old_to_hybrid_map
    selection = [old_to_hybrid[atom_idx] for atom_idx in range(atom_count)]
    return hybrid_positions[selection, :]

def get_dihedrals(i, name, length, out_dir, htf, dihedral_indices_new, dihedral_indices_old, max_iter=40):
    """Compute one dihedral per frame for the new and old systems from a checkpoint file.

    For each checkpoint frame, the positions of the replica in state 0 are
    read, mapped from the hybrid system onto the new and old topologies, and
    the requested dihedral is evaluated with mdtraj.

    NOTE: the file names are built from the module-level variable ``phase``,
    which is not a parameter and must be defined before this is called.

    Parameters
    ----------
    i, name, length
        Components of the file names
        ``{i}_{phase}_{name.lower()}_{length}ns.nc`` and
        ``{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc``.
    out_dir : str
        Directory containing the two netCDF files.
    htf
        Object providing ``_topology_proposal`` (old/new topologies) and the
        hybrid index maps used by ``new_positions`` / ``old_positions``.
    dihedral_indices_new, dihedral_indices_old : sequence of 4 int
        Atom indices of the dihedral in the new and old topology.
    max_iter : int or None, default 40
        Upper bound on the number of checkpoint frames read. ``None`` reads
        every frame. The bound is clamped to the number of frames stored, so
        short files no longer index out of range.

    Returns
    -------
    tuple
        ``(dihedrals_all, n_iter, all_pos_hybrid)`` where ``dihedrals_all`` is
        ``[dihedrals_new, dihedrals_old]`` as returned by
        ``md.compute_dihedrals`` and ``all_pos_hybrid`` has shape
        ``(n_iter, n_atoms, 3)``.
    """
    from contextlib import closing

    from tqdm import tqdm
    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    from perses.analysis.utils import open_netcdf

    new_top = md.Topology.from_openmm(htf._topology_proposal.new_topology)
    old_top = md.Topology.from_openmm(htf._topology_proposal.old_topology)

    # closing() guarantees both handles are released even if reading fails.
    # Assumes open_netcdf returns netCDF4 Dataset-like objects exposing close().
    with closing(open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns.nc"))) as nc, \
            closing(open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{name.lower()}_{length}ns_checkpoint.nc"))) as nc_checkpoint:
        checkpoint_interval = nc_checkpoint.CheckpointInterval
        all_positions = nc_checkpoint.variables['positions']
        n_iter, _, n_atoms, _ = np.shape(all_positions)
        if max_iter is not None:
            n_iter = min(n_iter, max_iter)

        all_pos_new = np.zeros(shape=(n_iter, new_top.n_atoms, 3))
        all_pos_old = np.zeros(shape=(n_iter, old_top.n_atoms, 3))
        all_pos_hybrid = np.zeros(shape=(n_iter, n_atoms, 3))
        for iteration in tqdm(range(n_iter)):
            # Checkpoint frame k corresponds to entry k*checkpoint_interval of 'states';
            # pick the replica whose state is 0 at that point.
            replica_id = np.where(nc.variables['states'][iteration*checkpoint_interval] == 0)[0]
            pos = all_positions[iteration,replica_id,:,:][0] *unit.nanometers
            all_pos_new[iteration] = new_positions(htf, pos).value_in_unit_system(unit.md_unit_system) # Get new positions only
            all_pos_hybrid[iteration] = pos # Get hybrid positions
            all_pos_old[iteration] = old_positions(htf, pos).value_in_unit_system(unit.md_unit_system)

    # Order matters: callers index the result as [new, old].
    dihedrals_all = []
    for pos, top, indices in zip([all_pos_new, all_pos_old], [new_top, old_top], [dihedral_indices_new, dihedral_indices_old]):
        traj = md.Trajectory(pos, top)
        dihedrals_all.append(md.compute_dihedrals(traj, np.array([indices])))

    return dihedrals_all, n_iter, all_pos_hybrid
    
def plot_dihedrals(dihedrals, outfile):
    """Draw a histogram of ``dihedrals`` with sqrt(count) error bars, then close the figure.

    ``outfile`` is accepted but currently unused: the save step is disabled,
    so nothing is written to disk.
    """
    # Histogram-with-error-bars recipe: https://stackoverflow.com/questions/35390276/how-to-add-error-bars-to-histogram-diagram-in-python
    counts, bin_edges, _patches = plt.hist(dihedrals)
    # Error bars sit at the midpoint of each bin
    midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.errorbar(midpoints, counts, yerr=np.sqrt(counts), fmt='r.')
    plt.xlim(-np.pi, np.pi)
    # Saving disabled: plt.savefig(outfile, dpi=300)
    plt.close()
    
def plot_time_series(dihedrals, n_iter, outfile):
    """Scatter-plot ``dihedrals`` against iteration number and return the uncorrelated indices.

    ``outfile`` is accepted but currently unused: the save step is disabled,
    so nothing is written to disk.

    Returns
    -------
    The fifth value returned by ``feptasks.compute_timeseries(dihedrals)``,
    which callers use as the uncorrelated frame indices.
    """
    from perses.dispersed import feptasks

    # Only the last of the five returned values is used here.
    _t0, _g, _neff_max, _a_t, decorrelated = feptasks.compute_timeseries(dihedrals)

    iterations = range(n_iter)
    plt.scatter(iterations, dihedrals)
    plt.xlabel("iteration number")
    plt.ylabel("dihedral")
    plt.ylim(-np.pi, np.pi)
    # Saving disabled: plt.savefig(outfile, dpi=300)
    plt.close()

    return decorrelated
    
def plot_dihedrals_uncorrelated(dihedrals, uncorrelated_indices, outfile):
    """Draw a histogram of ``dihedrals[uncorrelated_indices]`` with sqrt(count) error bars.

    ``outfile`` is accepted but currently unused: the save step is disabled,
    so nothing is written to disk.
    """
    # Histogram-with-error-bars recipe: https://stackoverflow.com/questions/35390276/how-to-add-error-bars-to-histogram-diagram-in-python
    subsample = dihedrals[uncorrelated_indices]
    counts, bin_edges, _patches = plt.hist(subsample)
    # Error bars sit at the midpoint of each bin
    midpoints = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.errorbar(midpoints, counts, yerr=np.sqrt(counts), fmt='r.')
    plt.xlim(-np.pi, np.pi)
    # Saving disabled: plt.savefig(outfile, dpi=300)
    plt.close()




In [3]:
# Where the simulation output lives and which phase / state to analyse
outdir = "/data/chodera/zhangi/perses_benchmark/neq/14/6/"
state = 0
phase = "apo"
length = 1  # appears in file names as "{length}ns"

# Residue being mutated: old / new residue names and residue id
old_aa_name = 'ASN'
new_aa_name = 'LYS'
resid = '439'
name = 'asn'

# The run identifier is the last component of outdir (a string, e.g. "6").
# normpath strips any trailing slash first, so this works with or without one.
i = os.path.basename(os.path.normpath(outdir))

# NOTE(review): pickle.load can execute arbitrary code from the file, so only load files you trust.
with open(os.path.join(outdir, f"{i}_{phase}_{state}.pickle"), 'rb') as f:
    htf = pickle.load(f)



In [4]:
# Atom names, in dihedral order, of the side-chain torsion tracked for each residue type
thr_dihedral = ['N', 'CA', 'CB', 'OG1']
other_dihedral = ['N', 'CA', 'CB', 'CG']
ala_dihedral = ['N', 'CA', 'CB', 'HB1']
asp_dihedral = ['CA', 'CB', 'CG', 'OD2']
ile_dihedral = ['N', 'CA', 'CB', 'CG2']

# dihedral_atoms[0] describes the old residue, dihedral_atoms[1] the new one
dihedral_atoms = []
for aa_name in [old_aa_name,new_aa_name]:
    if aa_name in ["PHE", "TYR", "TRP", "GLU", "LYS", "ARG", "GLN", "ASN"]:
        dihedral_atoms.append(other_dihedral)
    elif aa_name == "THR":
        dihedral_atoms.append(thr_dihedral)
    elif aa_name == "ALA":
        dihedral_atoms.append(ala_dihedral)
    elif aa_name == 'ASP':
        dihedral_atoms.append(asp_dihedral)
    elif aa_name == 'ILE':
        dihedral_atoms.append(ile_dihedral)
    else:
        # Without this, nothing is appended and old/new entries silently shift position.
        raise ValueError(f"No dihedral definition for residue type {aa_name!r}")


def _find_residue(topology, residue_id, chain_index=0):
    """Return the residue with id `residue_id` in chain `chain_index`; raise if there is none."""
    match = None
    for res in topology.residues():
        if res.id == residue_id and res.chain.index == chain_index:
            match = res  # if several residues match, the last one wins
    if match is None:
        raise ValueError(f"Residue id {residue_id!r} not found in chain {chain_index}")
    return match


def _dihedral_indices(residue, atom_names):
    """Return the atom indices of `atom_names` within `residue`, in the order given."""
    index_by_name = {atom.name: atom.index for atom in residue.atoms()}
    missing = [atom_name for atom_name in atom_names if atom_name not in index_by_name]
    if missing:
        raise ValueError(f"Residue {residue.name}{residue.id} has no atom(s) named {missing}")
    # Follow the dihedral definition order rather than the topology's atom order.
    return [index_by_name[atom_name] for atom_name in atom_names]


residue_old = _find_residue(htf._topology_proposal.old_topology, resid)
residue_new = _find_residue(htf._topology_proposal.new_topology, resid)
indices_old = _dihedral_indices(residue_old, dihedral_atoms[0])
indices_new = _dihedral_indices(residue_new, dihedral_atoms[1])
_logger.info(f"old indices: {indices_old}")
_logger.info(f"new indices: {indices_new}")




INFO:root:old indices: [1633, 1635, 1637, 1640]
INFO:root:new indices: [1633, 1635, 1637, 1640]


In [6]:
# The new-topology indices are passed before the old ones, matching get_dihedrals' parameter order;
# the returned `dihedrals` list is ordered the same way ([new, old]).
dihedrals, n_iter, all_pos_hybrid = get_dihedrals(i, name, length, outdir, htf, indices_new, indices_old)


100%|██████████| 20/20 [00:01<00:00, 17.90it/s]


In [9]:
# # Plot 
# dihedrals_new = dihedrals[0]
# dihedrals_old = dihedrals[1]
# plot_dihedrals(dihedrals_old, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{old_aa_name.lower()}_correlated.png"))
# uncorrelated_old = plot_time_series(dihedrals_old, n_iter, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{old_aa_name.lower()}_timeseries.png"))
# plot_dihedrals_uncorrelated(dihedrals_old, uncorrelated_old, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{old_aa_name.lower()}_decorrelated.png"))
# plot_dihedrals(dihedrals_new, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{new_aa_name.lower()}_correlated.png"))
# uncorrelated_new = plot_time_series(dihedrals_new, n_iter, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{new_aa_name.lower()}_timeseries.png"))
# plot_dihedrals_uncorrelated(dihedrals_new, uncorrelated_new, os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns_{new_aa_name.lower()}_decorrelated.png"))

# # Save 100 random uncorrelated hybrid pos snapshots
# if name == new_aa_name.lower():
#     uncorrelated_indices = uncorrelated_new
# elif name == old_aa_name.lower():
#     uncorrelated_indices = uncorrelated_old
# else:
#     raise Exception("Your specified amino acid did not match the old or new aa names")
# if len(uncorrelated_indices) >= 100:
#     subset_indices = random.sample(uncorrelated_indices, k=100) # Choose 100 random indices (without replacement) from uncorrelated indices
# else:
#     subset_indices = random.choices(uncorrelated_indices, k=100) # Choose 100 random indices (with replacement, since fewer than 100 are available) from uncorrelated indices
# _logger.info(f"randomly chosen indices: {subset_indices}")


INFO:numexpr.utils:Note: detected 72 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 72 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
INFO:root:randomly chosen indices: [61, 64, 12, 15, 75, 81, 97, 45, 17, 42, 94, 70, 88, 27, 82, 18, 89, 1, 87, 8, 34, 60, 62, 54, 4, 48, 3, 29, 76, 85, 55, 39, 22, 86, 90, 43, 93, 91, 98, 23, 0, 69, 74, 13, 28, 33, 26, 72, 51, 6, 7, 32, 84, 80, 99, 25, 30, 46, 20, 79, 56, 50, 67, 9, 38, 71, 92, 78, 63, 5, 40, 11, 31, 95, 47, 77, 41, 96, 52, 35, 65, 66, 16, 73, 68, 36, 57, 59, 19, 83, 44, 49, 14, 58, 37, 53, 21, 2, 24, 10]


In [20]:
# NOTE(review): in the visible notebook `uncorrelated_old` is only assigned inside the
# commented-out plotting cell, so this relies on leftover kernel state and would raise
# NameError after Restart & Run All.
uncorrelated_old

[169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 182,
 183,
 184,
 185,
 186,
 187,
 188,
 189,
 190,
 191,
 192,
 193,
 194,
 195,
 196,
 197,
 198,
 199,
 200,
 201,
 202,
 203,
 205,
 206,
 207,
 208,
 209,
 210,
 211,
 212,
 213,
 214,
 215,
 216,
 217,
 218,
 219,
 220,
 221,
 222,
 223,
 224,
 225,
 226,
 228,
 229,
 230,
 231,
 232,
 233,
 234,
 235,
 236,
 237,
 238,
 239,
 240,
 241,
 242,
 243,
 244,
 245,
 246,
 247,
 248,
 249,
 251,
 252,
 253,
 254,
 255,
 256,
 257,
 258,
 259,
 260,
 261,
 262,
 263,
 264,
 265,
 266,
 267,
 268,
 269,
 270,
 271,
 272,
 274,
 275,
 276,
 277,
 278,
 279,
 280,
 281,
 282,
 283,
 284,
 285,
 286,
 287,
 288,
 289,
 290,
 291,
 292,
 293,
 294,
 295,
 297,
 298,
 299,
 300,
 301,
 302,
 303,
 304,
 305,
 306,
 307,
 308,
 309,
 310,
 311,
 312,
 313,
 314,
 315,
 316,
 317,
 318,
 319,
 321,
 322,
 323,
 324,
 325,
 326,
 327,
 328,
 329,
 330,
 331,
 332,
 333,
 334,
 335,
 336,
 337,
 338,
 339,
 340,
 341,
 342

In [7]:
# subset_pos = all_pos_hybrid[subset_indices] # Make array of hybrid positions for 100 uncorrelated indices
# with open(os.path.join(outdir, f"test_snapshots.npy"), 'wb') as f:
#     np.save(f, subset_pos)


In [7]:
# Build a trajectory from the first 20 collected hybrid frames
traj = md.Trajectory(all_pos_hybrid[:20], htf.hybrid_topology)

# Keep only the atoms matched by the selection.
# Raw string: "\+" and "\-" are not valid Python escapes, so a plain literal triggers an
# invalid-escape warning; the selection text passed to mdtraj is unchanged.
traj.atom_slice(traj.topology.select(r"water or resname 'na\+' or resn 'cl\-'"), inplace=True)
# traj.atom_slice(traj.topology.select("protein or water"), inplace=True)

# Write all frames as DCD, plus the first frame as a PDB
traj.save(os.path.join(outdir, f"test_snapshots.dcd"))
traj[0].save(os.path.join(outdir, f"test_snapshots.pdb"))


In [8]:
htf.hybrid_system.getForces()

[<simtk.openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x2b7144834330> >,
 <simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x2b7144834480> >,
 <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x2b7144834a20> >,
 <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x2b7144834990> >,
 <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2b71448348a0> >]

In [18]:
# with open(os.path.join(outdir, f"test_snapshots.npy"), 'rb') as f:
#     cache = np.load(f)

# traj = md.Trajectory(cache, htf.hybrid_topology)

# # traj.atom_slice(traj.topology.select("water or resname 'na\+' or resn 'cl\-'"), inplace=True)

# traj.save(os.path.join(outdir, f"test_snapshots_2.dcd"))
# traj[0].save(os.path.join(outdir, f"test_snapshots_2.pdb"))


In [13]:
# NOTE(review): in the visible notebook `subset_pos` is only assigned inside a commented-out
# cell, so this relies on leftover kernel state and would raise NameError after Restart & Run All.
subset_pos.shape

(100, 200487, 3)

In [14]:
htf.hybrid_positions.shape

(200487, 3)

In [19]:
traj

<mdtraj.Trajectory with 100 frames, 200487 atoms, 62711 residues, without unitcells at 0x2b6063a99430>

In [24]:
# Snapshot at index 1000 (0-based, i.e. the 1001st stored frame)
traj = md.Trajectory(all_pos_hybrid[1000], htf.hybrid_topology)

# Keep only the atoms matched by the selection.
# Raw string: "\+" and "\-" are not valid Python escapes, so a plain literal triggers an
# invalid-escape warning; the selection text passed to mdtraj is unchanged.
traj.atom_slice(traj.topology.select(r"water or resname 'na\+' or resn 'cl\-'"), inplace=True)

traj[0].save(os.path.join(outdir, f"test_snapshot.pdb"))


In [25]:
traj

<mdtraj.Trajectory with 1 frames, 184146 atoms, 61382 residues, without unitcells at 0x2b605ec80a30>

In [30]:
# Count the atoms that belong to 'HOH' residues in the chain whose index is 5
n_atoms = sum(
    1
    for chain in htf.hybrid_topology.chains
    if chain.index == 5
    for atom in chain.atoms
    if atom.residue.name == 'HOH'
)

In [32]:
n_atoms

184146

In [2]:
import mdtraj as md

In [3]:
# Load the equilibrated PDB, then load the DCD using that loaded structure as its topology source
top = md.load("/data/chodera/zhangi/vir/coronavirus/WT/output/equilibrated.pdb")
traj = md.load("/data/chodera/zhangi/vir/coronavirus/WT/output/equilibrated.dcd", top=top)

In [4]:
traj

<mdtraj.Trajectory with 500 frames, 312033 atoms, 100032 residues, and unitcells at 0x2b61e8184b80>

In [11]:
# Keep every 50th frame and write the thinned trajectory to a new DCD file
traj[::50].save("/data/chodera/zhangi/vir/coronavirus/WT/output/shortened_equilibrated.dcd")