In [2]:
import copy
import pickle
import os
from openeye import oechem
###########################################
# IMPORTS
###########################################
from openmmtools.states import SamplerState, ThermodynamicState, CompoundThermodynamicState
from simtk import unit, openmm
from perses.tests.utils import compute_potential_components
from openmmtools.constants import kB
from perses.dispersed.utils import configure_platform
from perses.annihilation.rest import RESTTopologyFactory, RESTTopologyFactoryV2
from perses.annihilation.lambda_protocol import RESTState, RESTStateV2
import numpy as np
from perses.tests.test_topology_proposal import generate_atp, generate_dipeptide_top_pos_sys
from openmmtools.testsystems import AlanineDipeptideVacuum, AlanineDipeptideExplicit
import itertools


#############################################
# CONSTANTS
#############################################
temperature = 300.0 * unit.kelvin  # temperature used by compare_energy_components
kT = kB * temperature
beta = 1.0/kT  # passed as `beta` to compute_potential_components throughout the notebook

# Platform used for the energy-comparison contexts. Note: despite the name, this is NOT OpenMM's
# "Reference" platform. Prefer CUDA, but fall back to CPU so the notebook still runs (slowly)
# on machines without a CUDA-capable GPU instead of failing in the very first cell.
try:
    REFERENCE_PLATFORM = openmm.Platform.getPlatformByName("CUDA")
except Exception as platform_error:
    print(f"CUDA platform unavailable ({platform_error}); falling back to CPU")
    REFERENCE_PLATFORM = openmm.Platform.getPlatformByName("CPU")

conducting subsequent work with the following platform: CUDA


INFO:rdkit:Enabling RDKit 2021.03.3 jupyter extensions


In [3]:
def _get_energy_components(system, positions, platform):
    """
    Return the per-force energy components of ``system`` evaluated at ``positions``.

    The system is wrapped in a ThermodynamicState at the module-level ``temperature`` and the
    result is whatever ``compute_potential_components`` yields for that context (name/energy
    pairs), evaluated with the module-level ``beta``.
    """
    thermostate = ThermodynamicState(system=system, temperature=temperature)
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = thermostate.create_context(integrator, platform=platform)
    context.setPositions(positions)
    return list(compute_potential_components(context, beta=beta))


def compare_energy_components(rest_system, other_system, positions, platform=REFERENCE_PLATFORM):
    """
    Check that a REST system reproduces the energy components of the system it was built from.

    Bond, angle, and torsion energies (the first three components of each system) are compared
    one-to-one; all remaining components are summed and compared as the nonbonded energy, because
    the REST system spreads nonbonded interactions over several forces.

    Parameters
    ----------
    rest_system : openmm.System
        The REST-ified system.
    other_system : openmm.System
        The unmodified system the REST system was built from.
    positions : positions accepted by ``Context.setPositions``
        Coordinates at which both systems are evaluated.
    platform : openmm.Platform or str, optional, default=REFERENCE_PLATFORM
        Platform on which the contexts are created. A platform name is resolved with
        ``configure_platform``; a Platform object is used as given.

    Raises
    ------
    AssertionError
        If a valence energy or the summed nonbonded energy differs beyond ``np.isclose`` defaults.
    """
    # configure_platform takes a platform *name*; passing it a Platform object made it report a
    # different (CPU) platform, and its result was never used. Only resolve names.
    if isinstance(platform, str):
        platform = configure_platform(platform)

    components_other = _get_energy_components(other_system, positions, platform)
    components_rest = _get_energy_components(rest_system, positions, platform)

    print(components_other)
    print(components_rest)

    # Bond, angle, and torsion energies must match force by force
    for other, rest in zip(components_other[:3], components_rest[:3]):
        assert np.isclose(other[1], rest[1]), f"The energies do not match for the {other[0]}: {other[1]} (other system) vs. {rest[1]} (REST system)"

    # Nonbonded energies are split differently between the two systems, so compare the totals
    nonbonded_other = np.sum([component[1] for component in components_other[3:]])
    nonbonded_rest = np.sum([component[1] for component in components_rest[3:]])
    assert np.isclose(nonbonded_other, nonbonded_rest), f"The energies do not match for the NonbondedForce: {nonbonded_other} (other system) vs. {nonbonded_rest} (REST system)"

    print("Energy bookkeeping was a success!")


In [4]:
def _get_solvent_atom_indices(topology, solvent_residue_names=("NA", "CL", "HOH")):
    """
    Return the sorted indices of atoms in ion or water residues of ``topology``.

    Parameters
    ----------
    topology : openmm.app.Topology or mdtraj.Topology
        Topology to scan; the library is detected from the object's module name.
    solvent_residue_names : tuple of str, optional
        Residue names treated as solvent (positive ion, negative ion, water).

    Raises
    ------
    TypeError
        If ``topology`` is neither an OpenMM nor an MDTraj topology.
    """
    if 'openmm' in topology.__module__:
        atoms = topology.atoms()  # OpenMM: atoms() is a method
    elif 'mdtraj' in topology.__module__:
        atoms = topology.atoms  # MDTraj: atoms is a property
    else:
        raise TypeError(f"Unsupported topology type {type(topology)}; expected an OpenMM or MDTraj Topology")
    return sorted(atom.index for atom in atoms if atom.residue.name in solvent_residue_names)


def _find_unique_force(system, force_type):
    """Return the single force of ``force_type`` in ``system``; raise ValueError unless exactly one exists."""
    matches = [force for force in system.getForces() if isinstance(force, force_type)]
    if len(matches) != 1:
        raise ValueError(f"Expected exactly one {force_type.__name__} in the system, found {len(matches)}")
    return matches[0]


def _rest2_valence_scale(particles, rest_set, nonrest_set, rest_scaling, inter_scaling):
    """
    Return the REST2 scale factor for a term involving ``particles``, or None to leave it unscaled.

    All particles in the REST region -> ``rest_scaling``; particles in both the REST region and the
    non-REST solute -> ``inter_scaling``; anything else (e.g. REST + solvent only) -> None.
    """
    particle_set = set(particles)
    if particle_set <= rest_set:
        return rest_scaling
    if particle_set & rest_set and particle_set & nonrest_set:
        return inter_scaling
    return None


def _scale_system_for_rest2(system, rest_atoms, nonrest_atoms, solvent_atoms, rest_scaling, inter_scaling):
    """
    Scale ``system`` in place so that its energy reproduces REST2 scaling.

    - Bonds, angles, torsions, and existing nonbonded exceptions are scaled according to
      ``_rest2_valence_scale``.
    - New exceptions are added for REST-REST pairs (``rest_scaling``) and REST/non-REST solute pairs
      (``inter_scaling``) that have no exception yet, and for every REST-solvent pair (``rest_scaling``).
    Requires exactly one HarmonicBondForce, HarmonicAngleForce, PeriodicTorsionForce and NonbondedForce.
    """
    # Sets keep the per-term membership tests O(1); this system has ~1e5 particles.
    rest_set = set(rest_atoms)
    nonrest_set = set(nonrest_atoms)

    def scale_for(*particles):
        return _rest2_valence_scale(particles, rest_set, nonrest_set, rest_scaling, inter_scaling)

    bond_force = _find_unique_force(system, openmm.HarmonicBondForce)
    for index in range(bond_force.getNumBonds()):
        p1, p2, length, k = bond_force.getBondParameters(index)
        scale = scale_for(p1, p2)
        if scale is not None:
            bond_force.setBondParameters(index, p1, p2, length, k * scale)

    angle_force = _find_unique_force(system, openmm.HarmonicAngleForce)
    for index in range(angle_force.getNumAngles()):
        p1, p2, p3, angle, k = angle_force.getAngleParameters(index)
        scale = scale_for(p1, p2, p3)
        if scale is not None:
            angle_force.setAngleParameters(index, p1, p2, p3, angle, k * scale)

    torsion_force = _find_unique_force(system, openmm.PeriodicTorsionForce)
    for index in range(torsion_force.getNumTorsions()):
        p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(index)
        scale = scale_for(p1, p2, p3, p4)
        if scale is not None:
            torsion_force.setTorsionParameters(index, p1, p2, p3, p4, periodicity, phase, k * scale)

    nb_force = _find_unique_force(system, openmm.NonbondedForce)
    for index in range(nb_force.getNumExceptions()):
        p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(index)
        scale = scale_for(p1, p2)
        if scale is not None:
            nb_force.setExceptionParameters(index, p1, p2, scale * chargeProd, sigma, scale * epsilon)

    # Pairs that already have an exception were scaled above; only pairs without one get a new exception.
    exception_pairs = {tuple(sorted(nb_force.getExceptionParameters(index)[:2]))
                       for index in range(nb_force.getNumExceptions())}

    particle_parameters = {}  # cache: each partner atom is paired with every REST atom

    def add_scaled_exception(p1, p2, scale):
        # Lorentz-Berthelot combination (arithmetic sigma, geometric epsilon); charge product and epsilon scaled
        for particle in (p1, p2):
            if particle not in particle_parameters:
                particle_parameters[particle] = nb_force.getParticleParameters(particle)
        charge1, sigma1, epsilon1 = particle_parameters[p1]
        charge2, sigma2, epsilon2 = particle_parameters[p2]
        nb_force.addException(p1, p2, charge1 * charge2 * scale, 0.5 * (sigma1 + sigma2),
                              np.sqrt(epsilon1 * epsilon2) * scale)

    # REST-REST region
    for p1, p2 in itertools.combinations(sorted(rest_set), 2):
        if (p1, p2) not in exception_pairs:
            add_scaled_exception(p1, p2, rest_scaling)

    # REST / non-REST solute (inter) region
    for p1, p2 in itertools.product(rest_atoms, nonrest_atoms):
        if tuple(sorted((p1, p2))) not in exception_pairs:
            add_scaled_exception(p1, p2, inter_scaling)

    # REST-solvent region
    for p1, p2 in itertools.product(rest_atoms, solvent_atoms):
        add_scaled_exception(p1, p2, rest_scaling)


def _rest_reduced_potential(REST_system, positions, T_min, beta_0, beta_m):
    """
    Return the reduced potential of ``REST_system`` at ``T_min`` with REST2 parameters for beta_0 -> beta_m.

    Also prints the per-force energy components of the REST context for debugging.
    """
    rest_state = RESTStateV2.from_system(REST_system)
    thermostate = ThermodynamicState(REST_system, temperature=T_min)
    compound_state = CompoundThermodynamicState(thermostate, composable_states=[rest_state])
    compound_state.set_alchemical_parameters(beta_0, beta_m)

    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = compound_state.create_context(integrator)
    context.setPositions(positions)
    reduced_potential = compound_state.reduced_potential(SamplerState.from_context(context))

    # NOTE(review): components are evaluated with the module-level `beta`, not beta_0 -- presumably
    # they are therefore not reduced at T_min; they are printed for debugging only.
    print(list(compute_potential_components(context, beta=beta)))
    return reduced_potential


def compare_energies(topology, REST_system, other_system, positions, rest_atoms, T_min, T, tolerance=1):
    """
    Validate REST2 scaling by comparing a REST system against a manually scaled non-REST system.

    The REST system is evaluated at ``T_min`` with REST parameters for ``T``. A deep copy of
    ``other_system`` is then scaled by hand: terms entirely in the REST region by beta_m / beta_0,
    terms spanning the REST region and the non-REST solute by sqrt(beta_m / beta_0), and REST-solvent
    nonbonded pairs by beta_m / beta_0. The two reduced potentials must agree.

    Parameters
    ----------
    topology : openmm.app.Topology or mdtraj.Topology
        Used to find solvent atoms (residues named NA, CL, or HOH).
    REST_system : openmm.System
        REST-ified system.
    other_system : openmm.System
        Unmodified system. It is copied before scaling, so the caller's object is left untouched.
    positions : positions accepted by ``Context.setPositions``
    rest_atoms : iterable of int
        Indices of the atoms in the REST region.
    T_min : simtk.unit.Quantity
        Temperature defining beta_0 (and the thermodynamic state of both systems).
    T : simtk.unit.Quantity
        REST effective temperature defining beta_m.
    tolerance : float, optional, default=1
        Maximum allowed absolute difference between the two (dimensionless) reduced potentials.

    Raises
    ------
    TypeError
        If ``topology`` is neither an OpenMM nor an MDTraj topology.
    AssertionError
        If the reduced potentials differ by ``tolerance`` or more.
    """
    rest_atoms = [int(index) for index in rest_atoms]  # OpenMM rejects numpy integer indices
    rest_set = set(rest_atoms)
    solvent_atoms = _get_solvent_atom_indices(topology)
    solvent_set = set(solvent_atoms)

    # REST system energy
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    print("beta_0: ", beta_0)
    print("beta_m: ", beta_m)
    REST_energy = _rest_reduced_potential(REST_system, positions, T_min, beta_0, beta_m)

    # Manually scaled non-REST energy. Scale a copy: scaling other_system itself would corrupt the
    # caller's system and double-scale it if this function is called again.
    scaled_system = copy.deepcopy(other_system)
    nonrest_atoms = [index for index in range(scaled_system.getNumParticles())
                     if index not in rest_set and index not in solvent_set]
    rest_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)
    _scale_system_for_rest2(scaled_system, rest_atoms, nonrest_atoms, solvent_atoms, rest_scaling, inter_scaling)

    thermostate = ThermodynamicState(scaled_system, temperature=T_min)
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = thermostate.create_context(integrator)
    context.setPositions(positions)
    nonREST_energy = thermostate.reduced_potential(SamplerState.from_context(context))

    discrepancy = abs(REST_energy - nonREST_energy)
    print(f"Energies: {REST_energy} (rest), {nonREST_energy} (nonrest)")
    print(f"Discrepancy: {discrepancy}")
    assert discrepancy < tolerance, f"The energy of the REST system ({REST_energy}) does not match " \
                                    f"that of the non-REST system with terms manually scaled according to REST2({nonREST_energy})."

    print("Success!")


# Test endstate validation (energy bookkeeping) for production systems

In [2]:
# Load r-htf for N501Y
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by a trusted run.
htf_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex_0.pickle"
with open(htf_path, "rb") as pickle_file:
    htf = pickle.load(pickle_file)
    

In [3]:
# List the hybrid system's forces; the MonteCarloBarostat listed first is removed before REST-ifying
htf.hybrid_system.getForces()


[<simtk.openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x2b0bd9f79780> >,
 <simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x2b0bd9f88660> >,
 <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x2b0cee891e70> >,
 <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x2b0cee891ed0> >,
 <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2b0cee891fc0> >]

In [4]:
# Get unmodified hybrid system (without the barostat)
# Work on a copy and remove the barostat by type: the old `removeForce(0)` on htf.hybrid_system
# mutated the factory's system and, on a re-run, removed the HarmonicBondForce instead.
positions = htf.hybrid_positions
system = copy.deepcopy(htf.hybrid_system)
barostat_indices = [index for index, force in enumerate(system.getForces())
                    if isinstance(force, openmm.MonteCarloBarostat)]
for index in reversed(barostat_indices):  # reversed so earlier indices stay valid
    system.removeForce(index)



In [5]:
# Nonbonded method of force 3 of the original system (the NonbondedForce once the barostat is removed)
system.getForce(3).getNonbondedMethod()

4

In [6]:
# Integer value of openmm.NonbondedForce.PME, for comparison with the method reported above
openmm.NonbondedForce.PME

4

In [7]:
# Whether the original system's force 3 applies the long-range dispersion correction
system.getForce(3).getUseDispersionCorrection()

True

In [17]:
# Create REST-ified hybrid system; the REST region is every atom of residue 501 in chain 0
# Fail loudly if the residue is missing instead of silently reusing a stale `mutated_res`.
matching_residues = [res for res in htf.hybrid_topology.residues if res.resSeq == 501 and res.chain.index == 0]
if not matching_residues:
    raise ValueError("Residue 501 of chain 0 was not found in htf.hybrid_topology")
mutated_res = matching_residues[-1]  # if several residues match, use the last one
query_indices = [atom.index for atom in mutated_res.atoms]
factory = RESTTopologyFactoryV2(system, htf.hybrid_topology, rest_region=query_indices, use_dispersion_correction=True)
REST_system = factory.REST_system


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=13.804333500000002, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=-4.601444128722184, y=13.014850568080696, z=0.0), unit=nanometer), Quantity(value=Vec3(x=-4.601444128722184, y=-6.507424496441187, z=11.271191673136768), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:Handling constraints
INFO:REST:Handling bonds
INFO:REST:Handling angles
INFO:REST:Handling torsions
INFO:REST:Handling nonbondeds
INFO:REST:Handling nonbonded scaling (custom nb)
INFO:REST:Handling nonbonded exception scaling (custom bond)


In [18]:
# Nonbonded method of the REST system's force 3 (presumably its NonbondedForce)
REST_system.getForce(3).getNonbondedMethod()

4

In [19]:
# Integer value of openmm.NonbondedForce.CutoffPeriodic, for comparison with force 4's method below
openmm.NonbondedForce.CutoffPeriodic

2

In [20]:
# Nonbonded method of the REST system's force 4 (presumably the CustomNonbondedForce added by the factory)
REST_system.getForce(4).getNonbondedMethod()

2

In [21]:
# Whether the REST system's force 3 applies the dispersion correction
REST_system.getForce(3).getUseDispersionCorrection()

True

In [22]:
# Whether the REST system's force 4 applies a long-range correction
REST_system.getForce(4).getUseLongRangeCorrection()

False

In [25]:
# Compare valence and summed nonbonded energies of the REST system against the original system.
# Both are evaluated in plain ThermodynamicStates, so no REST scaling parameters are set here.
compare_energy_components(REST_system, system, positions)


conducting subsequent work with the following platform: CPU
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
[('HarmonicBondForce', 1661.2602820430661), ('HarmonicAngleForce', 4324.100327914152), ('PeriodicTorsionForce', 17558.905814415008), ('NonbondedForce', -841612.1463387645), ('AndersenThermostat', 0.0)]
[('CustomBondForce', 1661.2602820430661), ('CustomAngleForce', 4324.100327914152), ('CustomTorsionForce', 17558.906930616802), ('NonbondedForce', -841612.1463387645), ('CustomNonbondedForce', 0.0), ('CustomBondForce', 0.0), ('AndersenThermostat', 0.0)]
Energy bookkeeping was a success!


# Test energy scaling with NoCutoff

In [14]:
# Get unmodified hybrid system (without the barostat), with nonbonded cutoffs disabled
# Work on a copy and remove the barostat by type, so this cell never mutates htf.hybrid_system
# and cannot remove a different force if the barostat is already gone.
positions = htf.hybrid_positions
system = copy.deepcopy(htf.hybrid_system)
barostat_indices = [index for index, force in enumerate(system.getForces())
                    if isinstance(force, openmm.MonteCarloBarostat)]
for index in reversed(barostat_indices):  # reversed so earlier indices stay valid
    system.removeForce(index)
system.getForce(3).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)



In [15]:
# Create REST-ified hybrid system; the REST region is every atom of residue 501 in chain 0
# Fail loudly if the residue is missing instead of silently reusing a stale `mutated_res`.
matching_residues = [res for res in htf.hybrid_topology.residues if res.resSeq == 501 and res.chain.index == 0]
if not matching_residues:
    raise ValueError("Residue 501 of chain 0 was not found in htf.hybrid_topology")
mutated_res = matching_residues[-1]  # if several residues match, use the last one
query_indices = [atom.index for atom in mutated_res.atoms]
factory = RESTTopologyFactoryV2(system, htf.hybrid_topology, rest_region=query_indices)
REST_system = factory.REST_system
# Disable cutoffs on the REST nonbonded forces to match the NoCutoff reference system.
# Force 4 is a CustomNonbondedForce, so use that class's enum (same value as NonbondedForce.NoCutoff).
REST_system.getForce(3).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)
REST_system.getForce(4).setNonbondedMethod(openmm.CustomNonbondedForce.NoCutoff)


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=13.804333500000002, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=-4.601444128722184, y=13.014850568080696, z=0.0), unit=nanometer), Quantity(value=Vec3(x=-4.601444128722184, y=-6.507424496441187, z=11.271191673136768), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:Handling constraints
INFO:REST:Handling bonds
INFO:REST:Handling angles
INFO:REST:Handling torsions
INFO:REST:Handling nonbondeds
INFO:REST:Handling nonbonded scaling (custom nb)
INFO:REST:Handling nonbonded exception scaling (custom bond)


In [23]:
# Set temperatures: REST system at T_min with REST2 parameters for T, vs. a manually scaled non-REST system
# This takes a really long time
T_min = 298.0 * unit.kelvin
T = 600 * unit.kelvin
compare_energies(
    topology=htf.hybrid_topology,
    REST_system=REST_system,
    other_system=system,
    positions=htf.hybrid_positions,
    rest_atoms=query_indices,
    T_min=T_min,
    T=T,
)

beta_0:  0.00040359850685478543 mol/J
beta_m:  0.00020045392507121008 mol/J
conducting subsequent work with the following platform: CUDA
[('CustomBondForce', 1659.071135086621), ('CustomAngleForce', 4321.036042287911), ('CustomTorsionForce', 17547.90256723848), ('NonbondedForce', -858858.1690164063), ('CustomNonbondedForce', 82.17060246149275), ('CustomBondForce', -0.36445388564614173), ('AndersenThermostat', 0.0)]
Energies: -840854.0467683397 (rest), -840854.049021062 (nonrest)
Discrepancy: 0.002252722275443375
Success!


# Test energy scaling with production run systems

In [20]:
# Load r-htf for N501Y
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by a trusted run.
htf_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex_0.pickle"
with open(htf_path, "rb") as pickle_file:
    htf = pickle.load(pickle_file)
    

In [6]:
# List the hybrid system's forces; the MonteCarloBarostat listed first is removed before REST-ifying
htf.hybrid_system.getForces()


[<simtk.openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x2b06f5fa6960> >,
 <simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x2b06f5fa6a20> >,
 <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x2b06f5fa6a80> >,
 <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x2b06f5fa6b40> >,
 <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2b06f5fa6bd0> >]

In [21]:
# Get unmodified hybrid system (without the barostat)
# Work on a copy and remove the barostat by type: `removeForce(0)` on htf.hybrid_system mutated
# the factory's system and, on a re-run, removed the HarmonicBondForce instead.
positions = htf.hybrid_positions
system = copy.deepcopy(htf.hybrid_system)
barostat_indices = [index for index, force in enumerate(system.getForces())
                    if isinstance(force, openmm.MonteCarloBarostat)]
for index in reversed(barostat_indices):  # reversed so earlier indices stay valid
    system.removeForce(index)



In [8]:
# Nonbonded method of force 3 of the original system (the NonbondedForce once the barostat is removed)
system.getForce(3).getNonbondedMethod()

4

In [9]:
# Integer value of openmm.NonbondedForce.PME, for comparison with the method reported above
openmm.NonbondedForce.PME

4

In [10]:
# Whether the original system's force 3 applies the long-range dispersion correction
system.getForce(3).getUseDispersionCorrection()

True

In [22]:
# Reduced potential of the unmodified (non-REST) system at 298 K, for reference.
# T_min is defined here too: it was previously only defined in a later cell, so this cell raised
# NameError when the notebook is run top to bottom.
T_min = 298.0 * unit.kelvin
thermostate = ThermodynamicState(system, temperature=T_min)
integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context = thermostate.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
nonREST_energy = thermostate.reduced_potential(sampler_state)

In [23]:
# Reduced potential of the unmodified system computed in the previous cell
nonREST_energy

-823558.2683701934

In [11]:
# Create REST-ified hybrid system; the REST region is every atom of residue 501 in chain 0
# Fail loudly if the residue is missing instead of silently reusing a stale `mutated_res`.
matching_residues = [res for res in htf.hybrid_topology.residues if res.resSeq == 501 and res.chain.index == 0]
if not matching_residues:
    raise ValueError("Residue 501 of chain 0 was not found in htf.hybrid_topology")
mutated_res = matching_residues[-1]  # if several residues match, use the last one
query_indices = [atom.index for atom in mutated_res.atoms]
factory = RESTTopologyFactoryV2(system, htf.hybrid_topology, rest_region=query_indices, use_dispersion_correction=True)
REST_system = factory.REST_system


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=13.804333500000002, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=-4.601444128722184, y=13.014850568080696, z=0.0), unit=nanometer), Quantity(value=Vec3(x=-4.601444128722184, y=-6.507424496441187, z=11.271191673136768), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:Handling constraints
INFO:REST:Handling bonds
INFO:REST:Handling angles
INFO:REST:Handling torsions
INFO:REST:Handling nonbondeds
INFO:REST:Handling nonbonded scaling (custom nb)
INFO:REST:Handling nonbonded exception scaling (custom bond)


In [12]:
# Nonbonded method of the REST system's force 3 (presumably its NonbondedForce)
REST_system.getForce(3).getNonbondedMethod()

4

In [13]:
# Integer value of openmm.NonbondedForce.CutoffPeriodic, for comparison with force 4's method below
openmm.NonbondedForce.CutoffPeriodic

2

In [14]:
# Nonbonded method of the REST system's force 4 (presumably the CustomNonbondedForce added by the factory)
REST_system.getForce(4).getNonbondedMethod()

2

In [15]:
# Whether the REST system's force 3 applies the dispersion correction
REST_system.getForce(3).getUseDispersionCorrection()

True

In [16]:
# Whether the REST system's force 4 applies a long-range correction
REST_system.getForce(4).getUseLongRangeCorrection()

False

In [17]:
# Set temperatures: REST system at T_min with REST2 parameters for T, vs. a manually scaled non-REST system
# This takes a really long time
T_min = 298.0 * unit.kelvin
T = 600 * unit.kelvin
compare_energies(
    topology=htf.hybrid_topology,
    REST_system=REST_system,
    other_system=system,
    positions=htf.hybrid_positions,
    rest_atoms=query_indices,
    T_min=T_min,
    T=T,
)

beta_0:  0.00040359850685478543 mol/J
beta_m:  0.00020045392507121008 mol/J
conducting subsequent work with the following platform: CUDA
[('CustomBondForce', 1659.0711350866632), ('CustomAngleForce', 4321.03604228801), ('CustomTorsionForce', 17547.902807037044), ('NonbondedForce', -841612.1058550582), ('CustomNonbondedForce', 84.28177980505599), ('CustomBondForce', -0.364453885647623), ('AndersenThermostat', 0.0)]


SystemError: <built-in function Context_getState> returned NULL without setting an error

# Test energy validation for REST vs. alchemical system

In [None]:
# Load the N501Y hybrid topology factory whose hybrid_system is the non-REST reference system
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by a trusted run.
htf_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex.pickle"
with open(htf_path, "rb") as pickle_file:
    htf = pickle.load(pickle_file)
other_system = htf.hybrid_system
    

In [None]:
# Load r-htf for N501Y (used below to build the REST system)
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by a trusted run.
rhtf_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex_0.pickle"
with open(rhtf_path, "rb") as pickle_file:
    rhtf = pickle.load(pickle_file)
    

In [None]:
# Read in REST snapshot and use the first frame as positions
# NOTE(review): the array carries no units, so OpenMM will interpret it in nm -- confirm the snapshot was saved in nm.
snapshot_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex_asn_1ns_snapshots.npy"
with open(snapshot_path, "rb") as snapshot_file:
    cache = np.load(snapshot_file)
positions = cache[0]

In [None]:
# List the non-REST system's forces (used to choose the force indices below)
htf.hybrid_system.getForces()


In [None]:
# NOTE(review): assumes force 7 is a NonbondedForce -- confirm against the force list above
htf.hybrid_system.getForce(7).getUseDispersionCorrection()

In [None]:
# NOTE(review): assumes force 8 is a CustomNonbondedForce -- confirm against the force list above
htf.hybrid_system.getForce(8).getUseLongRangeCorrection()

In [None]:
# Energy components of the non-REST system at the REST snapshot positions
thermostate_other = ThermodynamicState(system=other_system, temperature=temperature)

# Create the context on REFERENCE_PLATFORM directly (the configure_platform result was never used)
integrator_other = openmm.VerletIntegrator(1.0*unit.femtosecond)
context_other = thermostate_other.create_context(integrator_other, platform=REFERENCE_PLATFORM)
context_other.setPositions(positions)

# Evaluate the components once and reuse them for both the value list and the printout
named_components_other = list(compute_potential_components(context_other, beta=beta))
components_other = [component[1] for component in named_components_other]
print(named_components_other)


In [None]:
# Create REST-ified hybrid system from rhtf; the REST region is every atom of residue 501 in chain 0.
# Look the residue up in rhtf's own topology so the indices refer to the system the factory is built from.
matching_residues = [res for res in rhtf.hybrid_topology.residues if res.resSeq == 501 and res.chain.index == 0]
if not matching_residues:
    raise ValueError("Residue 501 of chain 0 was not found in rhtf.hybrid_topology")
mutated_res = matching_residues[-1]  # if several residues match, use the last one
query_indices = [atom.index for atom in mutated_res.atoms]
factory = RESTTopologyFactoryV2(rhtf.hybrid_system, rhtf.hybrid_topology, rest_region=query_indices, use_dispersion_correction=True)
REST_system = factory.REST_system


In [None]:
# Energy components of the REST system (no REST parameters set) at the same positions
thermostate_rest = ThermodynamicState(system=REST_system, temperature=temperature)

# Create context for rest system on the same platform as the non-REST system
integrator_rest = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context_rest = thermostate_rest.create_context(integrator_rest, platform=REFERENCE_PLATFORM)
context_rest.setPositions(positions)

# Evaluate the components once and reuse them for both the value list and the printout
named_components_rest = list(compute_potential_components(context_rest, beta=beta))
components_rest = [component[1] for component in named_components_rest]
print(named_components_rest)

In [None]:
# Bonds: components 0+1 of the non-REST system (presumably its two bond forces) vs. the REST bond force
assert np.isclose([components_other[0] + components_other[1]], [components_rest[0]])

In [None]:
# Angles: components 2+3 of the non-REST system (presumably its two angle forces) vs. the REST angle force
assert np.isclose([components_other[2] + components_other[3]], [components_rest[1]])

In [None]:
# Torsions: components 4+5 of the non-REST system (presumably its two torsion forces) vs. the REST torsion force
assert np.isclose([components_other[4] + components_other[5]], [components_rest[2]])

In [None]:
# Nonbonded: components 6-8 of the non-REST system vs. components 3-5 of the REST system, compared as sums.
# The failure message reports both sums (it previously showed only components_rest[3]).
nonbonded_other = components_other[6] + components_other[7] + components_other[8]
nonbonded_rest = components_rest[3] + components_rest[4] + components_rest[5]
assert np.isclose(nonbonded_other, nonbonded_rest), f"The nonbonded energies do not match: {nonbonded_other} (other system) vs. {nonbonded_rest} (REST system)"



In [None]:
# REST system energy components (values only)
components_rest

In [None]:
# Non-REST system energy components (values only)
components_other

In [None]:
# Sum of components 6-8 of the non-REST system (its nonbonded total in the checks above)
components_other[6] + components_other[7] + components_other[8]

In [None]:
# Sum of components 3-5 of the REST system (its nonbonded total in the checks above)
components_rest[3] + components_rest[4] + components_rest[5]

In [None]:
# NOTE(review): hard-coded values, presumably copied from the two sums above. With np.isclose's default
# rtol=1e-5 the allowed difference at this magnitude is ~10, so a ~3 discrepancy still passes.
np.isclose(-1002461.6455134379, -1002464.5984964548)

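# Test energy validation for REST vs. alchemical system (John's recommended way)

TODO: the cells in this section are currently identical to those in the previous section; update them to implement the recommended validation approach.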

In [None]:
# Load the N501Y hybrid topology factory whose hybrid_system is the non-REST reference system
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by a trusted run.
htf_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex.pickle"
with open(htf_path, "rb") as pickle_file:
    htf = pickle.load(pickle_file)
other_system = htf.hybrid_system
    

In [None]:
# Load r-htf for N501Y (used below to build the REST system)
# NOTE: pickle.load can execute arbitrary code -- only load pickles produced by a trusted run.
rhtf_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex_0.pickle"
with open(rhtf_path, "rb") as pickle_file:
    rhtf = pickle.load(pickle_file)
    

In [None]:
# Read in REST snapshot and use the first frame as positions
# NOTE(review): the array carries no units, so OpenMM will interpret it in nm -- confirm the snapshot was saved in nm.
snapshot_path = "/data/chodera/zhangi/perses_benchmark/neq/14/98/98_complex_asn_1ns_snapshots.npy"
with open(snapshot_path, "rb") as snapshot_file:
    cache = np.load(snapshot_file)
positions = cache[0]

In [None]:
# List the non-REST system's forces (used to choose the force indices below)
htf.hybrid_system.getForces()


In [None]:
# NOTE(review): assumes force 7 is a NonbondedForce -- confirm against the force list above
htf.hybrid_system.getForce(7).getUseDispersionCorrection()

In [None]:
# NOTE(review): assumes force 8 is a CustomNonbondedForce -- confirm against the force list above
htf.hybrid_system.getForce(8).getUseLongRangeCorrection()

In [None]:
# Energy components of the non-REST system at the REST snapshot positions
thermostate_other = ThermodynamicState(system=other_system, temperature=temperature)

# Create the context on REFERENCE_PLATFORM directly (the configure_platform result was never used)
integrator_other = openmm.VerletIntegrator(1.0*unit.femtosecond)
context_other = thermostate_other.create_context(integrator_other, platform=REFERENCE_PLATFORM)
context_other.setPositions(positions)

# Evaluate the components once and reuse them for both the value list and the printout
named_components_other = list(compute_potential_components(context_other, beta=beta))
components_other = [component[1] for component in named_components_other]
print(named_components_other)


In [None]:
# Create REST-ified hybrid system from rhtf; the REST region is every atom of residue 501 in chain 0.
# Look the residue up in rhtf's own topology so the indices refer to the system the factory is built from.
matching_residues = [res for res in rhtf.hybrid_topology.residues if res.resSeq == 501 and res.chain.index == 0]
if not matching_residues:
    raise ValueError("Residue 501 of chain 0 was not found in rhtf.hybrid_topology")
mutated_res = matching_residues[-1]  # if several residues match, use the last one
query_indices = [atom.index for atom in mutated_res.atoms]
factory = RESTTopologyFactoryV2(rhtf.hybrid_system, rhtf.hybrid_topology, rest_region=query_indices, use_dispersion_correction=True)
REST_system = factory.REST_system


In [None]:
# Energy components of the REST system (no REST parameters set) at the same positions
thermostate_rest = ThermodynamicState(system=REST_system, temperature=temperature)

# Create context for rest system on the same platform as the non-REST system
integrator_rest = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context_rest = thermostate_rest.create_context(integrator_rest, platform=REFERENCE_PLATFORM)
context_rest.setPositions(positions)

# Evaluate the components once and reuse them for both the value list and the printout
named_components_rest = list(compute_potential_components(context_rest, beta=beta))
components_rest = [component[1] for component in named_components_rest]
print(named_components_rest)

In [None]:
# Bonds: components 0+1 of the non-REST system (presumably its two bond forces) vs. the REST bond force
assert np.isclose([components_other[0] + components_other[1]], [components_rest[0]])

In [None]:
# Angles: components 2+3 of the non-REST system (presumably its two angle forces) vs. the REST angle force
assert np.isclose([components_other[2] + components_other[3]], [components_rest[1]])

In [None]:
# Torsions: components 4+5 of the non-REST system (presumably its two torsion forces) vs. the REST torsion force
assert np.isclose([components_other[4] + components_other[5]], [components_rest[2]])

In [None]:
# Nonbonded: components 6-8 of the non-REST system vs. components 3-5 of the REST system, compared as sums.
# The failure message reports both sums (it previously showed only components_rest[3]).
nonbonded_other = components_other[6] + components_other[7] + components_other[8]
nonbonded_rest = components_rest[3] + components_rest[4] + components_rest[5]
assert np.isclose(nonbonded_other, nonbonded_rest), f"The nonbonded energies do not match: {nonbonded_other} (other system) vs. {nonbonded_rest} (REST system)"



In [None]:
# REST system energy components (values only)
components_rest

In [None]:
# Non-REST system energy components (values only)
components_other

In [None]:
# Sum of components 6-8 of the non-REST system (its nonbonded total in the checks above)
components_other[6] + components_other[7] + components_other[8]

In [None]:
# Sum of components 3-5 of the REST system (its nonbonded total in the checks above)
components_rest[3] + components_rest[4] + components_rest[5]

In [None]:
# NOTE(review): hard-coded values, presumably copied from the two sums above. With np.isclose's default
# rtol=1e-5 the allowed difference at this magnitude is ~10, so a ~3 discrepancy still passes.
np.isclose(-1002461.6455134379, -1002464.5984964548)