In [1]:
###########################################
# IMPORTS
###########################################
from openmmtools.states import SamplerState, ThermodynamicState, CompoundThermodynamicState
from simtk import unit, openmm
from perses.tests.utils import compute_potential_components
from openmmtools.constants import kB
from perses.dispersed.utils import configure_platform
from perses.annihilation.rest import RESTTopologyFactory, RESTTopologyFactoryV2
from perses.annihilation.lambda_protocol import RESTState, RESTStateV2
import numpy as np
from perses.tests.test_topology_proposal import generate_atp, generate_dipeptide_top_pos_sys
from openmmtools.testsystems import AlanineDipeptideVacuum, AlanineDipeptideExplicit
import itertools

#############################################
# CONSTANTS
#############################################
temperature = 300.0 * unit.kelvin
kT = kB * temperature
beta = 1.0/kT
REFERENCE_PLATFORM = openmm.Platform.getPlatformByName("CUDA")

conducting subsequent work with the following platform: CUDA


INFO:rdkit:Enabling RDKit 2021.03.3 jupyter extensions


In [2]:
def compare_energies(topology, REST_system, other_system, positions, rest_atoms, T_min, T):

    # Get solvent atoms
    positive_ion_name = "NA"
    negative_ion_name = "CL"
    water_name = "HOH"
    solvent_atoms = []
    if 'openmm' in topology.__module__:
        atoms = topology.atoms()
    elif 'mdtraj' in topology.__module__:
        atoms = topology.atoms
    for atom in atoms:
        if atom.residue.name == positive_ion_name or atom.residue.name == negative_ion_name or atom.residue.name == water_name:
            solvent_atoms.append(atom.index)
    
    # Create thermodynamic state
    lambda_zero_alchemical_state = RESTStateV2.from_system(REST_system)
    thermostate = ThermodynamicState(REST_system, temperature=T_min)
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate,
                                                              composable_states=[lambda_zero_alchemical_state])
    
    # Set alchemical parameters
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)
    print("beta_0: ", beta_0)
    print("beta_m: ", beta_m)
    
    # Minimize and save energy
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = compound_thermodynamic_state.create_context(integrator)
    context.setPositions(positions)
    sampler_state = SamplerState.from_context(context)
    REST_energy = compound_thermodynamic_state.reduced_potential(sampler_state)
    
    # Get energy components for rest system
    components_rest = [component for component in compute_potential_components(context, beta=beta)]
    print(components_rest)
    
    # Compute energy for non-RESTified system
    # Determine regions and scaling factors
    nonrest_atoms = [i for i in range(other_system.getNumParticles()) if i not in rest_atoms and i not in solvent_atoms]
    rest_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)

    # Scale the terms in the bond force appropriately
    bond_force = other_system.getForce(0)
    for bond in range(bond_force.getNumBonds()):
        p1, p2, length, k = bond_force.getBondParameters(bond)
        if p1 in rest_atoms and p2 in rest_atoms:
            bond_force.setBondParameters(bond, p1, p2, length, k * rest_scaling)
        elif (p1 in rest_atoms and p2 in nonrest_atoms) or (p1 in nonrest_atoms and p2 in rest_atoms):
            bond_force.setBondParameters(bond, p1, p2, length, k * inter_scaling)

    # Scale the terms in the angle force appropriately
    angle_force = other_system.getForce(1)
    for angle_index in range(angle_force.getNumAngles()):
        p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
        if p1 in rest_atoms and p2 in rest_atoms and p3 in rest_atoms:
            angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k * rest_scaling)
        elif set([p1, p2, p3]).intersection(set(rest_atoms)) != set() and set([p1, p2, p3]).intersection(set(nonrest_atoms)) != set():
            angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k * inter_scaling)

    # Scale the terms in the torsion force appropriately
    torsion_force = other_system.getForce(2)
    for torsion_index in range(torsion_force.getNumTorsions()):
        p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
        if p1 in rest_atoms and p2 in rest_atoms and p3 in rest_atoms and p4 in rest_atoms:
            torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k * rest_scaling)
        elif set([p1, p2, p3, p4]).intersection(set(rest_atoms)) != set() and set([p1, p2, p3, p4]).intersection(set(nonrest_atoms)) != set():
            torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k * inter_scaling)

    # Scale the exceptions in the nonbonded force appropriately
    nb_force = other_system.getForce(3)
    for nb_index in range(nb_force.getNumExceptions()):
        p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(nb_index)
        if (p1 in rest_atoms and p2 in rest_atoms):
            nb_force.setExceptionParameters(nb_index, p1, p2, rest_scaling * chargeProd, sigma, rest_scaling * epsilon)
        elif (p1 in rest_atoms and p2 in nonrest_atoms) or (p1 in nonrest_atoms and p2 in rest_atoms):
            nb_force.setExceptionParameters(nb_index, p1, p2, inter_scaling * chargeProd, sigma, inter_scaling * epsilon)

    # Scale nonbonded interactions for rest-rest region by adding exceptions for all pairs of atoms
    exception_pairs = [tuple(sorted([nb_force.getExceptionParameters(nb_index)[0], nb_force.getExceptionParameters(nb_index)[1]])) for nb_index in range(nb_force.getNumExceptions())]
    rest_pairs = set([tuple(sorted(pair)) for pair in list(itertools.product(rest_atoms, rest_atoms))])
    for pair in list(rest_pairs):
        p1 = pair[0]
        p2 = pair[1]
        p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
        p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
        if p1 != p2:
            if pair not in exception_pairs:
                nb_force.addException(p1, p2, p1_charge * p2_charge * rest_scaling, 0.5 * (p1_sigma + p2_sigma),
                                      np.sqrt(p1_epsilon * p2_epsilon) * rest_scaling)

    # Scale nonbonded interactions for inter region by adding exceptions for all pairs of atoms
    for pair in list(itertools.product(rest_atoms, nonrest_atoms)):
        p1 = pair[0]
        p2 = int(pair[1])  # otherwise, will be a numpy int
        p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
        p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
        if tuple(sorted(pair)) not in exception_pairs:
            nb_force.addException(p1, p2, p1_charge * p2_charge * inter_scaling, 0.5 * (p1_sigma + p2_sigma), np.sqrt(p1_epsilon * p2_epsilon) * inter_scaling)

    # Scale nonbonded interactions for rest-water region by adding exceptions for all pairs of atoms
    for pair in list(itertools.product(rest_atoms, solvent_atoms)):
        p1 = pair[0]
        p2 = pair[1]
        p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
        p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
        nb_force.addException(p1, p2, p1_charge * p2_charge * rest_scaling, 0.5 * (p1_sigma + p2_sigma), np.sqrt(p1_epsilon * p2_epsilon) * rest_scaling)
            
    # Get energy
    thermostate = ThermodynamicState(other_system, temperature=T_min)
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = thermostate.create_context(integrator)
    context.setPositions(positions)
    sampler_state = SamplerState.from_context(context)
    nonREST_energy = thermostate.reduced_potential(sampler_state)
    
    print(f"Energies: {REST_energy} (rest), {nonREST_energy} (nonrest)")
    print(f"Discrepancy: {abs(REST_energy - nonREST_energy)}")
    assert abs(REST_energy - nonREST_energy) < 1, f"The energy of the REST system ({REST_energy}) does not match " \
                                                        f"that of the non-REST system with terms manually scaled according to REST2({nonREST_energy})."
    
    print("Success!")


In [11]:
def test_energy_scaling():
    """
        Test whether the energy of a REST-ified system is equal to the energy of the system with terms manually scaled by
        the same factor as is used in REST.  T_min is 298 K and the thermodynamic state has temperature 600 K.
    """

    # Set temperatures
    T_min = 300.0 * unit.kelvin
    T = 1200 * unit.kelvin

#     ## CASE 1: alanine dipeptide in vacuum
#     # Create vanilla system
#     ala = AlanineDipeptideVacuum()
#     system = ala.system
#     system.removeForce(4)
#     positions = ala.positions
#     topology = ala.topology

#     # Create REST system
#     res1 = list(ala.topology.residues())[1]
#     rest_atoms = [atom.index for atom in res1.atoms()]
#     factory = RESTTopologyFactoryV2(system, topology, rest_region=rest_atoms)
#     REST_system = factory.REST_system

#     # Check energy scaling
#     compare_energies(topology, REST_system, system, positions, rest_atoms, T_min, T)

    ## CASE 2: alanine dipeptide in solvent
    # Create vanilla system
    ala = AlanineDipeptideExplicit()
    system = ala.system
    system.removeForce(4)
#     system.getForce(3).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)
    positions = ala.positions
    topology = ala.topology

    # Create REST system
    
    res1 = list(ala.topology.residues())[1]
    rest_atoms = [atom.index for atom in res1.atoms()]
    factory = RESTTopologyFactoryV2(system, topology, rest_region=rest_atoms, use_dispersion_correction=True)
    REST_system = factory.REST_system
#     REST_system.getForce(3).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)
#     REST_system.getForce(4).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)


    # Check energy scaling
    compare_energies(topology, REST_system, system, positions, rest_atoms, T_min, T)

    ## CASE 3: alanine dipeptide in solvent with repartitioned hybrid system
    # Create repartitioned hybrid system for lambda 0 endstate
    atp, system_generator = generate_atp(phase='solvent')
    htf = generate_dipeptide_top_pos_sys(atp.topology,
                                         new_res='THR',
                                         system=atp.system,
                                         positions=atp.positions,
                                         system_generator=system_generator,
                                         conduct_htf_prop=True,
                                         repartitioned=True,
                                         endstate=0,
                                         validate_endstate_energy=False)
    system = htf.hybrid_system
    system.removeForce(0) # Remove barostat
#     system.getForce(3).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)

    # Create REST-ified hybrid system
    res1 = list(htf.hybrid_topology.residues)[1]
    rest_atoms = [atom.index for atom in list(res1.atoms)]
    factory = RESTTopologyFactoryV2(system, htf.hybrid_topology, rest_region=rest_atoms, use_dispersion_correction=True)
    REST_system = factory.REST_system
#     REST_system.getForce(3).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)
#     REST_system.getForce(4).setNonbondedMethod(openmm.NonbondedForce.NoCutoff)

    # Check energy scaling
    compare_energies(htf.hybrid_topology, REST_system, system, htf.hybrid_positions, rest_atoms, T_min, T)

In [4]:
test_energy_scaling()

INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=2.0, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=2.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=2.0), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:Handling constraints
INFO:REST:Handling bonds
INFO:REST:Handling angles
INFO:REST:Handling torsions
INFO:REST:Handling nonbondeds
INFO:REST:Handling nonbonded scaling (custom nb)
INFO:REST:Handling nonbonded exception scaling (custom bond)


beta_0:  0.00040090785014242015 mol/J
beta_m:  0.00010022696253560504 mol/J
conducting subsequent work with the following platform: CUDA
[('CustomBondForce', 0.03416160942630912), ('CustomAngleForce', 0.2299864673818366), ('CustomTorsionForce', 1.6138417471377915), ('NonbondedForce', -39.186929599617706), ('CustomNonbondedForce', 75.67604299428817), ('CustomBondForce', -74.75726575321863), ('AndersenThermostat', 0.0)]
Energies: -36.390157145634475 (rest), -36.39016971661439 (nonrest)
Discrepancy: 1.2570979912140956e-05
Success!


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=3.2852863, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=3.2861648000000003, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=3.1855098), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:Handling constraints
INFO:REST:Handling bonds
INFO:REST:Handling angles
INFO:REST:Handling torsions
INFO:REST:Handling nonbondeds
INFO:REST:Handling nonbonded scaling (custom nb)
INFO:REST:Handling nonbonded exception scaling (custom bond)


beta_0:  0.00040090785014242015 mol/J
beta_m:  0.00010022696253560504 mol/J
conducting subsequent work with the following platform: CUDA
[('CustomBondForce', 0.034161591732089575), ('CustomAngleForce', 0.22998655020010014), ('CustomTorsionForce', 1.6138411377552013), ('NonbondedForce', -9889.978580573506), ('CustomNonbondedForce', 51.95404623924919), ('CustomBondForce', -74.7572664138298), ('AndersenThermostat', 0.0)]
Energies: -9910.903811468392 (rest), -9873.711205981175 (nonrest)
Discrepancy: 37.1926054872165


AssertionError: The energy of the REST system (-9910.903811468392) does not match that of the non-REST system with terms manually scaled according to REST2(-9873.711205981175).

# Run scaling test v2 (allows testing PME)

In [6]:
import copy


In [7]:
def compare_energies(topology, REST_system, other_system, positions, rest_atoms, T_min, T):

    # Create thermodynamic state
    lambda_zero_alchemical_state = RESTStateV2.from_system(REST_system)
    thermostate = ThermodynamicState(REST_system, temperature=T_min)
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate,
                                                              composable_states=[lambda_zero_alchemical_state])

    # Set alchemical parameters
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)

    # Get energy
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = compound_thermodynamic_state.create_context(integrator)
    context.setPositions(positions)
    sampler_state = SamplerState.from_context(context)
    REST_energy = compound_thermodynamic_state.reduced_potential(sampler_state)

    # Get solvent atoms
    positive_ion_name = "NA"
    negative_ion_name = "CL"
    water_name = "HOH"
    solvent_atoms = []
    if 'openmm' in topology.__module__:
        atoms = topology.atoms()
    elif 'mdtraj' in topology.__module__:
        atoms = topology.atoms
    for atom in atoms:
        if atom.residue.name == positive_ion_name or atom.residue.name == negative_ion_name or atom.residue.name == water_name:
            solvent_atoms.append(atom.index)
    
    # Compute energy for non-RESTified system
    # Determine regions and scaling factors
    nonrest_atoms = [i for i in range(other_system.getNumParticles()) if i not in rest_atoms]
    rest_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)
    
    # Compute 3 energies: 
    ## Test case 0: rest-rest energy
    ## Test case 1: nonrest-nonrest energy
    ## Test case 2: total energy
    ## Test case 3: nonrest_solvent - rest (this will only be a nonbonded energy)
    ## Test case 4: nonrest_solvent - nonrest-solvent (this will only be a nonbonded energy)
    unmodified_energies = []
    for test_case in range(5):
        system_copy = copy.deepcopy(other_system)
        
        if test_case != 2:
            bond_force = system_copy.getForce(0)
            for bond_index in range(bond_force.getNumBonds()):
                p1, p2, length, k = bond_force.getBondParameters(bond_index)
                particles = [p1, p2]
                if all(x in rest_atoms for x in particles):
                    if test_case in [1, 3, 4]:
                        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case in [0, 3, 4]:
                        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
                    else:
                        continue
                else:
                    bond_force.setBondParameters(bond_index, p1, p2, length, k*0)

            angle_force = system_copy.getForce(1)
            for angle_index in range(angle_force.getNumAngles()):
                p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
                particles = [p1, p2, p3]
                if all(x in rest_atoms for x in particles):
                    if test_case in [1, 3, 4]:
                        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case in [0, 3, 4]:
                        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 
                    else:
                        continue
                else:
                    angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 

            torsion_force = system_copy.getForce(2)
            for torsion_index in range(torsion_force.getNumTorsions()):
                p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
                particles = [p1, p2, p3, p4]
                if all(x in rest_atoms for x in particles):
                    if test_case in [1, 3, 4]:
                        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case in [0, 3, 4]:
                        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
                    else:
                        continue
                else:
                    torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)

            nb_force = system_copy.getForce(3)
            for i in range(nb_force.getNumExceptions()):
                p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
                particles = [p1, p2]
                if all(x in rest_atoms for x in particles):
                    if test_case in [1, 4]:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case in [0]:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    elif test_case in [3, 4]:
                        if all(x in solvent_atoms for x in particles):
                            continue
                        else:
                            nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        continue
                else:
                    if test_case == 3:
                        if any(x in solvent_atoms for x in particles):
                            continue
                        else:
                            nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

            nb_force = system_copy.getForce(3)
            for i in range(nb_force.getNumParticles()):
                charge, sigma, epsilon = nb_force.getParticleParameters(i)
                if i in rest_atoms:
                    if test_case in [1, 4]:
                        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                    else:
                        continue
                else:
                    if test_case == 0:
                        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                        
                    elif test_case in [3, 4]:
                        if i in solvent_atoms:
                            continue
                        else:
                            nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                            
                    else:
                        continue     
            
        # Get energy
        thermostate = ThermodynamicState(system_copy, temperature=T_min)
        integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
        context = thermostate.create_context(integrator)
        context.setPositions(positions)
        sampler_state = SamplerState.from_context(context)
        system_copy_energy = thermostate.reduced_potential(sampler_state)
        unmodified_energies.append(system_copy_energy)
    
    nonrest_solvent_rest = (unmodified_energies[3] - unmodified_energies[0] - unmodified_energies[4]) * rest_scaling
    nonrest_protein_rest = (unmodified_energies[2] - unmodified_energies[0] - unmodified_energies[1] - nonrest_solvent_rest) * inter_scaling
    unmodified_energy = unmodified_energies[0] * rest_scaling + unmodified_energies[1] +  nonrest_solvent_rest + nonrest_protein_rest
    print(REST_energy)
    print(unmodified_energy)
    assert np.isclose(REST_energy, unmodified_energy), f"REST energy was {REST_energy} and unmodified_energy was {unmodified_energy}"
                            

In [12]:
test_energy_scaling()

INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=3.2852863, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=3.2861648000000003, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=3.1855098), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:Handling constraints
INFO:REST:Handling bonds
INFO:REST:Handling angles
INFO:REST:Handling torsions
INFO:REST:Handling nonbondeds
INFO:REST:Handling nonbonded scaling (custom nb)
INFO:REST:Handling nonbonded exception scaling (custom bond)


-9910.903811468392
-9880.228257241612


AssertionError: REST energy was -9910.903811468392 and unmodified_energy was -9880.228257241612