# Core Imports

In [36]:
# Generic Imports
import re
from functools import partial, cached_property
from collections import defaultdict
from itertools import combinations, chain
from ast import literal_eval

# Numeric imports
import pandas as pd
import numpy as np

# File I/O
from pathlib import Path
import csv, json, openpyxl, pickle

# Logging
from tqdm import tqdm as tqdm_text
from tqdm.notebook import tqdm as tqdm_notebook

# Typing and Subclassing
from typing import Any, Callable, ClassVar, Generator, Iterable, Optional, Union
from dataclasses import dataclass, field
from abc import ABC, abstractmethod, abstractproperty

# Cheminformatics
from rdkit import Chem
from rdkit.Chem import rdChemReactions

from openff.toolkit import ForceField
from openff.toolkit.topology import Topology, Molecule

from openforcefields.openforcefields import get_forcefield_dirs_paths
OPENFF_DIR = Path(get_forcefield_dirs_paths()[0])

# File and chemistry type definitions

In [37]:
# Selects which subtree of each output directory this notebook works on:
# every per-format directory below is laid out as <format>/<nmer_name>/<lattice_size>
nmer_name = 'trimers'
# lattice_size = '1x1x1'
lattice_size = '5x5x5'

# per-format working directories (all share the same nmer/lattice suffix)
topo_dir = Path('Topologies') / nmer_name / lattice_size
ics_dir = Path('Interchanges') / nmer_name / lattice_size
lammps_dir = Path('LAMMPS') / nmer_name / lattice_size
omm_dir = Path('OpenMM') / nmer_name / lattice_size

# Creating OpenMM and LAMMPS systems

## Harvest and tabulate paths + info for all Interchange and Topology files

In [41]:
import pandas as pd
from polymerist.genutils.fileutils.pathutils import assemble_path


MOL_MASTER_DIR = Path('polymer_structures')
TAGS = (
    'mechanism',
    'mol_name',
    'oligomer_size',
    'lattice_size'
)

# Walk every lattice directory (named like "5x5x5") below the master directory and
# build one record per directory; the path components relative to the master
# directory are labelled, in order, with the names in TAGS
records = []
for lattice_dir in MOL_MASTER_DIR.glob('**/[0-9]x[0-9]x[0-9]'):
    entry = dict(zip(TAGS, lattice_dir.relative_to(MOL_MASTER_DIR).parts))
    entry['directory'] = lattice_dir

    # a path column is only filled in when the corresponding file is present on disk
    for column, file_ext in (('interchange_path', 'pkl'), ('topology_path', 'sdf')):
        candidate = assemble_path(lattice_dir, prefix=entry['lattice_size'], postfix=entry['mol_name'], extension=file_ext)
        if candidate.exists():
            entry[column] = candidate

    records.append(entry)

mol_file_frame = pd.DataFrame.from_records(records)
groups = mol_file_frame.groupby(['lattice_size', 'mechanism'])

## Setting Simulation parameters

In [42]:
from polymerist.openmmtools.parameters import SimulationParameters, IntegratorParameters, ThermoParameters, ReporterParameters
from openmm.unit import Unit, Quantity
from openmm.unit import kelvin, atmosphere # T & P
from openmm.unit import femtosecond, picosecond, nanosecond

# Switches handed to ReporterParameters as "state_data": one flag per named property
state_data_props : dict[str, bool] = {
    'step'            : True,
    'time'            : True,
    'potentialEnergy' : True,
    'kineticEnergy'   : True,
    'totalEnergy'     : True,
    'temperature'     : True,
    'volume'          : True,
    'density'         : True,
    'speed'           : True,
    'progress'        : False,
    'remainingTime'   : False,
    'elapsedTime'     : False
}

# Bundle integrator, thermodynamic and reporter settings into one object, then write it to JSON.
# The thermo_params and integ_params.time_step attributes are read again when building integrators below.
sim_params = SimulationParameters(
    integ_params = IntegratorParameters(
        time_step=2*femtosecond,
        total_time=1*nanosecond,
        num_samples=10
    ),
    thermo_params = ThermoParameters(
        ensemble='NVT', #'NPT,
        temperature=300*kelvin,
        pressure=1*atmosphere,
        friction_coeff=1*picosecond**-1,
        barostat_freq=25
    ),
    reporter_params = ReporterParameters(
        report_checkpoint=True,
        report_state     =True,
        report_trajectory=True,
        report_state_data=True,
        traj_ext='dcd',
        state_data=state_data_props,
    )
)
sim_params.to_file('sim_param.json')

In [43]:
force_name_remap = { # TODO : move this to energy eval as a simple remap (suitable names are already set by Interchange)
    'vdW force'                : 'vdW',
    'Electrostatics force'     : 'Electrostatic',
    'vdW 1-4 force'            : 'vdW 1-4',
    'Electrostatics 1-4 force' : 'Electrostatic 1-4',
    'PeriodicTorsionForce'     : 'Dihedral',
    'HarmonicAngleForce'       : 'Angle',
    'HarmonicBondForce'        : 'Bond'
}

## Manually create OpenMM sims from Interchange

## Playing with Interchange settings

In [49]:
import pickle
from polymerist.openfftools import topology

# pick the first tabulated structure of one (lattice_size, mechanism) group as a test case
frame = groups.get_group(('1x1x1', 'polyamide'))

row = frame.iloc[0]

# rebuild the OpenFF Topology from its SDF file and convert it to an OpenMM topology
offtop = topology.topology_from_sdf(row.topology_path)
ommtop = offtop.to_openmm()

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
# only unpickle Interchange files that were produced locally / come from a trusted source
with row.interchange_path.open('rb') as inc_file:
    interchange = pickle.load(inc_file)

In [51]:
from openff.interchange.components.mdconfig import MDConfig, _infer_constraints, get_smirnoff_defaults
from openff.interchange.constants import _PME
from openmm.unit import angstrom

# zero out the vdW switching width on the loaded Interchange before deriving an MD configuration from it
interchange['vdW'].switch_width = 0.0 * angstrom # TOSELF : setting to NoneType manually raises Exception, see if there is a more canonical way to do this

# NOTE(review): smirnoff_mdc is not used anywhere else in this notebook - presumably kept for side-by-side comparison with mdc
smirnoff_mdc = get_smirnoff_defaults(periodic=True)
mdc = MDConfig.from_interchange(interchange)
mdc.dict()

{'periodic': True,
 'constraints': 'none',
 'vdw_method': 'cutoff',
 'vdw_cutoff': 9.0 <Unit('angstrom')>,
 'mixing_rule': 'lorentz-berthelot',
 'switching_function': False,
 'switching_distance': 0.0 <Unit('angstrom')>,
 'coul_method': 'Ewald3D-ConductingBoundary',
 'coul_cutoff': 9.0 <Unit('angstrom')>}

In [52]:
from openmm import Context, Platform
from openmm.app import Simulation

from openff.interchange.interop.openmm._positions import to_openmm_positions

from polymerist.openmmtools import serialization, preparation
from polymerist.openmmtools.thermo import EnsembleFactory


# NOTE: describe_ommsys_forces and dict_to_indented_str are defined further down in this
# notebook (under "Defining utilities"); those cells must be executed before this one
ommsys = interchange.to_openmm(combine_nonbonded_forces=False)
desc_dict, force_map = describe_ommsys_forces(ommsys) # also assigns each Force its own force group
desc = dict_to_indented_str(desc_dict)
# print(desc)

# build an integrator matching the thermodynamic settings chosen in sim_params
ensfac = EnsembleFactory.from_thermo_params(sim_params.thermo_params)
integrator = ensfac.integrator(time_step=sim_params.integ_params.time_step)

# bind System + integrator into a Context and load the Interchange's coordinates into it
context = Context(ommsys, integrator)
context.setPositions(to_openmm_positions(interchange, include_virtual_sites=True))

In [12]:
{
    force_name : context.getState(getEnergy=True, groups={group_idx}).getPotentialEnergy()
        for force_name, group_idx in force_map.items()
}

{'vdW force': Quantity(value=42275038.18413356, unit=kilojoule/mole),
 'Electrostatics force': Quantity(value=-781.1802186879249, unit=kilojoule/mole),
 'vdW 1-4 force': Quantity(value=451.7369384765625, unit=kilojoule/mole),
 'Electrostatics 1-4 force': Quantity(value=-804.5572509765625, unit=kilojoule/mole),
 'PeriodicTorsionForce': Quantity(value=153.1345672607422, unit=kilojoule/mole),
 'HarmonicAngleForce': Quantity(value=1324.5399169921875, unit=kilojoule/mole),
 'HarmonicBondForce': Quantity(value=1322.8663330078125, unit=kilojoule/mole)}

In [13]:
from polymerist.genutils.fileutils.pathutils import assemble_path

# scratch output files for a one-off conversion of the Interchange loaded above
lammps_lmp = Path('test.lammps')
lammps_in  = Path('test.in')

# NOTE: lammps_files_from_interchange is defined further down (under "LAMMPS functions");
# that cell must be executed before this one
lammps_files_from_interchange(interchange, lmp_data_path=lammps_lmp, lmp_input_path=lammps_in)

# Defining utilities

## LAMMPS functions

In [118]:
from openmm.unit import kilocalorie_per_mole, joule, ergs, hartree
from openmm.unit import picogram, micrometer, microsecond
from openmm.unit import attogram, nanometer, nanosecond

# registering electron volts (not in the default OpenMM unit space)
from openmm.unit import joule, ScaledUnit # energy
from scipy.constants import electron_volt as electron_volt_joules
electronvolt_base = ScaledUnit(electron_volt_joules, joule, 'electronvolt', 'eV')
electronvolt = eV = Unit({electronvolt_base : 1.0})

from openff.interchange import Interchange
from polymerist.genutils.decorators.functional import allow_string_paths


# OpenMM energy unit to attach to raw LAMMPS thermo values, keyed by LAMMPS unit style name
# (the "lj" style has no entry here)
LAMMPS_ENERGY_UNITS = {
    'real'     : kilocalorie_per_mole,
    'metal'    : electronvolt,
    'si'       : joule,
    'cgs'      : ergs,
    'electron' : hartree,
    'micro'    : picogram * micrometer**2 * microsecond**-2,
    'nano'     : attogram * nanometer**2  * nanosecond**-2,
}

# human-readable labels for the LAMMPS thermo keywords which report energies
E_MAP = {
    'ebond'  : 'Bond',
    'eangle' : 'Angle',
    'edihed' : 'Proper Torsion',
    'eimp'   : 'Improper Torsion',
    'ecoul'  : 'Coulomb Short',
    'elong'  : 'Coulomb Long',
    'evdwl'  : 'vdW',
    'etail'  : 'Dispersion',
    'epair'  : 'Nonbonded',
    'pe'     : 'Potential',
    'ke'     : 'Kinetic',
    'etotal' : 'Total'
}

CELL_KW = ( # keywords for probing unit cell sizes and angles
    'cella',
    'cellb',
    'cellc',
    'cellalpha',
    'cellbeta',
    'cellgamma',
)

def lammps_files_from_interchange(interchange : Interchange, lmp_input_path : Path, lmp_data_path : Path) -> None:
    '''Write the LAMMPS data file and the LAMMPS input script corresponding to an OpenFF Interchange

    The input script is first written in its generic form and then patched so that
    it reads the data file written here, rather than the placeholder "out.lmp"'''
    # reject unexpected file extensions before anything is written
    assert lmp_input_path.suffix == '.in'
    assert lmp_data_path.suffix in ('.lmp', '.lammps')

    # structure/parameter data file, followed by the generic input directive file
    interchange.to_lammps(lmp_data_path)
    MDConfig.from_interchange(interchange).write_lammps_input(lmp_input_path)

    # point the input script at the data file from above; the surrounding double quotes
    # let LAMMPS cope with special symbols in the filename (if present)
    generic_script = lmp_input_path.read_text()
    quoted_data_path = f'"{lmp_data_path}"'
    lmp_input_path.write_text(generic_script.replace('out.lmp', quoted_data_path))


# "\s+" (rather than a single "\s") tolerates directives whose tokens are separated by several spaces or tabs
LMP_THERMO_STYLE_REGEX = re.compile(r'^thermo_style\s+(?P<thermo_style>\w+)\s+(?P<energy_evals>.*)$')
LMP_UNIT_REGEX = re.compile(r'^units\s+(?P<unit_style>\w*)$')

@allow_string_paths
def parse_lammps_input(lmp_in_path : Path) -> dict[str, Any]:
    '''Read the unit style and the requested thermodynamic outputs from a LAMMPS input file

    Returns a dict which may contain the keys:
        "thermo_style" : str, the style named on the "thermo_style" line
        "energy_evals" : list[str], the keywords following the style on that line
        "unit_style"   : str, the style named on the "units" line
    A key is absent when the corresponding directive does not occur in the file;
    if a directive occurs more than once, the last occurrence wins'''
    info_dict = {}
    with lmp_in_path.open('r') as lmp_in_file:
        for raw_line in lmp_in_file:
            line = raw_line.strip() # drop the newline and any indentation/trailing blanks, which would defeat the anchored patterns

            if (thermo_match := LMP_THERMO_STYLE_REGEX.search(line)):
                info_dict.update(thermo_match.groupdict())
                info_dict['energy_evals'] = info_dict['energy_evals'].split() # whitespace split never yields empty keywords

            if (units_match := LMP_UNIT_REGEX.search(line)):
                info_dict.update(units_match.groupdict())

    return info_dict

In [None]:
import lammps
from IPython.display import clear_output

def eval_lammps_energies(lmp_in_path : Path) -> dict[str, Quantity]:
    '''Perform an energy evaluation using a LAMMPS input file

    The unit style and the list of energy keywords are read from the input file itself;
    each returned energy carries the unit registered for that style in LAMMPS_ENERGY_UNITS.
    Keywords with a label in E_MAP are returned under that label, any others under the raw keyword

    Raises ValueError if the input file requests no thermo outputs,
    or uses a unit style for which no energy unit is registered'''
    lammps_info = parse_lammps_input(lmp_in_path)

    energy_contribs = lammps_info.get('energy_evals')
    if not energy_contribs:
        raise ValueError(f'No "thermo_style" line with energy keywords found in LAMMPS input file {lmp_in_path}')

    unit_style = lammps_info.get('unit_style', 'lj') # LAMMPS itself falls back to "lj" when no "units" command is given
    if unit_style not in LAMMPS_ENERGY_UNITS:
        raise ValueError(f'No energy unit registered for LAMMPS unit style "{unit_style}" (supported: {", ".join(LAMMPS_ENERGY_UNITS)})')
    energy_unit = LAMMPS_ENERGY_UNITS[unit_style]

    with lammps.lammps() as lmp: # need to create new lammps() object instance for each run
        lmp.file(str(lmp_in_path)) # read input file and calculate energies
        return {
            E_MAP.get(contrib, contrib) : lmp.get_thermo(contrib) * energy_unit
                for contrib in energy_contribs
        }

def eval_lammps_unit_cell(lmp_in_path : Path) -> dict[str, Quantity]:
    '''Run a LAMMPS input file and report the thermo value of each unit cell keyword in CELL_KW'''
    with lammps.lammps() as lmp_session: # a fresh lammps() instance is needed for every run
        lmp_session.file(str(lmp_in_path)) # execute the input script

        cell_params = {}
        for cell_keyword in CELL_KW:
            cell_params[cell_keyword] = lmp_session.get_thermo(cell_keyword) # TODO : include proper units based on unit style

    return cell_params

## OpenMM functions

In [None]:
from textwrap import indent

def dict_to_indented_str(dict_to_stringify : dict[Any, Any], level_delimiter : str='\t') -> str:
    '''Generate a pretty-printable string from a (possibly nested) dictionary,
    with each level of nesting indicated by one more repeat of "level_delimiter"

    Keys which map to a nested dict are written on a line of their own (via str()),
    followed by the indented contents of that dict; all other items are written
    as "key : value" using the repr of both'''
    text = []
    for key, value in dict_to_stringify.items():
        if isinstance(value, dict):
            text.append(str(key)) # str() so that non-string keys don't break the final join
            nested_text = dict_to_indented_str(value, level_delimiter=level_delimiter) # pass the delimiter down, otherwise deeper levels fall back to tabs
            text.append(indent(nested_text, level_delimiter))
        else:
            text.append(f'{key!r} : {value!r}') # call repr methods

    return '\n'.join(text)

In [121]:
from openmm import System, NonbondedForce
from openmm.unit import kilojoule_per_mole
from polymerist.genutils.containers import RecursiveDict


# names of the nonbonded method constants looked up on NonbondedForce
NONBOND_CUTOFF_METHOD_NAMES = (
    'NoCutoff',
    'CutoffNonPeriodic',
    'CutoffPeriodic',
    'Ewald',
    'PME',
    'LJPME',
)
# lookup from integer method code to readable name, with entries in ascending order of the code
NONBOND_CUTOFF_METHODS = dict(sorted(
    (getattr(NonbondedForce, cutoff_name), cutoff_name)
        for cutoff_name in NONBOND_CUTOFF_METHOD_NAMES
))

def describe_ommsys_forces(ommsys : System) -> tuple[RecursiveDict, dict[str, int]]:
    '''Describes accessible parameters associated with each Force in an OpenMM system
    Also assigns each Force to its own force group, numbered by the Force's position in the System

    NOTE : this modifies the Forces of "ommsys" in-place (via setForceGroup)

    Returns a nested dict of descriptions (keyed by Force name, then by parameter name;
    it can be rendered as text with dict_to_indented_str), and a dict mapping each Force's name to its group id.
    Forces which share a name will overwrite one another in both returned dicts'''
    force_map : dict[str, int] = {}
    descript_dict = RecursiveDict()

    for i, force in enumerate(ommsys.getForces()):
        force.setForceGroup(i)
        force_name = force.getName()
        force_map[force_name] = i
        descript_dict[force_name]['type'] = type(force).__name__

        for attr in dir(force):
            if not attr.startswith('get'):
                continue

            try:
                attr_val = getattr(force, attr)()
            except TypeError: # called when the getter expects more than 0 arguments
                continue

            if attr == 'getNonbondedMethod': # convert integer index into readable name of nonbonded cutoff method
                attr_val = NONBOND_CUTOFF_METHODS.get(attr_val, attr_val) # leave unrecognized codes as-is rather than aborting the whole description
            descript_dict[force_name][attr.removeprefix('get')] = attr_val

    return descript_dict, force_map

def eval_openmm_energies(context : Context) -> dict[str, Quantity]:
    '''Perform an energy evaluation on an OpenMM Context

    Returns the total potential and kinetic energies (under "Potential" and "Kinetic"),
    along with the potential energy of the force group of each Force in the Context's System.
    Per-force entries are labelled through force_name_remap where a remapped name is registered,
    and with the Force's own name otherwise

    NOTE : the per-force values are only a true decomposition if every Force was assigned
    its own force group before the Context was created (e.g. by describe_ommsys_forces)'''
    openmm_energies = {}

    # get global energies
    overall_state = context.getState(getEnergy=True) # get total potential energy
    openmm_energies['Potential'] = overall_state.getPotentialEnergy()
    openmm_energies['Kinetic'  ] = overall_state.getKineticEnergy()

    # get individual energies from each force type; the Forces are taken from the System
    # bound to this Context, rather than from whatever System happens to be defined globally
    for force in context.getSystem().getForces():
        state = context.getState(getEnergy=True, groups={force.getForceGroup()})
        force_label = force_name_remap.get(force.getName(), force.getName()) # check if a remapped name is registered, otherwise use the Force's set name
        openmm_energies[force_label] = state.getPotentialEnergy()

    return openmm_energies

In [124]:
NULL_ENERGY = 0.0*kilojoule_per_mole
PRECISION : int = 4

openmm_energies = eval_openmm_energies(context)
assert(openmm_energies['Kinetic'] == NULL_ENERGY) # check total KE to verify no integration is being done
print(openmm_energies)

{'Potential': Quantity(value=42276701.59864143, unit=kilojoule/mole), 'Kinetic': Quantity(value=0.0, unit=kilojoule/mole), 'vdW': Quantity(value=42275038.18413356, unit=kilojoule/mole), 'Electrostatic': Quantity(value=-781.1802186879249, unit=kilojoule/mole), 'vdW 1-4': Quantity(value=451.7369384765625, unit=kilojoule/mole), 'Electrostatic 1-4': Quantity(value=-804.5572509765625, unit=kilojoule/mole), 'Dihedral': Quantity(value=153.1345672607422, unit=kilojoule/mole), 'Angle': Quantity(value=1324.5399169921875, unit=kilojoule/mole), 'Bond': Quantity(value=1322.8663330078125, unit=kilojoule/mole)}


## Collapse

In [None]:
# specifying simulation and ensemble parameters
from openff.interchange.components.mdconfig import MDConfig

from openmm.app import Simulation
from openmm import NonbondedForce, CustomNonbondedForce
from openmm import MonteCarloBarostat, LangevinMiddleIntegrator

from openmm.unit import atmosphere, kelvin, nanometer
from openmm.unit import femtosecond, picosecond

from openff.interchange import Interchange
from openff.units import unit as offunit

from polymerist.genutils.fileutils.pathutils import assemble_path
from polymerist.unitutils import openff_to_openmm

from polymerist.openmmtools.parameters import IntegratorParameters, ThermoParameters, SimulationParameters
from polymerist.openmmtools.thermo import EnsembleFactory, ThermoParameters
from polymerist.openmmtools.serialization import serialize_system, serialize_state_from_context, DEFAULT_STATE_PROPS

# Long-range parameters; each key is the name of a setter which is called (with the given value)
# on every Force that has a method of that name
long_range_params = {
    'setCutoffDistance'          : 0.9 * nanometer,               # nonbonded cutoff distance
    'setNonbondedMethod'         : NonbondedForce.CutoffPeriodic, # .NoCutoff, .CutoffNonPeriodic
    'setUseSwitchingFunction'    : False,                         # whether to use a switching function (hard to cross-validate)
    'setUseDispersionCorrection' : True,                          # use dispersion correction
    'setUseLongRangeCorrection'  : True,                          # use dispersion correction (alias for Custom Nonbondeds)
}

# Thermodynamic/integrator parameters
timestep = 2*femtosecond

thermo_params = ThermoParameters(
    ensemble='NVT', #'NPT,
    temperature=300*kelvin,
    pressure=1*atmosphere,
    friction_coeff=1*picosecond**-1,
    barostat_freq=25
)
ens_fac = EnsembleFactory.from_thermo_params(thermo_params)

# ======================================

force_name_remap = { # TODO : move this to energy eval as a simple remap (suitable names are already set by Interchange)
    'vdW force'                : 'vdW',
    'Electrostatics force'     : 'Electrostatic',
    'vdW 1-4 force'            : 'vdW 1-4',
    'Electrostatics 1-4 force' : 'Electrostatic 1-4',
    'PeriodicTorsionForce'     : 'Dihedral',
    'HarmonicAngleForce'       : 'Angle',
    'HarmonicBondForce'        : 'Bond'
}

# For every pickled Interchange under ics_dir/<chemistry>/, write the matching LAMMPS input/data files
# under lammps_dir/<chemistry>/<mol_name>/ and the serialized OpenMM system, state, and topology
# under omm_dir/<chemistry>/<mol_name>/
for chem_dir in ics_dir.iterdir():
    if not chem_dir.is_dir(): # ignore any loose files sitting next to the chemistry directories
        continue

    chemistry = chem_dir.name
    lmp_chem_dir = lammps_dir / chemistry
    lmp_chem_dir.mkdir(parents=True, exist_ok=True) # the nmer/lattice parent directories may not exist yet

    omm_chem_dir = omm_dir / chemistry
    omm_chem_dir.mkdir(parents=True, exist_ok=True)

    progress = tqdm_notebook([file for file in chem_dir.iterdir() if file.suffix == '.pkl'])
    for ics_file in progress:
        mol_name = ics_file.stem
        progress.set_postfix_str(f'{chemistry} : {mol_name}')

        omm_mol_dir = omm_chem_dir / mol_name
        omm_mol_dir.mkdir(exist_ok=True)
        if any(omm_mol_dir.iterdir()): # skip over molecules whose output dir is already populated (checked before the costly unpickling)
            continue

        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file; only load trusted, locally-produced files
        with ics_file.open('rb') as pklfile:
            interchange = pickle.load(pklfile)

        # saving LAMMPS files (named to match what the LAMMPS evaluation loop looks for)
        progress.set_description('Generating LAMMPS files')
        lmp_mol_dir = lmp_chem_dir / mol_name
        lmp_mol_dir.mkdir(exist_ok=True)
        lammps_files_from_interchange(
            interchange,
            lmp_input_path=lmp_mol_dir / f'{mol_name}.in',
            lmp_data_path=lmp_mol_dir / f'{mol_name}.lammps',
        )

        # creating OpenMM Simulation
        progress.set_description('Building OpenMM Simulation')
        ## specifying thermo/baro to determine ensemble (a new integrator is needed for every Simulation)
        integrator   = ens_fac.integrator(time_step=timestep)
        extra_forces = ens_fac.forces()

        ## loading OpenMM sim components from Interchange
        omm_top = interchange.topology.to_openmm()
        omm_sys = interchange.to_openmm(combine_nonbonded_forces=False)
        omm_pos = interchange.positions.m_as(offunit.nanometer)

        ## configuring bound Force objects
        if extra_forces:
            for force in extra_forces:
                omm_sys.addForce(force)

        ## separate forces by number, remap names, and set long-range parameters
        for i, force in enumerate(omm_sys.getForces()):
            force.setForceGroup(i)
            force.setName(force_name_remap.get(force.getName(), force.getName())) # Forces without a registered remap (e.g. ensemble-specific extras) keep their own name

            for long_range_attr, chosen_value in long_range_params.items():
                if hasattr(force, long_range_attr):
                    getattr(force, long_range_attr)(chosen_value)

        ## create OpenMM Simulation
        sim = Simulation(omm_top, omm_sys, integrator)
        sim.context.setPositions(omm_pos)

        # saving OpenMM files
        progress.set_description('Generating OpenMM files')

        sdf_out_path = omm_mol_dir / f'{mol_name}_topology.sdf' # TODO : change this to copy the pre-existing OpenFF Topology SDF
        sdf_out_path.touch()

        # NOTE(review): every iteration writes to the same path - if to_file() truncates the file on each call,
        # only the last molecule of a multi-molecule topology survives; confirm, and address via the TODO above
        for mol in interchange.topology.molecules: # use OpenFF format for saving Molecules (much more convenient to work with)
            mol.to_file(str(sdf_out_path), file_format=sdf_out_path.suffix[1:])

        sys_out_path   = assemble_path(omm_mol_dir, mol_name, extension='xml', postfix='system')
        serialize_system(sys_out_path, sim.system)

        state_out_path = assemble_path(omm_mol_dir, mol_name, extension='xml', postfix='state')
        serialize_state_from_context(state_out_path, sim.context, state_params=DEFAULT_STATE_PROPS)

# Evaluating LAMMPS energies

In [None]:
import lammps
from IPython.display import clear_output


# Run every LAMMPS input file found under lammps_dir/<chemistry>/<mol_name>/ and collect,
# per (chemistry, molecule): the requested energy contributions and the 6 unit cell parameters
failed = defaultdict(list) # names of the molecules whose evaluation raised, grouped by chemistry
failure_reasons = {}       # the corresponding error for each failed (chemistry, molecule)
records = {}
cell_sizes = {}
for subdir in lammps_dir.iterdir():
    if not subdir.is_dir(): # skips loose files, such as previously-exported CSV tables
        continue

    chemistry = subdir.name
    for mol_dir in subdir.iterdir():
        if not mol_dir.is_dir():
            continue

        mol_name = mol_dir.name # .name rather than .stem, which would truncate directory names containing a "."
        lammps_in = mol_dir / f'{mol_name}.in'

        # create LAMMPS wrapper and execute input calc
        with lammps.lammps() as lmp: # need to create new lammps() object instance for each run
            try:
                ## Getting energies
                lmp.file(str(lammps_in)) # read input file and calculate energies
                calc_energies = parse_lammps_input(lammps_in)['energy_evals'] # keywords listed on the "thermo_style" line

                energies = {
                    E_MAP[contrib] : lmp.get_thermo(contrib)
                        for contrib in calc_energies
                }
            except Exception as error: # keep going with the other molecules, but keep track of what went wrong
                failed[chemistry].append(mol_name)
                failure_reasons[(chemistry, mol_name)] = repr(error)
                continue

            ## Getting unit cell dimensions
            cell_params = {
                cp : lmp.get_thermo(cp)
                    for cp in CELL_KW
            }

        # reformatting energies
        # NOTE(review): the label assumes that the input files use "units real" (i.e. kcal/mol) - confirm against the generated .in files
        energies = {
            f'{contrib} (kcal/mol)' : energy # add units to labels
                for contrib, energy in energies.items()
        }

        # save records for Pandas DataFrames
        records[(chemistry, mol_name)] = energies
        cell_sizes[(chemistry, mol_name)] = cell_params
        clear_output() # wipe lengthy LAMMPS printouts

In [None]:
failed

In [None]:
# one row per (chemistry, molecule) key of the collected records, ordered by molecule name
lmp_table = (
    pd.DataFrame.from_dict(records, orient='index')
        .rename_axis(index=['Chemistry', 'Molecule']) # ensure index levels are labelled consistently
        .sort_values('Molecule')
)
lmp_table.to_csv(lammps_dir/f'{lammps_dir.name}_PEs.csv')

# Evaluating OpenMM energies

In [None]:
from openmm import XmlSerializer, Force, System
from openmm.unit import kilojoule_per_mole, kilocalorie_per_mole

# parameters
sep_force_grps : bool = True
remove_constrs : bool = False

# utility function
def load_openmm_system(sys_path : Path, extra_forces : Optional[Union[Force, Iterable[Force]]]=None, sep_force_grps : bool=True, remove_constrs : bool=False) -> System:
    '''Load and configure a serialized OpenMM system, with optional additional parameters

    sys_path       : path to the serialized System; must have the ".xml" extension
    extra_forces   : a single Force or an iterable of Forces to add to the loaded System (None or empty adds nothing)
    sep_force_grps : if True, assigns each Force to its own force group, numbered by position in the System
    remove_constrs : if True, strips all constraints from the System'''
    assert(sys_path.suffix == '.xml')
    ommsys = XmlSerializer.deserialize(sys_path.read_text())

    if extra_forces is None:
        forces_to_add = []
    elif isinstance(extra_forces, Force): # a lone Force is allowed by the signature, but can't be iterated over
        forces_to_add = [extra_forces]
    else:
        forces_to_add = list(extra_forces)

    for force in forces_to_add:
        ommsys.addForce(force)

    if sep_force_grps:
        for i, force in enumerate(ommsys.getForces()):
            force.setForceGroup(i)

    if remove_constrs:
        for i in reversed(range(ommsys.getNumConstraints())): # need to remove in reverse order to avoid having prior constraints "fall back down"
            ommsys.removeConstraint(i)

    return ommsys


NULL_ENERGY = 0.0*kilojoule_per_mole
PRECISION : int = 4

# Integrators are drawn from the simulation parameters defined near the top of the notebook;
# one factory suffices, but every Simulation needs an integrator instance of its own
eval_ens_fac = EnsembleFactory.from_thermo_params(sim_params.thermo_params)

# iterate over serialized directory tree and load
data_dicts = []
omm_sims = defaultdict(defaultdict)
for subdir in omm_dir.iterdir():
    if subdir.is_dir():
        chemistry = subdir.name
        progress = tqdm_notebook([f for f in subdir.iterdir() if f.is_dir()]) # only molecule directories hold serialized files
        for mol_dir in progress:
            mol_name = mol_dir.name

            state_file = mol_dir / f'{mol_name}_state.xml'
            sys_file   = mol_dir / f'{mol_name}_system.xml'
            top_file   = mol_dir / f'{mol_name}_topology.sdf'

            offmol = Molecule.from_file(top_file)
            offtop = Topology.from_molecules(offmol)

            integrator = eval_ens_fac.integrator(time_step=sim_params.integ_params.time_step)
            extra_forces = None # no additional Forces (e.g. a barostat) are needed for single-point evaluation

            # load and configure System
            omm_top = offtop.to_openmm()
            omm_sys = load_openmm_system(
                sys_file,
                extra_forces=extra_forces,
                sep_force_grps=sep_force_grps,
                remove_constrs=remove_constrs
            )

            # putting it all together into a Simulation
            sim = Simulation(
                topology=omm_top,
                system=omm_sys,
                integrator=integrator,
                state=state_file
            )
            omm_sims[chemistry][mol_name] = sim

            # extract total and component energies from OpenMM force groups
            data_dict = {
                'Chemistry' : chemistry,
                'Molecule'  : mol_name
            }
            omm_energies = {}

            ## Total Potential
            overall_state = sim.context.getState(getEnergy=True) # get total potential energy
            PE = overall_state.getPotentialEnergy()
            omm_energies['Potential'] = PE

            ## Total Kinetic (to verify no integration is being done)
            KE = overall_state.getKineticEnergy()
            omm_energies['Kinetic'] = KE
            assert(KE == NULL_ENERGY)

            ## Individual force contributions
            for i, force in enumerate(sim.system.getForces()):
                state = sim.context.getState(getEnergy=True, groups={i})
                omm_energies[force.getName()] = state.getPotentialEnergy()

            # reformat to desired units and precision
            omm_energies_kcal = {}
            for contrib_name, energy_kj in omm_energies.items():
                energy_kcal = energy_kj.in_units_of(kilocalorie_per_mole)
                energy_value = energy_kcal.value_in_unit(kilocalorie_per_mole) # public accessor in place of the private "_value" attribute
                omm_energies_kcal[f'{contrib_name} ({energy_kcal.unit.get_symbol()})'] = round(energy_value, PRECISION)

            # compile data
            data_dict = {**data_dict, **omm_energies_kcal}
            data_dicts.append(data_dict)

omm_table = pd.DataFrame.from_records(data_dicts)
omm_table.sort_values('Molecule', inplace=True)
omm_table.set_index(['Chemistry', 'Molecule'], inplace=True)
omm_table.to_csv(omm_dir / f'{omm_dir.name}_PEs.csv')

# Comparing energies

## Loading energy tables and comparing contributions

In [None]:
pd.options.display.float_format = '{:.4f}'.format # disable scientific notation

@dataclass
class TableFormats:
    '''Describes how a raw table of energy contributions should be post-processed'''
    table_key : str                   # short identifier for the table
    sum_terms : dict[str, list[str]]  # name of each merged column, mapped to the columns which are summed into it
    del_terms : list[str]             # columns to discard outright
    table_dir : Optional[Path] = None # directory holding the table's CSV files; falls back to a directory named after table_key

def _process_energy_table(fmt : TableFormats) -> pd.DataFrame:
    '''Load the raw "<dir name>_PEs.csv" table from a format's directory, merge and delete
    the columns named by the format, and save the result alongside as "<dir name>_PEs_processed.csv"

    Columns named by the format which are not present in the table are skipped
    Returns the processed table'''
    table_dir = Path(fmt.table_key) if (fmt.table_dir is None) else fmt.table_dir
    table_in_path  = table_dir / f'{table_dir.name}_PEs.csv' # same naming scheme as used when the raw tables were exported
    table_out_path = table_dir / f'{table_dir.name}_PEs_processed.csv'
    table = pd.read_csv(table_in_path, index_col=(0, 1)).sort_index(axis=1)

    # combine selected terms
    for combined_contrib, contribs in fmt.sum_terms.items():
        present_contribs = [contrib for contrib in contribs if contrib in table.columns]
        if not present_contribs: # nothing to merge (e.g. a contribution that was never evaluated)
            continue

        new_term = sum(
            table[contrib]
                for contrib in present_contribs
        ) # merge contributions into a single new named term
        table = table.drop(columns=present_contribs) # clear contributions
        table[combined_contrib] = new_term # done after drop to ensure name clashes don't result in extra deletion

    # delete redundant terms
    table = table.drop(columns=[del_contrib for del_contrib in fmt.del_terms if del_contrib in table.columns])

    table.to_csv(table_out_path)
    return table

omm_formats = TableFormats(
    table_key = 'openmm',
    table_dir = omm_dir,
    sum_terms = {
        'vdW (kcal/mol)' : ['vdW (kcal/mol)', 'vdW 1-4 (kcal/mol)'],
        'Coulomb (kcal/mol)' : ['Electrostatic (kcal/mol)', 'Electrostatic 1-4 (kcal/mol)']
    },
    del_terms = ['Kinetic (kcal/mol)']
)

lmp_formats = TableFormats(
    table_key = 'lammps',
    table_dir = lammps_dir,
    sum_terms = {
        'vdW (kcal/mol)' : ['vdW (kcal/mol)', 'Dispersion (kcal/mol)'],
        'Dihedral (kcal/mol)' : ['Proper Torsion (kcal/mol)', 'Improper Torsion (kcal/mol)'],
        'Coulomb (kcal/mol)' : ['Coulomb Short (kcal/mol)', 'Coulomb Long (kcal/mol)']
    },
    del_terms = ['Nonbonded (kcal/mol)']
)

# apply reformatting to respective tables; the results are bound to explicit names
# (used by the comparison cells below), rather than being injected through globals()
openmm_table = _process_energy_table(omm_formats)
lammps_table = _process_energy_table(lmp_formats)

In [None]:
openmm_table

In [None]:
lammps_table

In [None]:
diff = openmm_table - lammps_table
diff

In [None]:
common_cols = ['Angle (kcal/mol)', 'Bond (kcal/mol)']# 'Torsion (kcal/mol)']

omm_redux = omm_table.drop(columns=common_cols)
lmp_redux = lmp_table.drop(columns=common_cols)

In [None]:
omm_table[common_cols] - lmp_table[common_cols]

In [None]:
omm_redux

## Evaluating energies with drivers

In [None]:
from openff.interchange.drivers.openmm import get_openmm_energies, _get_openmm_energies
from openff.interchange.drivers.lammps import get_lammps_energies, _get_lammps_energies,  _find_lammps_executable
from openff.units.openmm import to_openmm as openff_units_to_openmm

In [None]:
{
    contrib : openff_units_to_openmm(value).in_units_of(kilocalorie_per_mole)
        for contrib, value in get_openmm_energies(interchange, detailed=True, combine_nonbonded_forces=False).energies.items()
}

In [None]:
get_lammps_energies(interchange).energies

## Comparing ParmEd energy decomposition to native OpenMM force-group-based decomposition

In [None]:
import parmed
from openmm.openmm import Force

NULL_ENERGY = 0.0*kilojoule_per_mole

sim = omm_sims['urethane']['urethane_41']
# assign and initialize unique force groups for simulation
for i, force in enumerate(sim.system.getForces()):
    force.setForceGroup(i)
sim.context.reinitialize(preserveState=True) # need to reinitialize to get force labelling changes to "stick"

# energies from OpenMM force groups
print('\nOpenMM:')
print('='*30)

## extract total energies for state
overall_state = sim.context.getState(getEnergy=True) # get total potential energy
PE = overall_state.getPotentialEnergy()

KE = overall_state.getKineticEnergy()
assert(KE == NULL_ENERGY)

## extract the contribution of each individual force group
force_energies = {}
for i, force in enumerate(sim.system.getForces()):
    state = sim.context.getState(getEnergy=True, groups={i})
    force_name = force.getName().removesuffix('Force')
    pe = state.getPotentialEnergy()

    force_energies[force_name] = pe
    print(f'{force_name} : {pe}')

omm_energies = {'Total Potential Energy' : PE, **force_energies}

## converting name to match with ParmEd for comparison
namemap = {
    'Nonbonded' : 'nonbonded',
    'PeriodicTorsion' : 'dihedral',
    'HarmonicAngle' : 'angle',
    'HarmonicBond' : 'bond',
    'Total Potential Energy' : 'total'
}
compat_omm_energies = {
    namemap.get(contrib, contrib) : energy # Forces without a ParmEd counterpart keep their own name
        for contrib, energy in omm_energies.items()
}

# only the per-force contributions are summed; including the total itself would double-count it
total = sum(force_energies.values(), start=NULL_ENERGY) # need "seed" to have Quantity datatype to sum
print('ΔE_contrib: ', PE - total)

# ParmEd energy decomposition
# NOTE(review): ParmEd presumably decomposes by its own force group convention - confirm that its
# decomposition is still meaningful after the force groups were reassigned by index above
print('\nParmEd:')
print('='*30)
parm_energies = {}
parm_struct = parmed.openmm.load_topology(sim.topology, sim.system)
for contrib, energy_val in parmed.openmm.energy_decomposition(parm_struct, sim.context).items():
    parm_energies[contrib] = energy = energy_val*kilocalorie_per_mole # assign proper units
    print(contrib, energy.in_units_of(kilojoule_per_mole))