# Developing a more robust IBIS that can handle rings

In [None]:
import json
import pandas as pd
from pathlib import Path

from polymerist.rdutils.reactions import reactions

import networkx as nx
from rdkit import Chem

from polymerist.rdutils.rdtypes import RDMol
from polymerist.rdutils.rdgraphs import rdmol_to_networkx
from polymerist.rdutils.rdkdraw import set_rdkdraw_size, disable_substruct_highlights
from polymerist.rdutils.reactions import reactors, reactions, fragment

### Load rxn reference data

In [None]:
# Static Paths
RAW_DATA_DIR  = Path('monomer_data_raw')
FMT_DATA_DIR  = Path('monomer_data_formatted')
PROC_DATA_DIR = Path('monomer_data_processed')
RXN_FILES_DIR = Path('poly_rxns')

# load table of functional groups for each reaction
# JSON text is UTF-8 by specification, so read it as such rather than with the locale-dependent default encoding
with (RXN_FILES_DIR / 'rxn_groups.json').open('r', encoding='utf-8') as file:
    rxn_groups = json.load(file)

# one AnnotatedReaction per reaction named in the table, read from "<RXN_FILES_DIR>/<rxnname>.rxn"
rxns = {
    rxnname : reactions.AnnotatedReaction.from_rxnfile(RXN_FILES_DIR / f'{rxnname}.rxn')
        for rxnname in rxn_groups.keys()
}

In [None]:
# Master monomer table with a two-level index (first two CSV columns); the outer level names the mechanism
df = pd.read_csv(
    PROC_DATA_DIR / 'monomer_data_MASTER.csv',
    index_col=[0, 1],
)
# One sub-frame per mechanism; xs() drops the outer (mechanism) index level from each sub-frame
frames_by_mech = {mech_name : df.xs(mech_name) for mech_name in df.index.unique(level=0)}

### Individual reaction

In [None]:
from polymerist.monomers import specification
from polymerist.rdutils.rdkdraw import set_rdkdraw_size
set_rdkdraw_size(300, 3/2)

# Mechanism to inspect; leave exactly one of these assignments uncommented to switch between them
# mech = 'polyurethane_isocyanate'
# mech = 'polyamide'
# mech = 'polyurethane_nonisocyanate'
# mech = 'polyester'
# mech = 'polycarbonate_phosgene'
mech = 'polyimide'
# mech = 'vinyl'

ROW_POS = 3 # positional index (iloc) of the row within this mechanism's frame to react

frame = frames_by_mech[mech]
row = frame.iloc[ROW_POS]
rxn = rxns[mech]
reactor = reactors.PolymerizationReactor(rxn)
display(rxn)

# Build an unsanitized RDKit Mol from the expanded SMILES (no atom map numbers) of each of the row's two monomers
reactants = []
for index in range(2):
    smi = row[f'smiles_monomer_{index}']
    exp_smi = specification.expanded_SMILES(smi, assign_map_nums=False)
    mol = Chem.MolFromSmiles(exp_smi, sanitize=False)
    if mol is None: # RDKit signals a parse failure by returning None rather than raising
        raise ValueError(f'Could not parse expanded SMILES {exp_smi!r} for monomer {index} of row {ROW_POS} ({mech})')
    display(mol)

    reactants.append(mol)

In [None]:
set_rdkdraw_size(500)

# Run the reactor on the monomers, sanitize each product in place with the SANITIZE_AS_KEKULE op set, and display it.
# NOTE: later cells reuse the loop variable `product`, i.e. they operate on the LAST product yielded here
products = reactor.react(reactants)
for product in products:
    Chem.SanitizeMol(product, sanitizeOps=specification.SANITIZE_AS_KEKULE)
    display(product)

In [None]:
G = rdmol_to_networkx(product) # graph of the last product from the reaction cell

# Node colors by element symbol; elements not listed here are drawn in DEFAULT_ELEM_COLOR
elem_colors = {
    'C' : 'gray',
    'O' : 'red',
    'N' : 'blue',
    'H' : 'green',
    'Cl' : 'purple'
}
DEFAULT_ELEM_COLOR = 'black'

nx.draw(
    G,
    pos=nx.spring_layout(G),
    # one color per node, in the same node order nx.draw uses, falling back for unlisted or missing symbols
    node_color=[elem_colors.get(G.nodes[node].get('symbol'), DEFAULT_ELEM_COLOR) for node in G.nodes],
    with_labels=True
)

In [None]:
# Restrict to the atoms of G which carry a 'molRxnRole' attribute
S = G.subgraph(nx.get_node_attributes(G, 'molRxnRole').keys())
# alternative - heavy atoms only:
# S = G.subgraph(i for i, symbol in nx.get_node_attributes(G, 'atomic_num').items() if symbol > 1)

pos = nx.spring_layout(S) # kept so the edge labels below are drawn at the same node positions
nx.draw(
    S,
    pos=pos,
    # one color per node in S's node order; unlisted or missing symbols fall back to black instead of raising KeyError
    node_color=[elem_colors.get(S.nodes[node].get('symbol'), 'black') for node in S.nodes],
    with_labels=True
)

edge_labels = { # label by RDKit bond number
    (u, v) : data['idx']
        for u, v, data in S.edges(data=True)
}
# alternative - label by distinct chains in chain decomposition (inspect with `list(nx.chain_decomposition(S))`):
# edge_labels = {
#     (u, v) : i
#         for i, chain in enumerate(nx.chain_decomposition(S))
#             for u, v in chain
# }
_ = nx.draw_networkx_edge_labels(
    S,
    pos=pos,
    edge_labels=edge_labels
)

In [None]:
# Compare the nodes of S flagged 'was_dummy' against all nodes of S carrying a 'molRxnRole' attribute
nx.get_node_attributes(S, 'was_dummy').keys(), nx.get_node_attributes(S, 'molRxnRole').keys()

In [None]:
from itertools import combinations
from polymerist.genutils.iteration import sliding_window

def node_pairs_to_edges(pairs):
    '''Map an iterable of (u, v) node pairs of S to the RDKit bond indices stored in those edges' 'idx' attribute'''
    return [S.edges[u, v]['idx'] for u, v in pairs]


bridges = set(node_pairs_to_edges(nx.bridges(S))) # bond indices of the bridges of S
former_bh_ids = nx.get_node_attributes(S, 'was_dummy').keys() # not checking for heavies here
new_bond_ids = set(rxn.product_info_maps[0].new_bond_ids_to_map_nums)

# For each pair of former-dummy atoms, print the bonds along one shortest path and which of those are bridges
for bh_id_pair in combinations(former_bh_ids, 2):
    path_edges = node_pairs_to_edges(nx.utils.pairwise(nx.shortest_path(S, *bh_id_pair)))
    print(bh_id_pair, path_edges, set(path_edges) & bridges)

In [None]:
from openff.interchange import Interchange

In [None]:
# Display the bond indices flagged as newly formed in the reaction's first product info map
new_bond_ids

In [None]:
# Nodes of S which carry a 'was_dummy' attribute (a live keys view of the attribute dict)
former_bh_ids = nx.get_node_attributes(S, 'was_dummy').keys() # not checking for heavies here

In [None]:

# Bond indices of the bridges of S, recomputed from the graph: the global `bridges` already holds
# bond indices (ints), so unpacking its elements as node pairs raised a TypeError
bridge_edge_ids = [S.edges[u, v]['idx'] for u, v in nx.bridges(S)]
bridge_edge_id_set = set(bridge_edge_ids)
new_edge_ids = list(rxn.product_info_maps[0].new_bond_ids_to_map_nums)
former_bh_ids = nx.get_node_attributes(S, 'was_dummy').keys() # not checking for heavies here

bh_paths = {}   # (atom_a, atom_b) -> every simple path between them in S, each path as a list of bond indices
bh_bridges = {} # (atom_a, atom_b) -> set of bridge bond indices lying on any of those paths
for bh_id_pair in combinations(former_bh_ids, 2):
    # NOTE: the number of simple paths can grow exponentially with the number of rings in S
    paths = [
        [S.edges[u, v]['idx'] for u, v in path]
            for path in nx.all_simple_edge_paths(S, *bh_id_pair)
    ]
    bh_paths[bh_id_pair] = paths
    bh_bridges[bh_id_pair] = {
        bond_id
            for path in paths
                for bond_id in path
                    if bond_id in bridge_edge_id_set
    }

    


In [None]:
# Consecutive node windows of size n=2 along a shortest path in S between nodes 1 and 7 (node ids hard-coded for exploration);
# presumably equivalent to nx.utils.pairwise over the same path - confirm against sliding_window's definition
list(sliding_window(nx.shortest_path(S, 1, 7), n=2))

In [None]:
# Display the current value of `bridges` (whichever cell assigned it most recently)
bridges

In [None]:
# RDKit's shortest path (atom indices) between atoms 0 and 12 of the product, to cross-check networkx's result on G
Chem.GetShortestPath(product, 0,12)

In [None]:
# networkx shortest path between nodes 0 and 12 of the full product graph G, to cross-check RDKit's result
nx.shortest_path(G, 0, 12)

In [None]:
# Re-display `bridges` (whichever cell assigned it most recently)
bridges

In [None]:

class CutBridgesNearReactingAtoms(reactors.IBIS):
    '''
    IBIS which looks for bridges occurring within the active portion of a reaction product to minimize the number of cuts made

    NOTE(review): currently a placeholder - bond location is delegated unchanged to the parent IBIS
    '''
    def locate_intermonomer_bonds(self, product: RDMol, product_info: reactors.RxnProductInfo) -> reactors.Generator[int, None, None]:
        parent_bond_ids = super().locate_intermonomer_bonds(product, product_info)
        return parent_bond_ids

In [None]:
def test(mol=None, graph=None, reaction=None):
    '''
    Yield RDKit bond indices of bridge bonds which are candidates for cutting a reaction product

    If any newly-formed bond is itself a bridge of `graph`, only those bonds are yielded.
    Otherwise, the function considers every pair of atoms matched by fragment.HEAVY_FORMER_LINKER_QUERY.
    It walks each simple path between the two atoms in `graph` and yields the bridges lying on any path
    which passes through a newly-formed bond. Each bond index is yielded at most once.

    Parameters
    ----------
    mol : RDMol, optional
        Reaction product to query for former-linker atoms; defaults to the notebook-global `product`
    graph : nx.Graph, optional
        Graph whose nodes are atom indices of `mol` and whose edges carry the RDKit bond index as 'idx';
        defaults to the notebook-global `S`
    reaction : AnnotatedReaction, optional
        Reaction whose first product info map supplies the newly-formed bond indices; defaults to the notebook-global `rxn`

    Yields
    ------
    bond_id : int
        RDKit bond index of a candidate bridge bond
    '''
    if mol is None:
        mol = product
    if graph is None:
        graph = S
    if reaction is None:
        reaction = rxn

    # Atom ids matched by the query, deduplicated (order-preserving) and restricted to nodes of `graph`,
    # since path searches raise for nodes absent from the graph
    matched_atom_ids = (
        atom_id
            for match in mol.GetSubstructMatches(fragment.HEAVY_FORMER_LINKER_QUERY)
                for atom_id in match
    )
    former_bh_ids = [atom_id for atom_id in dict.fromkeys(matched_atom_ids) if atom_id in graph]

    # bond indices of the bridges of `graph` (bonds whose removal disconnects it); the global `bridges`
    # holds bond indices rather than atom pairs, so it cannot be fed to GetBondBetweenAtoms
    bridge_bond_ids = {graph.edges[u, v]['idx'] for u, v in nx.bridges(graph)}
    print(bridge_bond_ids)
    new_bond_ids = set(reaction.product_info_maps[0].new_bond_ids_to_map_nums)

    new_bridge_bond_ids = bridge_bond_ids & new_bond_ids
    if new_bridge_bond_ids: # some newly-formed bonds are bridges: yield only those
        yield from new_bridge_bond_ids
        return

    yielded = set()
    for bh_id_pair in combinations(former_bh_ids, 2):
        for path in nx.all_simple_edge_paths(graph, *bh_id_pair):
            path_bond_ids = {graph.edges[u, v]['idx'] for u, v in path}
            if path_bond_ids.isdisjoint(new_bond_ids):
                continue # only paths passing through a newly-formed bond are considered
            for bond_id in (path_bond_ids & bridge_bond_ids) - yielded:
                yielded.add(bond_id)
                yield bond_id

In [None]:
BUFFER = '='*10 # separator printed on either side of each step header
# Step through reactor.propagate, which yields (dimers, fragments) pairs, displaying every dimer and fragment of each step
for i, (dimers, frags) in enumerate(reactor.propagate(reactants)):
    print(f'{BUFFER}STEP {i}{BUFFER}')
    for j, dim in enumerate(dimers):
        print(f'Dimer {i}-{j}')
        display(dim)
    bins = [] # NOTE(review): not used anywhere in the visible code - remove unless something further down relies on it
    for j, frag in enumerate(frags):
        print(f'Fragment {i}-{j}')
        display(frag)