# Imports

In [1]:
## Logging and Shell
import logging

from polymerist.openfftools.partialcharge import molchargers
# Restrict notebook log output to ERROR-level records and above;
# force=True makes this configuration apply even if the root logger was already configured
logging.basicConfig(
    level=logging.ERROR,
    force=True
)

## Generic imports
from typing import Any, Optional
from collections import defaultdict

## Numeric imports
import numpy as np
import pandas as pd

## File I/O
from pathlib import Path
import json, pickle

# Cheminformatics
from rdkit import Chem

from openmm.unit import nanometer, angstrom, Quantity

from openff.toolkit import Molecule, Topology, ForceField
from openff.toolkit.utils.exceptions import (
    UnassignedChemistryInPDBError,
    IncorrectNumConformersWarning,
)

# Custom Imports
from polymerist.genutils.containers import RecursiveDict
from polymerist.genutils.fileutils import filetree
from polymerist.genutils.fileutils.pathutils import assemble_path
from polymerist.duration import Duration, Timer
from polymerist.unitutils.interop import openmm_to_openff, openff_to_openmm

from polymerist.maths.greek import GREEK_PREFIXES
from polymerist.rdutils.rdprops import copy_rd_props
from polymerist.rdutils.rdcoords.tiling import rdmol_effective_radius
from polymerist.rdutils.reactions import reactions, reactors

from polymerist.openfftools.partition import partition
from polymerist.openfftools import topology, boxvectors, FFDIR
from polymerist.monomers import specification, MonomerGroup
from polymerist.polymers import building

# Suppress IncorrectNumConformersWarning for the remainder of the session
import warnings
# NOTE: a bare `warnings.catch_warnings(record=True)` call used to sit here; it only constructed a
# context manager without ever entering it, so it had no effect on warning handling and was dropped
warnings.filterwarnings('ignore', category=IncorrectNumConformersWarning)



# Load monomer and rxn data 

In [2]:
# Static Paths (all relative to the notebook's working directory)
# NOTE(review): RAW_DATA_DIR and FMT_DATA_DIR are not read by any cell visible in this notebook -- confirm whether they are still needed
RAW_DATA_DIR  = Path('monomer_data_raw')
FMT_DATA_DIR  = Path('monomer_data_formatted')
PROC_DATA_DIR = Path('monomer_data_processed') # location of the master CSV table of monomer data loaded below
RXN_FILES_DIR = Path('poly_rxns')              # location of rxn_groups.json and of one .rxn file per reaction mechanism

In [3]:
# Alternative input tables (kept for reference):
# input_data_path = PROC_DATA_DIR / '20231114_polyid_data_density_DP2-6 - 1,2 monomers_FILTERED.csv'
# input_data_path = PROC_DATA_DIR / 'nipu_urethanes_FILTERED.csv'
input_data_path = PROC_DATA_DIR / 'monomer_data_MASTER.csv'

# Read the table with its first two columns as a two-level index, then replace
# NaN entries by explicit None values, which simplifies writing values out later on
df = pd.read_csv(input_data_path, index_col=[0,1]).replace(np.nan, None)

In [4]:
# benchsamp = pd.read_csv('oligomers_for_benchmark.csv', index_col=[0,1])
# df = df.loc[benchsamp.index]

## Load pre-defined reactions with functional group and name backmap

In [5]:
# Debug option: when set to an integer, only that many leading compounds are kept from each mechanism
take_first_n : Optional[int] = None
# take_first_n : Optional[int] = 1
blacklisted_rxns = ['polyimide']#, 'polyvinyl_head_tail']

# Split the master table into one sub-table per reaction mechanism, leaving out blacklisted mechanisms
frames_by_mech : dict[str, pd.DataFrame] = {}
for rxn_name in df.index.unique(level='mechanism'):
    if rxn_name not in blacklisted_rxns:
        rxn_df = df.xs(rxn_name)
        frames_by_mech[rxn_name] = rxn_df if (take_first_n is None) else rxn_df.head(take_first_n)

In [6]:
# Load the table of functional groups for each reaction; its keys name the available reactions
with (RXN_FILES_DIR / 'rxn_groups.json').open('r') as file:
    rxn_groups = json.load(file)

# Read one annotated reaction template per named reaction from the correspondingly-named .rxn file
rxns = {}
for rxnname in rxn_groups:
    rxns[rxnname] = reactions.AnnotatedReaction.from_rxnfile(RXN_FILES_DIR / f'{rxnname}.rxn')

# Auto-generating monomer fragments and Topologies

## Set up and format progress bars to track build status

In [7]:
from rich.progress import Progress
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.console import Group
from rich.live import Live

# one-line readout of the action currently being carried out
status_readout = Progress('STATUS:', TextColumn('[purple]{task.fields[action]}'), '...')
status_id = status_readout.add_task(description='[green]Current compound:', action='')

# one-line readout showing the name of the polymer currently being processed
compound_readout = Progress('Current compound:', TextColumn('[blue]{task.fields[polymer_name]}', justify='right'))
curr_compound_id  = compound_readout.add_task(description='[green]Compound:', polymer_name='')

# bar tracking progress over all individual compounds, irrespective of mechanism
compound_progress = Progress(
    SpinnerColumn(), "[progress.description]{task.description}", BarColumn(), TaskProgressColumn(),
    TextColumn('({task.completed} / {task.total})'),
)
comp_progress_id = compound_progress.add_task(description='[blue]Unique compound(s)   ', polymer_name='')

# bar tracking progress over the distinct reaction mechanisms
inter_mech_progress = Progress(
    SpinnerColumn(), "[progress.description]{task.description}", BarColumn(), TaskProgressColumn(),
    TextColumn('({task.completed} / {task.total})'),
)
curr_mechanism_id = inter_mech_progress.add_task(description='[blue]Reaction mechanism(s)', start=True, total=len(frames_by_mech))

# individual progress bars for compounds within each mechanism
intra_mech_progress = Progress(
    "[progress.description]{task.description}",
    BarColumn(),
    TaskProgressColumn(),
    TextColumn(
        '({task.completed} / {task.total})'
    ),
    'At:',
    TimeElapsedColumn(),
)

# register one bar per mechanism (sized by the number of compounds in that mechanism's table)
# and tally the overall number of compounds along the way
total_compounds = 0
mech_task_ids = {} # maps each mechanism name to the ID of its bar within intra_mech_progress
for rxn_name, rxn_df in frames_by_mech.items():
    num_compounds = len(rxn_df)
    mech_task_ids[rxn_name] = intra_mech_progress.add_task(f'[cyan]{rxn_name}', start=True, total=num_compounds)
    total_compounds += num_compounds

# size the overall compound bar using the task ID which was issued by compound_progress itself
# (this previously passed curr_compound_id, which belongs to compound_readout, and only worked
# because both displays happen to hold exactly one task each, so that the two IDs coincide)
compound_progress.update(comp_progress_id, total=total_compounds)

# combine progress readouts into unified live console (rendered top-to-bottom in the order given)
group = Group(
    status_readout,
    compound_readout,
    compound_progress,
    inter_mech_progress,
    intra_mech_progress,
)

## Define utility functions

In [8]:
import re
from time import sleep, time
from polymerist.rdutils.rdtypes import RDMol
from polymerist.maths.lattices.integral import CubicIntegerLattice


# Breaks a Hill formula apart into its element symbols: each match is one capital letter followed by at most
# one lowercase letter (captured), plus the run of digits after it, if any (consumed but not captured).
# NOTE: the digit run is matched greedily; a lazy "*?" at the very end of a pattern always matches zero characters,
# so the previous form of this pattern never actually consumed the digits (the captured symbols are the same either way)
HILL_REGEX = re.compile(r'([A-Z][a-z]?)[0-9]*')

def generate_smarts_fragments(reactants_dict : dict[str, RDMol], reactor : reactors.PolymerizationReactor) -> MonomerGroup:
    '''Propagate the reaction held by a PolymerizationReactor over a labelled collection of reactant Mols
    and collect every fragment produced along the way into a MonomerGroup

    Each fragment is paired (by position) with the label of a reactant and is stored under the key
    "<label>_TERM" when MonomerGroup.is_terminal() holds for it, or under "<label>_MID" otherwise;
    the stored value is a single-element list containing the fragment's spec-compliant SMARTS string.
    A fragment yielded later under an already-used key replaces the entry stored earlier'''
    fragment_group = MonomerGroup()
    reactant_labels = list(reactants_dict.keys())
    reactant_mols   = list(reactants_dict.values()) # the reactor is handed a plain list of the reactant Mols

    for _, fragments in reactor.propagate(reactant_mols): # the intermediates yielded at each step are not needed here
        for label, fragment in zip(reactant_labels, fragments):
            # convert fragment to spec-compliant SMARTS by way of its expanded SMILES
            smarts = specification.compliant_mol_SMARTS(
                specification.expanded_SMILES(
                    Chem.MolToSmiles(fragment)
                )
            )

            # file the SMARTS away under a key which marks the fragment as terminal or middle
            if MonomerGroup.is_terminal(fragment):
                suffix = 'TERM'
            else:
                suffix = 'MID'
            fragment_group.monomers[f'{label}_{suffix}'] = [smarts]

    return fragment_group

def generate_uniform_subpopulated_lattice(max_num_atoms : int, num_atoms_in_mol : int, dimension : int=3) -> CubicIntegerLattice:
    '''Build a cubic integer lattice holding one site per molecule, for as many whole molecules
    (each with num_atoms_in_mol atoms) as fit within a budget of max_num_atoms atoms

    The side length is the smallest integer whose dimension-th power is at least the number of molecules.
    Sites are drawn at random from the lattice's even sublattice first, and from the odd sublattice only
    once the even one is exhausted, which keeps the voids between occupied sites small.
    The lattice returned has had its points cut down to just the selected sites'''
    mol_count = max_num_atoms // num_atoms_in_mol # deliberately floor division: only whole molecules count
    points_per_side = np.ceil(mol_count**(1/dimension)).astype(int) # cast ensures an integer-typed (rather than float-typed) side length
    lattice = CubicIntegerLattice(np.array([points_per_side]*dimension))

    # decide how many sites to draw from each sublattice; the even sites alone need not be enough to hold every molecule
    even_capacity = lattice.even_idxs.size
    n_even = min(mol_count, even_capacity)
    n_odd  = max(0, mol_count - even_capacity) # odd sites only make up whatever shortfall remains

    # shuffle each sublattice before truncating, so that no region of the box is favored
    shuffled_even = np.random.permutation(lattice.even_idxs)
    shuffled_odd  = np.random.permutation(lattice.odd_idxs )
    chosen_idxs = np.concatenate([shuffled_even[:n_even], shuffled_odd[:n_odd]])

    lattice.points = lattice.points[chosen_idxs] # discard every site which was not selected
    return lattice

## Set parameters for build process

In [9]:
# Root directory beneath which all build outputs are written
# MASTER_OUT_DIR = Path('polymer_improved')
MASTER_OUT_DIR = Path('polymers_atom_limited')

# Chemistry and system-size settings
DOPs : list[int] = [3] # oligomer sizes to build; each value selects the Greek prefix of the output folder and is passed (plus one) as the DOP of the chain builder
charge_method : str = 'Espaloma-AM1-BCC' # key used to look up the charger class in MolCharger.subclass_registry
force_field_name : str = 'openff_unconstrained-2.0.0.offxml' # file name looked up under FFDIR; alternative: 'openff-2.0.0.offxml'
max_num_atoms_array : tuple[int, ...] = (10_000, 20_000,) # atom budgets; one tiled system is generated per budget for each oligomer

# Nonbonded and periodic box settings
switching_function : bool = False # when True the vdW switch width is set to 1.0 angstrom, otherwise to 0.0 angstrom
exclusion : Quantity = 0.0 * nanometer # uniform padding applied to the tight bounding box of each tiled topology
nonbond_cutoff : Quantity = 0.9 * nanometer # cutoff applied to both the vdW and Electrostatics interactions

# Flags controlling which stages are redone versus loaded from files already on disc
clear_existing           : bool = True # alternative: False; when True, filetree.clear_dir is called on MASTER_OUT_DIR before building -- presumably this discards all prior outputs, verify before enabling
refragment               : bool = False # regenerate monomer fragments even when a fragment JSON file already exists
repolymerize_pdbs        : bool = False # rebuild the chain PDB file even when one already exists
reparameterize           : bool = False # redo the substructure partition even when a single-molecule SDF already exists
reassign_partial_charges : bool = False # NOTE(review): confirm that this flag is actually consulted by the build loop
perform_energy_min       : bool = True # passed as energy_minimize when building the linear chain

# preprocess parameters
charger = molchargers.MolCharger.subclass_registry[charge_method]() # instantiate the charger class registered under the chosen method name
forcefield = ForceField(FFDIR / force_field_name)

min_box_dim : Quantity = 2 * nonbond_cutoff # should be at least twice the nonbonded cutoff to avoid self-interaction
min_bbox = openmm_to_openff(min_box_dim * np.eye(3)) # diagonal matrix of box vectors for the smallest admissible box, converted to OpenFF units

## Execute build loop

In [10]:
# Build loop: for every compound of every reaction mechanism this
#   0) reads the two reactant monomers from the data table
#   1) generates (or reloads) the monomer fragments for the compound
#   2) builds a PDB file of a linear chain for each requested DOP
#   3) partitions the chain by fragments and assigns partial charges
#   4) tiles the charged chain onto a lattice for each atom cap, and saves the topology, Interchange, and a build record
# Failures are collected into failure_record rather than raised, so that one bad compound does not halt the run

# create directories
MASTER_OUT_DIR.mkdir(exist_ok=True)
if clear_existing:
    filetree.clear_dir(MASTER_OUT_DIR)

# set up data structures for global output
failure_record = RecursiveDict() # keyed by failure type, then mechanism, then polymer name (then DOP, for substructure failures)
m2p_mismatches = RecursiveDict() # compounds whose built chain does not substructure-match the tabulated M2P SMILES (in either direction)

# execute build loop
num_successful : int = 0 # number of compounds for which every requested DOP was built without a recorded failure
md_build_records : list[dict[str, Any]] = [] # one entry per (compound, DOP, atom cap) system built
with Live(group, refresh_per_second=10) as live:
    # ensure bars start at 0
    for pbar in group.renderables: 
        for task_id in pbar.task_ids:
            pbar.reset(task_id)

    # iterate over all distinct chemistries by reaction mechanism
    for rxn_name, rxn_df in frames_by_mech.items():
        # look up reactive groups and pathway by mechanism
        mech_task_id = mech_task_ids[rxn_name]
        rxn_pathway  = rxns[rxn_name]
        reactor = reactors.PolymerizationReactor(rxn_pathway)
        
        # initialize output directories
        mech_dir : Path = MASTER_OUT_DIR / rxn_name
        mech_dir.mkdir(exist_ok=True)

        for (polymer_name, row) in rxn_df.iterrows():
            compound_readout.update(curr_compound_id, polymer_name=polymer_name)
            chem_dir : Path = mech_dir / polymer_name
            chem_dir.mkdir(exist_ok=True)

            # 0) load reactants with IUPAC names from chemical table
            status_readout.update(status_id, action='Gathering reactants')
            named_reactants = {}
            for j in range(2): # the table provides exactly two monomer columns (suffixed 0 and 1) per compound
                reactant = Chem.MolFromSmiles(row[f'smiles_monomer_{j}'], sanitize=False)
                Chem.SanitizeMol(reactant, sanitizeOps=specification.SANITIZE_AS_KEKULE)
                named_reactants[ row[f'IUPAC_name_monomer_{j}'] ] = reactant

            try:
                # 1) use rxn template to polymerize monomers into all possible fragments
                frag_path = assemble_path(chem_dir, polymer_name, extension='json')
                if frag_path.exists() and not refragment: # reuse fragments saved by an earlier run, unless regeneration was requested
                    status_readout.update(status_id, action='Loading pre-existing monomer fragments')
                    monogrp = MonomerGroup.from_file(frag_path)
                else:
                    status_readout.update(status_id, action='Generating monomer fragments via reaction mechanism')
                    monogrp = generate_smarts_fragments(named_reactants, reactor=reactor)

                    status_readout.update(status_id, action='Saving monomer fragments...')
                    monogrp.to_file(frag_path)

                # NOTE: the "continue" statements below only skip the current DOP (not the whole compound),
                # so this flag is needed to keep compounds with a recorded failure out of the success count
                all_dops_built = True
                for dop in DOPs:
                    nmer_name = f'{GREEK_PREFIXES[dop]}mer'
                    dop_dir : Path = chem_dir / nmer_name
                    dop_dir.mkdir(exist_ok=True)

                    # 2) Generate PDB file for linear chain from fragments
                    pdb_path : Path = assemble_path(dop_dir, polymer_name, extension='pdb')
                    if not pdb_path.exists() or repolymerize_pdbs:
                        status_readout.update(status_id, action=f'Generating PDB file (with{"" if perform_energy_min else "out"} UFF energy minimization)')
                        polymer = building.build_linear_polymer(monomers=monogrp, DOP=dop+1, sequence='BA', energy_minimize=perform_energy_min)  # "BA" is needed to make term groups align properly, DOP does not account for term group pair (hence the "+1")
                        building.mbmol_to_openmm_pdb(pdb_path, polymer)
                        
                        # checking that my method produces the same results as M2P
                        m2p_smiles = row.smiles_polymer_DP6
                        if m2p_smiles is not None:
                            m2p_mol = Chem.MolFromSmiles(m2p_smiles)
                            workflow_smiles = polymer.to_smiles()
                            workflow_mol    = Chem.MolFromSmiles(workflow_smiles)

                            # a mismatch is only recorded when neither structure is contained within the other
                            if not (workflow_mol.HasSubstructMatch(m2p_mol) or m2p_mol.HasSubstructMatch(workflow_mol)):
                                m2p_mismatches[rxn_name][polymer_name]['M2P_vers'] = m2p_smiles
                                m2p_mismatches[rxn_name][polymer_name]['workflow_vers'] = workflow_smiles

                    # 3a) Assign chemical info to PDB system
                    param_top_path : Path = assemble_path(dop_dir, polymer_name, extension='sdf')
                    if param_top_path.exists() and not reparameterize:
                        status_readout.update(status_id, action='Loading parameterized single-mol Topology')
                        offtop = topology.topology_from_sdf(param_top_path)
                    else:
                        try:
                            status_readout.update(status_id, action='Partitioning topology by fragments')
                            offtop = Topology.from_pdb(pdb_path, _custom_substructures=monogrp.monomers)
                            assert(partition(offtop)) # verify that a partition was possible
                            topology.topology_to_sdf(param_top_path, offtop)
                        except UnassignedChemistryInPDBError:
                            failure_record['No substruct cover'][rxn_name][polymer_name][dop] = monogrp
                            all_dops_built = False
                            continue # skip to the next DOP, don't proceed with parameterization
                        except AssertionError:
                            failure_record['No substruct partition'][rxn_name][polymer_name][dop] = monogrp
                            all_dops_built = False
                            continue # skip to the next DOP, don't proceed with parameterization

                    offmol = topology.get_largest_offmol(offtop)
                    offmol.name = polymer_name

                    # 3b) Assign partial charges, if not already present (or if reassignment was explicitly requested)
                    if reassign_partial_charges or not molchargers.has_partial_charges(offmol):
                        status_readout.update(status_id, action=f'Assigning partial charges via {charger.CHARGING_METHOD}')
                        cmol = charger.charge_molecule(offmol)
                    else:
                        # the loaded molecule already carries charges; without this branch, cmol would either be
                        # undefined or (worse) silently left over from the previously-processed compound
                        cmol = offmol
                    unique_elems = re.findall(HILL_REGEX, cmol.hill_formula) # unique element names in same order as found in Hill formula
                    
                    # generate tiled lattices as specified
                    for max_num_atoms in max_num_atoms_array:
                        lattice_str = f'sub_{max_num_atoms}_atoms'
                        latt_dir : Path = dop_dir / lattice_str
                        latt_dir.mkdir(exist_ok=True)

                        int_lattice = generate_uniform_subpopulated_lattice(max_num_atoms, num_atoms_in_mol=cmol.n_atoms)
                        r_eff = rdmol_effective_radius(cmol.to_rdkit())
                        lattice = int_lattice.linear_transformation(2.0*r_eff*np.eye(3), as_coords=True) # scale integer lattice by effective diameter

                        # create tiled version of parameterized topology
                        with Timer() as topo_timer:
                            status_readout.update(status_id, action=f'Generating tiled {lattice_str} topology')
                            tiled_offtop = topology.topology_from_molecule_onto_lattice(cmol, lattice_points=lattice.points, rotate_randomly=True, unique_mol_ids=True)
                            latt_top_path = assemble_path(latt_dir, lattice_str, postfix=polymer_name, extension='sdf')
                            topology.topology_to_sdf(latt_top_path, tiled_offtop)

                            latt_pdb_path = assemble_path(latt_dir, lattice_str, postfix=polymer_name, extension='pdb')
                            tiled_offtop.to_file(latt_pdb_path)

                        # generate appropriately-sized periodic box size, starting with the tight bounding box for the topology
                        top_box_vectors = boxvectors.get_topology_bbox(tiled_offtop) # determine tight box size
                        top_box_vectors = boxvectors.pad_box_vectors_uniform(top_box_vectors, exclusion) # pad the tight box by the exclusion distance
                        top_box_vectors = np.maximum(min_bbox, top_box_vectors) # ensure the box is no smaller than the minimum determined by the cutoff distance

                        top_box_vectors_omm = openff_to_openmm(top_box_vectors)
                        box_vector_sizes = np.linalg.norm(top_box_vectors_omm, axis=1) * top_box_vectors_omm.unit # rows are each a distinct box vector
                        box_vector_dict = { # flatten box sizes to plain numbers, with the unit recorded in the key instead
                            f'box_dim_{axis} ({size_quant.unit!s})' : size_quant._value
                                for (axis, size_quant) in zip('xyz', box_vector_sizes)
                        }

                        # create and save Interchange for MD export
                        with Timer() as inc_timer:
                            status_readout.update(status_id, action=f'Creating {lattice_str} OpenFF Interchange')
                            interchange = forcefield.create_interchange(tiled_offtop, charge_from_molecules=[cmol])
                            interchange.box = top_box_vectors # apply periodic box to Interchange

                            # configure nonbonded in Interchange to have correct cutoff and switching function width
                            interchange['vdW'].switch_width = (1.0 if switching_function else 0.0) * angstrom
                            interchange['vdW'           ].cutoff = nonbond_cutoff
                            interchange['Electrostatics'].cutoff = nonbond_cutoff

                        latt_inc_path = assemble_path(latt_dir, lattice_str, postfix=polymer_name, extension='pkl')
                        with latt_inc_path.open('wb') as pklfile: # NOTE: pickled files must be read/written in binary mode
                            pickle.dump(interchange, pklfile)

                        # record information about MD build run to simplify resuming, analyzing, and benchmarking structure outputs
                        md_build_entry = {
                            'mechanism'                : rxn_name,
                            'polymer_name'             : polymer_name,
                            'exper_density'            : row['Density'],
                            'n_atoms_cap'              : lattice_str,
                            'lattice_size'             : int_lattice.counts_along_dims_as_str(),
                            'num_oligomers'            : lattice.n_points,
                            'effective radius'         : r_eff,
                            'oligomer_type'            : nmer_name,
                            'n_atoms_in_topology'      : tiled_offtop.n_atoms,
                            'unique_elems_in_topology' : unique_elems, 
                            'directory'                : str(latt_dir),
                            'topology_path'            : str(latt_top_path),
                            'topology_time'            : topo_timer.time_taken,
                            'interchange_path'         : str(latt_inc_path),
                            'interchange_time'         : inc_timer.time_taken,
                        }
                        md_build_entry.update(box_vector_dict)
                        md_build_records.append(md_build_entry)

                        md_build_entry_path = assemble_path(latt_dir, f'{lattice_str}_{nmer_name}_{polymer_name}', postfix='RECORD', extension='json')
                        with md_build_entry_path.open('w') as record_file: # also save to disc individually, to allow reconstruction if loop fails halfway through
                            json.dump(md_build_entry, record_file, indent=4)
                    
                if all_dops_built: # compounds with a recorded substructure failure do not count as successful
                    num_successful += 1

            except Exception as other_error: # deliberately broad: record the failure and carry on with the next compound
                failure_record[other_error.__class__.__name__][rxn_name][polymer_name] = str(other_error)
            finally: # bars advance whether or not the compound was built successfully
                intra_mech_progress.advance(mech_task_id)
                compound_progress.advance(comp_progress_id)
                
        inter_mech_progress.advance(curr_mechanism_id, advance=1)
        sleep(0.1) # needed to give final bar enough time to catch up
    compound_readout.update(curr_compound_id, polymer_name=f'Completed! ({num_successful}/{total_compounds} successful)')

# collate the individual build records into a single table and save it alongside the build outputs
all_records_path = assemble_path(MASTER_OUT_DIR, 'build_records', extension='csv')
md_build_records_table = pd.DataFrame.from_records(md_build_records)
if md_build_records:
    md_build_records_table = md_build_records_table.set_index(['mechanism', 'polymer_name'])
    md_build_records_table.to_csv(all_records_path)
else:
    # a table built from zero records has no columns, so set_index would raise a KeyError here and abort
    # the cell before the failure record below (which is most needed in exactly this situation) is shown
    print(f'No build records were produced; skipping write to {all_records_path}')

# save the compounds whose built structure disagreed with the tabulated M2P structure
m2p_mismatch_path = assemble_path(MASTER_OUT_DIR, 'm2p_mismatches', extension='json')
with m2p_mismatch_path.open('w') as m2p_mismatch_file:
    json.dump(m2p_mismatches, m2p_mismatch_file, indent=4)

# NOTE: the failure record is only displayed, not written to disc
print(failure_record)
print(m2p_mismatches)

Output()