In [1]:
import yaml
import numpy as np
import pandas as pd
import pickle as pkl
import torch
from torch.utils.data import DataLoader
from dataset import ReactionDataset, collate_reaction, query_to_vec
from train import Trainer
from util import euclidean_sim, is_valid_smiles

### User Query and Specification for Search
1. Query for the search (product and/or reactant SMILES)
2. Number of records to retrieve
3. Identifier of the pre-trained model to load

In [2]:
# Search specification. At least one of 'product' / 'reactant' must be a valid
# SMILES string; the validation cell below asserts this.
query = {
    # Identifier of this query; copied into every retrieved record.
    'Q.ID': 1,
    # Product SMILES.
    'product': 'c1ccc(-c2ccc3nnc(CNc4ncnc5nc[nH]c45)n3n2)cc1',
    # Reactant SMILES; multiple reactants are separated by '.'.
    'reactant': 'Clc1ncnc2nc[nH]c12.NCc1nnc2ccc(-c3ccccc3)nn12',
}
# Number of distinct records to collect from the database.
n_retrieve = 15
# Selects which checkpoint and embedding files are loaded
# (./model/<identifier>_checkpoint.pt, ./embed/<identifier>_*.npz).
identifier = 'paper'

In [3]:
# Check query: every non-empty field must be a valid SMILES string, and at
# least one of product / reactant has to be supplied.
#
# Presence is decided by non-emptiness first, and is_valid_smiles is consulted
# only for non-empty fields. That way an empty field is never flagged as
# present, whatever the helper returns for '' (some SMILES parsers accept the
# empty string -- TODO confirm how util.is_valid_smiles behaves).
has_product = len(query['product']) > 0
has_reactant = len(query['reactant']) > 0

if has_product:
    assert is_valid_smiles(query['product']), \
        'Invalid product SMILES: %r' % query['product']
if has_reactant:
    assert is_valid_smiles(query['reactant']), \
        'Invalid reactant SMILES: %r' % query['reactant']

assert has_product or has_reactant, \
    'The query needs at least one of product / reactant'

### Load Configurations

In [4]:
# Pick the compute device: GPU when one is available, CPU otherwise.
# NOTE: despite its name, `cuda` may therefore hold a CPU device.
cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', cuda)

# Read the YAML configuration file.
config_url = 'config.yaml'
with open(config_url, 'r') as f:
    config = yaml.safe_load(f)

# `frac_var` entry of the configuration's `pca` section.
# The original chained assignment also bound this value to `batch_size`;
# a PCA variance setting is not a batch size, so that alias is not created.
frac_var = config['pca']['frac_var']

Using device: cuda


In [5]:
# Molecular dictionary, unpickled from disk.
# NOTE(review): unpickling can execute arbitrary code -- only load a trusted file.
with open('./data/uspto_mol_dict.pkl', 'rb') as mol_dict_file:
    mol_dict = pkl.load(mol_dict_file)

# Checkpoint path of the pre-trained representation model, chosen by `identifier`.
model_path = './model/%s_checkpoint.pt' % identifier

# Construct the trainer on the selected device, then load the checkpoint.
trainer = Trainer(None, model_path, mol_dict, cuda)
trainer.load(model_path)

## Prepare Database

In [6]:
# Load database to be searched: the train and valid splits are merged into a
# single mapping with the dict-union operator (presumably keyed by reaction
# id -- TODO confirm against ReactionDataset).
train_records = ReactionDataset('uspto_train', mol_dict, mode = 'dict').data
valid_records = ReactionDataset('uspto_valid', mol_dict, mode = 'dict').data
ref_data = train_records | valid_records

# Load reaction embeddings.
# np.load on a .npz archive returns an NpzFile that keeps the file open, so
# each archive is read inside a `with` block to release the handle afterwards.
with np.load('./embed/%s_pca.npz'%identifier) as pca:
    # Projection matrix applied to query embeddings before the search.
    V = torch.FloatTensor(pca['pc'])

with np.load('./embed/%s_embeddings_reduced.npz'%identifier) as data:
    ref_rid_list = data['ids']
    ref_embed_list = torch.FloatTensor(data['embeds'])

# Sanity checks: one embedding per database record, and one id per embedding
# (positions in the similarity vector are later used to index ref_rid_list).
assert len(ref_data) == len(ref_embed_list), \
    'Database size (%d) != number of embeddings (%d)' % (len(ref_data), len(ref_embed_list))
assert len(ref_rid_list) == len(ref_embed_list), \
    'Number of ids (%d) != number of embeddings (%d)' % (len(ref_rid_list), len(ref_embed_list))

imported dataset ./data/uspto_train.pkl
imported dataset ./data/uspto_valid.pkl


## Retrieve Relevant Records for Query

In [7]:
# Convert the query into model inputs
q_vec = query_to_vec(query)

# Embed the query; tensors (not numpy arrays) are requested from the trainer
q_product, q_product_pred = trainer.embed(q_vec[1], q_vec[2], to_numpy = False)

# If one side of the query is missing, reuse the embedding of the other side
if not has_reactant:
    q_product_pred = q_product
if not has_product:
    q_product = q_product_pred

# Neither embedding may be all zeros
assert q_product.abs().sum() > 0
assert q_product_pred.abs().sum() > 0

# Reduce dimensionality by projecting onto V
q_product = q_product @ V
q_product_pred = q_product_pred @ V

# Retrieve relevant records: score the concatenated query embedding against
# every reference embedding, then order indices from most to least similar
q_embed = torch.cat((q_product, q_product_pred), dim=1)
sim = euclidean_sim(q_embed, ref_embed_list).numpy().reshape(-1)
sort_idx = (-sim).argsort()

## Print Search Results

In [8]:
# Walk the database from most to least similar and keep the first `n_retrieve`
# distinct (product, reactant) pairs.
#   retrieved_dict: dedup key -> [Q.ID, query product, query reactant,
#                                 retrieved product, retrieved reactant]
#   print_list:     rows of [rank, reaction id, product, reactant, distance]
retrieved_dict = {}
print_list = []
for idx in sort_idx:
    # Stop as soon as enough distinct records are collected. Checking up front
    # with >= also makes n_retrieve <= 0 return nothing instead of scanning
    # the whole database.
    if len(retrieved_dict) >= n_retrieve:
        break

    rid = ref_rid_list[idx]

    # Translate the stored molecule entries back through mol_dict
    inst = ref_data[rid]
    res_product = mol_dict.get(inst['product'][0])
    res_reactant = '.'.join([mol_dict.get(x) for x in inst['reactant']])

    # Skip reactions whose product/reactant pair was already retrieved
    key = '%s_%s'%(res_product, res_reactant)
    if key in retrieved_dict:
        continue
    retrieved_dict[key] = [query['Q.ID'], query['product'], query['reactant'], res_product, res_reactant]

    # Convert the similarity to a distance as sqrt(-sim). Clamp at zero first
    # so that a tiny positive similarity caused by rounding cannot yield NaN.
    distance = np.sqrt(np.maximum(-sim[idx], 0))
    print_list.append([len(retrieved_dict), rid, res_product, res_reactant, distance])

In [9]:
# Tabulate the retrieved records, indexed by rank
result_columns = ['Rank', 'RX.ID', 'Product', 'Reactant', 'Distance']
result_table = pd.DataFrame(print_list, columns=result_columns)
result_table.set_index('Rank')

Unnamed: 0_level_0,RX.ID,Product,Reactant,Distance
Rank,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,train_164218,NCc1nnc2ncc(-c3ccccc3)nn12,[N-]=[N+]=NCc1nnc2ncc(-c3ccccc3)nn12,50.386028
2,train_402863,Nc1ccn2nc(-c3ccccc3)nc2c1,CC(C)(C)OC(=O)Nc1ccn2nc(-c3ccccc3)nc2c1,51.836639
3,train_82403,OCc1nnc2ccc(-c3ccccc3)nn12,OCc1nnc2ccc(Cl)nn12.OB(O)c1ccccc1,52.3568
4,train_391826,Oc1nc(O)n2ncc(-c3ccccc3)c2n1,[OH-].Oc1nc(S)nc2c(-c3ccccc3)cnn12,52.486568
5,train_71802,c1ccc(Nc2nsc3nc4ccccc4n23)nc1,Brc1nsc2nc3ccccc3n12.Nc1ccccn1,54.869946
6,train_156427,Nc1nc2nccc(-c3ccc(O)cc3)n2n1,Nc1nc2nccc(-c3ccc(OCc4ccccc4)cc3)n2n1,56.004971
7,train_5236,CC(C)(C)OC(=O)Nc1ccn2nc(-c3ccccc3)nc2c1,Brc1ccn2nc(-c3ccccc3)nc2c1.CC(C)(C)OC(N)=O,56.923813
8,train_364075,O=C(Nc1nc2c(Br)cccn2n1)c1ccccc1,Nc1nc2c(Br)cccn2n1.O=C(Cl)c1ccccc1,57.059418
9,train_7560,Nc1nsnc1-c1ccccc1,C[Si](C)(C)[N-][Si](C)(C)C.Clc1nsnc1-c1ccccc1,57.343334
10,train_320693,Clc1nc2nc(Br)nn2cc1-c1ccccc1,Clc1nc2nc(Br)nn2c(Br)c1-c1ccccc1,58.170109
