In [1]:
import pickle
import os
from perses.annihilation.rest import RESTTopologyFactory
from perses.annihilation.lambda_protocol import RESTState
from openmmtools.states import SamplerState, ThermodynamicState, CompoundThermodynamicState
from openmmtools import cache, utils
from perses.dispersed.utils import configure_platform
cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName())
from simtk import openmm, unit
import math
from openmmtools.constants import kB
from openmmtools import mcmc, multistate
import argparse
import copy
from perses.dispersed import feptasks
import mdtraj as md
import numpy as np
from perses.app.relative_point_mutation_setup import PointMutationExecutor

from openmmtools.integrators import LangevinIntegrator
from simtk.openmm import unit, app

INFO:numexpr.utils:Note: NumExpr detected 48 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:rdkit:Enabling RDKit 2021.03.4 jupyter extensions


In [2]:
# Set up context cache
context_cache = cache.ContextCache(capacity=None, time_to_live=None)

# Set up logger
import logging
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)

In [11]:
length = 1
move_length = 1
timestep = 4
radius = 0.7
outdir = "/data/chodera/zhangi/perses_benchmark/neq/14/148/"
phase = 'complex'
state = 0
n_replicas = 36
t_max = 1200
name = 'asn'

In [12]:
# Load rhtf
i = os.path.basename(os.path.dirname(outdir))
path = os.path.join(outdir, f"{i}_{phase}_{state}.pickle")
_logger.info(f"path: {path}")
htf = pickle.load(open(path, "rb" ))
positions = htf.hybrid_positions

INFO:root:path: /data/chodera/zhangi/perses_benchmark/neq/14/148/148_complex_0.pickle


ModuleNotFoundError: No module named 'simtk.openmm.app.topology'

In [None]:
# Build REST factory
traj = md.Trajectory(np.array(positions), htf.hybrid_topology)
RBD_chain = [atom.index for atom in htf.hybrid_topology.atoms if atom.residue.chain.index == 0]
ACE2_chain = [atom.index for atom in htf.hybrid_topology.atoms if atom.residue.chain.index == 2]
rbd_matches = md.compute_neighbors(traj, radius, ACE2_chain, haystack_indices=RBD_chain)
ace2_matches = md.compute_neighbors(traj, radius, RBD_chain, haystack_indices=ACE2_chain)
rest_atoms = list(rbd_matches[0]) + list(ace2_matches[0])
factory = RESTTopologyFactory(htf.hybrid_system, solute_region=rest_atoms, use_dispersion_correction=True)

# Get REST system
REST_system = factory.REST_system


In [10]:
with open(os.path.join(outdir, f"{i}_{phase}_{state}_rest.pickle"), "wb") as f:
    pickle.dump(factory, f)

In [16]:
from simtk.openmm import XmlSerializer
with open(os.path.join(outdir, f"{i}_{phase}_{state}_system.xml"), 'w') as f:
    f.write(XmlSerializer.serialize(REST_system))


In [9]:
# Create states for each replica
T_min = 300.0 * unit.kelvin  # Minimum temperature.
T_max = t_max * unit.kelvin  # Maximum temperature.
temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
                for i in range(n_replicas)]

# Create reference thermodynamic state
lambda_zero_alchemical_state = RESTState.from_system(REST_system)
thermostate = ThermodynamicState(REST_system, temperature=T_min)
compound_thermodynamic_state = CompoundThermodynamicState(thermostate, composable_states=[lambda_zero_alchemical_state])

In [10]:
# Create thermodynamics states
sampler_state =  SamplerState(positions, box_vectors=htf.hybrid_system.getDefaultPeriodicBoxVectors())
beta_0 = 1/(kB*T_min)
thermodynamic_state_list = []
sampler_state_list = []
for temperature in temperatures:
    beta_m = 1/(kB*temperature)
    compound_thermodynamic_state_copy = copy.deepcopy(compound_thermodynamic_state)
    compound_thermodynamic_state_copy.set_alchemical_parameters(beta_0, beta_m)
    thermodynamic_state_list.append(compound_thermodynamic_state_copy)

    # now generating a sampler_state for each thermodynamic state, with relaxed positions
    # context, context_integrator = context_cache.get_context(compound_thermodynamic_state_copy)
    feptasks.minimize(compound_thermodynamic_state_copy, sampler_state, max_iterations=0)
    sampler_state_list.append(copy.deepcopy(sampler_state))


In [11]:
from openmmtools.multistate import ReplicaExchangeSampler
import mpiplus
class ReplicaExchangeSampler2(ReplicaExchangeSampler):
    @mpiplus.on_single_node(rank=0, broadcast_result=False, sync_nodes=False)
    @mpiplus.delayed_termination
    def _report_iteration_items(self):
        """
        Sub-function of :func:`_report_iteration` which handles all the actual individual item reporting in a
        sub-class friendly way. The final actions of writing timestamp, last-good-iteration, and syncing
        should be left to the :func:`_report_iteration` and subclasses should extend this function instead
        """
        replica_id = np.where(self._replica_thermodynamic_states == 0)[0][0]
        print("ITERATION: ", self._iteration)
        print("REPLICA THERMOSTATES ", self._replica_thermodynamic_states, type(self._replica_thermodynamic_states))
        print("REPLICA ID ", replica_id, type(replica_id))
        self._reporter.write_sampler_states([self._sampler_states[replica_id]], self._iteration)
        self._reporter.write_replica_thermodynamic_states(self._replica_thermodynamic_states, self._iteration)
        self._reporter.write_mcmc_moves(self._mcmc_moves)  # MCMCMoves can store internal statistics.
        self._reporter.write_energies(self._energy_thermodynamic_states, self._neighborhoods, self._energy_unsampled_states,
                                      self._iteration)
        self._reporter.write_mixing_statistics(self._n_accepted_matrix, self._n_proposed_matrix, self._iteration)


In [12]:
# Set up sampler
move = mcmc.LangevinSplittingDynamicsMove(timestep=timestep*unit.femtoseconds, n_steps=int((move_length*1000)/timestep), context_cache=context_cache)
simulation = ReplicaExchangeSampler2(mcmc_moves=move, number_of_iterations=length*1000)




In [20]:
context, integrator = context_cache.get_context(thermodynamic_state_list[0])

In [22]:
with open(os.path.join(outdir, f"{i}_{phase}_{state}_state.xml"), 'w') as f:
    f.write(XmlSerializer.serialize(context.getState(getPositions=True, getVelocities=True)))
with open(os.path.join(outdir, f"{i}_{phase}_{state}_integrator.xml"), 'w') as f:
    f.write(XmlSerializer.serialize(integrator))


In [25]:
# Run t-repex
reporter_file = os.path.join(outdir, f"{i}_{phase}_{name.lower()}_{length}ns.nc")
reporter = multistate.MultiStateReporter(reporter_file, checkpoint_interval=1)
simulation.create(thermodynamic_states=thermodynamic_state_list,
                  sampler_states=sampler_state_list,
                  storage=reporter)
simulation.run()



Please cite the following:

        Friedrichs MS, Eastman P, Vaidyanathan V, Houston M, LeGrand S, Beberg AL, Ensign DL, Bruns CM, and Pande VS. Accelerating molecular dynamic simulations on graphics processing unit. J. Comput. Chem. 30:864, 2009. DOI: 10.1002/jcc.21209
        Eastman P and Pande VS. OpenMM: A hardware-independent framework for molecular simulations. Comput. Sci. Eng. 12:34, 2010. DOI: 10.1109/MCSE.2010.27
        Eastman P and Pande VS. Efficient nonbonded interactions for molecular dynamics on a graphics processing unit. J. Comput. Chem. 31:1268, 2010. DOI: 10.1002/jcc.21413
        Eastman P and Pande VS. Constant constraint matrix approximation: A robust, parallelizable constraint method for molecular simulations. J. Chem. Theor. Comput. 6:434, 2010. DOI: 10.1021/ct900463w
        Chodera JD and Shirts MR. Replica exchange and expanded ensemble simulations as Gibbs multistate: Simple improvements for enhanced mixing. J. Chem. Phys., 135:194110, 2011. DOI:10.1063/

OpenMMException: Error creating array interactingAtoms: CUDA_ERROR_OUT_OF_MEMORY (2)