In [1]:
###########################################
# IMPORTS
###########################################
from openmmtools.states import SamplerState, ThermodynamicState, CompoundThermodynamicState
from simtk import unit, openmm
from perses.tests.utils import compute_potential_components
from openmmtools.constants import kB
from perses.dispersed.utils import configure_platform
from perses.annihilation.rest import RESTTopologyFactory
from perses.annihilation.lambda_protocol import RESTState
import numpy as np
from perses.tests.test_topology_proposal import generate_atp, generate_dipeptide_top_pos_sys
from openmmtools.testsystems import AlanineDipeptideVacuum, AlanineDipeptideExplicit
import itertools

#############################################
# CONSTANTS
#############################################
# Thermal constants at the reference simulation temperature.
temperature = unit.Quantity(300.0, unit.kelvin)
kT = kB * temperature   # thermal energy at `temperature`
beta = 1.0 / kT         # inverse thermal energy; later passed to compute_potential_components

# Handle to OpenMM's Reference platform.
REFERENCE_PLATFORM = openmm.Platform.getPlatformByName("Reference")



INFO:rdkit:Enabling RDKit 2021.03.5 jupyter extensions


conducting subsequent work with the following platform: CUDA


# Run the existing energy-scaling test from test_rest.py

In [10]:
def compare_energies(REST_system, other_system, positions, rest_atoms, T_min, T, tolerance=1.0):
    """
    Assert that a REST-ified system reproduces a manually REST2-scaled copy of the original system.

    The REST system's reduced potential is evaluated at ``T_min`` after setting its REST parameters for an
    effective temperature ``T``. A deep copy of ``other_system`` is then scaled by hand:

    * bonds, angles, torsions and nonbonded exceptions entirely inside the solute region are scaled by
      ``beta_m / beta_0``;
    * those that mix solute and solvent atoms are scaled by ``sqrt(beta_m / beta_0)``;
    * nonbonded solute-solute and solute-solvent pairs that are not already exceptions are added as new
      exceptions with Lorentz-Berthelot-combined parameters, pre-scaled by the same factors.

    Solvent-only terms are left unchanged. Both reduced potentials are printed and then compared.

    Parameters
    ----------
    REST_system : openmm.System
        System produced by ``RESTTopologyFactory``.
    other_system : openmm.System
        The corresponding non-REST system. It is deep-copied, so the caller's object is not modified.
        NOTE(review): forces 0-3 are assumed to be a harmonic bond force, a harmonic angle force, a periodic
        torsion force and a NonbondedForce, in that order (the parameter unpacking below relies on it).
    positions
        Particle positions, as accepted by ``Context.setPositions``; used for both energy evaluations.
    rest_atoms : iterable of int
        Indices of the particles in the REST (solute) region.
    T_min : simtk.unit.Quantity
        Temperature of both thermodynamic states (defines ``beta_0``).
    T : simtk.unit.Quantity
        Effective REST temperature of the solute region (defines ``beta_m``).
    tolerance : float, optional, default=1.0
        Maximum allowed absolute difference between the two reduced potentials (in kT).

    Raises
    ------
    AssertionError
        If the reduced potentials differ by ``tolerance`` or more, in either direction.
    """
    import copy

    # --- REST side ----------------------------------------------------------------------------------------
    lambda_zero_alchemical_state = RESTState.from_system(REST_system)
    thermostate = ThermodynamicState(REST_system, temperature=T_min)
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate,
                                                              composable_states=[lambda_zero_alchemical_state])

    # Set alchemical parameters
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)

    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = compound_thermodynamic_state.create_context(integrator)
    context.setPositions(positions)
    sampler_state = SamplerState.from_context(context)
    REST_energy = compound_thermodynamic_state.reduced_potential(sampler_state)
    # Release the (possibly GPU) context before building the second one.
    del context, integrator

    # --- Manually scaled non-REST side -------------------------------------------------------------------
    # Scale a private copy so the caller's system is left untouched.
    other_system = copy.deepcopy(other_system)

    # Sets keep the many membership tests below O(1) instead of O(N).
    solute = sorted({int(i) for i in rest_atoms})
    solute_set = set(solute)
    solvent = [i for i in range(other_system.getNumParticles()) if i not in solute_set]
    solute_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)

    def _scale_factor(particles):
        """Return the REST factor for a term over ``particles``, or None for a solvent-only term."""
        n_solute = sum(p in solute_set for p in particles)
        if n_solute == len(particles):
            return solute_scaling
        if n_solute > 0:
            return inter_scaling
        return None

    # Bonds
    bond_force = other_system.getForce(0)
    for bond_index in range(bond_force.getNumBonds()):
        p1, p2, length, k = bond_force.getBondParameters(bond_index)
        factor = _scale_factor((p1, p2))
        if factor is not None:
            bond_force.setBondParameters(bond_index, p1, p2, length, k * factor)

    # Angles
    angle_force = other_system.getForce(1)
    for angle_index in range(angle_force.getNumAngles()):
        p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
        factor = _scale_factor((p1, p2, p3))
        if factor is not None:
            angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k * factor)

    # Torsions
    torsion_force = other_system.getForce(2)
    for torsion_index in range(torsion_force.getNumTorsions()):
        p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
        factor = _scale_factor((p1, p2, p3, p4))
        if factor is not None:
            torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k * factor)

    # Existing nonbonded exceptions
    nb_force = other_system.getForce(3)
    for nb_index in range(nb_force.getNumExceptions()):
        p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(nb_index)
        factor = _scale_factor((p1, p2))
        if factor is not None:
            nb_force.setExceptionParameters(nb_index, p1, p2, factor * chargeProd, sigma, factor * epsilon)

    # Pairs that already have an exception keep the (scaled) exception parameters set above.
    exception_pairs = set()
    for nb_index in range(nb_force.getNumExceptions()):
        p1, p2 = nb_force.getExceptionParameters(nb_index)[:2]
        exception_pairs.add((min(p1, p2), max(p1, p2)))

    # Read every particle's parameters once instead of once per pair.
    particle_params = [nb_force.getParticleParameters(i) for i in range(nb_force.getNumParticles())]

    def _add_scaled_exception(p1, p2, factor):
        """Add a (p1, p2) exception with combined particle parameters scaled by ``factor``, unless one exists."""
        if (min(p1, p2), max(p1, p2)) in exception_pairs:
            return
        q1, s1, e1 = particle_params[p1]
        q2, s2, e2 = particle_params[p2]
        nb_force.addException(p1, p2, q1 * q2 * factor, 0.5 * (s1 + s2), np.sqrt(e1 * e2) * factor)

    # Solute-solute pairs, scaled by beta_m / beta_0.
    for p1, p2 in itertools.combinations(solute, 2):
        _add_scaled_exception(p1, p2, solute_scaling)

    # Solute-solvent pairs, scaled by sqrt(beta_m / beta_0).
    for p1, p2 in itertools.product(solute, solvent):
        _add_scaled_exception(p1, p2, inter_scaling)

    # Get energy
    thermostate = ThermodynamicState(other_system, temperature=T_min)
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = thermostate.create_context(integrator)
    context.setPositions(positions)
    sampler_state = SamplerState.from_context(context)
    nonREST_energy = thermostate.reduced_potential(sampler_state)
    del context, integrator

    print(f"REST energy: {REST_energy}")
    print(f"nonREST energy: {nonREST_energy}")
    # Compare the absolute difference: a one-sided check would let a much lower REST energy pass.
    assert abs(REST_energy - nonREST_energy) < tolerance, (
        f"The energy of the REST system ({REST_energy}) does not match "
        f"that of the non-REST system with terms manually scaled according to REST2({nonREST_energy})."
    )

def test_energy_scaling():
    """
    Test whether the energy of a REST-ified system equals the energy of the system with terms manually scaled by
    the same factors REST uses. T_min is 300 K and the REST region's effective temperature is 600 K.

    Cases
    -----
    1. Alanine dipeptide in vacuum.
    2. Alanine dipeptide in explicit solvent (the system's own nonbonded method is kept).
    3. Repartitioned hybrid system (mutation to THR, lambda = 0 endstate) in solvent.

    Raises
    ------
    AssertionError
        From ``compare_energies`` if any case's REST and manually scaled energies disagree.
    """

    # Set temperatures
    T_min = 300.0 * unit.kelvin
    T = 600 * unit.kelvin

    ## CASE 1: alanine dipeptide in vacuum
    # Create vanilla system
    print("alanine dipeptide in vacuum")
    ala = AlanineDipeptideVacuum()
    system = ala.system
    positions = ala.positions

    # Create REST system with residue 1 as the REST region.
    # NOTE(review): force index 4 is assumed to be a force the REST factory does not handle — confirm.
    system.removeForce(4)
    res1 = list(ala.topology.residues())[1]
    rest_atoms = [atom.index for atom in res1.atoms()]
    factory = RESTTopologyFactory(system, solute_region=rest_atoms)
    REST_system = factory.REST_system

    # Check energy scaling
    compare_energies(REST_system, system, positions, rest_atoms, T_min, T)

    ## CASE 2: alanine dipeptide in explicit solvent (nonbonded method left as created by the test system)
    # Create vanilla system
    print("alanine dipeptide in solvent")
    ala = AlanineDipeptideExplicit()
    system = ala.system
    positions = ala.positions

    # Create REST system
    system.removeForce(4)
    res1 = list(ala.topology.residues())[1]
    rest_atoms = [atom.index for atom in res1.atoms()]
    factory = RESTTopologyFactory(system, solute_region=rest_atoms, use_dispersion_correction=True)
    REST_system = factory.REST_system

    # Check energy scaling
    compare_energies(REST_system, system, positions, rest_atoms, T_min, T)

    ## CASE 3: repartitioned hybrid system in solvent (nonbonded method left as created by the hybrid factory)
    # Create repartitioned hybrid system for lambda 0 endstate
    print("alanine dipepetide with repartitioned hybrid system")
    atp, system_generator = generate_atp(phase='solvent')
    htf = generate_dipeptide_top_pos_sys(atp.topology,
                                         new_res='THR',
                                         system=atp.system,
                                         positions=atp.positions,
                                         system_generator=system_generator,
                                         conduct_htf_prop=True,
                                         repartitioned=True,
                                         endstate=0,
                                         validate_endstate_energy=False)
    system = htf.hybrid_system
    system.removeForce(0)  # Remove barostat

    # Create REST-ified hybrid system (hybrid_topology exposes residues/atoms as attributes, not methods)
    res1 = list(htf.hybrid_topology.residues)[1]
    rest_atoms = [atom.index for atom in list(res1.atoms)]
    factory = RESTTopologyFactory(system, solute_region=rest_atoms, use_dispersion_correction=True)
    REST_system = factory.REST_system

    # Check energy scaling
    compare_energies(REST_system, system, htf.hybrid_positions, rest_atoms, T_min, T)

In [23]:
# Run all three energy-scaling cases defined above; raises AssertionError on the first mismatch.
test_energy_scaling()


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=2.0, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=2.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=2.0), unit=nanometer)]
INFO:REST:No unknown forces.


alanine dipeptide in vacuum


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=3.2852863, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=3.2861648000000003, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=3.1855098), unit=nanometer)]


REST energy: -38.356522004048074
nonREST energy: -38.35653377001734
alanine dipeptide in solvent


INFO:REST:No unknown forces.


REST energy: -9883.66816482165
nonREST energy: -9884.410368966905


# Adapt this test so it does not add nonbonded exceptions (exception pairs are not treated with PME)

In [2]:
import copy
from tqdm import tqdm_notebook


In [7]:
# REST temperatures: thermodynamic states at T_min, REST region at an effective temperature T.
T_min = 300.0 * unit.kelvin
T = 600 * unit.kelvin

## Alanine dipeptide in explicit solvent
# Create vanilla system
print("alanine dipeptide in solvent")
ala = AlanineDipeptideExplicit()
system = ala.system
positions = ala.positions

# Create REST system with residue 1 as the REST region.
# NOTE(review): force index 4 is assumed to be a force the REST factory does not handle — confirm.
system.removeForce(4)
res1 = list(ala.topology.residues())[1]
rest_atoms = [atom.index for atom in res1.atoms()]
factory = RESTTopologyFactory(system, solute_region=rest_atoms, use_dispersion_correction=True)
REST_system = factory.REST_system



alanine dipeptide in solvent


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=3.2852863, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=3.2861648000000003, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=3.1855098), unit=nanometer)]
INFO:REST:No unknown forces.


In [200]:
# system.getForce(3).setUseDispersionCorrection(False)
# REST_system.getForce(3).setUseDispersionCorrection(False)

In [17]:
# solute = rest_atoms
# solvent = [i for i in range(REST_system.getNumParticles()) if i not in rest_atoms]

# nb_force = REST_system.getForce(3)
# # for i in range(nb_force.getNumParticles()):
# #     charge, sigma, epsilon = nb_force.getParticleParameters(i)
# #     if i in rest_atoms:
# #         continue
# #     else:
# #         nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
# # #     nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)
        

# # for i in range(nb_force.getNumParticleParameterOffsets()):
# #     param, particle_idx, chargeScale, sigmaScale, epsilonScale = nb_force.getParticleParameterOffset(i)
# #     if particle_idx in rest_atoms:
# #         continue
# #     else:
# #         nb_force.setParticleParameterOffset(i, 'electrostatic_scale', particle_idx, 0.0, 0.0, 0.0)
# #         nb_force.setParticleParameterOffset(i, 'steric_scale', particle_idx, 0.0, 0.0, 0.0)

# zeroed_exceptions = []
# for i in range(nb_force.getNumExceptions()):
#     p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
#     particles = [p1, p2]
#     if all(x in rest_atoms for x in particles):
# #         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
# #         zeroed_exceptions.append(i)
#         continue
#     elif all(x in solvent for x in particles):
#         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
#         zeroed_exceptions.append(i)
#     else:
#         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
#         zeroed_exceptions.append(i)
# #     nb_force.setExceptionParameters(i, p1, p2, charge*0, sigma, epsilon*0)
        
# for i in range(nb_force.getNumExceptionParameterOffsets()):
#     param, exception_idx, chargeScale, sigmaScale, epsilonScale = nb_force.getExceptionParameterOffset(i)
#     if exception_idx in zeroed_exceptions:
#         nb_force.setExceptionParameterOffset(i, 'electrostatic_scale', exception_idx, 0.0, 0.0, 0.0)
#         nb_force.setExceptionParameterOffset(i, 'steric_scale', exception_idx, 0.0, 0.0, 0.0)


# print("adding exceptions for solute-solvent pairs")
# exception_pairs = [tuple(sorted([nb_force.getExceptionParameters(nb_index)[0], nb_force.getExceptionParameters(nb_index)[1]])) for nb_index in range(nb_force.getNumExceptions())]
# solute_solvent_pairs = set([tuple(sorted(pair)) for pair in list(itertools.product(solute, solvent))])
# solvent_pairs = set([tuple(sorted(pair)) for pair in list(itertools.product(solvent, solvent))])
# for pair in list(solute_solvent_pairs):
#     p1 = pair[0]
#     p2 = pair[1]
# #     p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
# #     p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
#     if p1 != p2:
#         if pair not in exception_pairs:
#             nb_force.addException(p1, p2, 0, 0.5, 0)
# print("adding exceptions for solvent pairs")
# for pair in tqdm_notebook(list(solvent_pairs)):
#     p1 = pair[0]
#     p2 = pair[1]
# #     p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
# #     p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
#     if p1 != p2:
#         if pair not in exception_pairs:
#             nb_force.addException(p1, p2, 0, 0.5, 0)


        



adding exceptions for solute-solvent pairs
adding exceptions for solvent pairs


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for pair in tqdm_notebook(list(solvent_pairs)):


  0%|          | 0/2552670 [00:00<?, ?it/s]

In [33]:
# bond_force = REST_system.getForce(0)
# for i in range(bond_force.getNumBonds()):
#     p1, p2, params = bond_force.getBondParameters(i)
#     length, k, identifier = params
#     if p1 in rest_atoms and p2 in rest_atoms:
#         continue
# #         bond_force.setBondParameters(i, p1, p2, [length, k*0, identifier])
#     elif (p1 in rest_atoms and p2 not in rest_atoms) or (p1 not in rest_atoms and p2 in rest_atoms):
#         bond_force.setBondParameters(i, p1, p2, [length, k*0, identifier])
#     else:
#         bond_force.setBondParameters(i, p1, p2, [length, k*0, identifier])

In [5]:
# Reduced potential of the REST-ified system at T_min, with the REST region scaled to an effective temperature T.
beta_0 = 1 / (kB * T_min)
beta_m = 1 / (kB * T)

lambda_zero_alchemical_state = RESTState.from_system(REST_system)
thermostate = ThermodynamicState(REST_system, temperature=T_min)
compound_thermodynamic_state = CompoundThermodynamicState(
    thermostate,
    composable_states=[lambda_zero_alchemical_state],
)
compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)

# Evaluate the energy in a fresh context.
integrator = openmm.VerletIntegrator(unit.Quantity(1.0, unit.femtosecond))
context = compound_thermodynamic_state.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
REST_energy = compound_thermodynamic_state.reduced_potential(sampler_state)

# Keep only the energy value (second element) of each entry reported by compute_potential_components.
components_rest = [entry[1] for entry in compute_potential_components(context, beta=beta)]


conducting subsequent work with the following platform: CUDA


In [6]:
# Solute-only energy: zero every term that involves at least one solvent atom, then scale the
# remaining (solute-solute) energy by beta_m / beta_0 to get the REST-scaled solute contribution.
system_A = copy.deepcopy(system)

# Determine regions and scaling factors
solute = rest_atoms
solvent = [i for i in range(system_A.getNumParticles()) if i not in solute]
solute_scaling = beta_m / beta_0
inter_scaling = np.sqrt(beta_m / beta_0)
solute_set = set(solute)  # O(1) membership tests in the loops below

# Bonded terms: keep a term only if all of its atoms are in the solute.
bond_force = system_A.getForce(0)
for bond_index in range(bond_force.getNumBonds()):
    p1, p2, length, k = bond_force.getBondParameters(bond_index)
    if not solute_set.issuperset((p1, p2)):
        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)

angle_force = system_A.getForce(1)
for angle_index in range(angle_force.getNumAngles()):
    p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
    if not solute_set.issuperset((p1, p2, p3)):
        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)

torsion_force = system_A.getForce(2)
for torsion_index in range(torsion_force.getNumTorsions()):
    p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
    if not solute_set.issuperset((p1, p2, p3, p4)):
        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)

# Nonbonded: switch off every solvent particle, which removes solvent-solvent and solute-solvent pairs.
nb_force = system_A.getForce(3)
for i in range(nb_force.getNumParticles()):
    if i not in solute_set:
        charge, sigma, epsilon = nb_force.getParticleParameters(i)
        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)

# Exceptions carry their own parameters, so those touching the solvent are zeroed explicitly,
# using the exception's own chargeProd.
for i in range(nb_force.getNumExceptions()):
    p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
    if not solute_set.issuperset((p1, p2)):
        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

# Get energy
thermostate = ThermodynamicState(system_A, temperature=T_min)
integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context = thermostate.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
system_A_energy = thermostate.reduced_potential(sampler_state)
system_A_energy_scaled = system_A_energy * solute_scaling


In [7]:
# Solvent-only energy: zero every term that involves at least one solute (REST) atom.
# Solvent-solvent terms are not REST-scaled, so this energy is used unscaled.
system_B = copy.deepcopy(system)

# Determine regions and scaling factors
solute = rest_atoms
solvent = [i for i in range(system_B.getNumParticles()) if i not in solute]
solute_scaling = beta_m / beta_0
inter_scaling = np.sqrt(beta_m / beta_0)
solute_set = set(solute)  # O(1) membership tests in the loops below

# Bonded terms: keep a term only if none of its atoms are in the solute.
bond_force = system_B.getForce(0)
for bond_index in range(bond_force.getNumBonds()):
    p1, p2, length, k = bond_force.getBondParameters(bond_index)
    if not solute_set.isdisjoint((p1, p2)):
        bond_force.setBondParameters(bond_index, p1, p2, length, k*0)

angle_force = system_B.getForce(1)
for angle_index in range(angle_force.getNumAngles()):
    p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
    if not solute_set.isdisjoint((p1, p2, p3)):
        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)

torsion_force = system_B.getForce(2)
for torsion_index in range(torsion_force.getNumTorsions()):
    p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
    if not solute_set.isdisjoint((p1, p2, p3, p4)):
        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)

# Nonbonded: switch off every solute particle, which removes solute-solute and solute-solvent pairs.
nb_force = system_B.getForce(3)
for i in range(nb_force.getNumParticles()):
    if i in solute_set:
        charge, sigma, epsilon = nb_force.getParticleParameters(i)
        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)

# Exceptions carry their own parameters, so those touching the solute are zeroed explicitly,
# using the exception's own chargeProd.
for i in range(nb_force.getNumExceptions()):
    p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
    if not solute_set.isdisjoint((p1, p2)):
        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

# Get energy
thermostate = ThermodynamicState(system_B, temperature=T_min)
integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context = thermostate.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
system_B_energy = thermostate.reduced_potential(sampler_state)
system_B_energy_scaled = system_B_energy

In [10]:
# # Zero REST and solvent energies and get inter energy
# system_C = copy.deepcopy(system)

# # Determine regions and scaling factors
# solute = rest_atoms
# solvent = [i for i in range(system_C.getNumParticles()) if i not in solute]
# solute_scaling = beta_m / beta_0
# inter_scaling = np.sqrt(beta_m / beta_0)

# bond_force = system_C.getForce(0)
# for bond_index in range(bond_force.getNumBonds()):
#     p1, p2, length, k = bond_force.getBondParameters(bond_index)
#     particles = [p1, p2]
# #     if all(x in solute for x in particles):
# #         bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
# #     elif all(x in solvent for x in particles):
# #         bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
# #     else:
# #         continue
# #     if p1 in solute and p2 in solute:
# #         bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
# #     elif (p1 in solute and p2 in solvent) or (p1 in solvent and p2 in solute):
# #         continue
# #     else:
# #         bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
#     bond_force.setBondParameters(bond_index, p1, p2, length, k*0)
        
# angle_force = system_C.getForce(1)
# for angle_index in range(angle_force.getNumAngles()):
#     p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
#     particles = [p1, p2, p3]
# #     if all(x in solute for x in particles):
# #         angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)
# #     elif all(x in solvent for x in particles):
# #         angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)
# #     else:
# #         continue
# #     if p1 in solute and p2 in solute and p3 in solute:
# #         angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)
# #     elif set([p1, p2, p3]).intersection(set(solute)) != set() and set([p1, p2, p3]).intersection(set(solvent)) != set():
# #         continue
# #     else:
# #         angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)
#     angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k*0)
    
# torsion_force = system_C.getForce(2)
# for torsion_index in range(torsion_force.getNumTorsions()):
#     p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
#     particles = [p1, p2, p3, p4]
# #     if all(x in solute for x in particles):
# #         torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
# #     elif all(x in solvent for x in particles):
# #         torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
# #     else:
# #         continue
# #     if p1 in solute and p2 in solute and p3 in solute and p4 in solute:
# #         torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
# #     elif set([p1, p2, p3, p4]).intersection(set(solute)) != set() and set([p1, p2, p3, p4]).intersection(
# #             set(solvent)) != set():
# #         continue
# #     else:
# #         torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
#     torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k*0)
    
# nb_force = system_C.getForce(3)
# # for i in range(nb_force.getNumParticles()):
# #     charge, sigma, epsilon = nb_force.getParticleParameters(i)
# #     nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)

# print("zeroing exceptions")
# for i in range(nb_force.getNumExceptions()):
#     p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
#     particles = [p1, p2]
#     if all(x in solute for x in particles):
#         print("solute exception: ", particles, chargeProd, sigma, epsilon)
#         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
#     elif all(x in solvent for x in particles):
#         print("solvent exception: ", particles, chargeProd, sigma, epsilon)
#         nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)
#     else:
#         continue
# #     if p1 in solute and p2 in solute:
# #         nb_force.setExceptionParameters(i, p1, p2, charge*0, sigma, epsilon*0)
# #     elif (p1 in solute and p2 in solvent) or (p1 in solvent and p2 in solute):
# #         nb_force.setExceptionParameters(i, p1, p2, charge*0, sigma, epsilon*0)
# #     else:
# #         continue
# #     nb_force.setExceptionParameters(i, p1, p2, charge*0, sigma, epsilon*0)

# print("adding exceptions for solute pairs")
# exception_pairs = [tuple(sorted([nb_force.getExceptionParameters(nb_index)[0], nb_force.getExceptionParameters(nb_index)[1]])) for nb_index in range(nb_force.getNumExceptions())]
# solute_pairs = set([tuple(sorted(pair)) for pair in list(itertools.product(solute, solute))])
# solvent_pairs = set([tuple(sorted(pair)) for pair in list(itertools.product(solvent, solvent))])
# for pair in list(solute_pairs):
#     p1 = pair[0]
#     p2 = pair[1]
# #     p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
# #     p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
#     if p1 != p2:
#         if pair not in exception_pairs:
#             nb_force.addException(p1, p2, 0, 0.5, 0)
# print("adding exceptions for solvent pairs")
# for pair in tqdm_notebook(list(solvent_pairs)):
#     p1 = pair[0]
#     p2 = pair[1]
# #     p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
# #     p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
#     if p1 != p2:
#         if pair not in exception_pairs:
#             nb_force.addException(p1, p2, 0, 0.5, 0)


            
# # Get energy
# thermostate = ThermodynamicState(system_C, temperature=T_min)
# integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
# context = thermostate.create_context(integrator)
# context.setPositions(positions)
# sampler_state = SamplerState.from_context(context)
# system_C_energy = thermostate.reduced_potential(sampler_state)
# system_C_energy_scaled = system_C_energy * inter_scaling

zeroing exceptions
solvent exception:  [3, 5] -0.053145975154069464 e**2 0.28047273444524545 nm 0.12012161263171671 kJ/mol
solvent exception:  [2, 5] -0.053145975154069464 e**2 0.28047273444524545 nm 0.12012161263171671 kJ/mol
solvent exception:  [0, 5] -0.053145975154069464 e**2 0.28047273444524545 nm 0.12012161263171671 kJ/mol
solute exception:  [13, 14] 0.030014325027576096 e**2 0.3024601147602527 nm 0.07687068174209402 kJ/mol
solute exception:  [12, 14] 0.030014325027576096 e**2 0.3024601147602527 nm 0.07687068174209402 kJ/mol
solute exception:  [11, 14] 0.030014325027576096 e**2 0.3024601147602527 nm 0.07687068174209402 kJ/mol
solute exception:  [9, 11] 0.004135575 e**2 0.2560442914951194 nm 0.03284440013043381 kJ/mol
solute exception:  [9, 12] 0.004135575 e**2 0.2560442914951194 nm 0.03284440013043381 kJ/mol
solute exception:  [9, 13] 0.004135575 e**2 0.2560442914951194 nm 0.03284440013043381 kJ/mol
solute exception:  [9, 15] -0.0389484751129111 e**2 0.2715637472143426 nm 0.12012

solvent exception:  [338, 339] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [340, 341] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [340, 342] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [341, 342] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [343, 344] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [343, 345] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [344, 345] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [346, 347] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [346, 348] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [347, 348] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [349, 350] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [349, 351] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [350, 351] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [352, 353] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [352, 354] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [353, 354] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [355, 356] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent except

solvent exception:  [739, 740] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [739, 741] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [740, 741] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [742, 743] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [742, 744] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [743, 744] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [745, 746] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [745, 747] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [746, 747] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [748, 749] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [748, 750] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [749, 750] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [751, 752] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [751, 753] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [752, 753] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [754, 755] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [754, 756] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent except

solvent exception:  [1138, 1140] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1139, 1140] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1141, 1142] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1141, 1143] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1142, 1143] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1144, 1145] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1144, 1146] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1145, 1146] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1147, 1148] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1147, 1149] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1148, 1149] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1150, 1151] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1150, 1152] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1151, 1152] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1153, 1154] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1153, 1155] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1154, 1155] 0.0 e**

solvent exception:  [1535, 1536] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1537, 1538] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1537, 1539] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1538, 1539] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1540, 1541] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1540, 1542] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1541, 1542] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1543, 1544] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1543, 1545] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1544, 1545] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1546, 1547] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1546, 1548] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1547, 1548] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1549, 1550] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1549, 1551] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1550, 1551] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1552, 1553] 0.0 e**

solvent exception:  [1838, 1839] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1840, 1841] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1840, 1842] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1841, 1842] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1843, 1844] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1843, 1845] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1844, 1845] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1846, 1847] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1846, 1848] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1847, 1848] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1849, 1850] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1849, 1851] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1850, 1851] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1852, 1853] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1852, 1854] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1853, 1854] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [1855, 1856] 0.0 e**

solvent exception:  [2239, 2240] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2239, 2241] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2240, 2241] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2242, 2243] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2242, 2244] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2243, 2244] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2245, 2246] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2245, 2247] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2246, 2247] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2248, 2249] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2248, 2250] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2249, 2250] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2251, 2252] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2251, 2253] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2252, 2253] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2254, 2255] 0.0 e**2 0.1 nm 0.0 kJ/mol
solvent exception:  [2254, 2256] 0.0 e**

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for pair in tqdm_notebook(list(solvent_pairs)):


  0%|          | 0/2552670 [00:00<?, ?it/s]

In [8]:
# Get total energy
# Evaluate an unmodified deep copy of `system` (every force term kept) so its
# energy can serve as the "total" reference in the decomposition below.
# NOTE(review): `system`, `T_min`, and `positions` come from earlier cells not
# shown here, and `import copy` lives in a cell placed further down the notebook
# (In [3]); on a fresh kernel that cell must run first.
system_C = copy.deepcopy(system)

# Get energy
# A throwaway context is built only to evaluate the energy at `positions`; the
# integrator is not stepped in this cell.
thermostate = ThermodynamicState(system_C, temperature=T_min)
integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context = thermostate.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
# Reduced (kT-scaled, per openmmtools) potential at T_min.
system_C_energy = thermostate.reduced_potential(sampler_state)

In [96]:
# Number of solvent atom pairs the exception loop iterates over.
# NOTE(review): `solvent_pairs` is only defined by code that is now commented
# out, so this cell raises NameError under Restart & Run All.
len(solvent_pairs)

2552670

In [9]:
# Cross-region energy: total (system C) minus the two single-region energies,
# scaled by `inter_scaling`. NOTE(review): assumes system_A / system_B hold the
# REST-only and non-REST-only energies; confirm against the cells defining them.
system_D_energy_scaled = (system_C_energy - system_A_energy - system_B_energy)*inter_scaling

In [10]:
# Display the REST-system energy computed in an earlier cell (not shown here).
REST_energy


-9883.658890418434

In [11]:
# Per-force energy components of the REST system. Presumably produced by
# compute_potential_components (imported at the top); confirm.
components_rest


[0.03429166476612171,
 0.3557026496311803,
 2.282951099724897,
 -9886.341110588197,
 0.0]

In [6]:
# Sum of the first three REST energy components. Read them from
# `components_rest` instead of pasting the printed values, so the result stays
# correct when the earlier cells are re-run with different inputs.
np.sum(components_rest[:3])

2.672945414122199

In [14]:
# Total REST energy as the sum of all per-force components.
np.sum(components_rest)

-9883.668165174075

In [13]:
# Reference energy rebuilt from the scaled per-region pieces; it should match
# np.sum(components_rest) from the neighbouring cells.
system_A_energy_scaled + system_B_energy_scaled + system_D_energy_scaled


-9883.668095370325

In [15]:
# True when the summed REST components agree with the scaled per-region
# reference energy.
np.isclose(
    np.sum(components_rest),
    system_A_energy_scaled + system_B_energy_scaled + system_D_energy_scaled,
)

True

# Write a new energy-scaling test function and run it on vanilla REST

In [3]:
import copy
from tqdm import tqdm_notebook


In [23]:
def _term_region(particles, rest_atoms, nonrest_atoms):
    """Classify a force term by the region its particles belong to.

    Returns 'rest' if every particle is in `rest_atoms`, 'nonrest' if every
    particle is in `nonrest_atoms`, and 'inter' otherwise (a cross-region term).
    """
    if all(x in rest_atoms for x in particles):
        return 'rest'
    if all(x in nonrest_atoms for x in particles):
        return 'nonrest'
    return 'inter'


def _zero_nonbonded(nb_force, zero_offsets):
    """Zero the charges/epsilons of all particles and exceptions of `nb_force`.

    When `zero_offsets` is True, the scales of every particle and exception
    parameter offset are zeroed as well. Each offset keeps its own global
    parameter name, so no offset is relabelled onto a different parameter.
    """
    for i in range(nb_force.getNumParticles()):
        charge, sigma, epsilon = nb_force.getParticleParameters(i)
        nb_force.setParticleParameters(i, charge*0, sigma, epsilon*0)

    for i in range(nb_force.getNumExceptions()):
        p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(i)
        nb_force.setExceptionParameters(i, p1, p2, chargeProd*0, sigma, epsilon*0)

    if not zero_offsets:
        return

    for i in range(nb_force.getNumParticleParameterOffsets()):
        offset = nb_force.getParticleParameterOffset(i)
        param, particle_idx = offset[0], offset[1]
        nb_force.setParticleParameterOffset(i, param, particle_idx, 0.0, 0.0, 0.0)

    for i in range(nb_force.getNumExceptionParameterOffsets()):
        offset = nb_force.getExceptionParameterOffset(i)
        param, exception_idx = offset[0], offset[1]
        nb_force.setExceptionParameterOffset(i, param, exception_idx, 0.0, 0.0, 0.0)


def _zero_terms_outside_region(system, rest_atoms, nonrest_atoms, kept_region):
    """Zero every bond/angle/torsion/nonbonded term whose region is not `kept_region`.

    Terms are classified with `_term_region`. Cross-region ('inter') terms are
    therefore always zeroed. Single particles are classified as one-particle
    terms.

    NOTE(review): assumes forces 0-3 are HarmonicBondForce, HarmonicAngleForce,
    PeriodicTorsionForce and NonbondedForce (the order the geometry log reports).
    """
    def outside(particles):
        return _term_region(particles, rest_atoms, nonrest_atoms) != kept_region

    bond_force = system.getForce(0)
    for index in range(bond_force.getNumBonds()):
        p1, p2, length, k = bond_force.getBondParameters(index)
        if outside((p1, p2)):
            bond_force.setBondParameters(index, p1, p2, length, k*0)

    angle_force = system.getForce(1)
    for index in range(angle_force.getNumAngles()):
        p1, p2, p3, angle, k = angle_force.getAngleParameters(index)
        if outside((p1, p2, p3)):
            angle_force.setAngleParameters(index, p1, p2, p3, angle, k*0)

    torsion_force = system.getForce(2)
    for index in range(torsion_force.getNumTorsions()):
        p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(index)
        if outside((p1, p2, p3, p4)):
            torsion_force.setTorsionParameters(index, p1, p2, p3, p4, periodicity, phase, k*0)

    nb_force = system.getForce(3)
    for index in range(nb_force.getNumExceptions()):
        p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(index)
        if outside((p1, p2)):
            nb_force.setExceptionParameters(index, p1, p2, chargeProd*0, sigma, epsilon*0)

    for index in range(nb_force.getNumParticles()):
        charge, sigma, epsilon = nb_force.getParticleParameters(index)
        if outside((index,)):
            nb_force.setParticleParameters(index, charge*0, sigma, epsilon*0)


def _potential_energy(state, positions):
    """Return the potential energy (md unit system, kJ/mol) of `positions` in a fresh context of `state`."""
    integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
    context = state.create_context(integrator)
    context.setPositions(positions)
    return context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)


def compare_energies(REST_system, other_system, positions, rest_atoms, T_min, T):
    """Check a REST system's energy against the REST2-scaled energy of an unscaled system.

    First, nonbonded interactions are switched off in copies of both systems:
    the NonbondedForce at index 3, including its parameter offsets for
    `REST_system`. The REST energy is then evaluated with
    beta_0 = 1/(kB*T_min) and beta_m = 1/(kB*T). It is compared with

        E_rr * (beta_m/beta_0) + E_nn + E_inter * sqrt(beta_m/beta_0)

    where E_rr / E_nn are energies of `other_system` restricted to REST-only /
    non-REST-only terms, and E_inter = E_total - E_rr - E_nn.

    Args:
        REST_system: RESTified system (input to RESTState / ThermodynamicState).
            It is copied, not modified.
        other_system: the corresponding non-RESTified system. It is copied, not
            modified.
        positions: particle positions used for every energy evaluation.
        rest_atoms: iterable of particle indices in the REST region.
        T_min: temperature Quantity of the unscaled (reference) region.
        T: temperature Quantity the REST region is scaled to.

    Raises:
        AssertionError: if the two energies are not `np.isclose`.
    """
    # Work on copies so the caller's systems keep their nonbonded parameters.
    REST_system = copy.deepcopy(REST_system)
    other_system = copy.deepcopy(other_system)

    # Switch off nonbonded interactions in both systems, so the comparison
    # exercises the remaining (valence) terms.
    _zero_nonbonded(REST_system.getForce(3), zero_offsets=True)
    _zero_nonbonded(other_system.getForce(3), zero_offsets=False)

    # Create the REST thermodynamic state and set its alchemical parameters.
    lambda_zero_alchemical_state = RESTState.from_system(REST_system)
    thermostate = ThermodynamicState(REST_system, temperature=T_min)
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate,
                                                              composable_states=[lambda_zero_alchemical_state])
    beta_0 = 1 / (kB * T_min)
    beta_m = 1 / (kB * T)
    compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)
    REST_energy = _potential_energy(compound_thermodynamic_state, positions)

    # Region membership as sets, for O(1) lookups inside the per-term loops.
    rest_set = set(rest_atoms)
    nonrest_set = set(range(other_system.getNumParticles())) - rest_set
    rest_scaling = beta_m / beta_0
    inter_scaling = np.sqrt(beta_m / beta_0)

    # Unscaled energies of other_system, in this order:
    #   0: REST-REST terms only, 1: non-REST/non-REST terms only, 2: all terms.
    unmodified_energies = []
    for kept_region in ('rest', 'nonrest', None):
        system_copy = copy.deepcopy(other_system)
        if kept_region is not None:
            _zero_terms_outside_region(system_copy, rest_set, nonrest_set, kept_region)
        thermostate = ThermodynamicState(system_copy, temperature=T_min)
        unmodified_energies.append(_potential_energy(thermostate, positions))

    print(unmodified_energies)
    rest_rest_energy, nonrest_nonrest_energy, total_energy = unmodified_energies
    inter_energy = total_energy - rest_rest_energy - nonrest_nonrest_energy
    unmodified_energy = rest_rest_energy * rest_scaling + nonrest_nonrest_energy + inter_energy * inter_scaling
    print(REST_energy)
    print(unmodified_energy)
    assert np.isclose(REST_energy, unmodified_energy), f"REST energy was {REST_energy} and unmodified_energy was {unmodified_energy}"
                            

In [24]:
# Run the energy-scaling test. NOTE(review): `test_energy_scaling` is not
# defined in this section. Its printed output matches the prints of the
# `compare_energies` defined in the previous cell, so it presumably calls that
# global; confirm, since `compare_energies` is defined twice in this notebook.
test_energy_scaling()


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=2.0, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=2.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=2.0), unit=nanometer)]
INFO:REST:No unknown forces.


alanine dipeptide in vacuum


INFO:REST:No MonteCarloBarostat added.
INFO:REST:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=3.2852863, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=3.2861648000000003, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=3.1855098), unit=nanometer)]


[1.2664324958308815, 0.34497100125756114, 9.656921484626121]
6.667235068767628
6.66722757631924
alanine dipeptide in solvent


INFO:REST:No unknown forces.
DEBUG:openmmforcefields.system_generators:Trying GAFFTemplateGenerator to load gaff-2.11


[1.2664292771783718, 0.3449719651280984, 9.656918764249463]
6.667232296012481
6.6672266016383634
alanine dipepetide with repartitioned hybrid system


INFO:proposal_generator:	Conducting polymer point mutation proposal...
INFO:proposal_generator:[Atom(name=CB, atomic number=6), Atom(name=HB1, atomic number=1), Atom(name=HB2, atomic number=1), Atom(name=HB3, atomic number=1)]
INFO:proposal_generator:[Atom(name=CB, atomic number=6), Atom(name=CG2, atomic number=6), Atom(name=OG1, atomic number=8), Atom(name=HB, atomic number=1), Atom(name=HG1, atomic number=1), Atom(name=HG21, atomic number=1), Atom(name=HG22, atomic number=1), Atom(name=HG23, atomic number=1)]
INFO:geometry:propose: performing forward proposal
INFO:geometry:propose: unique new atoms detected; proceeding to _logp_propose...
INFO:geometry:Conducting forward proposal...
INFO:geometry:Computing proposal order with NetworkX...
INFO:geometry:number of atoms to be placed: 8
INFO:geometry:Atom index proposal order is [10, 18, 14, 13, 16, 19, 17, 15]
INFO:geometry:omitted_bonds: []
INFO:geometry:direction of proposal is forward; creating atoms_with_positions and new positions 

making topology proposal
generating geometry engine
making geometry proposal from ALA to THR


INFO:geometry:creating growth system...
INFO:geometry:	creating bond force...
INFO:geometry:	there are 11 bonds in reference force.
INFO:geometry:	creating angle force...
INFO:geometry:	there are 43 angles in reference force.
INFO:geometry:	creating torsion force...
INFO:geometry:	creating extra torsions force...
INFO:geometry:	there are 72 torsions in reference force.
INFO:geometry:	creating nonbonded force...
INFO:geometry:		grabbing reference nonbonded method, cutoff, switching function, switching distance...
INFO:geometry:		creating nonbonded exception force (i.e. custom bond for 1,4s)...
INFO:geometry:		looping through exceptions calculating growth indices, and adding appropriate interactions to custom bond force.
INFO:geometry:		there are 1654 in the reference Nonbonded force
INFO:geometry:Neglected angle terms : []
INFO:geometry:omitted_growth_terms: {'bonds': [], 'angles': [], 'torsions': [], '1,4s': []}
INFO:geometry:extra torsions: {0: (7, 6, 8, 10, [1, Quantity(value=0.99000

conducting subsequent work with the following platform: CUDA


INFO:geometry:setting atoms_with_positions context new positions


conducting subsequent work with the following platform: CUDA


INFO:geometry:There are 8 new atoms
INFO:geometry:	reduced angle potential = 1.8764759956267916.
INFO:geometry:	reduced angle potential = 0.7480530411444264.
INFO:geometry:	reduced angle potential = 0.00025952110242746036.
INFO:geometry:	reduced angle potential = 0.028641196311232166.
INFO:geometry:	reduced angle potential = 0.0024587666803323263.
INFO:geometry:	reduced angle potential = 0.08513584378342594.
INFO:geometry:	reduced angle potential = 1.2242897963470216.
INFO:geometry:	reduced angle potential = 0.2900992163235368.
INFO:geometry:	beginning construction of no_nonbonded final system...
INFO:geometry:	initial no-nonbonded final system forces ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat']
INFO:geometry:	final no-nonbonded final system forces dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce'])
INFO:geometry:	there are 11 bond forces in the no-nonbonded final system
INFO:geo

conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced potential before atom placement: 9.116481141712038


conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced energy added from growth system: -39.98682378132465
INFO:geometry:final reduced energy -30.870345104074598
INFO:geometry:sum of energies: -30.870342639612616
INFO:geometry:magnitude of difference in the energies: 2.4644619855962446e-06
INFO:geometry:Final logp_proposal: 59.84296653065326


added energy components: [('CustomBondForce', 0.3440959708274345), ('CustomAngleForce', 10.994248505225615), ('CustomTorsionForce', 16.956110095782638), ('CustomBondForce', -68.28127835316033)]


INFO:geometry:logp_reverse: performing reverse proposal
INFO:geometry:logp_reverse: unique new atoms detected; proceeding to _logp_propose...
INFO:geometry:Conducting forward proposal...
INFO:geometry:Computing proposal order with NetworkX...
INFO:geometry:number of atoms to be placed: 4
INFO:geometry:Atom index proposal order is [10, 11, 13, 12]
INFO:geometry:omitted_bonds: []
INFO:geometry:direction of proposal is reverse; creating atoms_with_positions from old system/topology
INFO:geometry:creating growth system...
INFO:geometry:	creating bond force...
INFO:geometry:	there are 9 bonds in reference force.
INFO:geometry:	creating angle force...
INFO:geometry:	there are 36 angles in reference force.
INFO:geometry:	creating torsion force...
INFO:geometry:	creating extra torsions force...
INFO:geometry:	there are 42 torsions in reference force.
INFO:geometry:	creating nonbonded force...
INFO:geometry:		grabbing reference nonbonded method, cutoff, switching function, switching distance...

conducting subsequent work with the following platform: CUDA


INFO:geometry:setting atoms_with_positions context old positions


conducting subsequent work with the following platform: CUDA


INFO:geometry:There are 4 new atoms
INFO:geometry:	reduced angle potential = 0.08012165173241892.
INFO:geometry:	reduced angle potential = 1.2915588460963948e-10.
INFO:geometry:	reduced angle potential = 7.39096069988752e-11.
INFO:geometry:	reduced angle potential = 3.205832446488702e-13.
INFO:geometry:	beginning construction of no_nonbonded final system...
INFO:geometry:	initial no-nonbonded final system forces ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat']
INFO:geometry:	final no-nonbonded final system forces dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce'])
INFO:geometry:	there are 9 bond forces in the no-nonbonded final system
INFO:geometry:	there are 36 angle forces in the no-nonbonded final system
INFO:geometry:	there are 42 torsion forces in the no-nonbonded final system
INFO:geometry:reverse final system defined with 0 neglected angles.


conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced potential before atom placement: 9.116481141712038


conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA
conducting subsequent work with the following platform: CUDA


INFO:geometry:total reduced energy added from growth system: 20.67298667212561
INFO:geometry:final reduced energy 29.78946596595155
INFO:geometry:sum of energies: 29.789467813837646
INFO:geometry:magnitude of difference in the energies: 1.8478860965842614e-06
INFO:geometry:Final logp_proposal: -27102.43303179193
INFO:relative:*** Generating RepartitionedHybridTopologyFactory ***
INFO:relative:Old system forces: dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat'])


added energy components: [('CustomBondForce', 0.0005202039273265048), ('CustomAngleForce', 0.45111977562481365), ('CustomTorsionForce', 7.25046332853325), ('CustomBondForce', 12.970883364040219)]


INFO:relative:New system forces: dict_keys(['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat'])
INFO:relative:No unknown forces.
INFO:relative:Nonbonded method to be used (i.e. from old system): 4
INFO:relative:Adding and mapping old atoms to hybrid system...
INFO:relative:Adding and mapping new atoms to hybrid system...
INFO:relative:Added MonteCarloBarostat.
INFO:relative:getDefaultPeriodicBoxVectors added to hybrid: [Quantity(value=Vec3(x=2.56477354, y=0.0, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=2.56477354, z=0.0), unit=nanometer), Quantity(value=Vec3(x=0.0, y=0.0, z=2.56477354), unit=nanometer)]
INFO:relative:Determined atom classes.
INFO:relative:Generating old system exceptions dict...
INFO:relative:Generating new system exceptions dict...
INFO:relative:Handling constraints...
INFO:relative:Handling virtual sites...
INFO:relative:	_handle_virtual_sites: numVirtualSites: 0
INFO:relative:	_add_nonbonded_force_te

[53.57283004118337, 10.335210739359, 112.52719556658857]
71.50057081468864
71.50055980472234
