# A Demo of using RDKitMol as intermediate to generate TS by TS-GCN

A demo showing how RDKitMol can connect RMG and TS-GCN to help predict TS geometries. TS-GCN requires the reactant and the product to share the same atom ordering, which is seldom available in practice. RDKitMol + RMG makes it possible to match reactant and product atom indexes according to the RMG reaction family. <br>

Some of the code is adapted from https://github.com/ReactionMechanismGenerator/TS-GCN


In [None]:
import os
import sys
import subprocess
from typing import Iterable
# Add this RDMC repository to the Python path in case you haven't done so already
sys.path.append(os.path.dirname(os.path.abspath('')))

import numpy as np

from rdkit import Chem
from rdmc import RDKitMol
from rdmc.forcefield import optimize_mol
from rdmc.ts import get_formed_and_broken_bonds, is_DA_rxn_endo
from rdmc.rdtools.atommap import reverse_map
from rdmc.rdtools.view import grid_viewer, mol_viewer, ts_viewer

try:
    # import RMG dependencies
    from rdmc.external.rmg import (from_rdkit_mol,
                                   find_reaction_family,
                                   generate_reaction_complex,
                                   load_rmg_database,
                                   )
    # Load RMG database (all reaction families are loaded)
    rmg_db = load_rmg_database(all_families=True)
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so a missing RMG-Py
    # package is handled here too. If this branch is taken, `rmg_db` and the
    # imported helpers are undefined and later cells that use them will fail.
    print('You need to install RMG-Py first and run this IPYNB in rmg_env!')


# A helper function to generate molecules either from xyz or smiles
# It will also note which molecules have 3D information
def parse_xyz_or_smiles_list(mol_list, **kwargs):
    """
    Build RDKitMol instances from XYZ and/or SMILES strings and record which ones
    carry user-supplied 3D coordinates.

    Each item is first parsed with ``RDKitMol.FromXYZ`` (``kwargs`` are forwarded
    to it). If that raises ``ValueError``, the item is parsed as SMILES instead
    and a conformer is embedded. An item may also be an ``(identifier, multiplicity)``
    pair; the multiplicity is then applied with ``SaturateMol``.

    Args:
        mol_list (list): XYZ/SMILES strings, or ``(string, multiplicity)`` pairs.
        **kwargs: Extra keyword arguments passed to ``RDKitMol.FromXYZ`` only.

    Returns:
        tuple: ``(mols, is_3D)``. ``mols`` is the list of RDKitMol. ``is_3D[i]`` is
               True if item ``i`` was parsed from XYZ and False if it was
               embedded from SMILES.

    Raises:
        Any exception raised by ``RDKitMol.FromSmiles``, ``EmbedConformer`` or
        ``SaturateMol`` propagates unchanged.
    """
    mols, is_3D = [], []
    for mol in mol_list:
        if isinstance(mol, (tuple, list)) and len(mol) == 2:
            mol, mult = mol
        else:
            mult = None
        try:
            rd_mol = RDKitMol.FromXYZ(mol, **kwargs)
            from_xyz = True
        except ValueError:
            # Not a valid XYZ string: fall back to SMILES and embed a 3D guess.
            rd_mol = RDKitMol.FromSmiles(mol)
            rd_mol.EmbedConformer()
            from_xyz = False
        # Post-processing runs only after a successful parse. It is deliberately
        # not in a `finally`, so parse errors surface instead of being masked.
        if mult is not None:
            rd_mol.SaturateMol(multiplicity=mult)
        is_3D.append(from_xyz)
        mols.append(rd_mol)
    return mols, is_3D


# RDKit's `GetSubstructMatch` returns a tuple whose position is the atom index
# in the query molecule and whose value is the matched atom index in the
# molecule being searched. The function below collects one such match per
# input molecule against a multi-fragment complex.

def match_mols_to_complex(mols: list,
                          mol_complex: 'RDKitMol',
                          max_matches: int = 10000):
    """
    Generate a list of atom-index tuples, each corresponding to one of the
    fragments (input molecules) in the complex.

    Args:
        mols (list): RDKitMol fragments (e.g., the individual products) to locate
                     in ``mol_complex``.
        mol_complex (RDKitMol): The multi-fragment complex to search.
        max_matches (int, optional): Maximum number of non-unique substructure
                                     matches to enumerate when the first match of a
                                     molecule overlaps atoms already assigned to an
                                     earlier one. Defaults to 10000.

    Returns:
        list: ``frags_idx[i][j]`` is the index in ``mol_complex`` of atom ``j`` of
              ``mols[i]``. Entries beyond ``len(mols)`` keep the fragment atom
              indexes from ``GetMolFrags``.

    Raises:
        RuntimeError: If the complex fragments cannot be matched to the input
                      molecules, a molecule is not found in the complex at all,
                      or no non-overlapping match is found within ``max_matches``.
    """
    frags_idx = list(mol_complex.GetMolFrags())
    # With two equally sized fragments, size alone cannot tell which fragment is
    # which input molecule, so use substructure matching to decide the order.
    if len(frags_idx) > 1 and len(frags_idx[0]) == len(frags_idx[1]):
        frags = mol_complex.GetMolFrags(asMols=True)
        match1 = frags[0].GetSubstructMatch(mols[0])
        match2 = frags[0].GetSubstructMatch(mols[1])
        if match1:
            # Either two identical molecules
            # Or the first fragment is indeed the first one assigned
            pass
        elif match2:
            frags_idx = frags_idx[::-1]
        else:
            match1 = frags[1].GetSubstructMatch(mols[0])
            match2 = frags[1].GetSubstructMatch(mols[1])
            if match1:
                frags_idx = frags_idx[::-1]
            elif not match2:
                raise RuntimeError('Have difficulty matching molecules from the complex '
                                   'to the input molecules.')
    elif len(frags_idx[0]) != mols[0].GetNumAtoms():
        frags_idx = frags_idx[::-1]

    included_atoms = set()
    for i, mol in enumerate(mols):
        match = mol_complex.GetSubstructMatch(mol)
        if not match:
            raise RuntimeError(f'Cannot find fragment {i} in the complex.')
        if included_atoms.intersection(match):
            # The first match reuses atoms already assigned to an earlier
            # fragment (e.g., two identical molecules). Enumerate non-unique
            # matches and take the first disjoint one. Degenerate H atoms can make
            # this combinatorial, hence the `max_matches` cap.
            all_matches = mol_complex.GetSubstructMatches(mol,
                                                          uniquify=False,
                                                          maxMatches=max_matches)
            match = next((candidate for candidate in all_matches
                          if not included_atoms.intersection(candidate)),
                         None)
            if match is None:
                raise RuntimeError(f'Cannot find a proper match for fragment {i}, you may want to '
                                   f'change `mols` or increase `max_matches`.')
        included_atoms.update(match)
        frags_idx[i] = match
    return frags_idx


def get_bond_length_list(mol: 'RDKitMol',
                         match: list = None,):
    """
    Get a list of ``[[atom_i, atom_j], bond_length]`` entries, one per bond in ``mol``.

    Args:
        mol (RDKitMol): the molecule; bond lengths are read from its conformer
                        (``mol.GetConformer()``).
        match (list, optional): A subgraph match result where ``match[k]`` is the
                                index that atom ``k`` of ``mol`` maps to. If given
                                (and non-empty), atom indexes in the output are
                                translated through it; otherwise ``mol``'s own atom
                                indexes are used. Defaults to ``None``.

    Returns:
        list: Each item is ``[[idx1, idx2], length]``, where ``[idx1, idx2]`` is a
              list of the two bonded atom indexes and ``length`` is the value of
              ``conf.GetBondLength(bond)``.
    """
    # `None` instead of a mutable `[]` default; any falsy match means "no mapping".
    if match:
        match_dict = dict(enumerate(match))
    else:
        match_dict = {i: i for i in range(mol.GetNumAtoms())}
    conf = mol.GetConformer()
    bond_length = []
    for bond in mol.GetBondsAsTuples():
        bond_length.append([[match_dict[atom] for atom in bond],
                            conf.GetBondLength(bond)])
    return bond_length


# Experimental features
# Ideally, each reaction family may have a value that works better,
# and the author is still trying to find those numbers.
# The values recorded below are just for reference.
# Maps an RMG family label to the separation used for forming/breaking atom
# pairs in `frozen_non_bondings` (units presumably Angstrom -- TODO confirm
# against `optimize_mol`). Lookups use `.get(family_label, 3.0)`, so unlisted
# families fall back to 3.0.
BOND_CONSTRAINT = {'1,3_Insertion_ROR': 2.5,
                   'Retroene': 2.5,
                   '1,2_Insertion': 2.5,
                   '2+2_cycloaddition_Cd': 2.5,
                   'Diels_alder_addition': 3.0,
                   'Intra_ene_reaction': 4.0,
                   'H_Abstraction': 3.,  # listed once; it was previously duplicated
                   'Disproportionation': 3.,
                   'SubstitutionS': 3.,  # A + B = C + D
                   'Substitution_O': 3.,}

# There are several conditions under which the reactants can be more informative.
# 1. If a family only breaks bonds without forming any,
#    then the reactant geometry is more informative. One can
#    infer the geometry of products and the product alignment
#    from the reactant solely. (Since all reactions are elementary,
#    we expect stereospecificity is maintained.)
# 2. For ketoenol, the reactant (enol) may have cis-trans isomerism, while
#    this info may be lost in the product.
# How the value is used: it is compared with RMG's `forward` flag for the matched
# reaction. When `value != forward`, the notebook swaps reactants and products,
# so the value is the `forward` flag for which the given reactant side is kept
# as the more informative side.
# Commented-out entries are families currently excluded from this treatment.
REACTANT_MORE_INFORMATIVE_FAMILIES = {
    '1+2_Cycloaddition': False,  # A + B = C  Ring in product
#     '1,2_shiftS': False,  # A = B  Potentially 2 chiral in product vs 1 chiral center in reactant
    '1,4_Cyclic_birad_scission': True,  # A = C  Ring in product
    '2+2_cycloaddition': False,  # A + B = C  Ring in product
    'Birad_recombination': False,  # A = C  Ring in product
    'Concerted_Intra_Diels_alder_monocyclic_1,2_shiftH': False,  # A = C  C=CC=C in A can be less constraint
#     'Cyclic_Ether_Formation': False,  # A = C + D  Ring structure is more constraint than single bond
#     'Cyclic_Thioether_Formation': False,  # A = C + D  Ring structure is more constraint
    'Cyclopentadiene_scission': True,  # A = C  Ring in product
    'Diels_alder_addition': False,  # A + B = C  Ring in product
    'Intra_2+2_cycloaddition_Cd': False,  # A = C  Ring in product
    'Intra_5_membered_conjugated_C=C_C=C_addition': False,  # A = C
    'Intra_Diels_alder_monocyclic': False,  # A = C  Ring in product
#     'Intra_RH_Add_Endocyclic': False,  # A = C  Ring in product
#     'Intra_RH_Add_Exocyclic': False,  # A = C  Ring in product
    'Intra_R_Add_Endocyclic': False,  # A = C  Ring in product
    'Intra_R_Add_Exocyclic': False,  # A = C  Ring in product
#     'Intra_R_Add_ExoTetcyclic': False,  # A = C  Ring in product
    'Intra_Retro_Diels_alder_bicyclic': True,  # A = C  Ring in reactant
    'R_Addition_COm': False,  # A + B = C
    'R_Addition_CSm': False,  # A + B = C
    'R_Addition_MultipleBond': False, # A + B = C
    'R_Recombination': False,  # An extra bond formed
    'ketoenol': True, # A = B
}

# Families treated as their own reverse. Together with BIMOLECULAR, membership
# lets the notebook swap reactants and products when the product side has more
# user-supplied 3D geometry than the reactant side.
OWN_REVERSE = [
    '1,2_shiftC', '6_membered_central_C-C_shift',          # A = C
    'H_Abstraction',                                        # A + B = C + D
    'Intra_R_Add_Exo_scission', 'Intra_ene_reaction',
    'intra_H_migration',                                    # A = C
    'SubstitutionS', 'Substitution_O',                      # A + B = C + D
]

# Bimolecular families (A + B = C + D) that get the same geometry-based swap check.
BIMOLECULAR = [
    'CO_Disproportionation', 'Disproportionation', 'H_Abstraction',
    'SubstitutionS', 'Substitution_O',
]

%load_ext autoreload
%autoreload 2

## INPUT FIELDS

#### Forcefield arguments
- `force_field_type`: The type of the force field to use. Available: `MMFF94s`, `MMFF94`, `UFF`
- `tol`: The convergence tolerance of the optimization
- `max_step`: The maximum number of steps the optimization may take.

#### XYZ perception arguments
- `backends`: The backends for XYZ perception, tried in the listed order. They have no influence if you are using SMILES. By default, `openbabel` XYZ perception is preferred over `jensen`.
- `header`: Whether the XYZ strings contain a first line with the number of atoms and a second line with a title/comments. If your strings do not contain those two lines, set `header` to `False`.

#### TS-GCN arguments
- `TS_MODEL_PYTHON`: The path to the Python executable used to run TS-GCN. If a conda environment is installed
for TS-GCN, it should be something like `CONDA_HOME_PATH/envs/ENV_NAME/bin/python`.
- `TS_MODEL_DIR`: The path to the directory where TS-GCN is installed.

In [None]:
# ------------------------------ Force field ------------------------------
# Force field type: 'MMFF94s', 'MMFF94' or 'UFF'
force_field_type = "MMFF94s"
# Convergence tolerance and maximum number of optimization steps
tol = 1e-8
max_step = 10000

# ----------------------------- XYZ perception -----------------------------
# Perception backends, tried one after another until one succeeds
backends = ['openbabel', 'jensen']
# Whether the input XYZ strings include the first two lines
# (number of atoms + title/comments)
header = False

# ------------------------------ TS-GCN setup ------------------------------
TS_MODEL_PYTHON = '~/Apps/anaconda3/envs/ts_gcn/bin/python3.7'
TS_MODEL_DIR = '~/Apps/TS-GCN'

# ----------------------------- TS-EGNN setup ------------------------------
# Uncomment these two lines instead of the TS-GCN ones to use TS-EGNN
# TS_MODEL_PYTHON = '~/Apps/anaconda3/envs/ts_egnn/bin/python3.7'
# TS_MODEL_DIR = '~/Apps/ts_egnn'

# ----------------------- Diels-Alder reactions only -----------------------
# Desired stereoisomer of the DA reaction ('endo' or 'exo');
# a falsy value skips the stereo check
da_stereo_specific = 'endo'

# NOTE: The following is a testing feature that is not fully functional!!!!
# This option tells whether the user wants their input 3D geometries to be used
# without modifications (except the necessary alignment). Note that in several
# cases such a setting can lower the TS generation success rate, e.g., when the
# reactant and the product differ greatly in the geometry of non-reacting atoms,
# or when a reactant conformer is not on the IRC path.
# force_user_input = False

### 1. Input molecule information
You can input SMILES, XYZ, or a mix of both. Molecule instances are then generated from the input identifiers.<br>
**RECOMMENDATIONS:**
- **Define the single-species side of the reaction as the reactant.**
- **Put the heavier product first in the list.**
- **If you need to specify the multiplicity, make the molecule instance a tuple. E.g., reactants = [('XYZ_STRING', 1), ('SMILES', 2)] where 1 and 2 are multiplicities.**

Here, some examples are provided

Example: intra_H_migration

In [None]:
# intra_H_migration example: both the reactant and the product are given as XYZ
reactants = [
"""C -1.528265  0.117903  -0.48245
C -0.214051  0.632333  0.11045
C 0.185971  2.010727  -0.392941
O 0.428964  2.005838  -1.836634
O 1.53499  1.354342  -2.136876
H -1.470265  0.057863  -1.571456
H -1.761158  -0.879955  -0.103809
H -2.364396  0.775879  -0.226557
H -0.285989  0.690961  1.202293
H 0.605557  -0.056315  -0.113934
H -0.613001  2.746243  -0.275209
H 1.100271  2.372681  0.080302""",
]

products = [
    """C 1.765475  -0.57351  -0.068971
H 1.474015  -1.391926  -0.715328
H 2.791718  -0.529486  0.272883
C 0.741534  0.368416  0.460793
C -0.510358  0.471107  -0.412585
O -1.168692  -0.776861  -0.612765
O -1.768685  -1.15259  0.660846
H 1.164505  1.37408  0.583524
H 0.417329  0.069625  1.470788
H -1.221189  1.194071  0.001131
H -0.254525  0.771835  -1.433299
H -1.297409  -1.977953  0.837367""",
]

In [None]:
# intra_H_migration example: reactant as SMILES, product as XYZ
reactants = [
"""CCCO[O]""",
]

products = [
    """C 1.765475  -0.57351  -0.068971
H 1.474015  -1.391926  -0.715328
H 2.791718  -0.529486  0.272883
C 0.741534  0.368416  0.460793
C -0.510358  0.471107  -0.412585
O -1.168692  -0.776861  -0.612765
O -1.768685  -1.15259  0.660846
H 1.164505  1.37408  0.583524
H 0.417329  0.069625  1.470788
H -1.221189  1.194071  0.001131
H -0.254525  0.771835  -1.433299
H -1.297409  -1.977953  0.837367""",
]

In [None]:
# intra_H_migration example: both sides given as SMILES
reactants = [
"""CCCO[O]""",
]

products = [
"""[CH2]CCOO""",
]

Example: intra_OH_migration

In [None]:
# intra_OH_migration example: reactant as SMILES, product as XYZ
reactants = [
"""OCCC[O]""",
]

products = [
    """C 1.765475  -0.57351  -0.068971
H 1.474015  -1.391926  -0.715328
H 2.791718  -0.529486  0.272883
C 0.741534  0.368416  0.460793
C -0.510358  0.471107  -0.412585
O -1.168692  -0.776861  -0.612765
O -1.768685  -1.15259  0.660846
H 1.164505  1.37408  0.583524
H 0.417329  0.069625  1.470788
H -1.221189  1.194071  0.001131
H -0.254525  0.771835  -1.433299
H -1.297409  -1.977953  0.837367""",
]

Example: intra_ene_reaction

In [None]:
# intra_ene_reaction example: both sides given as SMILES
reactants = [
"""C=CC=CCC""",
]

products = [
"""CC=CC=CC""",
]

In [None]:
# intra_ene_reaction example: reactant as XYZ, product as SMILES
reactants = [
"""C      2.365139   -0.823066    0.195886
C      1.132133   -0.448278   -0.615530
C      0.601799    0.908821   -0.244634
C     -0.469930    1.514001   -0.781599
C     -1.310197    0.964460   -1.818982
C     -2.368267    1.614877   -2.317239
H      2.724279   -1.814346   -0.098775
H      2.140638   -0.849340    1.267595
H      3.177943   -0.106912    0.034253
H      0.358910   -1.205983   -0.444261
H      1.397321   -0.462750   -1.678915
H      1.139704    1.449320    0.533030
H     -0.736243    2.502985   -0.410259
H     -1.070361   -0.020969   -2.210454
H     -2.973523    1.164170   -3.097508
H     -2.658935    2.600563   -1.968362"""
]

products = [
"""CC=CC=CC""",
]

Example: keto-enol

In [None]:
# keto-enol example: both the reactant and the product are given as XYZ
reactants = [
"""O 0.898799  1.722422  0.70012
C 0.293754  -0.475947  -0.083092
C -1.182804  -0.101736  -0.000207
C 1.238805  0.627529  0.330521
H 0.527921  -1.348663  0.542462
H 0.58037  -0.777872  -1.100185
H -1.45745  0.17725  1.018899
H -1.813437  -0.937615  -0.310796
H -1.404454  0.753989  -0.640868
H 2.318497  0.360641  0.272256""",
]

products = [
    """O 2.136128  0.058786  -0.999372
C -1.347448  0.039725  0.510465
C 0.116046  -0.220125  0.294405
C 0.810093  0.253091  -0.73937
H -1.530204  0.552623  1.461378
H -1.761309  0.662825  -0.286624
H -1.923334  -0.892154  0.536088
H 0.627132  -0.833978  1.035748
H 0.359144  0.869454  -1.510183
H 2.513751  -0.490247  -0.302535""",
]

Example: 2+2_cycloaddition

In [None]:
# 2+2_cycloaddition example: one XYZ reactant and two identical XYZ products
# (both products share the same coordinates)
reactants = [
"""O -0.854577  1.055663  -0.58206
O 0.549424  1.357531  -0.196886
C -0.727718  -0.273028  -0.011573
C 0.76774  -0.043476  0.113736
H -1.066903  -1.044054  -0.706048
H -1.263435  -0.349651  0.939354
H 1.374762  -0.530738  -0.655177
H 1.220707  -0.172248  1.098653"""
           ]

products = [
"""O 0.0  0.0  0.682161
C 0.0  0.0  -0.517771
H 0.0  0.938619  -1.110195
H 0.0  -0.938619  -1.110195""",

"""O 0.0  0.0  0.682161
C 0.0  0.0  -0.517771
H 0.0  0.938619  -1.110195
H 0.0  -0.938619  -1.110195""",
]

Example: Diels_Alder

In [None]:
# Diels_Alder example: two reactants given as XYZ.
# NOTE(review): this cell only sets `reactants`; `products` keeps whatever value
# an earlier cell assigned. Confirm the intended product before running it alone.
reactants = [
"""C      2.788553    0.698686    0.674316
C      2.218817   -1.464988    0.029675
C      2.516823   -0.656661    1.258397
C      2.662208    0.650837   -0.659411
C      2.310059   -0.686509   -1.057857
H      3.046237    1.573804    1.251124
H      1.969545   -2.515127    0.032875
H      1.657462   -0.631845    1.934608
H      3.393616   -1.044418    1.784949
H      2.798561    1.473331   -1.344992
H      2.148949   -0.993686   -2.080010""",

"""C     -0.567538   -0.593271   -0.685125
C      0.550187   -0.609810    0.305262
C     -0.935561    0.922337   -0.812382
C      0.866807    0.620438    0.725116
C     -2.448912    0.929590   -0.465129
C     -1.921658   -1.240688   -0.288092
C      0.013503    1.688438    0.113822
C     -2.886280   -0.387595   -1.126997
C     -2.356832   -0.769335    1.082806
C     -2.672212    0.530266    0.975717
H     -0.801330    1.252767   -1.851145
H     -0.225560   -0.979503   -1.653092
H      1.647573    0.834801    1.442595
H      1.040594   -1.515236    0.635305
H     -2.000313   -2.316795   -0.448365
H     -3.004507    1.811308   -0.786923
H     -3.943990   -0.631087   -0.961967
H     -2.677424   -0.424904   -2.202421
H      0.641104    2.375974   -0.463365
H     -0.510855    2.269595    0.877800
H     -2.347672   -1.359962    1.985934
H     -2.969955    1.189249    1.776970""",

]

In [None]:
# Diels_Alder example: one reactant as XYZ, one as SMILES; product as SMILES
reactants = [
"""C      2.788553    0.698686    0.674316
C      2.218817   -1.464988    0.029675
C      2.516823   -0.656661    1.258397
C      2.662208    0.650837   -0.659411
C      2.310059   -0.686509   -1.057857
H      3.046237    1.573804    1.251124
H      1.969545   -2.515127    0.032875
H      1.657462   -0.631845    1.934608
H      3.393616   -1.044418    1.784949
H      2.798561    1.473331   -1.344992
H      2.148949   -0.993686   -2.080010""",

"""C1=CC2C3C=CC(C3)C2C1""",
]

products = ['C1=CC2CC1C1CC3C4C=CC(C4)C3C21']

In [None]:
# Example: one reactant as a single-line XYZ string; two products as XYZ, the
# second (a CH2 fragment) given as an (xyz, multiplicity=1) tuple
reactants = [
"""N     -0.235850    2.415630    0.592257\nC     -0.189406    1.303660    0.262333\nC     -0.117300   -0.107473   -0.154306\nC     -0.356309   -0.874025    1.119755\nC      0.668915   -0.886001    2.197513\nO      0.555397   -2.328930    2.249814\nC     -0.371544   -2.367112    1.153521\nC     -1.660797   -1.622402    1.301232\nC      1.207895   -0.411311   -0.842905\nH     -0.925686   -0.307130   -0.868447\nH      1.690050   -0.577603    1.962755\nH      0.360327   -0.425955    3.141272\nH     -0.151022   -2.934947    0.260576\nH     -2.128078   -1.611387    2.280157\nH     -2.368344   -1.672623    0.482321\nH      1.285265   -1.472416   -1.100426\nH      1.299761    0.165623   -1.770381\nH      2.071118   -0.154538   -0.219905""",
]

products = ['C     -0.688961   -0.876527    1.062550\nC      0.010910    0.011864    0.039818\nC     -0.535153    1.464633    0.128853\nC     -1.598752    2.140829   -0.361211\nO     -1.307383    3.269506    0.185148\nC     -0.108540    2.734726    0.794049\nC      1.467348    0.009173    0.264285\nN      2.613593    0.014300    0.449258\nH     -1.771062   -0.894655    0.891148\nH     -0.321266   -1.907257    1.004893\nH     -0.522346   -0.521521    2.086759\nH     -0.194738   -0.377266   -0.964573\nH     -2.425266    1.825331   -1.000264\nH     -0.147626    2.772910    1.886885\nH      0.796609    3.224983    0.422878',
('C     -0.000821    0.412249   -0.000000\nH     -0.909612   -0.206423   -0.000000\nH      0.910432   -0.205826    0.000000', 1)]

Example: Intra_R_Add_Endocyclic (A = B)

In [None]:
# Intra_R_Add_Endocyclic example (A = B): both sides given as SMILES
reactants = ["""C=CCCO[O]""",
]

products = ["""[CH2]C1CCOO1""",
]

Example: Retroene

In [None]:
# Retroene example: one reactant and two products, all given as SMILES
reactants = [
"""CCC1C=CC=C1""",
]

products = [
"""C1C=CC=C1""",

"""C=C""",
]

Example: HO2_elimination

In [None]:
# HO2_elimination example: reactant and both products given as XYZ
reactants = [
"""C -1.890664  -0.709255  -0.271996
C -0.601182  0.078056  -0.018811
C 0.586457  -0.545096  -0.777924
C -0.292203  0.188974  1.451901
H -0.683164  -0.56844  2.124827
C 0.477032  1.332664  2.012529
O -0.367239  2.493656  2.288335
O -0.679966  1.393013  -0.618968
O -1.811606  2.119506  -0.074789
H -1.819659  -1.711353  0.159844
H -2.063907  -0.801665  -1.346104
H -2.739557  -0.190076  0.171835
H 0.374452  -0.548385  -1.849706
H 1.501209  0.026135  -0.608139
H 0.747239  -1.572318  -0.444379
H 1.209047  1.707778  1.296557
H 0.998836  1.047896  2.931789
H -0.994076  2.235514  2.974109
H -1.392774  2.537261  0.704151"""
]

products = [
"""C -1.395681  1.528483  -0.00216
C -0.402668  0.411601  -0.210813
C -0.997629  -0.972081  -0.127641
C 0.890607  0.678979  -0.433435
C 2.015631  -0.28316  -0.676721
O 2.741986  0.043989  -1.867415
H -0.923699  2.509933  -0.072949
H -2.200649  1.479183  -0.744922
H -1.873843  1.44886  0.981238
H -1.839799  -1.068706  -0.822233
H -0.283424  -1.765173  -0.346167
H -1.400492  -1.154354  0.875459
H 1.201336  1.7219  -0.466637
H 2.754241  -0.212398  0.127575
H 1.667906  -1.32225  -0.7073
H 2.101868  0.079395  -2.5857""",

"""O -0.168488  0.443026  0.0
O 1.006323  -0.176508  0.0
H -0.837834  -0.266518  0.0""",
]

Example: H abstraction

In [None]:
# H_Abstraction example (A + B = C + D): all species given as SMILES
reactants = [
"""CCC[O]""",
"""CC(C)=C(C)C""",]

products = [
"""CCCO""",
"""[CH2]C(C)=C(C)C""",]

Example: Substitution_O
This family currently has issues matching templates.

In [None]:
# Substitution_O example (A + B = C + D): all species given as SMILES
reactants = [
"""CCCOCC""",
"""[CH3]""",]

products = [
"""CCCOC""",
"""[CH2]C""",]

Example: 1+2_Cycloaddition

In [None]:
# 1+2_Cycloaddition example: CH2 with multiplicity 1 plus ethene, given as SMILES
reactants = [
("""[CH2]""", 1),
"""C=C""",]

products = [
"""C1CC1""",]

Additional test cases

In [None]:
# Test case: OH (multiplicity 2) + CC(=O)OCCCC -> water + radical, as SMILES
reactants = [
("""[OH]""", 2),
"""CC(=O)OCCCC""",
]

products = [
"""O""",
"""CC(=O)OCCC[CH2]"""
]

In [None]:
# Test case: O2 (multiplicity 3) + a cyclic radical, all species as SMILES
reactants = [
("""[O][O]""", 3),
"""CC1[CH]C=CCC1""",
]

products = [
"""O[O]""",
"""CC1C=C=CCC1"""
]

In [None]:
# Test case: one reactant and two products as XYZ, each given with multiplicity 1
reactants = [
("""C     -3.179501   -2.101882   -3.441990
C     -3.542025   -1.477545   -2.107006
N     -2.443844   -0.718800   -1.561296
C     -2.201992    0.592030   -1.884568
O     -2.877732    1.266582   -2.648854
C     -1.009053    1.187838   -1.217057
N     -0.118197    0.258552   -0.478621
C      0.371111    0.850487   -1.720828
C      1.458422    1.867084   -1.689444
H     -2.311892   -2.762549   -3.341642
H     -4.017108   -2.690521   -3.828264
H     -2.929459   -1.334500   -4.182011
H     -4.410881   -0.821175   -2.223221
H     -3.793011   -2.256923   -1.380903
H     -1.767378   -1.152620   -0.939010
H     -1.177504    2.153008   -0.754097
H      0.274171    0.735071    0.336041
H      0.458069    0.179988   -2.568870
H      1.431462    2.472236   -2.601063
H      1.371919    2.543020   -0.832033
H      2.434612    1.374647   -1.637706""", 1),]

products = [
("""C     -3.904420   -2.075553   -2.703411
C     -2.639835   -1.417565   -3.224048
N     -1.920649   -0.744198   -2.171784
C     -2.186323    0.544332   -1.786790
O     -3.027014    1.297326   -2.247859
C     -1.302591    1.023987   -0.654957
N     -0.404112    0.247940   -0.135749
H     -4.589106   -1.334083   -2.278584
H     -3.671938   -2.802457   -1.918020
H     -4.423976   -2.597892   -3.512445
H     -1.968870   -2.170525   -3.648895
H     -2.889048   -0.694696   -4.008079
H     -1.179932   -1.214631   -1.654949
H     -1.481620    2.057464   -0.318709
H      0.084027    0.757116    0.611780""", 1),
("""C     -3.567030   -1.473882   -2.176145
C     -3.178358   -2.108183   -3.458538
H     -4.409864   -0.794754   -2.133275
H     -2.311618   -2.757529   -3.309789
H     -4.003343   -2.709713   -3.849537
H     -2.920653   -1.344111   -4.196849""", 1)]

In [None]:
# Test case given as SMILES with multiplicity 1. It appears to be the SMILES
# counterpart of the preceding XYZ case (same heavy-atom skeletons) -- confirm.
reactants = [
("""CCNC(=O)C1NC1C""", 1),]

products = [
("""CCNC(=O)C=N""", 1),
("""[CH2]C""", 1)]

In [None]:
# Test case (A + B = C): two reactants and one product as SMILES. The notebook
# flips such inputs so that the single species becomes the reactant.
reactants = [
("""CC=CC""", 1),
"""C=C"""]

products = [
("""C=CC(C)CC""", 1),]

## 2. Find RMG reaction and generate reactant/product complex

Check whether this reaction matches RMG family templates. If the reaction matches at least one RMG family, the result is shown and the reactant/product complexes are generated. Otherwise, this notebook cannot handle the reaction.

In [None]:
# Generate reactant and product complex
# Put the single-species side first: an A + B = C input is flipped to C = A + B.
if len(reactants) == 2 and len(products) == 1:
    reactants, products = products, reactants
    print('Warning: the reactants and the products are inverted for convenience!')

# Reset the state so that the checks below (and later cells) never see undefined
# or stale values when no backend produces a usable match.
r_complex, p_complex = None, None
family_label, forward = None, None
reaction_type = None

for backend in backends:
    print(f'Using \"{backend}\" method as the XYZ perception backend.')
    try:
        # Convert XYZ / SMILES to RDKitMol; `header` comes from the input-field cell
        r_mols, r_is_3D = parse_xyz_or_smiles_list(reactants, backend=backend, header=header)
        p_mols, p_is_3D = parse_xyz_or_smiles_list(products, backend=backend, header=header)
        # Detect 3D information
        r_all_3D, p_all_3D = all(r_is_3D), all(p_is_3D)
        r_any_3D, p_any_3D = r_all_3D or any(r_is_3D), p_all_3D or any(p_is_3D)
        # Convert rdkit mol to RMG mol
        r_rmg_mols = [from_rdkit_mol(r.ToRWMol()) for r in r_mols]
        p_rmg_mols = [from_rdkit_mol(p.ToRWMol()) for p in p_mols]
    except Exception as e:
        print(e)
        print(f'Cannot generate molecule instances using {backend}...')
        continue

    # A product complex with the same atom indexing as the reactant is generated
    family_label, forward = find_reaction_family(rmg_db,
                                                 r_rmg_mols,
                                                 p_rmg_mols,
                                                 verbose=False)
    r_complex, p_complex = generate_reaction_complex(rmg_db,
                                                     r_rmg_mols,
                                                     p_rmg_mols,
                                                     only_families=[family_label],
                                                     verbose=False)
    if not r_complex:
        # Cannot find the reaction
        continue

    try:
        # Convert complexes back from their RMG molecule forms to RDKitMol form
        r_complex, p_complex = RDKitMol.FromRMGMol(r_complex), RDKitMol.FromRMGMol(p_complex)
    except Exception as e:
        # There can be some problem converting RMG mol back to RDKit. Discard the
        # RMG complexes so they are not mistaken for a successful match below.
        print(e)
        r_complex, p_complex = None, None
        continue

    print('Find a match!\n')
    break
else:
    print('No matched RMG reaction is found for the given reactants and products.')

if r_complex:
    # CONDITION 1: the family's more informative side is currently on the product side
    product_more_informative = (family_label in REACTANT_MORE_INFORMATIVE_FAMILIES
                                and REACTANT_MORE_INFORMATIVE_FAMILIES[family_label] != forward)
    # CONDITION 2: own-reverse or bimolecular family, and the user supplied more 3D
    # geometry for the products than for the reactants (information is imbalanced)
    geometry_imbalanced = ((family_label in OWN_REVERSE or family_label in BIMOLECULAR)
                           and ((p_all_3D and not r_all_3D) or (p_any_3D and not r_any_3D)))
    if product_more_informative or geometry_imbalanced:
        # For convenience, revert the sequence of reactants and products
        reactants, products, r_mols, p_mols, r_rmg_mols, p_rmg_mols, r_complex, p_complex = \
            products, reactants, p_mols, r_mols, p_rmg_mols, r_rmg_mols, p_complex, r_complex
        r_is_3D, p_is_3D, r_any_3D, p_any_3D, r_all_3D, p_all_3D = \
            p_is_3D, r_is_3D, p_any_3D, r_any_3D, p_all_3D, r_all_3D
        forward = not forward
        print('Warning: the reactants and the products are inverted for convenience!')

    # e.g. 'A=C', 'A=C+D', 'A+B=C+D'; used by the complex-generation cells below
    reaction_type = '+'.join(['A', 'B'][:len(r_mols)]) + '=' + '+'.join(['C', 'D'][:len(p_mols)])
    print(' + '.join([s.ToSmiles() for s in r_mols]) +
          ' <=> ' +
          ' + '.join([s.ToSmiles() for s in p_mols]))
    print(f'RMG family: {family_label}\nIs forward reaction: {forward}')
    print(f'This is a {reaction_type} reaction\n')

    # Find formed and broken bonds
    formed_bonds, broken_bonds = get_formed_and_broken_bonds(r_complex, p_complex)
    print(f'Bonds are FORMED: {formed_bonds}\nBonds are BROKEN: {broken_bonds}')
    only_break_bonds = not any(formed_bonds)
    if only_break_bonds:
        print('This is a reaction that only breaks bonds!')

    # (all_3D, any_3D) -> wording; (True, False) cannot occur
    to_print = {(True, True): 'all of', (False, True): 'part of', (False, False): 'none of'}
    print(f'{to_print[(r_all_3D, r_any_3D)].capitalize()} the reactant geometries and '
          f'{to_print[(p_all_3D, p_any_3D)]} the products geometries are provided.')

    # A state variable of the script
    FINISHED = False

## 3. Complexes generation

### 3.1  A = C reactions

In [None]:
if reaction_type == 'A=C':
    # Separation applied to reacting atom pairs in the force-field optimizations
    # below (family-specific value, 3.0 when the family is not listed).
    bond_constraint = BOND_CONSTRAINT.get(family_label, 3.0)

    # ======================== 1. Geometry initialization ========================

    # 1.1 Reactant complex. Without a user geometry, embed a conformer and relax it
    #     while holding the to-be-formed bonds apart. Otherwise, copy the user
    #     geometry (re-indexed to the complex) and record its bond lengths.
    r_bond_length = []
    if not r_all_3D:
        # TODO: May embed several times until get the desired conformers
        r_complex.EmbedConformer()
        r_complex = optimize_mol(
            r_complex,
            frozen_non_bondings=[[bond, bond_constraint] for bond in formed_bonds],
        )
    else:
        r_match = r_complex.GetSubstructMatch(r_mols[0])
        r_bond_length.extend(get_bond_length_list(mol=r_mols[0], match=r_match))
        r_complex.SetPositions(
            r_mols[0].GetPositions()[reverse_map(r_match, as_list=False), :]
        )

    # 1.2 Product complex. Start from the reactant coordinates. If a product
    #     geometry is given, record its bond lengths. When there is no reactant
    #     geometry and the family is not in REACTANT_MORE_INFORMATIVE_FAMILIES,
    #     seed both complexes from the product geometry instead.
    p_bond_length = []
    p_complex.SetPositions(r_complex.GetPositions())
    if p_all_3D:
        p_match = p_complex.GetSubstructMatch(p_mols[0])
        p_bond_length.extend(get_bond_length_list(mol=p_mols[0], match=p_match))
        if not (r_all_3D or family_label in REACTANT_MORE_INFORMATIVE_FAMILIES):
            p_complex.SetPositions(
                p_mols[0].GetPositions()[reverse_map(p_match, as_list=False), :]
            )
            r_complex.SetPositions(p_complex.GetPositions())
            r_complex = optimize_mol(
                r_complex,
                frozen_non_bondings=[[bond, bond_constraint] for bond in formed_bonds],
            )

    # ======================== 2. Geometry optimization ==========================

    # 2.1 Relax the product starting from the reactant-like coordinates, so the
    #     non-reacting part stays close to the reactant. Product bond lengths (if
    #     any) are passed as frozen bonds, and the breaking bonds are held apart.
    p_complex = optimize_mol(
        p_complex,
        frozen_bonds=p_bond_length,
        frozen_non_bondings=[[bond, bond_constraint] for bond in broken_bonds],
    )

    # 2.2 Unless the family is in REACTANT_MORE_INFORMATIVE_FAMILIES, rebuild the
    #     reactant from the optimized product. This may add constraints to the
    #     reactant geometry.
    if family_label not in REACTANT_MORE_INFORMATIVE_FAMILIES:
        r_complex.SetPositions(p_complex.GetPositions())
        r_complex = optimize_mol(
            r_complex,
            frozen_bonds=r_bond_length,
            frozen_non_bondings=[[bond, bond_constraint] for bond in formed_bonds],
        )


### 3.2 A = C + D reactions

In [None]:
if reaction_type == 'A=C+D':

    ############################
    ## 1. Geometry initialize ##
    ############################

    # 1.1 Set the r_complex to the given geometry, otherwise, embed one.
    r_bond_length = []
    if r_all_3D:
        # r_match[k] is the complex index of atom k of the reactant molecule; its
        # inverse map reorders the reactant coordinates into complex atom order.
        r_match = r_complex.GetSubstructMatch(r_mols[0])
        r_bond_length.extend(get_bond_length_list(mol=r_mols[0],
                                                  match=r_match))
        r_complex.SetPositions(r_mols[0].GetPositions()[reverse_map(r_match, as_list=False), :])

        # Check the stereo type of the given geometry (only DA reactions are checked now).
        # A result of 'none' means the reaction does not distinguish endo / exo, so it is
        # accepted, consistent with the embedding loop in the `else` branch below.
        if family_label == 'Diels_alder_addition' and da_stereo_specific:
            rxn_is_endo = is_DA_rxn_endo(r_complex, p_complex, embed=True)
            if rxn_is_endo != 'none' and rxn_is_endo != (da_stereo_specific == 'endo'):
                raise ValueError('The provided DA product doesn\'t match the stereotype '
                                 'required. You have to provide another DA product geometry!')
    else:
        # TODO: May embed several times until get the desired conformers
        r_complex.EmbedConformer()

        if family_label == 'Diels_alder_addition' and da_stereo_specific:
            is_endo = da_stereo_specific == 'endo'
            max_num_try = 100
            # Re-embed until the conformer has the requested stereo type
            for _ in range(max_num_try):
                rxn_is_endo = is_DA_rxn_endo(r_complex, p_complex, embed=True)
                if rxn_is_endo == 'none' or rxn_is_endo == is_endo:
                    # This reaction may not distinguish endo or exo
                    break
                r_complex.EmbedConformer()
            else:
                raise RuntimeError('Have trouble to find a conformer with the desired stereo type.')

        r_complex = optimize_mol(r_complex,
                                 frozen_non_bondings=[[bond, BOND_CONSTRAINT.get(family_label, 3.0)]
                                                      for bond in formed_bonds])

    # 1.2 Initialize p_complex with the r_complex coordinates, then overwrite the atoms
    # of every product fragment that comes with a 3D geometry by that (aligned) geometry.
    p_bond_length = []
    p_complex.SetPositions(r_complex.GetPositions())
    if p_any_3D:

        new_xyz = np.zeros((p_complex.GetNumAtoms(), 3))
        p_frags_idx = match_mols_to_complex(mol_complex=p_complex,
                                             mols=p_mols)
        for i, is_3D in enumerate(p_is_3D):
            if not is_3D:
                # No geometry given: keep the coordinates inherited from r_complex
                new_xyz[p_frags_idx[i], :] = p_complex.GetPositions()[p_frags_idx[i], :]
            else:
                # Align the fragment onto the complex; atom k of the fragment maps
                # to complex atom p_frags_idx[i][k]
                atom_map = list(enumerate(p_frags_idx[i]))
                p_complex.AlignMol(prbMol=p_mols[i],
                                   atomMaps=[atom_map])
                new_xyz[p_frags_idx[i], :] = p_mols[i].GetPositions()
                p_bond_length.extend(get_bond_length_list(mol=p_mols[i],
                                                          match=p_frags_idx[i]))
        p_complex.SetPositions(new_xyz)

    ##############################
    ## 2. Geometry optimization ##
    ##############################

    # 2.1 Optimize the product. Initial guess is the geometry r_complex, which makes sure
    # non-reacting coordinates won't change too much from the reactants.
    p_complex = optimize_mol(p_complex,
                             frozen_bonds=p_bond_length,
                             frozen_non_bondings=[[bond, BOND_CONSTRAINT.get(family_label, 3.0)]
                                                   for bond in broken_bonds])

    # 2.2 Optimize the reactant geometry again, if it is not more informative than the product.
    # This step may introduce more constraints to the reactant geometry.
    if family_label not in REACTANT_MORE_INFORMATIVE_FAMILIES:
        r_complex.SetPositions(p_complex.GetPositions())
        r_complex = optimize_mol(r_complex,
                                 frozen_bonds=r_bond_length,
                                 frozen_non_bondings=[[bond, BOND_CONSTRAINT.get(family_label, 3.0)]
                                                      for bond in formed_bonds])


### 3.3 A + B = C + D reactions

In [None]:
# Experimental switch (test feature): passed to `optimize_mol` when pre-optimizing
# the product complex of A + B = C + D reactions. Turning it on may help find
# guesses with H-bonds.
ignore_interfrag_interaction = False

In [None]:
if reaction_type == 'A+B=C+D':

    # There is always a fragment being transferred between the two molecules:
    # - H_Abstraction, Disproportionation, CO_Disproportionation: H
    # - Substitution_O, Substitution_S: RO / RS group
    # Exactly one bond forms and one bond breaks; the implementation relies on this.

    # The transferred atom is the atom shared by the formed and the broken bond
    shared_atoms = set(formed_bonds[0]) & set(broken_bonds[0])
    if not shared_atoms:
        raise ValueError('The formed and the broken bond do not share an atom; '
                         'cannot identify the transferred atom.')
    transfered_atom = list(shared_atoms)[0]

    # Complex atom indexes of every reactant / product fragment
    r_frags_idx = match_mols_to_complex(mol_complex=r_complex,
                                         mols=r_mols)
    p_frags_idx = match_mols_to_complex(mol_complex=p_complex,
                                         mols=p_mols)

    # Find the flux pair: pairs[i] is the product fragment that shares its backbone with
    # reactant fragment i (the mapping is its own inverse, so it also works product -> reactant).
    # The reactant fragment holding the transferred atom pairs with the product fragment
    # NOT holding it, hence the swap when fragment 0 holds it on both sides (or on neither).
    # NOTE: the parentheses are required; `a in b == a in c` is a chained comparison
    # (`a in b and b == a and a in c`), which is always False here.
    pairs = {0: 0, 1: 1}
    if (transfered_atom in r_frags_idx[0]) == (transfered_atom in p_frags_idx[0]):
        pairs = {0: 1, 1: 0}

    # Re-analyze xyz based on the pairs. In the pair, if both reactant and the product is provided,
    # Then only use the reactant one.
    # Possible cases:
    # - r_all_3D: use 3D geometries of reactants
    # - p_all_3D and not r_all_3D: not possible, due to the reactant, product switch in previous step
    # - r_any_3D and p_any_3D: if same pair: use 3D geometries of the reactant
    #                          if different pair: use both geometries
    # - r_any_3D: Use the geometry anyway
    # - p_any_3D and not r_any_3D: not possible, due to the reactant, product switch in previous step
    # - non geometry: embed r_complex
    if not r_all_3D and not p_all_3D and r_any_3D and p_any_3D:
        if (r_is_3D[0] and p_is_3D[pairs[0]]) or \
           (r_is_3D[1] and p_is_3D[pairs[1]]):
            p_is_3D, p_any_3D = [False, False], False

    bond_constraint = BOND_CONSTRAINT.get(family_label, 3.0)

    # First, create complexes that stores alignment information
    # After embed, molecules are overlapping
    # forcefield optimization helps de-overlapping
    r_complex.EmbedConformer()
    r_complex = optimize_mol(r_complex,
                             frozen_non_bondings=[[bond, bond_constraint]
                                                  for bond in formed_bonds])
    p_complex.SetPositions(r_complex.GetPositions())
    p_complex = optimize_mol(p_complex,
                             # This is a experimental arguments
                             # It seems that ignore interfrag_interaction helps
                             # generate better A + B = C + D reactions
                             ignore_interfrag_interaction=ignore_interfrag_interaction,
                             frozen_non_bondings=[[bond, bond_constraint]
                                                  for bond in broken_bonds])

    r_bond_length = []
    p_bond_length = []

    if r_all_3D:
        new_xyz = np.zeros((r_complex.GetNumAtoms(), 3))
        for i, r_mol in enumerate(r_mols):
            # Align the fragment onto the embedded complex
            atom_map = list(enumerate(r_frags_idx[i]))
            r_mol.AlignMol(refMol=r_complex,
                           atomMaps=[atom_map])
            new_xyz[r_frags_idx[i], :] = r_mol.GetPositions()
            r_bond_length.extend(get_bond_length_list(mol=r_mol,
                                                      match=r_frags_idx[i]))
        r_complex.SetPositions(new_xyz)
        # Keep the given reactant bond lengths, as in the other branches
        r_complex = optimize_mol(r_complex,
                                 frozen_bonds=r_bond_length,
                                 frozen_non_bondings=[[bond, bond_constraint]
                                                      for bond in formed_bonds])
        if p_any_3D:
            for i, p_mol in enumerate(p_mols):
                if p_is_3D[i]:
                    p_bond_length.extend(get_bond_length_list(mol=p_mol,
                                                              match=p_frags_idx[i]))
        p_complex.SetPositions(r_complex.GetPositions())
        p_complex = optimize_mol(p_complex,
                                 frozen_bonds=p_bond_length,
                                 frozen_non_bondings=[[bond, bond_constraint]
                                                      for bond in broken_bonds])

    elif r_any_3D and not p_any_3D:
        new_xyz = np.zeros((r_complex.GetNumAtoms(), 3))
        for i, r_mol in enumerate(r_mols):
            if r_is_3D[i]:
                # Align the fragment onto the embedded complex
                atom_map = list(enumerate(r_frags_idx[i]))
                r_mol.AlignMol(refMol=r_complex,
                               atomMaps=[atom_map])
                new_xyz[r_frags_idx[i], :] = r_mol.GetPositions()
                r_bond_length.extend(get_bond_length_list(mol=r_mol,
                                                          match=r_frags_idx[i]))
            else:
                # No geometry given: keep the embedded coordinates
                new_xyz[r_frags_idx[i], :] = r_complex.GetPositions()[r_frags_idx[i], :]

        r_complex.SetPositions(new_xyz)
        r_complex = optimize_mol(r_complex,
                                 frozen_bonds=r_bond_length,
                                 frozen_non_bondings=[[bond, bond_constraint]
                                                      for bond in formed_bonds])
        p_complex.SetPositions(r_complex.GetPositions())
        p_complex = optimize_mol(p_complex,
                                 frozen_bonds=p_bond_length,
                                 frozen_non_bondings=[[bond, bond_constraint]
                                                      for bond in broken_bonds])

    elif r_any_3D and p_any_3D:
        # After the reduction above, one reactant fragment is 3D and the product
        # fragment paired with the *other* reactant fragment is 3D.
        embedded_xyz = r_complex.GetPositions()
        new_xyz = np.zeros((r_complex.GetNumAtoms(), 3))
        for i, r_mol in enumerate(r_mols):
            if r_is_3D[i]:
                # Align the fragment onto the embedded complex
                atom_map = list(enumerate(r_frags_idx[i]))
                r_mol.AlignMol(refMol=r_complex,
                               atomMaps=[atom_map])
                new_xyz[r_frags_idx[i], :] = r_mol.GetPositions()
                r_bond_length.extend(get_bond_length_list(mol=r_mol,
                                                          match=r_frags_idx[i]))
            else:
                # Default to the embedded coordinates, so atoms absent from the paired
                # product fragment (e.g. the transferred atom) get a sensible position.
                new_xyz[r_frags_idx[i], :] = embedded_xyz[r_frags_idx[i], :]
                j = pairs[i]
                if p_is_3D[j]:
                    # Borrow the paired product geometry for the atoms both fragments share.
                    # Product atoms outside this reactant fragment belong to the other
                    # reactant fragment and must not be overwritten here.
                    r_frag_atoms = set(r_frags_idx[i])
                    shared_map = [(k, idx) for k, idx in enumerate(p_frags_idx[j])
                                  if idx in r_frag_atoms]
                    if shared_map:
                        p_mols[j].AlignMol(refMol=r_complex,
                                           atomMaps=[shared_map])
                        prb_idx = [k for k, _ in shared_map]
                        cpx_idx = [idx for _, idx in shared_map]
                        new_xyz[cpx_idx, :] = p_mols[j].GetPositions()[prb_idx, :]
                    p_bond_length.extend(get_bond_length_list(mol=p_mols[j],
                                                              match=p_frags_idx[j]))

        r_complex.SetPositions(new_xyz)
        r_complex = optimize_mol(r_complex,
                                 frozen_bonds=r_bond_length,
                                 frozen_non_bondings=[[bond, bond_constraint]
                                                      for bond in formed_bonds])
        p_complex.SetPositions(r_complex.GetPositions())
        p_complex = optimize_mol(p_complex,
                                 frozen_bonds=p_bond_length,
                                 frozen_non_bondings=[[bond, bond_constraint]
                                                      for bond in broken_bonds])

    if p_any_3D:
        # Finally, overwrite the product fragments that come with a 3D geometry
        new_xyz = np.zeros((p_complex.GetNumAtoms(), 3))
        p_frags_idx = match_mols_to_complex(mol_complex=p_complex,
                                             mols=p_mols)
        for i, is_3D in enumerate(p_is_3D):
            if not is_3D:
                new_xyz[p_frags_idx[i], :] = p_complex.GetPositions()[p_frags_idx[i], :]
            else:
                # Align the fragment onto the product complex
                atom_map = list(enumerate(p_frags_idx[i]))
                p_mols[i].AlignMol(refMol=p_complex,
                                   atomMaps=[atom_map])
                new_xyz[p_frags_idx[i], :] = p_mols[i].GetPositions()
                p_bond_length.extend(get_bond_length_list(mol=p_mols[i],
                                                          match=p_frags_idx[i]))
        p_complex.SetPositions(new_xyz)
    


### Find the best atom mapping by RMSD. 
At this point, all heavy atoms are mapped, but some H atoms may no longer be mapped, for example due to a rotation of a methyl rotor during the optimization. We recommend performing this step, although it is not required.

NOTE:
1. This can perform relatively poorly if the reactant and the product have different stereochemistry (cis/trans), or if most rotors are oriented significantly differently. However, the previous step (matching according to the RMG reaction) makes sure that all heavy atoms and reacting H atoms are consistent, so only the less important H atoms are affected.
2. `AlignMol` can yield wrong numbers, so we switch to `GetBestRMS` and `CalcRMS`.

In [None]:
# Set to True to also score the mirror image of the product complex when
# searching for the best atom mapping (a reflection produces the enantiomer).
reflect = False

In [None]:
# Enumerate every self-match of the product complex; each one is a relabeling of
# atoms that preserves connectivity and is a candidate atom mapping.
matches = p_complex.GetSubstructMatches(p_complex, uniquify=False)

# Work on a copy so that p_complex keeps its original coordinates and ordering
p_align = p_complex.Copy()
# Mass-weight the RMSD so heavy atoms dominate the choice of mapping
weights = [atom.GetMass() for atom in p_align.GetAtoms()]

rmsds = []

# Score each candidate mapping by the in-place (no re-alignment) RMSD between
# the product copy and r_complex, optionally also for the mirror image.
for i, match in enumerate(matches):
    atom_map = [list(enumerate(match))]
    rmsd1 = Chem.rdMolAlign.CalcRMS(prbMol=p_align.ToRWMol(),
                                    refMol=r_complex.ToRWMol(),
                                    map=atom_map,
                                    weights=weights)
    if reflect:
        p_align.Reflect()
        rmsd2 = Chem.rdMolAlign.CalcRMS(prbMol=p_align.ToRWMol(),
                                        refMol=r_complex.ToRWMol(),
                                        map=atom_map,
                                        weights=weights)
        p_align.Reflect()  # undo the reflection
    else:
        rmsd2 = 1e10
    if rmsd1 > rmsd2:
        rmsds.append((i, True, rmsd2,))
    else:
        rmsds.append((i, False, rmsd1,))
# min() returns the first minimal entry, same as sorted(...)[0]
best = min(rmsds, key=lambda x: x[2])
print('Match index: {0}, Reflect Conformation: {1}, RMSD: {2}'.format(*best))

best_match = matches[best[0]]
# Reflect BEFORE aligning: reflecting after the alignment would undo it.
if best[1]:
    p_align.Reflect()
r_complex.AlignMol(prbMol=p_align,
                   atomMaps=[list(enumerate(best_match))],
                   weights=weights)
# Invert best_match: new atom best_match[k] is the current atom k
# (equivalent to [best_match.index(i) for i in ...] but O(n)).
new_atom_indexes = [0] * len(best_match)
for k, idx in enumerate(best_match):
    new_atom_indexes[idx] = k
p_align = p_align.RenumberAtoms(new_atom_indexes)

### 4. View Complexes

In [None]:
# Show the reactant and the (re-mapped) product complex side by side
mols_to_view = [r_complex, p_align]
entries = len(mols_to_view)

viewer = grid_viewer(viewer_grid=(1, entries), viewer_size=(240 * entries, 300),)
for column, mol in enumerate(mols_to_view):
    mol_viewer(mol, viewer=viewer, viewer_loc=(0, column))

print('reactant complex    product complex')
viewer.show()

### 5. Export to SDF file and run ts_gen

In [None]:
# Write both complexes (now sharing one atom ordering) as TS-GCN inputs
for mol_to_write, sdf_path in ((r_complex, 'reactant.sdf'), (p_align, 'product.sdf')):
    mol_to_write.ToSDFFile(sdf_path)

#### 5.1 Run TS model

In [None]:
# Shell command that runs TS-GCN inference on the exported SDF files and writes
# the predicted TS geometry to TS.xyz. It runs through the shell so that
# `export PYTHONPATH=...` applies to the inference call.
ts_command = (f'export PYTHONPATH=$PYTHONPATH:{TS_MODEL_DIR};'
              f'{TS_MODEL_PYTHON} {TS_MODEL_DIR}/inference.py '
              f'--r_sdf_path reactant.sdf '
              f'--p_sdf_path product.sdf '
              f'--ts_xyz_path TS.xyz')
try:
    subprocess.run(ts_command, check=True, shell=True)
except subprocess.CalledProcessError as err:
    print(err)
else:
    # The TS shares the reactant complex's atom ordering; only positions change
    with open('TS.xyz', 'r') as xyz_file:
        ts_xyz = xyz_file.read()
    ts = r_complex.Copy()
    ts.SetPositions(ts_xyz, header=True)

### 6. Visualize TS

In [None]:
# Align the TS onto the reactant complex so the structures share one orientation
ts.GetBestAlign(refMol=r_complex, keepBestConformer=True)

print('reactant    TS      product')
ts_viewer(r_complex, p_align, ts,
          alignment=['r', 'ts', 'p'],
          vertically_aligned=False,
          ts_bond_color='#f2f2f2',
          ts_bond_width=0.05)

In [None]:
# Dump the predicted TS geometry in XYZ format
ts_xyz_block = ts.ToXYZ()
print(ts_xyz_block)