In [1]:
import os
import pickle
from glob import glob
from typing import Optional

from rdmc.conformer_generation.ts_verifiers import TSVerifier
from torch_geometric.data import Batch
from ts_ml.dataloaders.ts_screener_loader import TSScreenerDataset, mol2data
from ts_ml.trainers.ts_screener_trainer import LitScreenerModule

In [2]:
# log file -> features

from typing import Union

import numpy as np

from rdkit import Chem
from rdmc.rdtools.element import PERIODIC_TABLE as PT
from rdmc.external.gaussian import GaussianLog

def get_frames_from_freq(log,
                         amplitude: float = 1.0,
                         num_frames: int = 10,
                         weights: Union[bool, np.ndarray] = False):
    """
    Generate geometries displaced along the first vibrational mode of a frequency job.

    Frame ``k`` is ``equilibrium - amp_factors[k] * weights * displacement`` where
    ``amp_factors`` runs linearly from ``-amplitude`` to ``+amplitude``. The middle
    frame is therefore the equilibrium geometry itself.

    Args:
        log (GaussianLog): A gaussian log object with vibrational freq calculated. It must
                           report exactly one negative frequency.
        amplitude (float): The largest factor (in both directions) by which the displacement
                           vector is scaled. Defaults to 1.0.
        num_frames (int): The number of frames in each direction (forward and reverse). Defaults to 10.
        weights (bool or np.ndarray): If ``True``, use the sqrt(atom mass) as a scaling factor to the displacement.
                              If ``False``, use the identity weights. If an array is provided, it is used
                              directly as the per-atom scaling factor; both a ``N x 1`` array and a 1D array
                              of length ``N`` are accepted. The concern is that light atoms (e.g., H) tend
                              to have larger motions than heavier atoms.

    Returns:
        np.array: The atomic numbers as an 1D array
        np.array: The 3D geometries at each frame as a 3D array (number of frames x 2 + 1, number of atoms, 3)
    """
    assert log.num_neg_freqs == 1, "expected a log with exactly one negative frequency"

    equ_xyz = log.converged_geometries[-1]
    disp = log.cclib_results.vibdisps[0]
    amp_factors = np.linspace(-amplitude, amplitude, 2 * num_frames + 1)
    num_atoms = equ_xyz.shape[0]

    # Generate weights
    if isinstance(weights, bool):
        if weights:
            atom_masses = np.array([PT.GetAtomicWeight(int(num)) for num in log.cclib_results.atomnos])
            weights = np.sqrt(atom_masses)
        else:
            weights = np.ones(num_atoms)

    # A flat per-atom vector must become a column before multiplying the (N, 3)
    # displacement; otherwise NumPy aligns it with the x/y/z axis (an error for
    # N != 3 and a silent per-axis scaling for N == 3).
    weights = np.asarray(weights)
    if weights.ndim == 1 and weights.shape[0] == num_atoms:
        weights = weights.reshape(-1, 1)

    xyzs = equ_xyz - np.einsum('i,jk->ijk', amp_factors, weights * disp)

    return log.cclib_results.atomnos, xyzs


def convert_log_to_mol(log_path: str,
                       amplitude: float = 1.0,
                       num_frames: int = 10,
                       weights: Union[bool, np.ndarray] = False):
    """
    Build a molecule whose conformers are the frames of the TS reacting mode.

    Args:
        log_path (str): The path to the log file.
        amplitude (float): The amplitude of the motion, forwarded to ``get_frames_from_freq``.
                           Defaults to 1.0.
        num_frames (int): The number of frames in each direction (forward and reverse). Defaults to 10.
        weights (bool or np.ndarray): If ``True``, use the sqrt(atom mass) as a scaling factor to the displacement.
                              If ``False``, use the identity weights. An array is forwarded unchanged to
                              ``get_frames_from_freq``. The concern is that light atoms (e.g., H) tend
                              to have larger motions than heavier atoms.

    Returns:
        The molecule returned by ``GaussianLog.get_mol`` with ``2 * num_frames + 1`` conformers
        embedded (one per frame), or ``None`` if the job did not succeed, is not a TS job, or
        does not have exactly one negative frequency.
    """
    glog = GaussianLog(log_path)

    # Explicit checks rather than ``assert``: assertions are removed when Python
    # runs with ``-O``, which would let unusable logs through.
    if not (glog.success and glog.is_ts and glog.num_neg_freqs == 1):
        return None

    # Get TS mol object and construct geometries as numpy arrays for all frames
    mol = glog.get_mol(converged=True, embed_conformers=False, sanitize=False)
    _, xyzs = get_frames_from_freq(glog, amplitude=amplitude, num_frames=num_frames, weights=weights)

    # Embed geometries to the mol object for output
    mol.EmbedMultipleNullConfs(xyzs.shape[0])
    for conf_id, xyz in enumerate(xyzs):
        mol.SetPositions(xyz, id=conf_id)

    return mol

In [None]:
class TSScreener(TSVerifier):
    """
    The class for screening TS guesses.

    Each still-accepted TS guess is converted into frames of its reacting mode
    (from the frequency job log), scored by the trained TS-Screener model, and
    kept only if the score exceeds ``threshold``.
    """

    def __init__(self,
                 trained_model_dir: str,
                 threshold: Optional[float],
                 track_stats: Optional[bool] = False):
        """
        Initialize the TS-Screener model.

        Args:
            trained_model_dir (str): The path to the directory storing the trained TS-Screener model.
                                     The checkpoint is read from ``best_model.ckpt`` in this directory.
            threshold (float): Threshold prediction at which we classify a failure/success. Predictions
                               strictly greater than this value count as success; it must be a number
                               by the time ``verify_ts_guesses`` is called.
            track_stats (bool, optional): Whether to track timing stats. Defaults to False.
        """
        super(TSScreener, self).__init__(track_stats)

        # Load the TS-Screener model
        self.module = LitScreenerModule.load_from_checkpoint(
            checkpoint_path=os.path.join(trained_model_dir, "best_model.ckpt")
        )

        # Setup configuration
        self.config = self.module.config
        self.module.model.eval()
        self.threshold = threshold

    def verify_ts_guesses(self,
                          ts_mol: 'RDKitMol',
                          multiplicity: int = 1,
                          save_dir: Optional[str] = None,
                          **kwargs):
        """
        Screen poor TS guesses by using reacting mode from frequency calculation.

        ``ts_mol.KeepIDs`` is updated in place and then pickled to
        ``screener_check_ids.pkl`` inside ``save_dir``.

        Args:
            ts_mol ('RDKitMol'): The TS in RDKitMol object with 3D geometries embedded. Its ``KeepIDs``
                                 mapping (conformer index -> bool) selects which guesses are screened.
            multiplicity (int, optional): The spin multiplicity of the TS. Defaults to 1.
                                          Not used by this verifier.
            save_dir (str, optional): The directory holding the ``*opt*`` sub-directories and receiving
                                      the results. Although it defaults to None, a path is required.
            **kwargs: Must contain ``rxn_smiles``, which is stored as the ``Name`` property of each
                      molecule passed to the featurizer.

        Returns:
            None
        """
        rxn_smiles = kwargs["rxn_smiles"]
        mol_data, ids = [], []

        # parse all optimization folders (which hold the frequency jobs)
        opt_dirs = sorted(
            (d for d in glob(os.path.join(save_dir, "*opt*")) if os.path.isdir(d)),
            key=lambda x: int(x.split("opt")[-1]),
        )
        for log_dir in opt_dirs:
            idx = int(log_dir.split("opt")[-1])
            if not ts_mol.KeepIDs[idx]:
                continue

            log_paths = glob(os.path.join(log_dir, "*opt.log"))
            ts_freq_mol = convert_log_to_mol(log_paths[0]) if log_paths else None
            if ts_freq_mol is None:
                # No log, or a log that is not a successful TS job with a single
                # negative frequency: the guess cannot be screened, so reject it
                # instead of crashing on ``None``.
                ts_mol.KeepIDs[idx] = False
                continue

            ts_freq_mol.SetProp("Name", rxn_smiles)
            # Featurize with this instance's own config (not a notebook-level global).
            data = mol2data(ts_freq_mol, self.config, eval_mode=True)

            mol_data.append(data)
            ids.append(idx)

        # create data batch and run screener model (nothing to do for an empty batch)
        if mol_data:
            batch_data = Batch.from_data_list(mol_data)
            preds = self.module.model(batch_data) > self.threshold

            # update which TSs to keep
            updated_keep_ids = {idx: pred.item() for idx, pred in zip(ids, preds)}
            ts_mol.KeepIDs.update(updated_keep_ids)

        # write ids to file
        with open(os.path.join(save_dir, "screener_check_ids.pkl"), "wb") as f:
            pickle.dump(ts_mol.KeepIDs, f)

In [72]:
from rdmc import RDKitMol
from torch_geometric.data import Batch
from glob import glob

# Atom-mapped reaction SMILES; attached as the "Name" property of every molecule
# that is featurized for the screener.
rxn_smiles = '[C:1]([O:2][C:3]([C:4]1([C:5]([H:15])([H:16])[H:17])[C:6]([H:18])([H:19])[C:7]([C:8]([H:21])([H:22])[H:23])([H:20])[C:9]1([H:24])[H:25])([H:13])[H:14])([H:10])([H:11])[H:12]>>[C:1]([O:2][C:3]([C:4]1([H:17])[C:6]([H:18])([H:19])[C:7]([C:8]([H:21])([H:22])[H:23])([H:20])[C:9]1([H:24])[H:25])([H:13])[H:14])([H:10])([H:11])[H:12].[C:5]([H:15])[H:16]'
# Predictions strictly above this value are treated as passing the screen.
threshold = 0.2

trained_model_dir = "../../TS-ML/trained_models/ts_screener/2022_05_18/"

ts_screener = LitScreenerModule.load_from_checkpoint(
    checkpoint_path=os.path.join(trained_model_dir, "best_model.ckpt")
)
# A freshly loaded module is in training mode; switch to eval mode before
# inference, as TSScreener.__init__ does for its own model.
ts_screener.eval()

In [73]:
save_dir = "./data/rmg_results/2988/"

# Read the optimized TS conformers written to this results directory.
ts_mol = RDKitMol.FromFile(os.path.join(save_dir, "ts_optimized_confs.sdf"), sameMol=True)

# Conformer indices that are still accepted; every other index is flagged False.
keep = [1, 2, 3, 4, 6, 8, 9, 10, 11, 15, 17, 18]
KeepIDs = {conf_id: conf_id in keep for conf_id in range(ts_mol.GetNumConformers())}

In [78]:
mol_data, ids = [], []

# Visit the optimization folders in numeric order of their "opt<N>" suffix.
for log_dir in sorted((d for d in glob(os.path.join(save_dir, "*opt*")) if os.path.isdir(d)),
                      key=lambda x: int(x.split("opt")[-1])):

    idx = int(log_dir.split("opt")[-1])
    if not KeepIDs[idx]:
        continue

    freq_logs = glob(os.path.join(log_dir, "*opt.log"))
    if not freq_logs:
        print(f"Skipping conformer {idx}: no '*opt.log' file in {log_dir}")
        continue

    freq_log_path = freq_logs[0]
    ts_freq_mol = convert_log_to_mol(freq_log_path)
    if ts_freq_mol is None:
        # convert_log_to_mol returns None for logs that are not a successful
        # TS job with exactly one negative frequency; these cannot be featurized.
        print(f"Skipping conformer {idx}: {freq_log_path} is not a usable TS frequency job")
        continue

    ts_freq_mol.SetProp("Name", rxn_smiles)
    data = mol2data(ts_freq_mol, ts_screener.config, eval_mode=True)

    mol_data.append(data)
    ids.append(idx)

assert mol_data, "no conformer could be featurized; nothing to screen"

# One prediction per entry of `ids`; True means the score exceeds `threshold`.
batch_data = Batch.from_data_list(mol_data)
preds = ts_screener(batch_data) > threshold

In [85]:
# Show the screener decisions computed above: True where the model output
# exceeded `threshold`, in the same order as `ids`.
preds

tensor([False,  True, False, False, False, False, False, False, False, False,
        False, False])

In [98]:
# Display the repr of the network wrapped by the loaded screener module.
ts_screener.model

FrameClassifier(
  (featurizer): DimeNetPlusPlus(
    (rbf): BesselBasisLayer(
      (envelope): Envelope()
    )
    (sbf): SphericalBasisLayer(
      (envelope): Envelope()
    )
    (emb): EmbeddingBlock(
      (emb): Linear(in_features=9, out_features=64, bias=True)
      (lin_rbf): Linear(in_features=6, out_features=64, bias=True)
      (lin): Linear(in_features=192, out_features=64, bias=True)
    )
    (output_blocks): ModuleList(
      (0): OutputPPBlock(
        (lin_rbf): Linear(in_features=6, out_features=64, bias=False)
        (lin_up): Linear(in_features=64, out_features=32, bias=True)
        (lins): ModuleList(
          (0): Linear(in_features=32, out_features=32, bias=True)
          (1): Linear(in_features=32, out_features=32, bias=True)
        )
        (lin): Linear(in_features=32, out_features=32, bias=False)
      )
      (1): OutputPPBlock(
        (lin_rbf): Linear(in_features=6, out_features=64, bias=False)
        (lin_up): Linear(in_features=64, out_feature

In [103]:
from rdmc.view import conformer_viewer

# Index of the conformer whose reacting mode is built in this cell.
target_idx = 2

for log_dir in sorted((d for d in glob(os.path.join(save_dir, "*opt*")) if os.path.isdir(d)),
                      key=lambda x: int(x.split("opt")[-1])):

    idx = int(log_dir.split("opt")[-1])
    if idx == target_idx:
        freq_log_path = glob(os.path.join(log_dir, "*opt.log"))[0]
        ts_freq_mol = convert_log_to_mol(freq_log_path)
        break
else:
    # Without this, `ts_freq_mol` would silently keep whatever an earlier cell
    # left in it and later cells would use the wrong conformer.
    raise FileNotFoundError(f"no '*opt*' directory with index {target_idx} under {save_dir}")

if ts_freq_mol is None:
    # convert_log_to_mol returns None when the log is not a successful TS job
    # with exactly one negative frequency.
    raise ValueError(f"{freq_log_path} is not a usable TS frequency job")

In [104]:
# Visualize the conformers embedded in `ts_freq_mol` (the reacting-mode frames
# built by convert_log_to_mol).
# NOTE(review): `.update()` presumably refreshes the viewer widget in place — confirm.
conformer_viewer(ts_freq_mol).update()

In [106]:
# Export every frame of `ts_freq_mol` as a separate record of one SDF file.
ts_path = "../../TS-ML/exps/plots/ts_screener/ex1.sdf"
ts_writer = Chem.rdmolfiles.SDWriter(ts_path)
try:
    for i in range(ts_freq_mol.GetNumConformers()):
        conf = ts_freq_mol.GetConformer(i).ToMol().ToRWMol()
        ts_writer.write(conf)
finally:
    # Close even if a frame fails to convert or write, so the file handle is
    # released and already-written records are flushed.
    ts_writer.close()