In [1]:
import polars as pl
import numpy as np
from pathlib import Path
from hydra import compose, initialize
from src.utils import construct_sparse_adj_mat, load_json, augment_idx_feature
import json
import rdkit
from rdkit import Chem
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_auc_score,
    confusion_matrix,
)
from collections import defaultdict
from rxnmapper import RXNMapper
from src.ergochemics.mapping import get_reaction_center
from tqdm import tqdm

In [2]:
# Compose the "base" filepaths config with Hydra; the cells below read
# `fps.scratch` from it to locate input and output files.
with initialize(
    version_base=None,
    config_path="../configs/filepaths",
):
    fps = compose(config_name="base")

In [3]:
# Read the test split from the scratch directory.
rcmcs_test = pl.read_parquet(Path(fps.scratch).joinpath("rcmcs_test.parquet"))

# Number of distinct values in the "smarts" column.
print(rcmcs_test.unique(subset="smarts").height)

# Preview the first rows.
rcmcs_test.head()

729


protein_idx,reaction_idx,pid,rid,protein_embedding,smarts,am_smarts,reaction_center,y
i32,i32,str,str,list[f32],str,str,list[list[list[i64]]],i64
16113,5226,"""Q09KQ6""","""16449""","[0.038376, 0.180993, … 0.215456]","""O=C(O)C=CC=C(O)C(=O)O.N>>NC(=C…","""[C:1](=[CH:3][CH:5]=[CH:8][C:9…","[[[6, 7], [0]], [[1, 0], [0]]]",1
21519,6100,"""P0DJQ7""","""9228""","[0.122922, 0.275483, … 0.109619]","""N=C(N)NCCCC(C(=O)O)N1CCC1=O.O=…","""[N:1]1([CH:7]([CH2:8][CH2:10][…","[[[14, 11], [2], [17, 15]], [[8, 11, 13], [18, 15]]]",1
5853,2476,"""P9WIN5""","""16898""","[-0.054251, 0.153572, … 0.120209]","""CC(C)(COP(=O)(O)OP(=O)(O)OCC1O…","""[SH:2][CH2:3][CH2:4][NH:5][C:6…","[[[47], [19, 18]], [[15, 17], [18]]]",1
13019,4673,"""Q6F7B8""","""9043""","[0.13116, 0.108926, … -0.013253]","""CC(C)(COP(=O)(O)OP(=O)(O)OCC1O…","""[SH:2][CH2:3][CH2:4][NH:5][C:6…","[[[47], [15], [3, 4, … 47]], [[15, 17], [3, 4, … 47]]]",1
13020,4673,"""Q79FC3""","""9043""","[0.287938, 0.200004, … 0.146176]","""CC(C)(COP(=O)(O)OP(=O)(O)OCC1O…","""[SH:2][CH2:3][CH2:4][NH:5][C:6…","[[[47], [15], [3, 4, … 47]], [[15, 17], [3, 4, … 47]]]",1


In [4]:
# Keep only rows whose "smarts" string is shorter than 512 characters.
# NOTE(review): 512 presumably reflects an input-length limit of the atom
# mapper used below -- confirm and consider naming the constant.
smarts_is_short = pl.col("smarts").str.len_chars().lt(512)
rcmcs_test = rcmcs_test.filter(smarts_is_short)

# Number of distinct "smarts" values remaining after the filter.
print(rcmcs_test.unique(subset="smarts").height)

720


In [5]:
# Atom-map every distinct reaction with RXNMapper, in fixed-size batches.
rxn_mapper = RXNMapper()
batch_size = 32

# reaction_idx -> unmapped reaction string.
# Deduplicate on the dict key (reaction_idx) rather than on "smarts": if two
# reaction indices shared the same string, deduplicating on "smarts" would
# drop one of them and the later per-row lookup by reaction_idx would raise
# KeyError.
unmapped = {
    row["reaction_idx"]: row["smarts"]
    for row in rcmcs_test.unique("reaction_idx").iter_rows(named=True)
}

# Materialise the values once; rebuilding this list for every batch is
# loop-invariant work.
unmapped_smarts = list(unmapped.values())

res = []
# tqdm takes its total from len(range(...)), which is the exact batch count
# (a hand-computed `n // batch_size + 1` is off by one when n is a multiple
# of batch_size).
for i in tqdm(range(0, len(unmapped_smarts), batch_size)):
    batch_smarts = unmapped_smarts[i:i + batch_size]
    batch_res = rxn_mapper.get_attention_guided_atom_maps(batch_smarts)
    res.extend(batch_res)

# Later cells zip `unmapped.keys()` with `res`; zip truncates silently, so
# fail loudly here if the mapper did not return one result per reaction.
assert len(res) == len(unmapped), (
    f"expected {len(unmapped)} mapping results, got {len(res)}"
)

100%|██████████| 23/23 [03:57<00:00, 10.32s/it]


In [6]:
# reaction_idx -> reaction center computed from the atom-mapped reaction.
# `res` was built by iterating `unmapped` in order, so zipping the keys with
# `res` pairs each reaction_idx with its own mapping result.
rcs = {
    ridx: get_reaction_center(mapping["mapped_rxn"])
    for ridx, mapping in zip(unmapped.keys(), res)
}

In [9]:
# reaction_idx -> RXNMapper result dict (keys used here: "mapped_rxn", "confidence").
ridx_to_res = dict(zip(unmapped.keys(), res))

# Reuse the existing column dtype so the replaced "reaction_center" column
# keeps the same schema.
rc_dtype = rcmcs_test.schema["reaction_center"]

ridx_col = pl.col("reaction_idx")

# The atom-mapped reaction string is written to both "smarts" and "am_smarts".
mapped_rxn_expr = ridx_col.map_elements(
    lambda x: ridx_to_res[x]['mapped_rxn'],
    return_dtype=pl.String,
)

rxnmapper_test = rcmcs_test.with_columns(
    smarts=mapped_rxn_expr,
    am_smarts=mapped_rxn_expr,
    reaction_center=ridx_col.map_elements(
        lambda x: rcs[x],
        return_dtype=rc_dtype,
    ),
    rxnmapper_confidence=ridx_col.map_elements(
        lambda x: ridx_to_res[x]['confidence'],
        return_dtype=pl.Float32,
    ),
)
rxnmapper_test.head()

protein_idx,reaction_idx,pid,rid,protein_embedding,smarts,am_smarts,reaction_center,y,rxnmapper_confidence
i32,i32,str,str,list[f32],str,str,list[list[list[i64]]],i64,f32
16113,5226,"""Q09KQ6""","""16449""","[0.038376, 0.180993, … 0.215456]","""[NH3:1].[C:2](=[CH:3][CH:4]=[C…","""[NH3:1].[C:2](=[CH:3][CH:4]=[C…","[[[0], [0, 10]], [[0, 1], [0]]]",1,0.902623
21519,6100,"""P0DJQ7""","""9228""","[0.122922, 0.275483, … 0.109619]","""[NH:1]=[C:2]([NH2:3])[NH:4][CH…","""[NH:1]=[C:2]([NH2:3])[NH:4][CH…","[[[8, 11], [18], [0, 3, … 6]], [[8, 11, 13], [18, 19, … 25]]]",1,0.105942
5853,2476,"""P9WIN5""","""16898""","[-0.054251, 0.153572, … 0.120209]","""[SH:18][CH2:19][CH2:20][NH:21]…","""[SH:18][CH2:19][CH2:20][NH:21]…","[[[0], [15, 17]], [[15, 17], [18]]]",1,0.109628
13019,4673,"""Q6F7B8""","""9043""","[0.13116, 0.108926, … -0.013253]","""[SH:18][CH2:19][CH2:20][NH:21]…","""[SH:18][CH2:19][CH2:20][NH:21]…","[[[0], [15], [3, 4, … 47]], [[15, 17], [3, 4, … 47]]]",1,0.058677
13020,4673,"""Q79FC3""","""9043""","[0.287938, 0.200004, … 0.146176]","""[SH:18][CH2:19][CH2:20][NH:21]…","""[SH:18][CH2:19][CH2:20][NH:21]…","[[[0], [15], [3, 4, … 47]], [[15, 17], [3, 4, … 47]]]",1,0.058677


In [10]:
# Persist the RXNMapper-annotated test split under the scratch directory.
out_path = (
    Path(fps.scratch)
    / "rxn_mapper_am_v3_folded_pt_ns"
    / "rcmcs_random"
    / "Nonefold"
    / "test.parquet"
)

# write_parquet does not create missing parent directories by default, so
# make sure the nested output directory exists before writing.
out_path.parent.mkdir(parents=True, exist_ok=True)

rxnmapper_test.write_parquet(out_path)