# Find TS conformers
This notebook uses RDKit as the backend to generate TS conformers and Psi4 to calculate their energies.

Note: if the Jupyter notebook crashes, it is very likely that `openbabel` and `rdkit` <br>
were compiled against different dependencies (assuming you are using conda builds).<br>
You can either build your own versions (which solves the problem) or change the import order<br>
(the latter may cause one of the packages to malfunction when some of its methods are used).<br>

## 1. Generate conformers
Necessary packages

In [None]:
from typing import Optional, Union
import os
import sys
# To add this RDMC into PYTHONPATH in case you haven't do it
sys.path.append(os.path.dirname(os.path.abspath('')))
from itertools import combinations, product
from collections.abc import Iterable

import numpy as np

from rdkit import Chem
from rdmc import RDKitMol
from rdmc.ts import get_all_changing_bonds
from rdmc.view import mol_viewer, grid_viewer
from rdmc.external.gaussian import GaussianLog

%load_ext autoreload
%autoreload 2


In [None]:
def gen_scan_angle_list(samplings: Union[list, tuple],
                        from_angles: Optional[Iterable] = None,
                        scale=360.,):
    """
    Get an angle list for each input dimension.

    Each element of ``samplings`` describes one dimension and can be:

    - a number ``n``: ``n`` angles evenly spaced over ``scale`` (``scale / n``
      apart); ``0`` means the angle is kept unchanged;
    - an iterable of angle offsets to be sampled.

    The sampled offsets are added to the corresponding entry of
    ``from_angles`` and the results are wrapped into ``[0, scale)``.

    Examples:
    [[120, 240,], 4, 0] => [[120., 240.],
                            [0., 90., 180., 270.],
                            [0.]]

    Args:
        samplings (Union[list, tuple]): An array of sampling information.
                  For each element, it can be either an iterable or a number.
        from_angles (Iterable, optional): Initial angle of each dimension.
                    If not set, angles will begin at zeros.
        scale (float, optional): Period of the angles. Defaults to ``360.``.

    Returns:
        list: A list (one entry per dimension) of lists of sampled angles.

    Raises:
        ValueError: If a negative number of sampling points is given.
        TypeError: If an element of ``samplings`` is neither a number nor an iterable.
    """
    # ``is None`` rather than ``or``: a multi-element numpy array has no truth value.
    if from_angles is None:
        from_angles = [0.] * len(samplings)
    else:
        # Materialize so that any iterable can be indexed per dimension.
        from_angles = list(from_angles)

    angle_list = []
    for ind, sampling in enumerate(samplings):
        if isinstance(sampling, (int, float, np.integer, np.floating)):
            num_points = int(sampling)
            if num_points < 0:
                raise ValueError(
                    f'The number of sampling points cannot be negative, got {sampling}.')
            if num_points == 0:
                # Does not change
                offsets = np.array([0.])
            else:
                # True division keeps the points evenly spaced even when
                # ``scale`` is not a multiple of ``num_points``.
                offsets = np.arange(num_points) * (scale / num_points)
        elif isinstance(sampling, Iterable):
            offsets = np.array(list(sampling), dtype=float)
        else:
            raise TypeError(f'Unsupported sampling specification: {sampling!r}.')

        # Wrap into [0, scale) so that e.g. 0 and 360 map to the same angle.
        angles = np.mod(from_angles[ind] + offsets, scale)
        angle_list.append(angles.tolist())
    return angle_list


def conformers_by_change_torsions(conf: 'RDKitConf',
                                  angle_mesh,
                                  bookkeep: dict,
                                  torsions=None,
                                  on_the_fly_check=False):
    """
    Generate conformers by rotating the angles of the torsions. The result will be saved into
    ``bookkeep``. An on-the-fly check can be applied, which identifies the conformers with colliding
    atoms.

    ``bookkeep`` is filled in place. Key ``i`` (the position in ``angle_mesh``) maps to a dict with:

    - ``'angles'``: the torsion angles of the conformer;
    - ``'coords'``: ``conf.GetPositions()`` as nested lists;
    - ``'colliding_atoms'``: ``conf.HasCollidingAtoms()``, or ``None`` when the check is off.

    Args:
        conf (RDKitConf): A RDKit Conformer to be used. Its geometry is modified in place.
        angle_mesh (iterable): An iterable of angle lists, one per conformer to generate.
            Without ``torsions``, each item gives the angles of all torsional modes of ``conf``.
            With ``torsions``, each item gives the angles of those torsions, in the same order.
        bookkeep (dict): A dictionary to save the results.
        torsions (list, optional): Four-atom-index lists indicating the torsional modes to rotate.
            They must be among ``conf.GetTorsionalModes()``. If not given, all modes are rotated.
        on_the_fly_check (bool): Whether to check colliding atoms on the fly.

    Raises:
        ValueError: If a requested torsion is not one of the conformer's torsional modes.
    """
    if not torsions:
        # Rotate every torsional mode registered on the conformer at once.
        for ind, angles in enumerate(angle_mesh):
            conf.SetAllTorsionsDeg(angles)
            bookkeep[ind] = {
                'angles': angles,
                'coords': conf.GetPositions().tolist(),
                'colliding_atoms': conf.HasCollidingAtoms() if on_the_fly_check else None,
            }
        return

    # Compare as tuples so torsions given as lists still match tuple modes (and vice versa).
    all_torsions = [tuple(tor) for tor in conf.GetTorsionalModes()]
    try:
        changing_torsions_index = [all_torsions.index(tuple(tor)) for tor in torsions]
    except ValueError as err:
        raise ValueError(
            f'Torsions {torsions} are not all among the torsional modes of the '
            f'conformer: {all_torsions}') from err

    # Running record of every torsion angle; only the rotated ones are updated below.
    current_angles = list(conf.GetAllTorsionsDeg())

    for ind, angles in enumerate(angle_mesh):
        for tor_index, angle, tor in zip(changing_torsions_index, angles, torsions):
            conf.SetTorsionDeg(tor, angle)
            current_angles[tor_index] = angle

        bookkeep[ind] = {
            # Store a snapshot: ``current_angles`` keeps changing in later iterations.
            'angles': list(current_angles),
            'coords': conf.GetPositions().tolist(),
            'colliding_atoms': conf.HasCollidingAtoms() if on_the_fly_check else None,
        }

## Arguments

In [None]:
# Toggle for the interactive molecule viewers used throughout this notebook;
# set to False to skip all `if VISUAL_MOLECULE:` visualization blocks.
VISUAL_MOLECULE = True

## 1.1 Perceive TS

### 1.1.1 Directly input the TS conformer geometry [WIP]

Directly input the geometry of the TS. You also need to provide the atom-mapped reactant and product complexes to help analyze the bonding changes.

In [None]:
############## INPUT  ##################
# Cartesian coordinates of the TS guess: one atom per line, element symbol
# followed by x, y, z. The atom order must match the atom order of
# `r_complex` below, because these coordinates are assigned directly onto a
# copy of it.
xyz_str = """C     -3.513463   -0.214965   -0.355384
C     -2.054689   -0.689928   -0.311191
C     -1.171525    0.126171    0.627947
C     -0.583913    1.322614    0.186349
O      1.375372    1.065851   -0.208339
C      1.838730   -0.086320    0.026443
O      1.196280   -1.003171    0.628779
C      3.238549   -0.405117   -0.451715
H     -1.626384   -0.657815   -1.319769
H     -2.027644   -1.742210   -0.009143
H     -0.038732   -0.513987    0.715911
H     -1.472457    0.124541    1.677775
H     -0.676687    1.649585   -0.842428
H     -0.306660    2.098582    0.887087
H     -3.582321    0.818358   -0.709252
H     -4.115398   -0.839707   -1.021863
H     -3.968731   -0.253716    0.638890
H      3.175289   -1.092244   -1.300079
H      3.755340    0.503284   -0.758985
H      3.789038   -0.910844    0.343636
"""

# Atom-mapped reactant and product complexes. They supply the bonding
# information used later to find which bonds form, break or change.
# NOTE(review): the map numbers start at 0, and RDKit generally treats an
# atom map number of 0 as "unmapped"; confirm that RDKitMol.FromSmiles orders
# these atoms as intended.
r_complex = RDKitMol.FromSmiles('[C:0]([C:1]([C:2]([C:3]([O:4][C:5](=[O:6])[C:7]([H:17])([H:18])[H:19])([H:12])[H:13])([H:10])[H:11])([H:8])[H:9])([H:14])([H:15])[H:16]')
p_complex = RDKitMol.FromSmiles('[C:0]([C:1]([C:2](=[C:3]([H:12])[H:13])[H:11])([H:8])[H:9])([H:14])([H:15])[H:16].[O:4]=[C:5]([O:6][H:10])[C:7]([H:17])([H:18])[H:19]')

######################################
# The TS molecule reuses the reactant connectivity with the TS coordinates.
ts = r_complex.Copy()
# header=False presumably means xyz_str has no atom-count/comment header
# lines; TODO confirm against RDKitMol.SetPositions.
ts.SetPositions(xyz_str, header=False)

### 1.1.2 Read a TS frequency job / IRC job
Frequency or IRC jobs provide extra information about the bonding.


In [None]:
############## INPUT  ##################
log_path = 'data/ts-cbsqb3.out'
########################################
glog = GaussianLog(log_path)
# NOTE(review): `mol` is not used by any later cell in this notebook; later
# cells build the fake TS from `ts` (defined in section 1.1.1). Confirm whether
# `ts` should be taken from this log instead.
mol = glog.get_mol(backend='openbabel')

if 'freq' in glog.job_type and glog.success:
    # Guess the reactant/product complexes from the normal mode of the
    # frequency job; the guesses come back as sequences and only the first
    # of each is kept.
    r_complex, p_complex = glog.guess_rxn_from_normal_mode(
        amplitude=0.5, atom_weighted=True)
    r_complex = r_complex[0]
    p_complex = p_complex[0]
elif 'irc' in glog.job_type and glog.success:
    r_complex, p_complex = glog.guess_rxn_from_irc()
else:
    # Fail loudly instead of silently reusing stale (or undefined)
    # r_complex/p_complex from a previous cell.
    raise ValueError(
        f'{log_path} is neither a successful frequency job nor a successful '
        'IRC job; cannot guess the reactant and product complexes from it.')

### 1.1.3 Create a fake molecule

Create a fake molecule that has all the bonds (from both the reactants and the products). The purpose is to identify the rigid dihedrals in the TS.

In [None]:
formed_bonds, broken_bonds, change_bonds = get_all_changing_bonds(r_complex, p_complex)

# Start from the TS connectivity and give every bond whose order changes the
# higher of its reactant / product bond types.
fake_ts = ts.Copy()
for atom_pair in change_bonds:
    r_bond = r_complex.GetBondBetweenAtoms(*atom_pair)
    p_bond = p_complex.GetBondBetweenAtoms(*atom_pair)
    # Strictly higher reactant order wins; on a tie the product bond type is used.
    r_is_higher = r_bond.GetBondTypeAsDouble() > p_bond.GetBondTypeAsDouble()
    chosen_bond = r_bond if r_is_higher else p_bond
    fake_ts.GetBondBetweenAtoms(*atom_pair).SetBondType(chosen_bond.GetBondType())

# Also add the bonds formed in the reaction, so the fake molecule carries
# every bond of both sides.
fake_ts = fake_ts.AddRedundantBonds(bonds=formed_bonds)

## 1.2 Use RDKit to generate conformers

### 1.2.1 Get the torsional mode and the original angles

In [None]:
######################################
# INPUT
# Provide the complete list of torsions here, or leave it empty to let RDKit
# perceive the torsional modes of the fake TS.
torsions = []
exclude_methyl_rotors = False
######################################
if len(torsions) == 0:
    torsions = fake_ts.GetTorsionalModes(excludeMethyl=exclude_methyl_rotors)
    print(f'RDKit perceived torsions: {torsions}')


num_torsions = len(torsions)
conf = fake_ts.GetConformer()
# Register the torsions on the conformer; the "all torsions" getters/setters
# used afterwards presumably operate on exactly this list.
conf.SetTorsionalModes(torsions)
original_angles = conf.GetAllTorsionsDeg()
print(f'The original dihedral angles is: {original_angles}')
if VISUAL_MOLECULE:
    mol_viewer(fake_ts).update()

### 1.2.3 Generate conformers according to the angle mesh

#### Example 1:
Sampling the angles `0, 120, 240` for each torsion of a species with 7 heavy atoms and 5 rotors costs ~20 ms on an Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz.

#### Example 2:
Sampling the angles on a 45 x 45 evenly spaced mesh for each pair of torsions of a species with 7 heavy atoms and 5 rotors costs 1.4 s on an Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz.

- `RESOLUTION`: the resolution in degree for rotational bond scan
- `RESOLUTION_METHYL`: the resolution in degree for rotational bond scan for the methyl group
- `DIMENSION`: the dimension for rotor coupling. The default is `0` for coupling all rotors
- `SAMPLING`: the sampling for each rotor (a number of evenly spaced points, or a list of angles). If `SAMPLING` is an empty list `[]`, it will be created automatically from the resolutions above.

In [None]:
################ INPUT ################################

RESOLUTION = 60  # degrees
RESOLUTION_METHYL = 180  # degrees
DIMENSION = 0
SAMPLING = []  # you can provide something like SAMPLING = [3, 3, 3] to customize the sampling
########################################################

if SAMPLING:
    sampling = SAMPLING
else:
    # Carbons of CH3 groups: rotors around them get the coarser methyl resolution.
    methyl_carbons = {match[0] for match in
                      fake_ts.GetSubstructMatches(RDKitMol.FromSmarts('[CH3]'))}
    points_regular = 360 // RESOLUTION
    points_methyl = 360 // RESOLUTION_METHYL
    sampling = [
        points_methyl if (tor[1] in methyl_carbons or tor[2] in methyl_carbons)
        else points_regular
        for tor in torsions
    ]
print(sampling)

Generate initial guesses

In [None]:
# One entry per combination of `DIMENSION` torsions, keyed by str(tor_indexes);
# each value is the `bookkeep` dict filled by conformers_by_change_torsions.
bookkeeps = {}
# DIMENSION == 0 means "couple all rotors": rotate every torsion together.
if DIMENSION == 0:
    DIMENSION = len(torsions)
# Starting geometry, restored before each torsion combination.
init_coords = conf.GetPositions()
for tor_indexes in combinations(range(len(torsions)), DIMENSION):
    # Reset the geometry
    conf.SetPositions(init_coords)
    # Get angles
    sampling_points = [sampling[i] for i in tor_indexes]
    tor_orig_angles = [original_angles[i] for i in tor_indexes]
    tor_to_gen = [torsions[i] for i in tor_indexes]

    # Sampled values are offsets on top of the original dihedral angles.
    angles_list = gen_scan_angle_list(sampling_points,
                                      tor_orig_angles)
    # Lazy Cartesian product: every combination of the sampled angles.
    angle_mesh = product(*angles_list)
    # Generate conformers
    bookkeep = {}
    conformers_by_change_torsions(conf,
                                  angle_mesh,
                                  bookkeep=bookkeep,
                                  torsions=tor_to_gen,
                                  on_the_fly_check=False)
    bookkeeps[str(tor_indexes)] = bookkeep
# NOTE: `bookkeep` stays bound to the last combination's results after the loop.

### 1.3 [OPTIONAL] Check volume of the TS
You can check the distribution of molecular volumes to see whether the molecule has fallen apart. Smaller volumes may also appear because of folding or colliding atoms.

In [None]:
import seaborn as sns
from rdkit import Chem

# Torsion groups whose conformers are inspected (here: all of them).
check_bookkeep = bookkeeps
# Maximum number of conformers drawn at random from each group.
random_points = 30

In [None]:
from rdkit.Chem import AllChem

# For each torsion group, compute the molecular volume of a random subset of
# its conformers.
all_volumes = []
for bk in check_bookkeep.values():
    num_samples = min(random_points, len(bk))
    # Sample without replacement so no conformer is counted twice.
    rnd_sample = np.random.choice(len(bk), size=num_samples, replace=False)
    # Volumes are real numbers; a float array avoids truncating them to integers.
    volume = np.zeros(num_samples, dtype=float)
    for index, conf_idx in enumerate(rnd_sample):
        conf.SetPositions(bk[int(conf_idx)]['coords'])
        volume[index] = AllChem.ComputeMolVolume(fake_ts.ToRWMol())
    all_volumes.append(volume)

In [None]:
# One violin per torsion group, showing its sampled molecular volumes.
ax = sns.violinplot(data=all_volumes)
ax.set_xlabel('group index')
ax.set_ylabel('volume')
# Rotate the x tick labels vertically so many groups stay legible.
for tick_label in ax.get_xticklabels():
    tick_label.set_rotation(90)

## 2.1 Calculate using Psi4 [Not working]

This section is for testing only, not for production tasks.
From experience, for conformer searches it is better to use `n_threads_each_calculation = 1` and as many workers (`n_worker`) as possible.

In [None]:
import psi4

# Number of parallel workers (joblib convention: -1 = all CPUs, -2 = all but one, ...).
n_worker = 8
# Total memory budget in MB, split evenly between the workers; adjust to your machine.
total_memory_mb = 12000
# Resolve negative worker counts to an actual number of workers so the
# per-calculation memory below stays positive.
if n_worker > 0:
    _effective_workers = n_worker
else:
    _effective_workers = max(1, (os.cpu_count() or 1) + 1 + n_worker)
# Psi4 memory per calculation in MB (whole number).
n_memory_each_calculation = total_memory_mb // _effective_workers
n_threads_each_calculation = 1
reference = 'uhf'
level_of_theory = 'b3lyp/def2-svp'

In [None]:
def geom_producer(bookkeep, xyz_dict):
    """
    Yield ``(index, xyz_file)`` pairs, one per conformer in ``bookkeep``.

    ``xyz_dict`` serves as a template. For each conformer, a fresh shallow copy
    gets that conformer's ``'coords'``, so the caller's dictionary is not modified.

    NOTE(review): ``xyz_dict_to_xyz_file`` is not defined or imported in the
    visible source; it must be provided before this generator is consumed.

    Args:
        bookkeep (dict): Maps conformer index to a record with a ``'coords'`` entry.
        xyz_dict (dict): Template xyz dictionary (presumably with ``'symbols'``).

    Yields:
        tuple: ``(index, xyz_file)`` for each conformer.
    """
    for ind, record in bookkeep.items():
        xyz_file = xyz_dict_to_xyz_file({**xyz_dict, 'coords': record['coords']})
        yield ind, xyz_file
        
def get_psi4_dftenergy(ind, xyz_file):
    """
    Run a Psi4 single-point energy for one geometry.

    Uses the notebook-level settings ``n_memory_each_calculation``,
    ``reference``, ``n_threads_each_calculation`` and ``level_of_theory``.

    Args:
        ind: Identifier of the conformer, returned unchanged.
        xyz_file (str): Geometry passed to ``psi4.geometry``.

    Returns:
        tuple: ``(ind, energy)``. If anything fails while setting the geometry
        or computing the energy, the error is printed and ``1e4`` is returned
        as a sentinel energy.
    """
    psi4.set_memory(f'{n_memory_each_calculation} MB')
    psi4.set_options({'reference': reference})
    try:
        psi4.geometry(xyz_file)
        psi4.set_num_threads(n_threads_each_calculation)
        energy = psi4.energy(level_of_theory)
    except Exception as err:
        print(err)
        energy = 1e4
    return ind, energy

In [None]:
# NOTE(review): `Parallel`/`delayed` (presumably from joblib) and `xyz_dict`
# are not defined or imported in the visible notebook, so this cell cannot run
# as is.
runner = Parallel(n_jobs=n_worker, verbose=100)
jobs = (delayed(get_psi4_dftenergy)(*ind_and_xyz)
        for ind_and_xyz in geom_producer(bookkeep, xyz_dict))
result = runner(jobs)

### 2.2 Optimize using Forcefield

In [None]:
from rdmc.forcefield import RDKitFF

# Take the first N-dimensional torsion combination as a worked example.
first_group_key = list(bookkeeps)[0]
bookkeep = bookkeeps[first_group_key]

# A force field needs a chemically normal molecule as its template, otherwise
# it may fail to find parameters, so the reactant complex is used here.
ts_mol = r_complex.Copy()

### 2.2.1 Align reaction center

In [None]:
from rdkit.Chem import rdMolAlign

# Create one conformer slot per generated geometry, then overwrite each slot
# with the sampled coordinates.
ts_mol.EmbedMultipleConfs(len(bookkeep))
confs = ts_mol.GetAllConformers()
for i, value in bookkeep.items():
    confs[i].SetPositions(value['coords'])
symbols = ts_mol.GetElementSymbols()

# Align according to the reaction center: every atom taking part in a formed,
# broken or changed bond. Bonds may be tuples (which `set.union` rejects), so
# flatten them explicitly; sorting gives a deterministic atom order.
atom_list = sorted({atom
                    for bonds in (formed_bonds, broken_bonds, change_bonds)
                    for bond in bonds
                    for atom in bond})
rdMolAlign.AlignMolConformers(ts_mol.ToRWMol(), maxIters=200, atomIds=atom_list)

if VISUAL_MOLECULE:
    view = grid_viewer((1,1), viewer_size=(600, 400))
    for i in range(len(bookkeep)):
        view.addModel(Chem.MolToMolBlock(ts_mol.ToRWMol(), confId=i), 'sdf')
    view.zoomTo()
    view.update()

In [None]:
from rdmc.forcefield import RDKitFF

### 2.2.2 Optimize By force field

In [None]:
ff = RDKitFF(force_field='MMFF94s')
# Two strategies: optimize all conformers at once, or one by one.
# It is not yet known which is faster: `optimize_confs` on all conformers, or
# optimizing each conformer individually.

# # All at once
# ff.setup(ts_mol)
# for atom in atom_list:
#     ff.fix_atom(atom)
# results = ff.optimize_confs(num_threads=-1)
# energies = [e for _, e in results]

# Iterative: for each conformer, set up the force field, fix the
# reaction-center atoms (`atom_list`), optimize, and record the energy.
energies = []
for i in range(ts_mol.GetNumConformers()):
    ff.setup(ts_mol, conf_id=i)
    for atom in atom_list:
        ff.fix_atom(atom)
    ff.optimize()
    energies.append(ff.get_energy())

# NOTE(review): the optimized molecule is fetched once after the loop. This
# assumes `ff.setup` works on `ts_mol` itself rather than on a copy, so that
# every conformer's optimization is retained; confirm in rdmc.forcefield.
ts_mol = ff.get_optimized_mol()

In [None]:
if VISUAL_MOLECULE:
    # Overlay all force-field-optimized conformers in a single viewer.
    view = grid_viewer((1,1), viewer_size=(600, 400))
    for conf_idx in range(len(bookkeep)):
        mol_block = Chem.MolToMolBlock(ts_mol.ToRWMol(), confId=conf_idx)
        view.addModel(mol_block, 'sdf')
    view.zoomTo()
    view.update()

## 2.3 Cluster conformers by energies
This makes filtering out duplicate conformers easier in later steps.

In [None]:
# Group conformer indices by their force-field energy rounded to 2 decimals.
# NOTE: rounding is a coarse bucket; two energies straddling a rounding
# boundary land in different groups even if they differ only slightly.
grouped = {}
for ind, energy in enumerate(energies):
    grouped.setdefault(round(energy, 2), []).append(ind)
# Order the groups from lowest to highest energy so the output is
# deterministic and easy to read.
energy_clusters = dict(sorted(grouped.items()))
print(energy_clusters)

You can visualize a conformer from each group.

In [None]:
# Index of the conformer to display; pick one from the energy groups printed above.
conf_id_to_show = 4
mol_viewer(ts_mol.ToMolBlock(confId=conf_id_to_show), 'sdf')

## 2.4 Filter out duplicate conformers

### 2.4.1 Filtering according to the torsional fingerprint
Currently, it uses the most naive fingerprint (the raw dihedral angle values).

In [None]:
import scipy.cluster.hierarchy as hcluster

# Maximum torsion-fingerprint distance (in degrees) for two conformers to be
# merged into the same cluster.
threshold = 10.0

In [None]:
def _periodic_torsion_distance(tor_a, tor_b):
    """Euclidean distance between two torsion fingerprints, with each angle difference wrapped into [0, 180] degrees."""
    diff = np.abs(np.asarray(tor_a, dtype=float) - np.asarray(tor_b, dtype=float)) % 360.
    diff = np.minimum(diff, 360. - diff)
    return float(np.sqrt(np.sum(diff ** 2)))


for energy_level, confs in energy_clusters.items():
    # A single conformer is trivially unique, and scipy's hierarchical
    # clustering needs at least two observations.
    if len(confs) < 2:
        continue

    tor_matrix = []
    for conf_id in confs:
        ts_conf = ts_mol.GetConformer(id=conf_id)
        ts_conf.SetTorsionalModes(torsions)
        tor_matrix.append(ts_conf.GetAllTorsionsDeg())

    # Periodic metric: 359 deg and 1 deg are 2 deg apart, not 358.
    clusters = hcluster.fclusterdata(np.array(tor_matrix), threshold,
                                     criterion='distance',
                                     metric=_periodic_torsion_distance).tolist()

    # Keep the first conformer of each cluster. Membership (setdefault) is used
    # instead of truthiness because conformer id 0 is falsy.
    clusters_unique = {}
    for conf_id, cluster in zip(confs, clusters):
        clusters_unique.setdefault(cluster, conf_id)
    energy_clusters[energy_level] = list(clusters_unique.values())

energy_clusters

### 2.4.2 Filtering according to RMSD without changing atom orders

In [None]:
# Conformers whose aligned RMSD to a reference is at or below this value are
# treated as duplicates of it.
rmsd_threshold = 1e-3

for energy_level, confs in energy_clusters.items():

    # Nothing to de-duplicate in a single-member group.
    if len(confs) == 1:
        continue

    distinct_confs = []
    # Greedy rounds: the first remaining conformer becomes a reference and is
    # kept. Conformers matching it (plain alignment, then reflection-enabled
    # alignment) are dropped, and the survivors go to the next round.
    while len(confs) > 1:
        distinct_confs.append(confs[0])
        rmsd_list = []
        # Align the remaining conformers onto confs[0]. The code below assumes
        # RMSlist receives one value per conformer in confs[1:], in order.
        Chem.rdMolAlign.AlignMolConformers(ts_mol.ToRWMol(),
                                           confIds=confs,
                                           maxIters=1000,
                                           RMSlist=rmsd_list,
                                          )

        # Reference plus every conformer that is not a duplicate of it.
        confs_no_reflect = [confs[0]] + [conf for idx, conf in enumerate(confs[1:])
                                         if rmsd_list[idx] > rmsd_threshold]

        rmsd_list = []
        # Same comparison with reflect=True, presumably to also catch
        # mirror-image duplicates of the reference.
        Chem.rdMolAlign.AlignMolConformers(ts_mol.ToRWMol(),
                                           confIds=confs_no_reflect,
                                           maxIters=1000,
                                           RMSlist=rmsd_list,
                                           reflect=True,)

        # Reflect everything back
        # NOTE(review): relies on a second reflect=True call with maxIters=0
        # undoing the reflection applied above; confirm against the RDKit
        # AlignMolConformers documentation.
        Chem.rdMolAlign.AlignMolConformers(ts_mol.ToRWMol(),
                                           confIds=confs_no_reflect,
                                           maxIters=0,
                                           reflect=True,)

        # Survivors of both filters (excluding the reference) form the next round.
        confs = [conf for idx, conf in enumerate(confs_no_reflect[1:])
                 if rmsd_list[idx] > rmsd_threshold]

    # At most one conformer is left; it matched none of the references.
    distinct_confs += confs
    energy_clusters[energy_level] = distinct_confs

energy_clusters

### 2.4.3 Filtering according to RMSD with changing atom orders [Broken]
This can be very computationally expensive.

In [None]:
def calc_rmsd_single_thread(symbols,
                            confs_id,
                            confs_coords,
                            rmsd_threshold=1e-2):
    """
    Greedily pick the conformers that are distinct by RMSD.

    Conformers are visited in input order. Each one is kept only if its RMSD to
    every conformer kept so far is not below ``rmsd_threshold``, so the first
    conformer is always kept.

    NOTE(review): ``calc_rmsd_wrapper`` is not defined or imported in the visible
    source. Here it is called with two ``{'symbols', 'coords'}`` dicts, and its
    result is compared against ``rmsd_threshold``.

    Args:
        symbols: Element symbols shared by all conformers.
        confs_id (list): Identifiers of the conformers.
        confs_coords (list): Coordinates of each conformer, aligned with ``confs_id``.
        rmsd_threshold (float): Conformers closer than this are treated as duplicates.

    Returns:
        list: ``(coords, conf_id)`` tuples of the distinct conformers; empty
        when ``confs_id`` is empty.
    """
    distinct = []
    for idx, conf_id in enumerate(confs_id):
        new_coords = confs_coords[idx]
        is_duplicate = any(
            calc_rmsd_wrapper({'symbols': symbols, 'coords': new_coords},
                              {'symbols': symbols, 'coords': coords}) < rmsd_threshold
            for coords, _ in distinct)
        if not is_duplicate:
            distinct.append((new_coords, conf_id))
    return distinct

In [None]:
# NOTE(review): `Parallel`/`delayed` (presumably from joblib) and
# `energy_clusters_coords` are not defined or imported in the visible
# notebook, so this cell cannot run as is.
runner = Parallel(n_jobs=-1, verbose=100)
jobs = (delayed(calc_rmsd_single_thread)(symbols,
                                         energy_clusters[group_key],
                                         energy_clusters_coords[group_key])
        for group_key in energy_clusters)
result = runner(jobs)