In [2]:
###########################################
# IMPORTS
###########################################
from openmmtools.states import SamplerState, ThermodynamicState, CompoundThermodynamicState
from simtk import unit, openmm
from perses.tests.utils import compute_potential_components
from openmmtools.constants import kB
from perses.dispersed.utils import configure_platform
from perses.annihilation.rest import RESTTopologyFactory
from perses.annihilation.lambda_protocol import RESTState, RESTCapableRelativeAlchemicalState, RESTCapableLambdaProtocol
import numpy as np
from perses.tests.test_topology_proposal import generate_atp, generate_dipeptide_top_pos_sys
from openmmtools.testsystems import AlanineDipeptideVacuum, AlanineDipeptideExplicit
import itertools
import copy

import pickle

#############################################
# CONSTANTS
#############################################
temperature = 300.0 * unit.kelvin
kT = kB * temperature
beta = 1.0/kT
REFERENCE_PLATFORM = openmm.Platform.getPlatformByName("CUDA")

INFO:rdkit:Enabling RDKit 2021.03.4 jupyter extensions


conducting subsequent work with the following platform: CUDA


# Run scaling test v2 (allows testing PME)

In [14]:
def test_energy_scaling():
    """
        Test whether the energy of a REST-ified system is equal to the energy of the system with terms manually scaled by
        the same factor as is used in REST.  T_min is 300 K and the thermodynamic state has temperature 600 K.
    """

    # Set temperatures
    T_min = 300.0 * unit.kelvin
    T = 600 * unit.kelvin

    ## CASE 1: alanine dipeptide in vacuum
    # Create vanilla system
    ala = AlanineDipeptideVacuum()
    system = ala.system
    system.removeForce(4)
    positions = ala.positions
    topology = ala.topology

    # Create REST system
    res1 = list(ala.topology.residues())[1]
    rest_atoms = [atom.index for atom in res1.atoms()]
    factory = RESTTopologyFactory(system, solute_region=rest_atoms)
    REST_system = factory.REST_system

    # Check energy scaling
    compare_energies(REST_system, system, positions, rest_atoms, T_min, T)

    ## CASE 2: alanine dipeptide in solvent
    # Create vanilla system
    ala = AlanineDipeptideExplicit()
    system = ala.system
    system.removeForce(4)
    positions = ala.positions
    topology = ala.topology

    # Create REST system
    
    res1 = list(ala.topology.residues())[1]
    rest_atoms = [atom.index for atom in res1.atoms()]
    factory = RESTTopologyFactory(system, solute_region=rest_atoms, use_dispersion_correction=True)
    REST_system = factory.REST_system

    # Check energy scaling
    compare_energies(REST_system, system, positions, rest_atoms, T_min, T)

    ## CASE 3: alanine dipeptide in solvent with repartitioned hybrid system
    # Create repartitioned hybrid system for lambda 0 endstate
    atp, system_generator = generate_atp(phase='solvent')
    htf = generate_dipeptide_top_pos_sys(atp.topology,
                                         new_res='THR',
                                         system=atp.system,
                                         positions=atp.positions,
                                         system_generator=system_generator,
                                         conduct_htf_prop=True,
                                         repartitioned=True,
                                         endstate=0,
                                         validate_endstate_energy=False)
    system = htf.hybrid_system
    system.removeForce(0) # Remove barostat

    # Create REST-ified hybrid system
    res1 = list(htf.hybrid_topology.residues)[1]
    rest_atoms = [atom.index for atom in list(res1.atoms)]
    factory = RESTTopologyFactory(system, solute_region=rest_atoms, use_dispersion_correction=True)
    REST_system = factory.REST_system

    # Check energy scaling
    compare_energies(REST_system, system, htf.hybrid_positions, rest_atoms, T_min, T)

In [15]:
def compare_energies(REST_system, other_system, positions, rest_atoms, T_min, T):

#     # Zero nb forces
#     nb_force = REST_system.getForce(3)
#     for i in range(nb_force.getNumParticles()):
#         charge, sigma, epsilon = nb_force.getParticleParameters(i)
#         nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
    
#     # Zero exceptions force
#     for i in range(nb_force.getNumExceptions()):
#         p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
#         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
    
#     for i in range(nb_force.getNumParticleParameterOffsets()):
#         param, particle_idx, chargeScale, sigmaScale, epsilonScale = nb_force.getParticleParameterOffset(i)
#         nb_force.setParticleParameterOffset(i, 'electrostatic_scale', particle_idx, 0.0, 0.0, 0.0)
#         nb_force.setParticleParameterOffset(i, 'steric_scale', particle_idx, 0.0, 0.0, 0.0)

#     for i in range(nb_force.getNumExceptionParameterOffsets()):
#         param, exception_idx, chargeScale, sigmaScale, epsilonScale = nb_force.getExceptionParameterOffset(i)
#         nb_force.setExceptionParameterOffset(i, 'electrostatic_scale', exception_idx, 0.0, 0.0, 0.0)
#         nb_force.setExceptionParameterOffset(i, 'steric_scale', exception_idx, 0.0, 0.0, 0.0)

#     nb_force = other_system.getForce(3)
#     for i in range(nb_force.getNumExceptions()):
#         p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
#         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

#     for i in range(nb_force.getNumParticles()):
#         charge, sigma, epsilon = nb_force.getParticleParameters(i)
#         nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
    
    # Create thermodynamic state
    lambda_zero_alchemical_state = RESTState.from_system(REST_system)
    thermostate = ThermodynamicState(REST_system, temperature=T_min)
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate,
                                                              composable_states=[lambda_zero_alchemical_state])

    # Set alchemical parameters
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)

    # Get energy
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = compound_thermodynamic_state.create_context(integrator)
    context.setPositions(positions)
#     sampler_state = SamplerState.from_context(context)
#     REST_energy = compound_thermodynamic_state.reduced_potential(sampler_state)
    REST_energy = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)

    # Compute energy for non-RESTified system
    # Determine regions and scaling factors
    nonrest_atoms = [i for i in range(other_system.getNumParticles()) if i not in rest_atoms]
    rest_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)
    
    # Compute 3 energies: 
    ## Test case 1: rest-rest energy
    ## Test case 2: nonrest-nonrest energy
    ## Test case 3: total energy
    unmodified_energies = []
    for test_case in range(3):
        system_copy = copy.deepcopy(other_system)
        
        if test_case != 2:
            bond_force = system_copy.getForce(0)
            for bond_index in range(bond_force.getNumBonds()):
                p1, p2, length, k = bond_force.getBondParameters(bond_index)
                particles = [p1, p2]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
                    else:
                        continue
                else:
                    bond_force.setBondParameters(bond_index, p1, p2, length, k*0)

            angle_force = system_copy.getForce(1)
            for angle_index in range(angle_force.getNumAngles()):
                p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
                particles = [p1, p2, p3]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 
                    else:
                        continue
                else:
                    angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 

            torsion_force = system_copy.getForce(2)
            for torsion_index in range(torsion_force.getNumTorsions()):
                p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
                particles = [p1, p2, p3, p4]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
                    else:
                        continue
                else:
                    torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)

            nb_force = system_copy.getForce(3)
            for i in range(nb_force.getNumExceptions()):
                p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
                particles = [p1, p2]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        continue
                else:
                    nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

            nb_force = system_copy.getForce(3)
            for i in range(nb_force.getNumParticles()):
                charge, sigma, epsilon = nb_force.getParticleParameters(i)
                if i in rest_atoms:
                    if test_case == 1:
                        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                    else:
                        continue
                else:
                    if test_case == 0:
                        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                    else:
                        continue     
            
        # Get energy
        thermostate = ThermodynamicState(system_copy, temperature=T_min)
        integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
        context = thermostate.create_context(integrator)
        context.setPositions(positions)
#         sampler_state = SamplerState.from_context(context)
#         system_copy_energy = thermostate.reduced_potential(sampler_state)
        system_copy_energy = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)
        unmodified_energies.append(system_copy_energy)
    
    print(unmodified_energies)
    unmodified_energy = unmodified_energies[0] * rest_scaling + unmodified_energies[1] + (unmodified_energies[2] - unmodified_energies[0] - unmodified_energies[1]) * inter_scaling
    print(REST_energy)
    print(unmodified_energy)
    assert np.isclose(REST_energy, unmodified_energy), f"REST energy was {REST_energy} and unmodified_energy was {unmodified_energy}"
                            

In [18]:
test_energy_scaling()


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=2.0, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=2.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=2.0), unit=nanometer)]
INFO:REST:No unknown forces.
INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=3.2852863, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=3.2861648000000003, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=3.1855098), unit=nanometer)]
INFO:REST:No unknown forces.


[99.16267167200934, -43.86889880970683, -88.08855116600438]
-95.67415484266755
-95.67417659640483
[99.11655729667459, -24568.413671886898, -24659.27016879196]
-24653.19382010929
-24653.18662810733


DEBUG:openmmforcefields.system_generators:Trying GAFFTemplateGenerator to load gaff-2.11
INFO:proposal_generator:	Conducting polymer point mutation proposal...
INFO:proposal_generator:[Atom(name=CB, atomic number=6), Atom(name=HB1, atomic number=1), Atom(name=HB2, atomic number=1), Atom(name=HB3, atomic number=1)]
INFO:proposal_generator:[Atom(name=CB, atomic number=6), Atom(name=CG2, atomic number=6), Atom(name=OG1, atomic number=8), Atom(name=HB, atomic number=1), Atom(name=HG1, atomic number=1), Atom(name=HG21, atomic number=1), Atom(name=HG22, atomic number=1), Atom(name=HG23, atomic number=1)]


making topology proposal


INFO:geometry:propose: performing forward proposal
INFO:geometry:propose: unique new atoms detected; proceeding to _logp_propose...
INFO:geometry:Conducting forward proposal...
INFO:geometry:Computing proposal order with NetworkX...
INFO:geometry:number of atoms to be placed: 6
INFO:geometry:Atom index proposal order is [14, 18, 17, 15, 19, 16]
INFO:geometry:omitted_bonds: []
INFO:geometry:direction of proposal is forward; creating atoms_with_positions and new positions from old system/topology...


generating geometry engine
making geometry proposal from ALA to THR


INFO:geometry:creating growth system...
INFO:geometry:	creating bond force...
INFO:geometry:	there are 11 bonds in reference force.
INFO:geometry:	creating angle force...
INFO:geometry:	there are 43 angles in reference force.
INFO:geometry:	creating torsion force...
INFO:geometry:	creating extra torsions force...
INFO:geometry:	there are 72 torsions in reference force.
INFO:geometry:	creating nonbonded force...
INFO:geometry:		grabbing reference nonbonded method, cutoff, switching function, switching distance...
INFO:geometry:		creating nonbonded exception force (i.e. custom bond for 1,4s)...
INFO:geometry:		looping through exceptions calculating growth indices, and adding appropriate interactions to custom bond force.
INFO:geometry:		there are 1654 in the reference Nonbonded force
INFO:geometry:Neglected angle terms : []
INFO:geometry:omitted_growth_terms: {'bonds': [], 'angles': [], 'torsions': [], '1,4s': []}
INFO:geometry:extra torsions: {0: (19, 18, 10, 8, [1, Quantity(value=-0.07

conducting subsequent work with the following platform: CUDA


INFO:geometry:setting atoms_with_positions context new positions


conducting subsequent work with the following platform: CUDA


INFO:geometry:There are 6 new atoms
INFO:geometry:	reduced angle potential = 0.016981347129623535.
INFO:geometry:	reduced angle potential = 0.005093936715684544.
INFO:geometry:	reduced angle potential = 1.248656142948018.
INFO:geometry:	reduced angle potential = 1.7287424922210293.
INFO:geometry:	reduced angle potential = 0.3123291381457235.
INFO:geometry:	reduced angle potential = 0.15940799278644846.
INFO:geometry:	beginning construction of no_nonbonded final system...
INFO:geometry:	initial no-nonbonded final system forces ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat']
INFO:geometry:	final no-nonbonded final system forces dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce'])
INFO:geometry:	there are 11 bond forces in the no-nonbonded final system
INFO:geometry:	there are 43 angle forces in the no-nonbonded final system
INFO:geometry:	there are 72 torsion forces in the no-nonbonded

conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced potential before atom placement: 16.815513847873078


conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced energy added from growth system: -54.33331783373363
INFO:geometry:final reduced energy -37.51780442070653
INFO:geometry:sum of energies: -37.51780398586055
INFO:geometry:magnitude of difference in the energies: 4.3484597966880756e-07
INFO:geometry:Final logp_proposal: 45.56341936342642
INFO:geometry:logp_reverse: performing reverse proposal
INFO:geometry:logp_reverse: unique new atoms detected; proceeding to _logp_propose...
INFO:geometry:Conducting forward proposal...
INFO:geometry:Computing proposal order with NetworkX...
INFO:geometry:number of atoms to be placed: 2
INFO:geometry:Atom index proposal order is [13, 12]
INFO:geometry:omitted_bonds: []
INFO:geometry:direction of proposal is reverse; creating atoms_with_positions from old system/topology


added energy components: [('CustomBondForce', 0.8144506729290285), ('CustomAngleForce', 4.590479418047062), ('CustomTorsionForce', 10.506513235499154), ('CustomBondForce', -70.2447611602089)]


INFO:geometry:creating growth system...
INFO:geometry:	creating bond force...
INFO:geometry:	there are 9 bonds in reference force.
INFO:geometry:	creating angle force...
INFO:geometry:	there are 36 angles in reference force.
INFO:geometry:	creating torsion force...
INFO:geometry:	creating extra torsions force...
INFO:geometry:	there are 42 torsions in reference force.
INFO:geometry:	creating nonbonded force...
INFO:geometry:		grabbing reference nonbonded method, cutoff, switching function, switching distance...
INFO:geometry:		creating nonbonded exception force (i.e. custom bond for 1,4s)...
INFO:geometry:		looping through exceptions calculating growth indices, and adding appropriate interactions to custom bond force.
INFO:geometry:		there are 1631 in the reference Nonbonded force
INFO:geometry:Neglected angle terms : []
INFO:geometry:omitted_growth_terms: {'bonds': [], 'angles': [], 'torsions': [], '1,4s': []}
INFO:geometry:extra torsions: {0: (6, 8, 10, 12, [1, Quantity(value=-2.0823

conducting subsequent work with the following platform: CUDA


INFO:geometry:setting atoms_with_positions context old positions


conducting subsequent work with the following platform: CUDA


INFO:geometry:There are 2 new atoms
INFO:geometry:	reduced angle potential = 7.39096069988752e-11.
INFO:geometry:	reduced angle potential = 3.205832446488702e-13.
INFO:geometry:	beginning construction of no_nonbonded final system...
INFO:geometry:	initial no-nonbonded final system forces ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat']
INFO:geometry:	final no-nonbonded final system forces dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce'])
INFO:geometry:	there are 9 bond forces in the no-nonbonded final system
INFO:geometry:	there are 36 angle forces in the no-nonbonded final system
INFO:geometry:	there are 42 torsion forces in the no-nonbonded final system
INFO:geometry:reverse final system defined with 0 neglected angles.


conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced potential before atom placement: 16.815513847873078


conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced energy added from growth system: 5.747545230963934
INFO:geometry:final reduced energy 22.563058899034125
INFO:geometry:sum of energies: 22.56305907883701
INFO:geometry:magnitude of difference in the energies: 1.798028863575496e-07
INFO:geometry:Final logp_proposal: -17921.269846110896
INFO:relative:*** Generating RepartitionedHybridTopologyFactory ***
INFO:relative:Old system forces: dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat'])
INFO:relative:New system forces: dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat'])
INFO:relative:No unknown forces.
INFO:relative:Nonbonded method to be used (i.e. from old system): 4
INFO:relative:Adding and mapping old atoms to hybrid system...
INFO:relative:Adding and mapping new atoms to hybrid system...
INFO:relative:Added MonteCarloBarostat.
INFO:relative:getDefaultPeriodicBoxVectors adde

added energy components: [('CustomBondForce', 0.0), ('CustomAngleForce', 0.00017810068841776427), ('CustomTorsionForce', 0.0028912052701551938), ('CustomBondForce', 5.744475925005359)]


INFO:relative:Generating old system exceptions dict...
INFO:relative:Generating new system exceptions dict...
INFO:relative:Handling constraints...
INFO:relative:Handling virtual sites...
INFO:relative:	_handle_virtual_sites: numVirtualSites: 0
INFO:relative:	_add_nonbonded_force_terms: <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2ac57c020cc0> > added to hybrid system
INFO:relative:	_add_nonbonded_force_terms: nonbonded_method is PME or Ewald
INFO:relative:	_add_nonbonded_force_terms: 4 added to standard nonbonded force
INFO:relative:	_add_nonbonded_force_terms: 2 added to sterics_custom_nonbonded force
INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=2.56477354, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=2.56477354, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=2.56477354), unit=nanometer)]
INFO:REST:No unknown forces.


[-34.800948356611826, -2288.8878282656515, -2465.9308346738835]
-2406.8687197496624
-2406.8686262621886


In [3]:
def compare_energies_hybrid(htf, rest_atoms, T_min, T, is_old=True, is_solvated=False):
    hybrid_system = htf.hybrid_system
    hybrid_positions = htf.hybrid_positions
    other_system = htf._topology_proposal.old_system if is_old else htf._topology_proposal.new_system
    other_positions = htf.old_positions(htf.hybrid_positions) if is_old else htf.new_positions(htf.hybrid_positions)
    
    # Omit reciprocal space (because this cannot be scaled by rest)
    reciprocal_space_force_index = 7 if is_solvated else 6
    reciprocal_space_force = hybrid_system.getForce(reciprocal_space_force_index)
    for i in range(reciprocal_space_force.getNumParticles()):
        charge, sigma, epsilon = reciprocal_space_force.getParticleParameters(i)
        reciprocal_space_force.setParticleParameters(i, charge*0, sigma*0, epsilon*0)
    for i in range(reciprocal_space_force.getNumExceptions()):
        p1, p2, chargeProd, sigma, epsilon = reciprocal_space_force.getExceptionParameters(i)
        reciprocal_space_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma*0, epsilon*0)
    
    # Zero the unique old/new bonds in the hybrid system
    bond_force_index = 1 if is_solvated else 0
    angle_force_index = 2 if is_solvated else 1
    torsion_force_index = 3 if is_solvated else 2
    custom_bond_force = hybrid_system.getForce(bond_force_index)
    custom_angle_force = hybrid_system.getForce(angle_force_index)
    custom_torsion_force = hybrid_system.getForce(torsion_force_index)
    hybrid_to_bond_indices = htf._hybrid_to_new_bond_indices if is_old else htf._hybrid_to_old_bond_indices
    hybrid_to_angle_indices = htf._hybrid_to_new_angle_indices if is_old else htf._hybrid_to_old_angle_indices
    hybrid_to_torsion_indices = htf._hybrid_to_new_torsion_indices if is_old else htf._hybrid_to_old_torsion_indices
    for hybrid_idx, idx in hybrid_to_bond_indices.items():
        p1, p2, hybrid_params = custom_bond_force.getBondParameters(hybrid_idx)
        hybrid_params = list(hybrid_params)
        hybrid_params[-2] *= 0 # zero K_old
        hybrid_params[-1] *= 0 # zero K_new
        custom_bond_force.setBondParameters(hybrid_idx, p1, p2, hybrid_params)
    for hybrid_idx, idx in hybrid_to_angle_indices.items():
        p1, p2, p3, hybrid_params = custom_angle_force.getAngleParameters(hybrid_idx)
        hybrid_params = list(hybrid_params)
        hybrid_params[-1] *= 0
        hybrid_params[-2] *= 0
        custom_angle_force.setAngleParameters(hybrid_idx, p1, p2, p3, hybrid_params)
    for hybrid_idx, idx in hybrid_to_torsion_indices.items():
        p1, p2, p3, p4, hybrid_params = custom_torsion_force.getTorsionParameters(hybrid_idx)
        hybrid_params = list(hybrid_params)
        hybrid_params[-1] *= 0
        hybrid_params[-2] *= 0
        custom_torsion_force.setTorsionParameters(hybrid_idx, p1, p2, p3, p4, hybrid_params)
           
    # Create thermodynamic state
    lambda_protocol = RESTCapableLambdaProtocol(functions='no-alchemy')
    lambda_zero_alchemical_state = RESTCapableRelativeAlchemicalState.from_system(hybrid_system)
    thermostate = ThermodynamicState(hybrid_system, temperature=T_min)
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate,
                                                              composable_states=[lambda_zero_alchemical_state])

    # Set alchemical parameters
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    global_lambda = 1
    endstate = 0 if is_old else 1
    compound_thermodynamic_state.set_alchemical_parameters(global_lambda, beta_0, beta_m, lambda_protocol=lambda_protocol, endstate=endstate)
    
    # Get energy
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = compound_thermodynamic_state.create_context(integrator)
    context.setPositions(hybrid_positions)
    context.setParameter('lambda_alchemical_electrostatics_reciprocal', 0) # Zero offsets for reciprocal space
    REST_energy = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)

    # Compute energy for non-RESTified system
    # Determine regions and scaling factors
    nonrest_atoms = [i for i in range(other_system.getNumParticles()) if i not in rest_atoms]
    rest_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)
    
    # Compute 3 energies: 
    ## Test case 1: rest-rest energy
    ## Test case 2: nonrest-nonrest energy
    ## Test case 3: total energy
    unmodified_energies = []
    for test_case in range(3):
        system_copy = copy.deepcopy(other_system)
        
        if test_case != 2:
            bond_force = system_copy.getForce(0)
            for bond_index in range(bond_force.getNumBonds()):
                p1, p2, length, k = bond_force.getBondParameters(bond_index)
                particles = [p1, p2]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
                    else:
                        continue
                else:
                    bond_force.setBondParameters(bond_index, p1, p2, length, k*0)

            angle_force = system_copy.getForce(1)
            for angle_index in range(angle_force.getNumAngles()):
                p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
                particles = [p1, p2, p3]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 
                    else:
                        continue
                else:
                    angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0) 

            torsion_force = system_copy.getForce(2)
            for torsion_index in range(torsion_force.getNumTorsions()):
                p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
                particles = [p1, p2, p3, p4]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
                    else:
                        continue
                else:
                    torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)

            nb_force = system_copy.getForce(3)
            for i in range(nb_force.getNumExceptions()):
                p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
                particles = [p1, p2]
                if all(x in rest_atoms for x in particles):
                    if test_case == 1:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        continue
                elif all(x in nonrest_atoms for x in particles):
                    if test_case == 0:
                        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
                    else:
                        continue
                else:
                    nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

            nb_force = system_copy.getForce(3)
            for i in range(nb_force.getNumParticles()):
                charge, sigma, epsilon = nb_force.getParticleParameters(i)
                if i in rest_atoms:
                    if test_case == 1:
                        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                    else:
                        continue
                else:
                    if test_case == 0:
                        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
                    else:
                        continue     
            
        # Omit reciprocal space
        system_copy.getForce(3).setReciprocalSpaceForceGroup(31)
#         print("other system nb force direct space group: ", system_copy.getForce(3).getForceGroup())
#         print("other system torsion force group: ", system_copy.getForce(2).getForceGroup())
        
        # Get energy
        thermostate = ThermodynamicState(system_copy, temperature=T_min)
        integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
        context = thermostate.create_context(integrator)
        context.setPositions(other_positions)
#         sampler_state = SamplerState.from_context(context)
#         system_copy_energy = thermostate.reduced_potential(sampler_state)
        system_copy_energy = context.getState(getEnergy=True, groups={0}).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)
        unmodified_energies.append(system_copy_energy)
    
    print(unmodified_energies)
    unmodified_energy = unmodified_energies[0] * rest_scaling + unmodified_energies[1] + (unmodified_energies[2] - unmodified_energies[0] - unmodified_energies[1]) * inter_scaling
    print(REST_energy)
    print(unmodified_energy)
    assert np.isclose(REST_energy, unmodified_energy), f"REST energy was {REST_energy} and unmodified_energy was {unmodified_energy}"
                            

In [5]:
with open("atp_solvent_scale_region.pickle", "rb") as f:
    htf = pickle.load(f)


In [11]:
htf.hybrid_system.getForces()

[<openmm.openmm.CustomBondForce; proxy of <Swig Object of type 'OpenMM::CustomBondForce *' at 0x2b2a57647090> >,
 <openmm.openmm.CustomAngleForce; proxy of <Swig Object of type 'OpenMM::CustomAngleForce *' at 0x2b2a53fd1ae0> >,
 <openmm.openmm.CustomTorsionForce; proxy of <Swig Object of type 'OpenMM::CustomTorsionForce *' at 0x2b2a53fd1db0> >,
 <openmm.openmm.CustomNonbondedForce; proxy of <Swig Object of type 'OpenMM::CustomNonbondedForce *' at 0x2b2a53fd1d50> >,
 <openmm.openmm.CustomNonbondedForce; proxy of <Swig Object of type 'OpenMM::CustomNonbondedForce *' at 0x2b2a546586f0> >,
 <openmm.openmm.CustomBondForce; proxy of <Swig Object of type 'OpenMM::CustomBondForce *' at 0x2b2a54658660> >,
 <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2b2a54658600> >]

In [6]:
htf._atom_classes['core_atoms']

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19, 20, 21}

In [7]:
htf._atom_classes['unique_old_atoms']

{12, 13}

In [8]:
htf._atom_classes['unique_new_atoms']

{1557, 1558, 1559, 1560, 1561, 1562}

In [9]:
#### Set temperatures
T_min = 300.0 * unit.kelvin
T = 600 * unit.kelvin

# ## CASE 1: alanine dipeptide in vacuum
# print("ala dipeptide in vacuum at lambda = 0")
# with open("atp_vacuum_scale_region.pickle", "rb") as f:
#     htf = pickle.load(f)

# # Check energy scaling
# rest_atoms = [htf._hybrid_to_old_map[index] for index in [10, 12, 13]]
# compare_energies_hybrid(htf, rest_atoms, T_min, T)

# print("ala dipeptide in vacuum at lambda = 1")
# with open("atp_vacuum_scale_region.pickle", "rb") as f:
#     htf = pickle.load(f)

# # Check energy scaling
# rest_atoms = [htf._hybrid_to_new_map[index] for index in [10, 12, 22, 23, 24, 25, 26, 27]]
# compare_energies_hybrid(htf, rest_atoms, T_min, T, is_old=False)


## CASE 2: alanine dipeptide in solvent
print("ala dipeptide in solvent at lambda = 0")
with open("atp_solvent_scale_region.pickle", "rb") as f:
    htf = pickle.load(f)

# Check energy scaling
rest_atoms = [htf._hybrid_to_old_map[index] for index in [10, 11, 12, 13]]
compare_energies_hybrid(htf, rest_atoms, T_min, T, is_solvated=True)

print("ala dipeptide in solvent at lambda = 1")
with open("atp_solvent_scale_region.pickle", "rb") as f:
    htf = pickle.load(f)

# Check energy scaling
rest_atoms = [htf._hybrid_to_new_map[index] for index in [10, 11, 1557, 1558, 1559, 1560, 1561, 1562]]
compare_energies_hybrid(htf, rest_atoms, T_min, T, is_old=False, is_solvated=True)



ala dipeptide in solvent at lambda = 0
[10.033547687424486, 119656.2595678248, 119720.8221766204]
119699.82968678446
119699.83421054983
ala dipeptide in solvent at lambda = 1
[3329.783119775322, 119658.25804912952, 123043.05544347492]
121362.0462282735
121362.05057562774


In [None]:
# Note: in order to get ala dipeptide in solvent tests passing, i need to use an htf that
# was generated with rest scaling applied to all exceptions in custom exceptions expression
# (including the subtracted out reciprocal space)