In [1]:
import json
from collections import Counter
from multiprocessing import Pool

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from scipy import sparse
from tqdm import tqdm

In [2]:
def load_jsonl(path):
    """Load a JSON Lines file into a list of parsed objects.

    Parameters
    ----------
    path : str
        Path to a file containing one JSON document per line.

    Returns
    -------
    list
        The parsed documents in file order. Blank or whitespace-only lines
        are skipped rather than passed to ``json.loads`` (which would raise).
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


# Atom-mapped reactions with their extracted templates, plus the template vocabulary.
train_rxns = load_jsonl('data/train_rxns_with_template.jsonl')
val_rxns = load_jsonl('data/val_rxns_with_template.jsonl')
templates = load_jsonl('data/templates.jsonl')

In [3]:
def clear_atom_map(smi):
    """Return the canonical SMILES of ``smi`` with every atom-map number removed.

    Parameters
    ----------
    smi : str
        A (possibly atom-mapped) SMILES string, e.g. one component of an
        atom-mapped reaction SMILES.

    Returns
    -------
    str
        Canonical SMILES without atom-map annotations.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; without this
        # check the loop below dies with an unhelpful AttributeError.
        raise ValueError(f'RDKit could not parse SMILES: {smi!r}')
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    # The map-free SMILES is re-parsed and canonicalized once more by CanonSmiles.
    return Chem.CanonSmiles(Chem.MolToSmiles(mol))

In [4]:
# Inspect one raw training record: atom-mapped reaction SMILES plus its extracted template.
example_rxn = train_rxns[1]
example_rxn

{'class': '2',
 'id': 'US20120114765A1',
 'rxn_smiles': 'O[C:1](=[O:2])[c:3]1[cH:4][c:5]([N+:6](=[O:7])[O-:8])[c:9]([S:10][c:11]2[c:12]([Cl:13])[cH:14][n:15][cH:16][c:17]2[Cl:18])[s:19]1.[NH2:20][c:21]1[cH:22][cH:23][cH:24][c:25]2[cH:26][n:27][cH:28][cH:29][c:30]12>>[O:2]=[C:1]([NH:20][c:21]1[cH:22][cH:23][cH:24][c:25]2[cH:26][n:27][cH:28][cH:29][c:30]12)[c:3]1[cH:4][c:5]([N+:6](=[O:7])[O-:8])[c:9]([S:10][c:11]2[c:12]([Cl:13])[cH:14][n:15][cH:16][c:17]2[Cl:18])[s:19]1',
 'canon_reaction_smarts': '[#16&a:4]:[c:3]-[C&H0&D3&+0:1](=[O&D1&H0:2])-[N&H1&D2&+0:5]-[c:6]>>O-[C&H0&D3&+0:1](=[O&D1&H0:2])-[c:3]:[#16&a:4].[N&H2&D1&+0:5]-[c:6]'}

In [5]:
# Count occurrences of each canonical, atom-map-free molecule in the training set.
# unique_products counts product occurrences only; unique_molecules counts every
# product and reactant occurrence. Counter remembers first-insertion order, so
# list(unique_molecules) is ordered by first appearance (a product before its reactants).
unique_products = Counter()
unique_molecules = Counter()
for rxn in tqdm(train_rxns):
    reactants_smi, product_smi = rxn['rxn_smiles'].split('>>')
    assert '.' not in product_smi, f'expected a single product, got {product_smi!r}'
    product = clear_atom_map(product_smi)
    reactants = [clear_atom_map(smi) for smi in reactants_smi.split('.')]
    unique_products[product] += 1
    unique_molecules[product] += 1
    unique_molecules.update(reactants)

 35%|███▍      | 13858/39803 [00:10<00:20, 1270.13it/s]


KeyboardInterrupt: 

In [7]:
# Map each template's reaction SMARTS (as stored in data/templates.jsonl) to its integer index.
template2idx = dict((entry['reaction_smarts'], entry['index']) for entry in templates)

In [8]:
def rxns_to_triplets(rxns, template_to_idx):
    """Expand reactions into (reactant, template_idx, product) triplets.

    Parameters
    ----------
    rxns : list of dict
        Records with 'rxn_smiles' (atom-mapped, single-product) and
        'canon_reaction_smarts' keys.
    template_to_idx : dict
        Maps a template SMARTS string to its integer index.

    Returns
    -------
    list of tuple
        One (reactant SMILES, template index, product SMILES) triplet per
        reactant of each reaction, with atom maps cleared.

    Raises
    ------
    KeyError
        If a reaction's template is not in ``template_to_idx``.
    """
    out = []
    for rxn in tqdm(rxns):
        reactants_smi, product_smi = rxn['rxn_smiles'].split('>>')
        assert '.' not in product_smi
        reactants = [clear_atom_map(smi) for smi in reactants_smi.split('.')]
        product = clear_atom_map(product_smi)
        template_idx = template_to_idx[rxn['canon_reaction_smarts']]
        out.extend((reactant, template_idx, product) for reactant in reactants)
    return out


triplets = rxns_to_triplets(train_rxns, template2idx)

# TODO: `corrupted_rxns` is not defined anywhere in this notebook. Once it is, the
# `corrupted_triplets` consumed by the multiprocessing cell are produced with:
# corrupted_triplets = rxns_to_triplets(corrupted_rxns, template2idx)

  0%|          | 0/39803 [00:00<?, ?it/s]

100%|██████████| 39803/39803 [00:31<00:00, 1266.72it/s]


In [9]:
# First training triplet: (reactant SMILES, template index, product SMILES).
sample_triplet = triplets[0]
sample_triplet

('COC(=O)[C@H](CCCCNC(=O)OCc1ccccc1)NC(=O)Nc1cc(OC)cc(C(C)(C)C)c1O',
 93,
 'COC(=O)[C@H](CCCCN)NC(=O)Nc1cc(OC)cc(C(C)(C)C)c1O')

In [10]:
# Every training reaction contributes one triplet per reactant, hence at least one.
assert len(triplets) >= len(train_rxns)
len(triplets)

68117

In [14]:
# Featurize each training triplet: head = product fingerprint, relation = template
# index, tail = reactant fingerprint (Morgan, radius 2, 2048 bits).
# A product appears once per reactant, so fingerprints are memoized by SMILES.
fp_cache = {}
head_fps, relation_ids, tail_fps = [], [], []
for reactant, template_idx, product in tqdm(triplets):
    for smiles in (product, reactant):
        if smiles not in fp_cache:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')
            fp_cache[smiles] = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    head_fps.append(fp_cache[product])
    relation_ids.append(template_idx)
    tail_fps.append(fp_cache[reactant])
head_fps = np.array(head_fps)
# Converted for parity with val_relation_ids; np.save writes the same file either way.
relation_ids = np.array(relation_ids)
tail_fps = np.array(tail_fps)

100%|██████████| 68117/68117 [00:16<00:00, 4109.98it/s]


In [49]:
def process_triplet(triplet):
    """Fingerprint one (reactant, template_idx, product) triplet.

    Returns
    -------
    tuple
        (head_fp, template_idx, tail_fp), where head_fp / tail_fp are numpy
        arrays of the radius-2, 2048-bit Morgan fingerprints of the product /
        reactant respectively.

    Raises
    ------
    ValueError
        If RDKit cannot parse either SMILES.
    """
    reactant, template_idx, product = triplet
    head_mol = Chem.MolFromSmiles(product)
    tail_mol = Chem.MolFromSmiles(reactant)
    if head_mol is None or tail_mol is None:
        raise ValueError(f'RDKit could not parse triplet: {triplet!r}')
    head_fp = np.array(AllChem.GetMorganFingerprintAsBitVect(head_mol, 2, nBits=2048))
    tail_fp = np.array(AllChem.GetMorganFingerprintAsBitVect(tail_mol, 2, nBits=2048))
    return head_fp, template_idx, tail_fp


# NOTE(review): `corrupted_triplets` is only created by commented-out code in the
# training-triplet cell, so this cell raises NameError on a fresh kernel.
# NOTE(review): workers must be able to unpickle `process_triplet`; under the
# "spawn" start method (default on macOS/Windows) functions defined in a notebook
# typically cannot be found by the children -- consider moving it to a .py module.
N_WORKERS = 16
with Pool(N_WORKERS) as pool:
    # chunksize > 1 amortizes inter-process overhead over many tiny tasks;
    # imap still yields results in input order.
    results = list(tqdm(pool.imap(process_triplet, corrupted_triplets, chunksize=256),
                        total=len(corrupted_triplets)))

corrupted_head_fps = np.array([head_fp for head_fp, _, _ in results])
corrupted_relation_ids = np.array([relation_id for _, relation_id, _ in results])
corrupted_tail_fps = np.array([tail_fp for _, _, tail_fp in results])

100%|██████████| 408702/408702 [00:55<00:00, 7344.58it/s]


In [15]:
# Fingerprint every unique training molecule (reactants and products), keeping the
# order of unique_molecules so row i of all_fps belongs to unique_molecules_lst[i].
unique_molecules_lst = list(unique_molecules)
all_fps = np.array([
    AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smiles), 2, nBits=2048)
    for smiles in tqdm(unique_molecules_lst)
])

  3%|▎         | 2524/82006 [00:00<00:09, 8356.80it/s]

100%|██████████| 82006/82006 [00:11<00:00, 7261.96it/s]


In [16]:
# Save fingerprints of every unique training molecule (reactants and products, not
# only products). Row i corresponds to unique_molecules_lst[i]; the SMILES list is
# saved alongside because the .npy alone cannot be mapped back to molecules.
np.save('data/molecule_fps.npy', all_fps)
with open('data/molecule_smiles.json', 'w', encoding='utf-8') as f:
    json.dump(unique_molecules_lst, f)

In [17]:
# Persist the training arrays: head = product fps, relation = template indices,
# tail = reactant fps. The corrupted-triplet saves remain disabled below.
for out_path, array in (
    ('data/train_head_fps.npy', head_fps),
    ('data/train_relation_ids.npy', relation_ids),
    ('data/train_tail_fps.npy', tail_fps),
):
    np.save(out_path, array)
# np.save('data/train_corrupted_head_fps.npy', corrupted_head_fps)
# np.save('data/train_corrupted_relation_ids.npy', corrupted_relation_ids)
# np.save('data/train_corrupted_tail_fps.npy', corrupted_tail_fps)

In [11]:
# Re-read the validation reactions (the data-loading cell already reads this file;
# presumably repeated so this section can be re-run on its own).
# Explicit UTF-8 and blank-line skipping keep it robust across platforms.
with open('data/val_rxns_with_template.jsonl', 'r', encoding='utf-8') as f:
    val_rxns = [json.loads(line) for line in f if line.strip()]

In [12]:
# Build validation triplets like the training ones, except that templates never
# seen in training get the sentinel index -1.
# NOTE(review): -1 is a valid (last-element) index for numpy arrays; make sure
# downstream code masks these triplets instead of indexing with them.
val_triplets = []
for rxn in tqdm(val_rxns):
    lhs, rhs = rxn['rxn_smiles'].split('>>')
    assert '.' not in rhs
    reactant_smiles = [clear_atom_map(smi) for smi in lhs.split('.')]
    product = clear_atom_map(rhs)
    template_idx = template2idx.get(rxn['canon_reaction_smarts'], -1)
    val_triplets.extend((smi, template_idx, product) for smi in reactant_smiles)

  0%|          | 0/5001 [00:00<?, ?it/s]

100%|██████████| 5001/5001 [00:04<00:00, 1215.80it/s]


In [20]:
# First validation triplet: (reactant SMILES, template index, product SMILES).
sample_val_triplet = val_triplets[0]
sample_val_triplet

('O=C(OC(=O)C(F)(F)F)C(F)(F)F',
 330,
 'O=C(Nc1ccc(Oc2ccnc3[nH]ccc23)c(F)c1)C(F)(F)F')

In [13]:
# Every validation reaction contributes one triplet per reactant, hence at least one.
assert len(val_triplets) >= len(val_rxns)
len(val_triplets)

8561

In [23]:
# Featurize validation triplets the same way as training: head = product Morgan
# fingerprint, relation = template index (-1 for templates unseen in training),
# tail = reactant Morgan fingerprint (radius 2, 2048 bits).
# A product appears once per reactant, so fingerprints are memoized by SMILES.
val_fp_cache = {}
val_head_fps, val_relation_ids, val_tail_fps = [], [], []
for reactant, template_idx, product in tqdm(val_triplets):
    for smiles in (product, reactant):
        if smiles not in val_fp_cache:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')
            val_fp_cache[smiles] = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    val_head_fps.append(val_fp_cache[product])
    val_relation_ids.append(template_idx)
    val_tail_fps.append(val_fp_cache[reactant])
val_head_fps = np.array(val_head_fps)
val_relation_ids = np.array(val_relation_ids)
val_tail_fps = np.array(val_tail_fps)

  0%|          | 0/8561 [00:00<?, ?it/s]

100%|██████████| 8561/8561 [00:02<00:00, 4069.03it/s]


In [24]:
# Persist the validation arrays under the same naming scheme as the training ones.
for val_path, val_array in (
    ('data/val_head_fps.npy', val_head_fps),
    ('data/val_relation_ids.npy', val_relation_ids),
    ('data/val_tail_fps.npy', val_tail_fps),
):
    np.save(val_path, val_array)

In [26]:
# Size of the template vocabulary loaded from data/templates.jsonl.
n_templates = len(templates)
n_templates

10225