In [1]:
from hydra import initialize, compose
import polars as pl
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from tqdm import tqdm
from src.ergochemics.standardize import standardize_rxn, hash_reaction
import json
import matplotlib.pyplot as plt
from collections import Counter
from itertools import chain
import pickle
import torch
from unimol_tools import UniMolRepr
from tqdm import tqdm

In [2]:
# Compose the "base" Hydra config found under ../configs/filepaths.
# Later cells read `cfg.scratch` and `cfg.data` from it to locate input files.
with initialize(version_base=None, config_path="../configs/filepaths"):
    cfg = compose(config_name="base")

In [17]:
# Test split stored under the "None_alternate_reaction_center" directory of the scratch area
# (presumably the alternate-reaction-center negatives, hence the name `arc_negs` — TODO confirm).
arc_negs = pl.read_parquet(Path(cfg.scratch) / "time_split_250915_time_split/None_alternate_reaction_center/Nonefold/test.parquet")
# Preview the first rows to check the schema.
arc_negs.head()

protein_idx,reaction_idx,pid,rid,protein_embedding,smarts,am_smarts,reaction_center,y
i64,i64,str,str,list[f32],str,str,list[list[list[i64]]],i64
0,0,"""Q9H3Z7""","""b974ffacf1d79e2635848eb4228153…","[0.03637, 0.16163, … 0.000506]","""*C(=O)OCC(O)CO.O>>*C(=O)O.OCC(…","""[*:3][C:1](=[O:4])[O:5][CH2:6]…","[[[1, 3], [0]], [[1, 3], [0]]]",1
1,1,"""Q0RWC9""","""1a55f3c111dc32bc80f27e9d4874ea…","[-0.023623, 0.137708, … 0.061511]","""O=C(O)c1ccc(O)c(O)c1.O=C=O>>O=…","""[O:12]=[C:8]([OH:13])[c:5]1[cH…","[[[10], [1, 0]], [[10, 11, 13]]]",1
2,2,"""Q80XH4""","""46b8f933837da28efe31dccc0dc939…","[-0.046594, 0.372996, … -0.210981]","""[1*]C(O)C(COC1OC(CO)C(OC2OC(CO…","""[1*:73][CH:71]([OH:74])[CH:70]…","[[[53, 56], [17, 15]], [[53, 54], [15, 18]]]",1
3,3,"""H2LNR5""","""9851de4d3521c90a80bc8cf68a948f…","[-0.062493, 0.197662, … 0.024808]","""NC(CCC(=O)NC(CS)C(=O)NCC(=O)O)…","""[NH2:17][CH:16]([CH2:15][CH2:1…","[[[14, 16], [21], [2, 1]], [[14, 16], [0], [22, 23]]]",1
4,4,"""Q8NKB0""","""17bb99c7a7a73802ae1dc5760af1cb…","[0.072061, 0.220584, … 0.06905]","""CC1CCCC(=O)CCCC=Cc2cc(O)cc(O)c…","""[CH3:4][CH:2]1[CH2:5][CH2:6][C…","[[[1, 22], [0]], [[1, 2, 23]]]",1


In [18]:
# Test split stored under the "None_random" directory of the scratch area
# (presumably the randomly sampled negatives, hence the name `rand_negs` — TODO confirm).
rand_negs = pl.read_parquet(Path(cfg.scratch) / "time_split_250915_time_split/None_random/Nonefold/test.parquet")
# Preview the first rows to check the schema.
rand_negs.head()

protein_idx,reaction_idx,pid,rid,protein_embedding,smarts,am_smarts,reaction_center,am_chiral_smarts,chiral_smarts,y
i32,i32,str,str,list[f32],str,str,list[list[list[i64]]],str,str,i64
0,0,"""P9WEM0""","""772cec55ef0385f46f9c59a741eff3…","[-0.045221, 0.127457, … -0.064685]","""CC1=C(COP(=O)(O)O)C2(C)CCCC(C)…","""[CH3:6][C:4]1=[C:3]([CH2:1][O:…","[[[3, 4], [0]], [[3, 4], [2]]]","""[CH3:1][C:2]1=[C:3]([CH2:4][OH…","""CC1=C(CO)[C@@]2(C)CCCC(C)(C)[C…",1
1,1,"""A0A085GHR3""","""6b5456b8e1371df3f1326343255cbc…","[0.08258, 0.175821, … 0.161705]","""*NC(CO)C(*)=O.Nc1ncnc2c1ncn2C1…","""[*:6][NH:4][CH:3]([CH2:1][OH:1…","[[[3, 4], [22, 19]], [[3, 4], [19, 21]]]","""[*:1][NH:2][C@@H:3]([CH2:4][O:…","""*N[C@@H](COP(=O)(O)O)C(*)=O.Nc…",1
2,2,"""O24769""","""181b940baaf957ae014788e3f2b289…","[0.059184, 0.253729, … 0.162025]","""Nc1ncnc2c1ncn2C1OC(COP(=O)(O)O…","""[NH2:18][c:16]1[n:17][cH:15][n…","[[[13, 14], [0]], [[13, 14], [2]]]","""[NH2:1][c:2]1[n:3][cH:4][n:5][…","""Nc1ncnc2c1ncn2[C@@H]1O[C@H](CO…",1
3,3,"""A0A397HK53""","""e083443eeced3a35861687bd0dbcf4…","[0.072891, 0.278549, … 0.093182]","""Nc1ncnc2c1ncn2C1OC(CSCCC(N)C(=…","""[NH2:27][c:25]1[n:26][cH:24][n…","[[[14], [0, 1]], [[0, 1], [17]]]","""[CH3:1][O:45][c:44]1[cH:43][cH…","""COc1ccc(/C=C2\NC(=O)[C@H](CCCN…",1
4,4,"""Q4WBI5""","""d1e49e481f532c530e85f82580657a…","[0.083018, 0.236445, … 0.102194]","""Cc1cc(O)cc(=O)o1.CC(C)=CCCC(C)…","""[CH3:10][c:9]1[cH:5][c:3]([OH:…","[[[5], [19, 20]], [[19, 20], [2]]]","""[CH3:1][C:2]([CH3:3])=[CH:4][C…","""CC(C)=CCC/C(C)=C/CC/C(C)=C/CC/…",1


In [19]:
# Table of contents for the time split. The file is tab-separated despite its .csv extension.
# Later cells use its "Entry" and "Sequence" columns to look up protein sequences by id.
toc = pl.read_csv(Path(cfg.data) / "time_split" / "250915_time_split.csv", separator='\t')
toc.head()


Entry,Label,Sequence
str,str,str
"""P9WEM0""","""772cec55ef0385f46f9c59a741eff3…","""MCTTFKAAIFDMGGVLFTWNPIVDTQVSLK…"
"""A0A085GHR3""","""6b5456b8e1371df3f1326343255cbc…","""MFNVSNNVAPSRYQGPSSTSVTPNAFHDVP…"
"""O24769""","""181b940baaf957ae014788e3f2b289…","""MSSLLDIIYQLRQVPRWDGSFQFEKEDVSQ…"
"""A0A397HK53""","""e083443eeced3a35861687bd0dbcf4…","""MYESIIPFPDETASTVDKYCCSHSTSLPPL…"
"""Q4WBI5""","""d1e49e481f532c530e85f82580657a…","""MCKMYKVVETDASPGQRSVLQLVKDLLILS…"


clipzyme

In [20]:
# Directory that receives the CSV files written for clipzyme.
clip_dir = "/home/stef/clipzyme/files"

# One entry per file in the af2 directory: the key is the second '-'-separated field of
# the file stem, the value is the file path with the local data root swapped for the
# /projects/p30041/spn1560 root.
af2_whitelist = {
    file.stem.split("-")[1]: str(file).replace("/home/stef/quest_data", "/projects/p30041/spn1560")
    for file in (Path(cfg.data) / "time_split" / "af2").iterdir()
}

In [21]:
# Keep only pairs whose protein id is a key of af2_whitelist, attach the sequence from the
# table of contents, and project straight onto the output column names
# (reaction, sequence, protein_id, y, cif).
arc_clip = (
    arc_negs
    .filter(pl.col("pid").is_in(list(af2_whitelist)))
    .join(toc, left_on="pid", right_on="Entry", how="left")
    .select(
        pl.col("am_smarts").alias("reaction"),
        pl.col("Sequence").alias("sequence"),
        pl.col("pid").alias("protein_id"),
        pl.col("y"),
        # Protein id replaced by the path stored for it in af2_whitelist.
        pl.col("pid").replace(af2_whitelist).alias("cif"),
    )
)
print(f"# pairs post af2 filtering: {arc_clip.height} vs pre-filtering: {arc_negs.height}")
arc_clip.head()

# pairs post af2 filtering: 6332 vs pre-filtering: 7234


reaction,sequence,protein_id,y,cif
str,str,str,i64,str
"""[*:3][C:1](=[O:4])[O:5][CH2:6]…","""MCVICFVKALVRVFKIYLTASYTYPFRGWP…","""Q9H3Z7""",1,"""/projects/p30041/spn1560/hiec/…"
"""[O:12]=[C:8]([OH:13])[c:5]1[cH…","""MNETDIEGARVVVAEACRVAAARGLMEGIL…","""Q0RWC9""",1,"""/projects/p30041/spn1560/hiec/…"
"""[1*:73][CH:71]([OH:74])[CH:70]…","""MAKPFFRLQKFLRRTQFLLLFLTAAYLMTG…","""Q80XH4""",1,"""/projects/p30041/spn1560/hiec/…"
"""[NH2:17][CH:16]([CH2:15][CH2:1…","""MAPMLVSLNCGIRVQRRTLTLLIRQTSSYH…","""H2LNR5""",1,"""/projects/p30041/spn1560/hiec/…"
"""[CH3:4][CH:2]1[CH2:5][CH2:6][C…","""MRTRSTISTPNGITWYYEQEGTGPDVVLVP…","""Q8NKB0""",1,"""/projects/p30041/spn1560/hiec/…"


In [22]:
# Write the filtered alternate-reaction-center pair table as CSV into clip_dir.
arc_clip.write_csv(Path(clip_dir) / "time_split_arc.csv")

In [23]:
# Same preparation as for the arc table, but starting from rand_negs and taking the
# reaction string from the "am_chiral_smarts" column: keep pairs whose protein id is a key
# of af2_whitelist, attach the sequence, and project onto the output column names.
random_clip = (
    rand_negs
    .filter(pl.col("pid").is_in(list(af2_whitelist)))
    .join(toc, left_on="pid", right_on="Entry", how="left")
    .select(
        pl.col("am_chiral_smarts").alias("reaction"),
        pl.col("Sequence").alias("sequence"),
        pl.col("pid").alias("protein_id"),
        pl.col("y"),
        # Protein id replaced by the path stored for it in af2_whitelist.
        pl.col("pid").replace(af2_whitelist).alias("cif"),
    )
)
print(f"# pairs post af2 filtering: {random_clip.height} vs pre-filtering: {rand_negs.height}")
random_clip.head()

# pairs post af2 filtering: 5285 vs pre-filtering: 6072


reaction,sequence,protein_id,y,cif
str,str,str,i64,str
"""[*:1][NH:2][C@@H:3]([CH2:4][O:…","""MFNVSNNVAPSRYQGPSSTSVTPNAFHDVP…","""A0A085GHR3""",1,"""/projects/p30041/spn1560/hiec/…"
"""[NH2:1][c:2]1[n:3][cH:4][n:5][…","""MSSLLDIIYQLRQVPRWDGSFQFEKEDVSQ…","""O24769""",1,"""/projects/p30041/spn1560/hiec/…"
"""[CH3:1][O:45][c:44]1[cH:43][cH…","""MYESIIPFPDETASTVDKYCCSHSTSLPPL…","""A0A397HK53""",1,"""/projects/p30041/spn1560/hiec/…"
"""[CH3:1][C:2]([CH3:3])=[CH:4][C…","""MCKMYKVVETDASPGQRSVLQLVKDLLILS…","""Q4WBI5""",1,"""/projects/p30041/spn1560/hiec/…"
"""[NH2:1][C:2](=[O:3])[C:4]1=[CH…","""MNEILKKRLKLLKNNFGTHINKIANKKILI…","""A0A0U3AP28""",1,"""/projects/p30041/spn1560/hiec/…"


In [24]:
# Write the filtered random-negative pair table as CSV into clip_dir.
random_clip.write_csv(Path(clip_dir) / "time_split_random.csv")

reactzyme

In [25]:
# Output directory for the embedding dictionaries saved below (*.pt files).
embed_dir = Path("/home/stef/ReactZyme/data/embedding")
# Output directory for the positive/negative test pair files.
data_dir = Path("/home/stef/ReactZyme/data/25_time_split")
# Source directory holding one "<pid>.pt" ESM file per protein.
esm_src = Path(cfg.data) / "time_split" / "esm"

In [34]:
# Iron-containing ([Fe]) molecule that is dropped from a reaction's molecule list before
# embedding (presumably because Uni-Mol cannot embed it — TODO confirm). The reaction
# itself is still embedded from its remaining molecules.
EXCLUDED_SMILES = '[H]C(=O)C1=C(CCC(=O)[O-])C2=N3/C1=C\\C1=C(C)C([C@@H](O)CC/C=C(\\C)CC/C=C(\\C)CCC=C(C)C)=C4/C=C5/C(C)=C(C=C)C6=N5[Fe]3(N14)N1C(=C2)C(CCC(=O)[O-])=C(C)/C1=C/6'


def rz_smi_convert(x):
    """Convert a reaction string into one '.'-joined molecule string.

    Every '*' is replaced with 'C', then the two sides of '>>' are joined with '.'.
    """
    return ".".join(x.replace('*', 'C').split(">>"))


unique_mols = [rz_smi_convert(elt)
               for elt in rand_negs['chiral_smarts'].unique().to_list()
]
clf = UniMolRepr(data_type='molecule', remove_hs=False)
mol_dict = {}   # converted reaction string -> tensor with one mean-pooled row per molecule
bad_ones = []   # converted reaction strings whose embedding raised
for elt in tqdm(unique_mols, total=len(unique_mols)):
    elts = elt.split('.')
    if EXCLUDED_SMILES in elts:
        elts.remove(EXCLUDED_SMILES)

    try:
        unimol_repr = clf.get_repr(elts, return_atomic_reprs=True)
        # Average each molecule's atomic representations into a single vector.
        mol_dict[elt] = torch.tensor([rep.mean(0) for rep in unimol_repr['atomic_reprs']])
    except Exception as err:
        # Catch Exception rather than using a bare except so that KeyboardInterrupt and
        # SystemExit still stop the loop, and report why the entry was skipped.
        bad_ones.append(elt)
        tqdm.write(f"Skipping {elt!r}: {type(err).__name__}: {err}")
        continue

print(f"Failed to embed {len(bad_ones)} molecules.")
torch.save(mol_dict, embed_dir / 'unimol_mol_embedding.pt')

2025-12-16 10:37:01 | unimol_tools/models/unimol.py | 167 | INFO | Uni-Mol Tools | Loading pretrained weights from /home/stef/miniconda3/envs/hiec2/lib/python3.11/site-packages/unimol_tools/weights/mol_pre_all_h_220816.pt
  0%|          | 0/239 [00:00<?, ?it/s]2025-12-16 10:37:02 | unimol_tools/data/conformer.py | 182 | INFO | Uni-Mol Tools | Start generating conformers...
4it [00:00,  5.15it/s]
2025-12-16 10:37:04 | unimol_tools/data/conformer.py | 197 | INFO | Uni-Mol Tools | Succeeded in generating conformers for 100.00% of molecules.
2025-12-16 10:37:04 | unimol_tools/data/conformer.py | 214 | INFO | Uni-Mol Tools | Succeeded in generating 3d conformers for 100.00% of molecules.
2025-12-16 10:37:04 | unimol_tools/tasks/trainer.py | 103 | INFO | Uni-Mol Tools | Using CPU.
100%|██████████| 1/1 [00:00<00:00,  1.08it/s]
  0%|          | 1/239 [00:02<09:24,  2.37s/it]2025-12-16 10:37:05 | unimol_tools/data/conformer.py | 182 | INFO | Uni-Mol Tools | Start generating conformers...
3it [0

Failed to embed 0 molecules.





In [27]:
# Protein id -> full sequence, taken from the table of contents.
id2seq = dict(zip(toc["Entry"].to_list(), toc["Sequence"].to_list()))

# One entry per unique protein id in rand_negs, with the sequence cut to its first
# 5000 characters.
id_2_trun_unq_seq = {}
for pid in rand_negs['pid'].unique().to_list():
    id_2_trun_unq_seq[pid] = id2seq[pid][:5000]

# Key each protein's stored representation by its (truncated) sequence. Entry 33 of
# 'mean_representations' is used (presumably the ESM layer index — TODO confirm).
esm_dict = {
    seq: torch.load(esm_src / f"{pid}.pt")['mean_representations'][33]
    for pid, seq in id_2_trun_unq_seq.items()
}

torch.save(esm_dict, embed_dir / 'esm_seq_embedding.pt')


In [28]:
# Split the rand_negs test pairs by label: rows with y == 0 ...
neg_rand_negs = rand_negs.filter(
    pl.col("y") == 0
)
# ... and rows with y == 1.
pos_rand_negs = rand_negs.filter(
    pl.col("y") == 1
)

In [None]:
# Each test set maps a running integer (row order) to a
# (converted reaction string, truncated protein sequence) pair.
pos_tst = dict(enumerate(
    (rz_smi_convert(row['chiral_smarts']), id_2_trun_unq_seq[row['pid']])
    for row in pos_rand_negs.iter_rows(named=True)
))

neg_tst = dict(enumerate(
    (rz_smi_convert(row['chiral_smarts']), id_2_trun_unq_seq[row['pid']])
    for row in neg_rand_negs.iter_rows(named=True)
))

torch.save(pos_tst, data_dir / 'positive_test.pt')
torch.save(neg_tst, data_dir / 'negative_test.pt')

In [30]:
# All reaction strings referenced by the negative and positive test pairs ...
foo = {smi for smi, _ in neg_tst.values()} | {smi for smi, _ in pos_tst.values()}
# ... versus all reaction strings that have a molecule embedding.
bar = set(mol_dict)

In [37]:
# Strings present in exactly one of the two sets; an empty result means they match.
foo.symmetric_difference(bar)

set()

rxnfp zone

In [21]:
from rxnfp.transformer_fingerprints import (
    RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints
)
import pandas as pd
import numpy as np
import json
from pathlib import Path

In [22]:
def construct_sparse_adj_mat(path: Path):
    '''
    Builds index lookups for the samples and labels listed in a dataset
    "table of contents" file.

    Despite the name, no adjacency matrix is built or returned; only the
    two index-to-name lookups are.

    Args
        - path: Path to the tab-separated table of contents file. It must
          have an 'Entry' column and a 'Label' column whose values are
          ';'-separated label strings.
    Returns
        - idx_sample: Dict mapping row position to the row's 'Entry' value
        - idx_feature: Dict mapping label index to label; labels are
          numbered in order of first appearance
    '''
    df = pd.read_csv(path, delimiter='\t')

    print(f"Constructing {path.stem} sparse adjacency matrix")
    idx_sample = {}
    feature_idx = {}
    # Walk the rows positionally rather than via df.loc[entry, 'Label']:
    # a label lookup on a repeated Entry returns a Series, which has no .split.
    for i, (entry, label_field) in enumerate(zip(df['Entry'], df['Label'])):
        idx_sample[i] = entry
        for label in label_field.split(';'):
            # First time a label is seen it receives the next free index.
            feature_idx.setdefault(label, len(feature_idx))

    idx_feature = {j: label for label, j in feature_idx.items()}

    return idx_sample, idx_feature

In [23]:
# NOTE(review): this cell re-binds two names used earlier in the notebook. `data_dir`
# previously pointed at the ReactZyme output directory and `toc` was the polars
# table-of-contents DataFrame; from here on `toc` is a file stem string. Earlier cells
# that call DataFrame methods on `toc` or save into `data_dir` cannot be re-run
# after this cell without first re-running the cells that define those names.
data_dir = Path("/home/stef/quest_data/hiec/data")
dataset = "time_split"
toc = "250915_time_split"

In [24]:
# Only the label-index -> label lookup (second return value) is needed here.
idx_feature = construct_sparse_adj_mat(data_dir / dataset / f"{toc}.csv")[1]

# Reactions from the JSON file that sits next to the table of contents; looked up by
# label below.
rxns = json.loads((data_dir / dataset / f"{toc}.json").read_text())

Constructing 250915_time_split sparse adjacency matrix


In [25]:
# Load rxnfp's default model and tokenizer and wrap them in a fingerprint generator.
model, tokenizer = get_default_model_and_tokenizer()
rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)

In [26]:
# One fingerprint per label, in ascending label-index order, computed from the
# reaction's 'smarts' string.
embeds = [
    rxnfp_generator.convert(rxns[rid]['smarts'])
    for _, rid in sorted(idx_feature.items())
]
    

In [28]:
# Stack the per-reaction fingerprints row-wise into one array (this re-binds `embeds`
# from a list to an array) and save it under pretrained_rxn_embeddings.
embeds = np.vstack(embeds)
np.save(f"{data_dir}/pretrained_rxn_embeddings/{dataset}_{toc}_rxnfp.npy", embeds)