In [1]:
from rdkit import Chem
from rdkit.Chem import AllChem

from copy import deepcopy
import numpy as np
import os

import torch
from model.ts_trainer import LitTSModule
from model.data import TSDataset
from torch_geometric.data import Batch

from rdmc.mol import RDKitMol
from rdmc.view import mol_viewer

In [2]:
class EvalTSDataset(TSDataset):
    """TSDataset variant for evaluation that only carries preprocessing options.

    The constructor copies four options out of ``config`` onto the instance:

    * ``set_similar_mols`` - use the species (r/p) which is more similar to the
      TS as the starting mol
    * ``shuffle_mols`` - randomize which species is reactant/product
    * ``prep_mols`` - prep as if starting from SMILES
    * ``prod_feat`` - whether product features include distance or adjacency

    NOTE(review): ``TSDataset.__init__`` is never called, presumably so that no
    training data is loaded; confirm that ``process_mols`` relies on nothing
    beyond these attributes.
    """

    def __init__(self, config):
        # Read the keys in a fixed order; a missing key raises KeyError.
        option_names = ("set_similar_mols", "shuffle_mols", "prep_mols", "prod_feat")
        for name in option_names:
            setattr(self, name, config[name])

In [3]:
# The prep_mols option is important: it aligns the product in its "reacting"
# configuration with the reactant configuration.
# If the product has 1 fragment and the reactant has 2 fragments, simply swap
# the reactant and product when passing them to the model.

exp_dir = "./trained_models/2022_02_01/"

TSModule = LitTSModule.load_from_checkpoint(
    checkpoint_path=os.path.join(exp_dir, "best_model.ckpt"),
)

# Work on a copy of the training config: overriding evaluation options on
# TSModule.config directly would silently change the loaded module's own
# hyperparameters (hidden state shared with every later cell).
config = deepcopy(TSModule.config)
config["shuffle_mols"] = False  # keep the given reactant/product assignment
config["prep_mols"] = True  # prep as if starting from SMILES
test_dataset = EvalTSDataset(config)

In [4]:
# Start from atom-mapped SMILES (3D mols could be used directly instead) and
# embed one 3D conformer for each species.
# NOTE(review): EmbedConformer is called without a seed, so the starting
# geometries (and therefore the prediction) may vary between runs; consider
# passing a seed if RDKitMol.EmbedConformer supports one.
r_smi = "[C:1]([c:2]1[n:3][o:4][n:5][n:6]1)([H:7])([H:8])[H:9]"
p_smi = "[C:1]([N:3]=[C:2]=[N:6][N:5]=[O:4])([H:7])([H:8])[H:9]"

# Parse reactant first, then product (same order as listed above).
r_mol, p_mol = (RDKitMol.FromSmiles(smi) for smi in (r_smi, p_smi))

for species in (r_mol, p_mol):
    species.EmbedConformer()

In [5]:
# Render the reactant's embedded 3D conformer.
mol_viewer(r_mol)

<py3Dmol.view at 0x7f2c78b458d0>

In [6]:
# Render the product's embedded 3D conformer.
mol_viewer(p_mol)

<py3Dmol.view at 0x7f2c78ab7fd0>

In [7]:
# Featurize the (reactant, TS, product) triple as a single-graph batch. The
# middle (presumably TS) slot is None and no_ts=True because there is no
# reference TS geometry to supply.
mols = (r_mol.ToRWMol(), None, p_mol.ToRWMol())
data = test_dataset.process_mols(mols, no_ts=True)
batch_data = Batch.from_data_list([data])

# Inference only: switch modules to eval mode (disables training-time layers
# such as dropout/batch-norm updates, if the model has any) and skip autograd
# bookkeeping, which the original discarded with .detach() anyway.
TSModule.eval()
with torch.no_grad():
    # Keep the first 3 output columns; they are used as xyz coordinates below.
    predicted_ts_coords = TSModule.model(batch_data)[:, :3].cpu().numpy()

# Reuse the reactant's atoms/connectivity as a container for the predicted
# geometry; deepcopy so that r_mol keeps its own conformer.
predicted_ts = deepcopy(r_mol)
predicted_ts.SetPositions(np.array(predicted_ts_coords, dtype=float))

In [8]:
# Render the predicted TS: reactant connectivity with the model-predicted coordinates.
mol_viewer(predicted_ts)

<py3Dmol.view at 0x7f2c78ac8c50>