In [1]:
import argparse
import pickle
import os
from perses.annihilation.lambda_protocol import LambdaProtocol
import simtk.unit as unit
from openmmtools.multistate import MultiStateReporter
# from perses.samplers.multistate import HybridRepexSampler
from openmmtools import mcmc
import logging
import numpy as np




In [2]:
# Run directory holding the pickled hybrid factory, named "<run>_<phase>.pickle".
outdir = "/data/chodera/zhangi/perses_benchmark/neq/14/156/"
# Leg of the transformation to simulate.
phase = "complex"

In [3]:
# The run directory's basename (e.g. "156") is also the prefix of the files inside it.
i = os.path.basename(os.path.dirname(outdir))
# Load the hybrid factory inside a `with` block so the file handle is closed afterwards.
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by trusted runs.
with open(os.path.join(outdir, f"{i}_{phase}.pickle"), "rb") as htf_pickle:
    htf = pickle.load(htf_pickle)


ERROR! Session/line number was not unique in database. History logging moved to new session 5018


INFO:rdkit:Enabling RDKit 2021.09.11 jupyter extensions


In [4]:
# Build the hybrid repex samplers
# Verbose (DEBUG) root logging while the hybrid repex samplers are built.
_logger = logging.getLogger()
_logger.setLevel(logging.DEBUG)

# Sampler settings.
checkpoint_interval = 10  # full sampler states are written only every 10th iteration
n_states, n_cycles = 12, 5000
lambda_protocol = LambdaProtocol(functions='default')
# Evenly spaced lambda values spanning only 0 -> 0.1 (not the full 0 -> 1 path).
lambda_schedule = np.linspace(0., 0.1, n_states)

In [7]:
from perses.annihilation.lambda_protocol import RelativeAlchemicalState, LambdaProtocol

from openmmtools.multistate import sams, replicaexchange
from openmmtools import cache, utils
from perses.dispersed.utils import configure_platform
# Point OpenMMTools' global context cache at the fastest available platform
# (reported as CUDA in this session).
_platform_name = utils.get_fastest_platform().getName()
cache.global_context_cache.platform = configure_platform(_platform_name)
# A separate ContextCache with neither a capacity nor a time-to-live limit.
context_cache = cache.ContextCache(capacity=None, time_to_live=None)
from openmmtools.states import CompoundThermodynamicState, SamplerState, ThermodynamicState
from perses.dispersed.utils import create_endstates

import numpy as np
import copy

import logging
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)
_logger = logging.getLogger("multistate")


class HybridCompatibilityMixin(object):
    """
    Mixin that allows the MultistateSampler to accommodate the situation where
    unsampled endpoints have a different number of degrees of freedom.

    The ``hybrid_factory`` keyword supplies the hybrid system and positions
    that :meth:`setup` uses to build the thermodynamic and sampler states.
    """

    def __init__(self, *args, hybrid_factory=None, **kwargs):
        self._hybrid_factory = hybrid_factory
        # ``setup`` reads ``self._factory``; set it here so the mixin does not
        # depend on every subclass remembering to assign it.
        self._factory = hybrid_factory
        super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)

    @staticmethod
    def _resolve_lambda_schedule(lambda_schedule, n_states):
        """Return ``lambda_schedule``, or ``n_states`` evenly spaced values in [0, 1] if it is None.

        Raises
        ------
        AssertionError
            If a supplied schedule does not have ``n_states`` entries or is not
            non-decreasing.
        """
        if lambda_schedule is None:
            return np.linspace(0., 1., n_states)
        assert (len(lambda_schedule) == n_states), 'length of lambda_schedule must match the number of states, n_states'
        # Only monotonicity is enforced: schedules need not start at 0 or end at 1.
        assert np.all(np.diff(lambda_schedule) >= 0.), 'lambda_schedule must be monotonicly increasing'
        return lambda_schedule

    @staticmethod
    def _subsample_evenly(items, n_select):
        """Return ``n_select`` roughly evenly spaced entries of ``items``, in their original order.

        With ``n_select == 1`` only the first entry is returned.
        """
        # np.unique sorts and de-duplicates the rounded indices.
        idx = np.unique(np.round(np.linspace(0, len(items) - 1, n_select)).astype(int))
        return [items[j] for j in idx]

    def setup(self, n_states, temperature, storage_file, minimisation_steps=100,
              n_replicas=None, lambda_schedule=None,
              lambda_protocol=None, endstates=True):
        """
        Build one thermodynamic/sampler state per lambda value and create the sampler.

        Parameters
        ----------
        n_states : int
            Number of alchemical states (one per entry of ``lambda_schedule``).
        temperature : simtk.unit.Quantity
            Temperature shared by every thermodynamic state.
        storage_file
            Passed through to ``self.create`` as ``storage`` (the notebook
            passes a ``MultiStateReporter``).
        minimisation_steps : int, default 100
            Maximum number of minimizer iterations per lambda state.
        n_replicas : int, optional
            Number of replicas; defaults to, and is capped at, ``n_states``.
        lambda_schedule : array-like of float, optional
            Non-decreasing lambda values of length ``n_states``; defaults to
            ``n_states`` evenly spaced values from 0 to 1.
        lambda_protocol : LambdaProtocol, optional
            Maps lambda onto alchemical parameters; a fresh ``LambdaProtocol()``
            is created when omitted.
        endstates : bool, default True
            If True, also register unsampled end states built by
            ``create_endstates`` from the first and last thermodynamic states.

        Raises
        ------
        AssertionError
            If ``lambda_schedule`` has the wrong length or is not non-decreasing.
        """
        from perses.dispersed import feptasks

        # Build the default per call instead of sharing one instance created at import time.
        if lambda_protocol is None:
            lambda_protocol = LambdaProtocol()

        hybrid_system = self._factory.hybrid_system
        positions = self._factory.hybrid_positions

        lambda_zero_alchemical_state = RelativeAlchemicalState.from_system(hybrid_system)
        thermostate = ThermodynamicState(hybrid_system, temperature=temperature)
        compound_thermodynamic_state = CompoundThermodynamicState(
            thermostate, composable_states=[lambda_zero_alchemical_state])

        if n_replicas is None:
            _logger.info(f'n_replicas not defined, setting to match n_states, {n_states}')
            n_replicas = n_states
        elif n_replicas > n_states:
            _logger.warning(f'More sampler states: {n_replicas} requested greater than number of states: {n_states}. Setting n_replicas to n_states: {n_states}')
            n_replicas = n_states

        lambda_schedule = self._resolve_lambda_schedule(lambda_schedule, n_states)

        # Start from the hybrid positions and minimize sequentially along the
        # schedule. minimize's return value is unused, so it is expected to
        # update ``sampler_state`` in place. Each lambda therefore starts from the
        # previous relaxed positions, and a snapshot is stored per state.
        sampler_state = SamplerState(positions, box_vectors=hybrid_system.getDefaultPeriodicBoxVectors())
        thermodynamic_state_list = []
        sampler_state_list = []
        for lambda_val in lambda_schedule:
            thermodynamic_state = copy.deepcopy(compound_thermodynamic_state)
            thermodynamic_state.set_alchemical_parameters(lambda_val, lambda_protocol)
            thermodynamic_state_list.append(thermodynamic_state)

            feptasks.minimize(thermodynamic_state, sampler_state, max_iterations=minimisation_steps)
            sampler_state_list.append(copy.deepcopy(sampler_state))

        # Keep exactly n_replicas sampler states, picked roughly evenly along the schedule.
        if len(sampler_state_list) != n_replicas:
            sampler_state_list = self._subsample_evenly(sampler_state_list, n_replicas)
        assert len(sampler_state_list) == n_replicas

        create_kwargs = dict(thermodynamic_states=thermodynamic_state_list,
                             sampler_states=sampler_state_list,
                             storage=storage_file)
        if endstates:
            _logger.info('Generating unsampled endstates.')
            create_kwargs['unsampled_thermodynamic_states'] = create_endstates(
                copy.deepcopy(thermodynamic_state_list[0]), copy.deepcopy(thermodynamic_state_list[-1]))
        self.create(**create_kwargs)


class HybridRepexSampler(HybridCompatibilityMixin, replicaexchange.ReplicaExchangeSampler):
    """
    Replica-exchange sampler built from a hybrid factory.

    Combines ``HybridCompatibilityMixin.setup`` with
    ``ReplicaExchangeSampler`` so that unsampled end states may have a
    different number of positions than the sampled hybrid states.
    """

    def __init__(self, *args, hybrid_factory=None, **kwargs):
        # Forward everything (including hybrid_factory) up the MRO, then keep
        # the factory under the name that setup() reads.
        super().__init__(*args, hybrid_factory=hybrid_factory, **kwargs)
        self._factory = hybrid_factory

conducting subsequent work with the following platform: CUDA


In [9]:
# The storage file is a relative path, so it is written to the current working directory (not `outdir`).
reporter_file = f"{i}_{phase}.nc"
reporter = MultiStateReporter(reporter_file, checkpoint_interval=checkpoint_interval)

# Langevin propagation between exchange attempts: 250 steps of 4 fs per move.
langevin_move = mcmc.LangevinSplittingDynamicsMove(
    timestep=4.0 * unit.femtoseconds,
    collision_rate=5.0 / unit.picosecond,
    n_steps=250,
    reassign_velocities=False,
    n_restart_attempts=20,
    splitting="V R R R O R R R V",
    constraint_tolerance=1e-06,
)
hss = HybridRepexSampler(
    mcmc_moves=langevin_move,
    replica_mixing_scheme='swap-all',
    hybrid_factory=htf,
    online_analysis_interval=10,
)
# NOTE(review): the config cell's `lambda_schedule` (0 -> 0.1) is NOT passed here,
# so setup() falls back to its default 0 -> 1 schedule. Add
# `lambda_schedule=lambda_schedule` if the truncated schedule was intended.
hss.setup(n_states=n_states, temperature=300 * unit.kelvin, storage_file=reporter,
          lambda_protocol=lambda_protocol)


INFO:multistate:n_replicas not defined, setting to match n_states, 12
INFO:multistate:Generating unsampled endstates.


Please cite the following:

        Friedrichs MS, Eastman P, Vaidyanathan V, Houston M, LeGrand S, Beberg AL, Ensign DL, Bruns CM, and Pande VS. Accelerating molecular dynamic simulations on graphics processing unit. J. Comput. Chem. 30:864, 2009. DOI: 10.1002/jcc.21209
        Eastman P and Pande VS. OpenMM: A hardware-independent framework for molecular simulations. Comput. Sci. Eng. 12:34, 2010. DOI: 10.1109/MCSE.2010.27
        Eastman P and Pande VS. Efficient nonbonded interactions for molecular dynamics on a graphics processing unit. J. Comput. Chem. 31:1268, 2010. DOI: 10.1002/jcc.21413
        Eastman P and Pande VS. Constant constraint matrix approximation: A robust, parallelizable constraint method for molecular simulations. J. Chem. Theor. Comput. 6:434, 2010. DOI: 10.1021/ct900463w
        Chodera JD and Shirts MR. Replica exchange and expanded ensemble simulations as Gibbs multistate: Simple improvements for enhanced mixing. J. Chem. Phys., 135:194110, 2011. DOI:10.1063/

In [11]:
_logger = logging.getLogger()
_logger.setLevel(logging.DEBUG)

In [12]:
# Run `n_cycles` further sampler iterations (mix replicas, propagate, report).
hss.extend(n_iterations=n_cycles)


DEBUG:mpiplus.mpiplus:Single node: executing <bound method MultiStateSampler._store_options of <instance of HybridRepexSampler>>
DEBUG:openmmtools.multistate.multistatesampler:Storing general ReplicaExchange options...
DEBUG:openmmtools.multistate.multistatesampler:********************************************************************************
DEBUG:openmmtools.multistate.multistatesampler:Iteration 3/5002
DEBUG:openmmtools.multistate.multistatesampler:********************************************************************************
DEBUG:mpiplus.mpiplus:Single node: executing <function ReplicaExchangeSampler._mix_replicas at 0x2b3ec7849e50>
DEBUG:openmmtools.multistate.replicaexchange:Mixing replicas...
DEBUG:openmmtools.utils:Mixing of replicas took    0.000s
DEBUG:openmmtools.multistate.replicaexchange:Accepted 402/3456 attempted swaps (11.6%)
DEBUG:openmmtools.multistate.multistatesampler:Propagating all replicas...
DEBUG:mpiplus.mpiplus:Running _propagate_replica serially.
DEBUG:m

DEBUG:openmmtools.utils:Computing energy matrix took    2.076s
DEBUG:mpiplus.mpiplus:Single node: executing <function MultiStateSampler._report_iteration at 0x2b3ec1ad34c0>
DEBUG:mpiplus.mpiplus:Single node: executing <function MultiStateSampler._report_iteration_items at 0x2b3ec1ad3790>
DEBUG:openmmtools.multistate.multistatereporter:Iteration 6 not on the Checkpoint Interval of 10. Sampler State not written.
DEBUG:openmmtools.utils:Storing sampler states took    0.003s
DEBUG:openmmtools.utils:Writing iteration information to storage took    0.059s
DEBUG:openmmtools.multistate.multistatesampler:Not enough iterations for online analysis (self.online_analysis_minimum_iterations = 200)
DEBUG:mpiplus.mpiplus:Single node: executing <function MultiStateSampler._online_analysis at 0x2b3ec1ad5310>
DEBUG:openmmtools.multistate.multistatesampler:*** ONLINE analysis free energies:
DEBUG:openmmtools.multistate.multistatesampler:        -0.0    12.8    12.8    12.8    12.8    12.8    12.8    12.8 

DEBUG:openmmtools.multistate.multistatesampler:********************************************************************************
DEBUG:openmmtools.multistate.multistatesampler:Iteration 10/5002
DEBUG:openmmtools.multistate.multistatesampler:********************************************************************************
DEBUG:mpiplus.mpiplus:Single node: executing <function ReplicaExchangeSampler._mix_replicas at 0x2b3ec7849e50>
DEBUG:openmmtools.multistate.replicaexchange:Mixing replicas...
DEBUG:openmmtools.utils:Mixing of replicas took    0.000s
DEBUG:openmmtools.multistate.replicaexchange:Accepted 406/3456 attempted swaps (11.7%)
DEBUG:openmmtools.multistate.multistatesampler:Propagating all replicas...
DEBUG:mpiplus.mpiplus:Running _propagate_replica serially.


KeyboardInterrupt: 