Ran REST on vanilla solvated THR and ALA systems. The analysis took only 40 seconds. Why was it so fast?

In [1]:
import math
from simtk import unit
import os
import tempfile
import pickle
import mdtraj as md
import numpy as np
from simtk.unit.quantity import Quantity
import logging

# Set up logger
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

from matplotlib import pyplot as plt
from simtk.openmm import app
from tqdm import tqdm
import argparse
import pickle


In [None]:
def get_dihedrals_all_replicas(i, aa, length, out_dir, topology, dihedral_indices, phase=None):
    """Compute one dihedral's time series for states 0, 6 and 11 of a run.

    Opens ``{i}_{phase}_{aa}_{length}.nc`` and the matching ``_checkpoint.nc``
    file in ``out_dir``. For each of the three hard-coded state indices, the
    function picks, at every stored frame, the replica whose ``states`` entry
    equals that index, stacks those frames into an mdtraj trajectory, and
    evaluates the requested dihedral on it.

    Parameters
    ----------
    i, aa, length
        Components of the input file names (see the pattern above).
    out_dir
        Directory containing the two NetCDF files.
    topology
        Topology object; must provide ``getNumAtoms()`` and is handed to
        ``md.Trajectory``.
    dihedral_indices
        Atom indices defining the dihedral, passed to
        ``md.compute_dihedrals`` as a single row.
    phase
        Phase component of the file names. When ``None`` (the default) the
        module-level ``args.phase`` is used, which is what this function
        always read before the parameter existed.

    Returns
    -------
    (dihedrals_master, n_iter)
        ``dihedrals_master`` is a list with one ``md.compute_dihedrals``
        result per state index, in the order ``[0, 6, 11]``; ``n_iter`` is
        the first dimension of the checkpoint ``positions`` variable.

    Raises
    ------
    NameError
        If ``phase`` is not given and no module-level ``args`` exists.
    """
    # From Hannah: https://github.com/hannahbrucemacdonald/endstate_pdbs/blob/master/scripts/input_for_pol_calc.py
    from perses.analysis.utils import open_netcdf
    from tqdm import tqdm

    if phase is None:
        # Legacy path: the phase used to come only from a global argparse
        # namespace, which does not exist when running interactively.
        try:
            phase = args.phase
        except NameError:
            raise NameError(
                "phase was not given and no module-level `args` is defined; "
                "pass phase=... explicitly"
            ) from None

    nc = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}.nc"))
    nc_checkpoint = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}_checkpoint.nc"))
    checkpoint_interval = nc_checkpoint.CheckpointInterval
    all_positions = nc_checkpoint.variables['positions']
    n_iter, n_replicas, n_atoms, _ = np.shape(all_positions)

    dihedrals_master = []
    # The loop variable is deliberately not called `i`: that name is the
    # file-name parameter above and must not be shadowed.
    for state_index in [0, 6, 11]:
        all_pos = np.zeros(shape=(n_iter, topology.getNumAtoms(), 3))
        for iteration in tqdm(range(n_iter)):
            # Frame `iteration` of the checkpoint file is matched with row
            # `iteration * checkpoint_interval` of the analysis file's
            # `states` variable (presumably one state index per replica --
            # TODO confirm against the reporter's file layout).
            replica_id = np.where(nc.variables['states'][iteration*checkpoint_interval] == state_index)[0]
            # Positions are tagged as nanometers here; assumes that is the
            # unit they were stored in -- TODO confirm.
            pos = all_positions[iteration,replica_id,:,:][0] *unit.nanometers
            all_pos[iteration] = pos

        traj = md.Trajectory(np.array(all_pos), topology)
        dihedrals = md.compute_dihedrals(traj, np.array([dihedral_indices]))
        dihedrals_master.append(dihedrals)
    return dihedrals_master, n_iter

In [2]:
def plot_dihedrals(dihedrals, outfile):
    """Save a histogram of ``dihedrals`` with sqrt(count) error bars.

    The x axis is fixed to [-pi, pi]. The figure is written to ``outfile``
    at 300 dpi and then closed.
    """
    # Histogram with error bars, following
    # https://stackoverflow.com/questions/35390276/how-to-add-error-bars-to-histogram-diagram-in-python
    counts, bin_edges, _patches = plt.hist(dihedrals)

    # Error bars sit at the midpoint of each bin.
    lower_edges = bin_edges[:-1]
    upper_edges = bin_edges[1:]
    midpoints = 0.5 * (lower_edges + upper_edges)

    # The error on each bin is taken as the square root of its count.
    count_errors = np.sqrt(counts)
    plt.errorbar(midpoints, counts, yerr=count_errors, fmt='r.')

    plt.xlim(-np.pi, np.pi)
    plt.savefig(outfile, dpi=300)
    plt.close()

def plot_time_series(dihedrals, n_iter, outfile):
    """Scatter-plot ``dihedrals`` against iteration number and save the figure.

    Also runs ``feptasks.compute_timeseries`` on ``dihedrals`` and returns
    the last of the five values it yields (the uncorrelated indices). The
    y axis is fixed to [-pi, pi]; the figure is written to ``outfile`` at
    300 dpi and then closed.
    """
    from perses.dispersed import feptasks

    # Only the final element is used; the other four are unpacked so that a
    # result of unexpected length still fails here.
    _t0, _g, _neff_max, _a_t, decorrelated = feptasks.compute_timeseries(dihedrals)

    iterations = range(n_iter)
    plt.scatter(iterations, dihedrals)
    plt.xlabel("iteration number")
    plt.ylabel("dihedral")
    plt.ylim(-np.pi, np.pi)
    plt.savefig(outfile, dpi=300)
    plt.close()

    return decorrelated

def plot_dihedrals_uncorrelated(dihedrals, uncorrelated_indices, outfile):
    """Save a histogram of the decorrelated subset of ``dihedrals``.

    Selects ``dihedrals[uncorrelated_indices]`` and draws it with
    ``plot_dihedrals``, so both plots share one implementation of the
    histogram, error bars, axis limits and output settings.

    Parameters
    ----------
    dihedrals
        Indexable collection of dihedral values.
    uncorrelated_indices
        Index used to subset ``dihedrals`` (e.g. the value returned by
        ``plot_time_series``).
    outfile
        Path the figure is written to.
    """
    plot_dihedrals(dihedrals[uncorrelated_indices], outfile)

In [7]:
# Run selection: these values are combined into the input file names.
i = 10
aa = 'ala'
length = '5ns'
out_dir = "/data/chodera/zhangi/perses_benchmark/neq/11/10"
phase = 'solvent'

# Atom indices of the dihedral to analyse, keyed by residue name.
# For 'ala' the entries correspond to N, CA, CB, HB1 in the atom listing
# printed further down in this notebook.
dihedral_atoms_by_residue = {
    'thr': [6, 7, 10, 12],
    'ala': [6, 8, 10, 15],
}
if aa not in dihedral_atoms_by_residue:
    # Fail here rather than leave `indices` undefined, or silently reuse a
    # value from an earlier run of this cell.
    raise ValueError(
        f"no dihedral indices defined for aa={aa!r}; "
        f"expected one of {sorted(dihedral_atoms_by_residue)}"
    )
indices = dihedral_atoms_by_residue[aa]

# Get topology
# NOTE: unpickling can execute arbitrary code; only load topology pickles
# that were produced by a trusted run.
with open(os.path.join(out_dir, f"{i}_{aa}_vanilla_topology.pickle"), "rb") as f:
    topology = pickle.load(f)

# dihedrals, n_iter = get_dihedrals_all_replicas(i, aa, length, out_dir, topology, indices)

# for j, replica in tqdm(enumerate(dihedrals)):
#     plot_dihedrals(replica, os.path.join(out_dir, f"{i}_{args.phase}_{aa}_{length}_{j}_{aa}_correlated.png"))
#     uncorrelated_old = plot_time_series(replica, n_iter, os.path.join(out_dir, f"{i}_{args.phase}_{aa}_{length}_{j}_{aa}_timeseries.png"))
#     plot_dihedrals_uncorrelated(replica, uncorrelated_old, os.path.join(out_dir, f"{i}_{args.phase}_{aa}_{length}_{j}_{aa}_decorrelated.png"))

In [6]:
# Print every atom of the loaded topology, one per line, so that atom
# indices and names can be looked up when choosing dihedral indices.
for atom in topology.atoms():
    print(atom)

<Atom 0 (H1) of chain 0 residue 0 (ACE)>
<Atom 1 (CH3) of chain 0 residue 0 (ACE)>
<Atom 2 (H2) of chain 0 residue 0 (ACE)>
<Atom 3 (H3) of chain 0 residue 0 (ACE)>
<Atom 4 (C) of chain 0 residue 0 (ACE)>
<Atom 5 (O) of chain 0 residue 0 (ACE)>
<Atom 6 (N) of chain 0 residue 1 (ALA)>
<Atom 7 (H) of chain 0 residue 1 (ALA)>
<Atom 8 (CA) of chain 0 residue 1 (ALA)>
<Atom 9 (HA) of chain 0 residue 1 (ALA)>
<Atom 10 (CB) of chain 0 residue 1 (ALA)>
<Atom 11 (HB2) of chain 0 residue 1 (ALA)>
<Atom 12 (HB3) of chain 0 residue 1 (ALA)>
<Atom 13 (C) of chain 0 residue 1 (ALA)>
<Atom 14 (O) of chain 0 residue 1 (ALA)>
<Atom 15 (HB1) of chain 0 residue 1 (ALA)>
<Atom 16 (N) of chain 0 residue 2 (NME)>
<Atom 17 (H) of chain 0 residue 2 (NME)>
<Atom 18 (C) of chain 0 residue 2 (NME)>
<Atom 19 (H1) of chain 0 residue 2 (NME)>
<Atom 20 (H2) of chain 0 residue 2 (NME)>
<Atom 21 (H3) of chain 0 residue 2 (NME)>
<Atom 22 (O) of chain 1 residue 3 (HOH)>
<Atom 23 (H1) of chain 1 residue 3 (HOH)>
<Atom 24

In [8]:
from perses.analysis.utils import open_netcdf
# Open the two NetCDF files of the run selected by i / phase / aa / length:
# the main file (read later for its 'states' variable) and its checkpoint
# companion (read for positions).
nc = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}.nc"))
nc_checkpoint = open_netcdf(os.path.join(out_dir, f"{i}_{phase}_{aa}_{length}_checkpoint.nc"))
# Attribute stored on the checkpoint file; presumably the number of
# iterations between saved position frames -- TODO confirm.
checkpoint_interval = nc_checkpoint.CheckpointInterval
all_positions = nc_checkpoint.variables['positions']
# The positions variable is four-dimensional; the first three axes are
# unpacked as (stored frames, replicas, atoms) and the last is discarded.
n_iter, n_replicas, n_atoms, _ = np.shape(all_positions)



In [9]:
n_atoms

1548

In [None]:
from tqdm import tqdm
dihedrals_master = []
for i in [0, 6, 11]:
    index = i # of replica
    all_pos = np.zeros(shape=(n_iter, topology.getNumAtoms(), 3))
    for iteration in tqdm(range(n_iter)):
        replica_id = np.where(nc.variables['states'][iteration*checkpoint_interval] == index)[0]
        pos = all_positions[iteration,replica_id,:,:][0] *unit.nanometers
        all_pos[iteration] = pos