In [1]:
from hydra import compose, initialize
from pathlib import Path
from chemprop import data, featurizers, models
from chemprop.nn.message_passing import BondMessagePassing
import torch
from lightning import pytorch as pl
import pandas as pd
import numpy as np
from ergochemics.mapping import rc_to_nest
from rdkit import Chem

# Compose the project's file-path configuration with Hydra's compose API.
# `filepaths.raw_data` from this config is used to locate the input parquet file.
hydra_config_dir = "../configs/filepaths"
with initialize(version_base=None, config_path=hydra_config_dir):
    filepaths = compose(config_name="filepaths")


In [2]:
# Load the atom-mapped reactions joined to mechanistic rules, converting the stored
# `reaction_center` column into its nested representation via ergochemics' rc_to_nest.
raw_parquet = (
    Path(filepaths.raw_data)
    / "mapped_sprhea_240310_v3_mapped_no_subunits_x_mechanistic_rules.parquet"
)
df = pd.read_parquet(raw_parquet).assign(
    reaction_center=lambda frame: frame["reaction_center"].apply(rc_to_nest)
)
df.head()

Unnamed: 0,rxn_id,smarts,am_smarts,rule,reaction_center,rule_id
0,1,CC(O)C(O)C(O)C(O)C(=O)O>>O.CC(O)C(O)CC(=O)C(=O)O,[CH3:11][CH:9]([OH:12])[CH:8]([OH:10])[CH:7]([...,[O&D1&v2&H0&0*&!R:1]=[C&D3&v4&H0&0*&!R&z2:2](-...,"(((10, 9, 11, 7, 8, 5, 6),), ((0,), (0, 1, 2, ...",841
1,10,O=P(O)(O)OP(=O)(O)O.CCCCCCCCCCCC(=O)OP(=O)(O)O...,[O:10]=[P:1]([OH:2])([OH:11])[O:3][P:4](=[O:5]...,[*:1](~[O&D1&v2&H1&0*&!R:2])~[*:3]~[*:4](~[O&D...,"(((1, 2, 4, 5, 6, 7, 8), (14, 16, 13)), ((0, 1...",632
2,100,*C(=O)OCC(COP(=O)(O)OC1C(O)C(OP(=O)(O)O)C(O)C(...,[*:30][C:27](=[O:31])[O:25][CH2:23][CH:22]([CH...,[*:1](~[O&D2&v2&H0&0*&!R:2]-[P&D4&v5&H0&0*&!R:...,"(((15, 16, 17, 19, 13, 14), (0,)), ((0, 1, 2, ...",813
3,10001,Nc1c(NCC(O)C(O)C(O)CO)[nH]c(=O)[nH]c1=O.Cc1cc2...,[NH2:25][c:26]1[c:27]([NH:29][CH2:36][CH:37]([...,[N&D1&v3&H2&0*&!R:1]-[c&D3&v4&H0&0*&R&z1:2]:[c...,"(((0, 1, 2, 3, 13), (1, 2, 3, 4, 5, 12, 11, 9,...",338
4,10008,O.O=[N+]([O-])c1ccc(OP(=O)(O)O)cc1>>O=P(O)(O)O...,[OH2:5].[O:14]=[N+:13]([O-:15])[c:12]1[cH:10][...,[O&D0&v2&H2&0*&!R:1].[O&D2&v2&H0&0*&!R:2]-[P&D...,"(((0,), (7, 8, 9, 10, 11)), ((0, 1, 2, 3, 4), ...",585


In [11]:
# Atom-mapped reaction SMILES (reactants>>products), one per row of `df`.
smis = list(df["am_smarts"])
smis[:5]

['[CH3:11][CH:9]([OH:12])[CH:8]([OH:10])[CH:7]([OH:1])[CH:5]([OH:6])[C:3](=[O:4])[OH:2]>>[OH2:1].[CH3:11][CH:9]([OH:12])[CH:8]([OH:10])[CH2:7][C:5](=[O:6])[C:3](=[O:2])[OH:4]',
 '[O:10]=[P:1]([OH:2])([OH:11])[O:3][P:4](=[O:5])([OH:6])[OH:7].[CH3:45][CH2:44][CH2:43][CH2:42][CH2:41][CH2:40][CH2:39][CH2:38][CH2:37][CH2:36][CH2:34][C:33](=[O:35])[O:32][P:8](=[O:12])([OH:9])[O:13][CH2:14][CH:15]1[O:16][CH:18]([n:21]2[cH:23][n:25][c:27]3[c:29]([NH2:31])[n:30][cH:28][n:26][c:24]32)[CH:19]([OH:22])[CH:17]1[OH:20]>>[NH2:31][c:29]1[n:30][cH:28][n:26][c:24]2[c:27]1[n:25][cH:23][n:21]2[CH:18]1[O:16][CH:15]([CH2:14][O:13][P:8](=[O:12])([OH:9])[O:7][P:4](=[O:5])([OH:6])[O:3][P:1](=[O:10])([OH:2])[OH:11])[CH:17]([OH:20])[CH:19]1[OH:22].[CH3:45][CH2:44][CH2:43][CH2:42][CH2:41][CH2:40][CH2:39][CH2:38][CH2:37][CH2:36][CH2:34][C:33](=[O:35])[OH:32]',
 '[*:30][C:27](=[O:31])[O:25][CH2:23][CH:22]([CH2:21][O:18][P:17](=[O:19])([OH:20])[O:16][CH:12]1[CH:3]([OH:4])[CH:1]([O:2][P:32](=[O:35])([OH:33])[OH:36])[

In [4]:
# Build a chemprop reaction dataset from the atom-mapped SMILES.
# The targets are placeholders: only featurization / message passing is inspected here.
# Each datapoint's `y` is a 1-D row with one value per task (not a (1, N_TASKS) matrix,
# which would stack into a 3-D target array), drawn from a seeded generator so
# re-running the notebook reproduces the same placeholder targets.
N_TASKS = 3
rng = np.random.default_rng(0)
test_data = [data.ReactionDatapoint.from_smi(smi, y=rng.random(N_TASKS)) for smi in smis]

# Condensed Graph of Reaction in PROD_DIFF mode with the v2 multi-hot atom featurizer.
featurizer = featurizers.CondensedGraphOfReactionFeaturizer(
    mode_="PROD_DIFF", atom_featurizer=featurizers.MultiHotAtomFeaturizer.v2()
)
test_dataset = data.ReactionDataset(test_data, featurizer=featurizer)
# shuffle=False keeps batch order aligned with `smis`.
test_dataloader = data.build_dataloader(test_dataset, shuffle=False)

In [5]:
# Number of atoms in the reactant-side RDKit mol of the third reaction (index 2).
third_rct = test_data[2].rct
third_rct.GetNumAtoms()

36

In [6]:
# Untrained bond-based message-passing encoder sized to the CGR featurizer's atom and
# bond feature dimensions. Its weights are randomly initialised, so seed torch first to
# make the hidden states computed from it reproducible across runs.
torch.manual_seed(0)
mp = BondMessagePassing(d_v=featurizer.atom_fdim, d_e=featurizer.bond_fdim)

In [7]:
# Encode the first batch's batched molecular graph. This is an inspection-only forward
# pass, so disable autograd to avoid building (and holding onto) a gradient graph.
# NOTE(review): rows of Ha are presumably one hidden vector per atom across the batch;
# columns are the encoder's hidden size.
with torch.no_grad():
    Ha = mp(next(iter(test_dataloader)).bmg)
Ha.shape

torch.Size([1110, 300])

In [8]:
# Peek at the first featurized reaction graph: the shape of its atom-feature matrix and
# its first 9 columns (the start of the atomic-number one-hot block).
# Index directly instead of looping with an unconditional `break`.
first_graph = test_dataset[0].mg
print(first_graph.V.shape)
print(first_graph.V[:, :9])

(12, 106)
[[0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]]


In [9]:
# Sanity check: decode the atomic-number one-hot chemprop wrote into the first columns
# of each CGR atom-feature row and compare it with RDKit's atomic numbers for the
# reactant atoms of the same reaction.
#
# MultiHotAtomFeaturizer.v2 (used for `featurizer`) encodes atomic numbers 1-36 plus
# iodine (53), followed by a single "unknown" slot. So column j is NOT simply Z = j + 1:
# - iodine sits in column 36;
# - every out-of-vocabulary element (including the Z=0 wildcard "*") lands in column 37.
V2_ATOMIC_NUMS = list(range(1, 37)) + [53]
UNKNOWN_COL = len(V2_ATOMIC_NUMS)  # index of the catch-all slot
N_ATOMIC_NUM_COLS = UNKNOWN_COL + 1  # width of the atomic-number block (38)
Z_TO_COL = {z: col for col, z in enumerate(V2_ATOMIC_NUMS)}

for smi, elt in zip(smis, test_dataset):
    rct_smi = smi.split(">>")[0]
    mol = Chem.MolFromSmiles(rct_smi)
    assert mol is not None, f"RDKit failed to parse reactants: {rct_smi}"

    # Decode one column per atom row. np.nonzero(...)[1] flattens across rows and would
    # silently misalign atoms if any row had zero or several bits set, so check one-hot first.
    atom_num_block = elt.mg.V[:, :N_ATOMIC_NUM_COLS]
    assert (atom_num_block.sum(axis=1) == 1).all(), (
        f"Atomic-number block is not one-hot for SMILES: {smi}"
    )
    cp_cols = atom_num_block.argmax(axis=1)

    # The CGR graph may carry extra rows beyond the reactant atoms (presumably product-only
    # atoms), so compare only the first mol.GetNumAtoms() rows. This assumes CGR row order
    # follows RDKit's reactant atom order.
    assert len(cp_cols) >= mol.GetNumAtoms(), (
        f"CGR has {len(cp_cols)} atom rows but reactants have {mol.GetNumAtoms()} atoms for SMILES: {smi}"
    )
    for atom, c in zip(mol.GetAtoms(), cp_cols):
        r = atom.GetAtomicNum()
        expected = Z_TO_COL.get(r, UNKNOWN_COL)
        assert c == expected, (
            f"CP col: {c}, expected col: {expected} (RDKit Z: {r}), "
            f"AM: {atom.GetAtomMapNum()} for SMILES: {smi}"
        )