# Imports

In [None]:
## Logging and Shell
import logging

# Only report errors (and worse); force=True replaces any handlers already attached to the root logger
logging.basicConfig(force=True, level=logging.ERROR)

## Generic imports
from collections import defaultdict

## Numeric imports
import numpy as np
import pandas as pd

## File I/O
from pathlib import Path
import json, pickle

# Cheminformatics
from rdkit import Chem

from openmm.unit import nanometer, Quantity

from openff.toolkit import Molecule, Topology, ForceField
from openff.toolkit.utils.exceptions import (
    UnassignedChemistryInPDBError,
    IncorrectNumConformersWarning,
)

# Custom Imports
from polymerist.genutils.containers import RecursiveDict
from polymerist.genutils.fileutils import filetree

from polymerist.maths.greek import GREEK_PREFIXES
from polymerist.maths.lattices import generate_int_lattice

from polymerist.rdutils import rdkdraw
from polymerist.rdutils.rdcoords import tiling
from polymerist.rdutils.rdprops import copy_rd_props
from polymerist.rdutils.reactions import reactions, reactors

from polymerist.monomers import specification, MonomerGroup
from polymerist.residues.partition import partition
from polymerist.polymers import building
from polymerist.openfftools import topology, boxvectors, pcharge, FFDIR

# Mol drawing settings
from rdkit.Chem.Draw import IPythonConsole # NOTE(review): never referenced by name in this notebook; presumably imported for its side effects - confirm before removing

DIM, ASPECT = 300, 3/2 # the two arguments passed on to rdkdraw.set_rdkdraw_size()
rdkdraw.set_rdkdraw_size(DIM, ASPECT)

# Silence conformer-count warnings for the remainder of the session.
# NOTE : a bare "warnings.catch_warnings(record=True)" call used to precede the filter; since the
# context manager was never entered (no "with" statement) it had no effect and has been dropped
import warnings
warnings.filterwarnings('ignore', category=IncorrectNumConformersWarning)

In [None]:
# Static Paths (all relative, so they resolve against the current working directory)
RAW_DATA_DIR  : Path = Path('monomer_data_raw')
FMT_DATA_DIR  : Path = Path('monomer_data_formatted')
PROC_DATA_DIR : Path = Path('monomer_data_processed') # the master monomer CSV is read from here
RXN_FILES_DIR : Path = Path('poly_rxns')              # rxn_groups.json and the per-reaction .rxn files are read from here

# Load monomer and rxn data 

In [None]:
# Alternate inputs, kept around for quickly switching datasets:
# input_data_path = PROC_DATA_DIR / '20231114_polyid_data_density_DP2-6 - 1,2 monomers_FILTERED.csv'
# input_data_path = PROC_DATA_DIR / 'nipu_urethanes_FILTERED.csv'
input_data_path = PROC_DATA_DIR.joinpath('monomer_data_MASTER.csv')

# the first column of the CSV is used as the DataFrame index
df = pd.read_csv(
    input_data_path,
    index_col=0,
)

## Load pre-defined reactions with functional group and name backmap

In [None]:
keys = ['rxn_name']
(rxn_key,) = keys # exactly one grouping column is expected; unpacking fails loudly if that ever changes

# discard all rows whose mechanism is blacklisted
blacklisted_rxns = ['imide']#, 'vinyl']
df = df[~df.mechanism.isin(blacklisted_rxns)]

# NOTE : grouping by the scalar column name (rather than a length-1 list) keeps the group labels as plain
# values instead of 1-tuples; these labels are later used as dict keys and as directory names
df_grouper = df.groupby(rxn_key)
frames_by_mech = { # separate into individual dataframes, one per distinct value of the grouping column
    mech : mech_df
        for mech, mech_df in df_grouper
}

In [None]:
# load table of functional group for each reaction
rxn_groups_path = RXN_FILES_DIR / 'rxn_groups.json'
with rxn_groups_path.open('r') as rxn_groups_file:
    rxn_groups = json.load(rxn_groups_file)

# read one annotated reaction template per reaction name listed in the table (from "<rxnname>.rxn")
rxns = {}
for rxnname in rxn_groups.keys():
    rxns[rxnname] = reactions.AnnotatedReaction.from_rxnfile(RXN_FILES_DIR / f'{rxnname}.rxn')

# Auto-generating monomer fragments and Topologies

## Set up and format progress bars to track build status

In [None]:
from time import sleep
from rich.progress import Progress
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.console import Group
from rich.live import Live

# textual display of the action currently being performed (driven by the "action" task field)
status_readout = Progress(
    'STATUS:',
    TextColumn('[purple]{task.fields[action]}'),
    '...',
)
status_id = status_readout.add_task('[green]Current compound:', action='')

# textual display of the name of the current polymer (driven by the "polymer_name" task field)
compound_readout = Progress(
    'Current compound:',
    TextColumn('[blue]{task.fields[polymer_name]}', justify='right'),
)
curr_compound_id = compound_readout.add_task('[green]Current compound:', polymer_name='')

def _bar_with_count_columns() -> tuple:
    '''Column layout shared by all progress bars below: description, bar, percentage and a "(done / total)" counter
    A fresh set of column objects is created on every call, so no column instance is shared between Progress displays'''
    return (
        "[progress.description]{task.description}",
        BarColumn(),
        TaskProgressColumn(),
        TextColumn('({task.completed} / {task.total})'),
    )

# progress over individual compounds (irrespective of mechanism)
compound_progress = Progress(SpinnerColumn(), *_bar_with_count_columns())
comp_progress_id = compound_progress.add_task('[blue]Unique compound(s)   ', polymer_name='')

# progress over distinct classes of mechanism
inter_mech_progress = Progress(SpinnerColumn(), *_bar_with_count_columns())
curr_mechanism_id = inter_mech_progress.add_task(
    '[blue]Reaction mechanism(s)',
    start=True,
    total=len(frames_by_mech),
)

# individual progress bars for compounds within each mechanism (no spinner, but with elapsed time)
intra_mech_progress = Progress(*_bar_with_count_columns(), 'At:', TimeElapsedColumn())
# preprocess dataframes by mechanism to determine progress bar layout and task lengths
total_compounds = 0 # running count of compounds (dataframe rows) across all mechanisms
mech_task_ids = {}  # maps each reaction name to the ID of its task in intra_mech_progress
for rxn_name, rxn_df in frames_by_mech.items():
    num_compounds = len(rxn_df)
    mech_task_ids[rxn_name] = intra_mech_progress.add_task(f'[cyan]{rxn_name}', start=True, total=num_compounds)
    total_compounds += num_compounds

# NOTE : the total must be set on the task owned by compound_progress (comp_progress_id); the ID of the
# compound_readout task was previously passed here, which only worked because both IDs happen to coincide
compound_progress.update(comp_progress_id, total=total_compounds)

# combine progress readouts into unified live console (rendered top-to-bottom in the order listed)
_live_displays = (
    status_readout,
    compound_readout,
    compound_progress,
    inter_mech_progress,
    intra_mech_progress,
)
group = Group(*_live_displays)

## Define utility functions

In [None]:
from polymerist.rdutils.rdtypes import RDMol


def generate_smarts_fragments(reactants_dict : dict[str, RDMol], reactor : reactors.PolymerizationReactor) -> MonomerGroup:
    '''Enumerate the fragments produced by propagating a reaction over a set of labelled reactants

    Each fragment yielded by the reactor is paired (by position) with a key of "reactants_dict",
    converted to a spec-compliant SMARTS string, and stored in the returned MonomerGroup under
    "<key>_TERM" or "<key>_MID", depending on whether MonomerGroup.is_terminal() flags the fragment.
    A later fragment which maps to the same name replaces the earlier entry'''
    fragment_group = MonomerGroup()
    group_names = reactants_dict.keys()
    starting_mols = list(reactants_dict.values()) # the reactor is handed a list, rather than a dict view

    for _, fragments in reactor.propagate(starting_mols): # the first item of each yielded pair is not needed here
        for group_name, fragment in zip(group_names, fragments):
            # SMILES -> expanded SMILES -> spec-compliant SMARTS
            smiles = Chem.MolToSmiles(fragment)
            smarts = specification.compliant_mol_SMARTS(
                specification.expanded_SMILES(smiles)
            )

            # file the fragment under a name which reflects its position along a chain
            if MonomerGroup.is_terminal(fragment):
                fragment_name = f'{group_name}_TERM'
            else:
                fragment_name = f'{group_name}_MID'
            fragment_group.monomers[fragment_name] = [smarts]

    return fragment_group

def topology_from_molecule_onto_lattice(cmol : Molecule, lattice : np.ndarray):
    '''Tile copies of an OpenFF Molecule according to a lattice and collect the copies into a single Topology

    The Molecule is converted to RDKit and tiled, then the tiled RDKit Mol is split into its disconnected
    fragments (without sanitization), each of which is converted back into an OpenFF Molecule'''
    def _fragment_to_offmol(parent_rdmol, fragment_rdmol) -> Molecule:
        '''Copy the properties of the tiled parent Mol onto one of its fragments, then convert that fragment to an OpenFF Molecule'''
        copy_rd_props(parent_rdmol, fragment_rdmol) # ensure each individual fragment preserves the information of the parent molecule
        return Molecule.from_rdkit(
            rdmol=fragment_rdmol,
            allow_undefined_stereo=True,
            hydrogens_are_explicit=True,
        )

    parent_rdmol = tiling.tile_lattice_with_rdmol(cmol.to_rdkit(), lattice)
    fragment_rdmols = Chem.GetMolFrags(parent_rdmol, asMols=True, sanitizeFrags=False)
    offmol_copies = [
        _fragment_to_offmol(parent_rdmol, fragment_rdmol)
            for fragment_rdmol in fragment_rdmols
    ]

    return Topology.from_molecules(offmol_copies)

## Set parameters for build process

In [None]:
# root directory beneath which all build output is written
MASTER_OUT_DIR = Path('polymer_structures')

# degrees of polymerization to build for every compound
DOPs : list[int] = [3]

# key into the registry of available partial charge assignment methods
charge_method : str = 'Espaloma-AM1-BCC'

# name of the force field file, looked up relative to FFDIR
force_field_name : str = 'openff_unconstrained-2.0.0.offxml' # 'openff-2.0.0.offxml'

# number of copies along each of the three lattice directions, for each tiling to generate
lattice_sizes : list[np.ndarray] = [
    np.array([num_copies, num_copies, num_copies])
        for num_copies in (1, 2, 3, 5) # NOTE : the 1x1x1 case is just a single molecule in a box
]
# padding applied to the bounding box of each tiled topology when setting the periodic box
exclusion : Quantity = Quantity(0.9, nanometer) # should match nonbonded cutoff for MD file generation

# flags controlling which pre-existing outputs are discarded or regenerated
clear_existing           : bool = True  # empty the master output directory before building
refragment               : bool = False # regenerate monomer fragments even if a fragment file already exists
repolymerize_pdbs        : bool = False # rebuild PDB files even if they already exist
reparameterize           : bool = False # redo the substructure partition even if a saved single-molecule SDF exists
reassign_partial_charges : bool = False # NOTE(review): presumably forces charges to be recomputed - verify against the build loop

# preprocess parameters
## instantiate the charger class registered under the chosen charge method
charger = pcharge.MolCharger.subclass_registry[
    charge_method
]()

## integer lattices keyed by a label of the form "AxBxC", built from the entries of each lattice size
lattices = {
    'x'.join(map(str, dimensions)) : generate_int_lattice(*dimensions)
        for dimensions in lattice_sizes
}

## load the force field from the chosen file inside FFDIR
forcefield = ForceField(FFDIR / force_field_name)

## Execute build loop

In [10]:
# Build loop : for every compound (dataframe row) of every reaction mechanism, this
#   0) reads the two reactants of the row and derives a polymer name from their IUPAC names
#   1) generates (or reloads) the monomer fragments of the polymer
#   2) builds a PDB file of a linear chain for each requested DOP
#   3) partitions the chain by fragments, assigns partial charges, then tiles the charged molecule
#      onto each lattice, saving an SDF topology and a pickled Interchange per lattice
# Failures are collected in "failure_record" rather than aborting the whole run

# create directories
MASTER_OUT_DIR.mkdir(exist_ok=True)
if clear_existing:
    filetree.clear_dir(MASTER_OUT_DIR)

# set up data structures for global output
frag_registry  = RecursiveDict() # polymer names, keyed by reaction name, then by dataframe row index
failure_record = RecursiveDict() # failures, keyed by failure type, then by reaction name and polymer name

# execute build loop
num_successful : int = 0 # number of compounds for which every requested DOP was built without a recorded failure
with Live(group, refresh_per_second=10) as live:
    # ensure bars start at 0
    for pbar in group.renderables: 
        for task_id in pbar.task_ids:
            pbar.reset(task_id)

    # iterate over all distinct chemistries by reaction mechanism
    for rxn_name, rxn_df in frames_by_mech.items():
        # look up reactive groups and pathway by rxn_name
        mech_task_id = mech_task_ids[rxn_name]
        rxn_pathway  = rxns[rxn_name]
        reactor = reactors.PolymerizationReactor(rxn_pathway)
        
        # initialize output directories
        mech_dir : Path = MASTER_OUT_DIR / rxn_name
        mech_dir.mkdir(exist_ok=True)

        for (i, row) in rxn_df.iterrows():
            # 0) load reactants with IUPAC names from chemical table
            status_readout.update(status_id, action='Gathering reactants')
            named_reactants = {}
            for j in range(2): # each row provides exactly two monomers (columns suffixed with 0 and 1)
                reactant = Chem.MolFromSmiles(row[f'smiles_monomer_{j}'], sanitize=False)
                Chem.SanitizeMol(reactant, sanitizeOps=specification.SANITIZE_AS_KEKULE)
                named_reactants[ row[f'IUPAC_name_monomer_{j}'] ] = reactant

            # 0a) auto-generate name of the current polymer (this is needed to finish setting up several prereqs viz dirs, progress, etc.) 
            polymer_name = f'poly({"-co-".join(named_reactants.keys())})' # TODO : make sure this conforms to IUPAC standards for naming
            compound_readout.update(curr_compound_id, polymer_name=polymer_name)
            frag_registry[rxn_name][i] = polymer_name

            chem_dir : Path = mech_dir / polymer_name
            chem_dir.mkdir(exist_ok=True)

            try:
                # 1) use rxn template to polymerize monomers into all possible fragments
                frag_path = chem_dir / f'{polymer_name}.json'
                if frag_path.exists() and not refragment: # if fragments have already been generated, reuse them (unless regeneration was requested)
                    status_readout.update(status_id, action='Loading pre-existing monomer fragments')
                    monogrp = MonomerGroup.from_file(frag_path)
                else:
                    status_readout.update(status_id, action='Generating monomer fragments via reaction mechanism')
                    monogrp = generate_smarts_fragments(named_reactants, reactor=reactor)

                    status_readout.update(status_id, action='Saving monomer fragments...')
                    monogrp.to_file(frag_path)

                all_dops_built = True # cleared as soon as any DOP of this compound records a failure
                for dop in DOPs:
                    nmer_name = f'{GREEK_PREFIXES[dop]}mer'
                    dop_dir : Path = chem_dir / nmer_name
                    dop_dir.mkdir(exist_ok=True)

                    # 2) Generate PDB file for linear chain from fragments
                    pdb_path : Path = dop_dir / f'{polymer_name}.pdb'
                    if not pdb_path.exists() or repolymerize_pdbs:
                        status_readout.update(status_id, action='Generating PDB file')
                        polymer = building.build_linear_polymer(monomers=monogrp, DOP=dop, sequence='AB')  # TOSELF : may need to double DOP here, since each AB sequence of fragments technically counts as 1 monomer
                        building.mbmol_to_openmm_pdb(pdb_path, polymer)

                    # 3a) Assign chemical info to PDB system
                    param_top_path = dop_dir / f'{polymer_name}.sdf'
                    if param_top_path.exists() and not reparameterize:
                        status_readout.update(status_id, action='Loading parameterized single-mol Topology')
                        offtop = topology.topology_from_sdf(param_top_path)
                    else:
                        try:
                            status_readout.update(status_id, action='Partitioning topology by fragments')
                            offtop = Topology.from_pdb(pdb_path, _custom_substructures=monogrp.monomers)
                            # verify that a partition was possible; raised explicitly (rather than via an
                            # "assert" statement) so that the check is not stripped when running with -O
                            if not partition(offtop):
                                raise AssertionError('No partition by fragments was possible')
                            topology.topology_to_sdf(param_top_path, offtop)
                        except UnassignedChemistryInPDBError:
                            failure_record['No substruct cover'][rxn_name][polymer_name][dop] = monogrp
                            all_dops_built = False
                            continue # skip to next DOP, don't proceed with parameterization
                        except AssertionError:
                            failure_record['No substruct partition'][rxn_name][polymer_name][dop] = monogrp
                            all_dops_built = False
                            continue # skip to next DOP, don't proceed with parameterization

                    offmol = topology.get_largest_offmol(offtop)
                    offmol.name = polymer_name

                    # 3b) Assign partial charges, if not already present (or if reassignment was explicitly requested)
                    if reassign_partial_charges or not pcharge.has_partial_charges(offmol):
                        status_readout.update(status_id, action=f'Assigning partial charges via {charger.CHARGING_METHOD}')
                        cmol = charger.charge_molecule(offmol)
                    else:
                        # NOTE : "cmol" must be rebound on every pass; otherwise an already-charged molecule would
                        # either raise a NameError, or silently reuse the charged molecule of a previous compound
                        cmol = offmol
                    
                    # generate tiled lattices as specified
                    for lattice_str, lattice in lattices.items(): # NOTE : key that this is done AFTER parameterization to avoid reassigning parameters to a (potentially) much larger Topology
                        latt_dir : Path = dop_dir / lattice_str
                        latt_dir.mkdir(exist_ok=True)

                        status_readout.update(status_id, action=f'Generating tiled {lattice_str} topology')
                        tiled_offtop = topology_from_molecule_onto_lattice(cmol, lattice=lattice)
                        latt_top_path = latt_dir / f'{lattice_str}_{polymer_name}.sdf'
                        topology.topology_to_sdf(latt_top_path, tiled_offtop)

                        # create and save Interchange for MD export
                        status_readout.update(status_id, action=f'Creating {lattice_str} OpenFF Interchange')
                        interchange = forcefield.create_interchange(tiled_offtop, charge_from_molecules=[cmol])
                        top_box_vectors = boxvectors.get_topology_bbox(tiled_offtop) # determine tight box size
                        interchange.box = boxvectors.pad_box_vectors_uniform(top_box_vectors, exclusion) # apply periodic box (with padding) to Interchange

                        latt_inc_path = latt_dir / f'{lattice_str}_{polymer_name}.pkl'
                        with latt_inc_path.open('wb') as pklfile: # NOTE: pickled files must be read/written in binary mode
                            pickle.dump(interchange, pklfile)

                # only count the compound as a success if none of its DOPs was skipped due to a recorded failure
                if all_dops_built:
                    num_successful += 1

            except Exception as other_error: # deliberately broad : record any other failure and move on to the next compound
                failure_record[other_error.__class__.__name__][rxn_name][polymer_name] = str(other_error)
            finally: # advance the per-compound bars regardless of whether the build succeeded
                intra_mech_progress.advance(mech_task_id)
                compound_progress.advance(comp_progress_id)
                
        inter_mech_progress.advance(curr_mechanism_id, advance=1)
    
    # Ensure readout are current at end of process
    compound_readout.update(curr_compound_id, polymer_name=f'Completed! ({num_successful}/{total_compounds} successful)')
    sleep(0.1) # needed to give final bar enough time to catch up

print(failure_record)