In [1]:
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import tokenizers
import transformers
from transformers import AutoTokenizer, EncoderDecoderModel, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import datasets
from datasets import load_dataset, load_metric
import sentencepiece
import argparse
from datasets.utils.logging import disable_progress_bar
from rdkit import Chem
import rdkit
disable_progress_bar()

class CFG:
    """Run configuration shared by the cells of this notebook."""

    # Model / checkpoint
    model = 't5'
    model_name_or_path = 'sagawa/ReactionT5v2-forward'

    # Generation settings (not referenced by the evaluation cells shown here)
    num_beams = 5
    num_return_sequences = 5

    # Data and run control
    dataset_path = 'multiinput_prediction_output.csv'
    debug = True
    seed = 42


# Target torch device string; not referenced by the evaluation cells shown here.
device = 'cpu'

def seed_everything(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's default
    generator and the current CUDA device, exports PYTHONHASHSEED, and asks
    cuDNN for deterministic algorithms.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    # Note: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. subprocess workers); the running process keeps its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed=CFG.seed)

# Tokenizer for the configured checkpoint (not referenced by the evaluation
# cells shown here). `return_tensors` is deliberately not passed:
# it is an option of the tokenizer's __call__/encode methods, and passing it to
# from_pretrained only stores it as an init kwarg without making encodings
# return tensors.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
from rdkit import Chem
def canonicalize(mol):
    """Return RDKit's canonical isomeric SMILES for a SMILES string.

    Args:
        mol: SMILES string to canonicalize.

    Returns:
        The canonical SMILES of the parsed molecule, with stereochemistry
        kept (``isomericSmiles=True``).

    Raises:
        ValueError: if RDKit cannot parse ``mol``.
    """
    parsed = Chem.MolFromSmiles(mol)
    if parsed is None:
        # MolFromSmiles reports a parse failure by returning None; handing that
        # to MolToSmiles would fail with an opaque argument error instead.
        raise ValueError(f'Invalid SMILES: {mol!r}')
    return Chem.MolToSmiles(parsed, isomericSmiles=True)
def canonicalize2(mol):
    """Canonicalize a SMILES string, returning None if that fails.

    Lenient wrapper around ``canonicalize`` for model predictions, which may
    be unparsable SMILES.

    Args:
        mol: SMILES string (or any cell value) to canonicalize.

    Returns:
        The canonical SMILES, or None when canonicalization raises.
    """
    try:
        return canonicalize(mol)
    except Exception:
        # Catch ordinary errors only: a bare `except:` would also swallow
        # KeyboardInterrupt and make long scoring loops uninterruptible.
        return None
def remove_space(row, n=5):
    """Remove space characters from the candidate columns of one row.

    Intended for ``df.apply(remove_space, axis=1)``; cleans the columns
    '0th', '1th', ..., f'{n - 1}th'.

    Args:
        row: One DataFrame row (pandas Series).
        n: Number of candidate columns to clean (default 5).

    Returns:
        A copy of ``row`` with spaces removed from the candidate columns.
        Non-string cells (e.g. NaN for a missing candidate) are left as-is.
    """
    # DataFrame.apply does not support mutating the object it passes in,
    # so edit a copy rather than `row` itself.
    cleaned = row.copy()
    for i in range(n):
        col = f'{i}th'
        value = cleaned[col]
        if isinstance(value, str):
            cleaned[col] = value.replace(' ', '')
    return cleaned


In [3]:
# Prediction CSV of the `t5_finetune_sampling10_similarall` run: one row per
# reaction with the input, candidates '0th'..'4th' and their scores.
predictions_csv = '/data1/ReactionT5_neword/task_forward/t5_finetune_sampling10_similarall/output.csv'
df = pd.read_csv(predictions_csv)
seed_everything(seed=CFG.seed)
df

Unnamed: 0,input,0th,1th,2th,3th,4th,0th score,1th score,2th score,3th score,4th score
0,REACTANT:N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc...,N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-],N#Cc1ccsc1Nc1cc(F)c2cc(F)c(F)cc2c1[N+](=O)[O-],N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-])cc1F,N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-])c(F)c1F,N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-])cc(F)c1F,0.000000,-1.149990,-1.232230,-1.261917,-1.278072
1,REACTANT:COC(=O)Cc1cn(C)c2cc(O)ccc12.Cc1nn(-c2...,COC(=O)Cc1cn(C)c2cc(OCC(C)c3cn(-c4ccc(C(F)(F)F...,COC(=O)Cc1cn(CCC(C)c2cn(-c3ccc(C(F)(F)F)cc3)nc...,COC(=O)Cc1cn(C)c2cc(OCC(C)c3cn(-c4ccc(C(F)(F)F...,COC(=O)Cc1cn(C)c2cc(OCC(C)c3cn(-c4ccc(C(F)(F)F...,COC(=O)Cc1cn(C)c2cc(OCC(C)c3cn(-c4ccc(C(F)(F)F...,0.000000,-0.557471,-0.646210,-0.667282,-0.675386
2,REACTANT:Cl.NC1CCN(CC2Cn3c(=O)ccc4ncc(F)c2c43)...,Cl.O=c1ccc2ncc(F)c3c2n1CC3CN1CCC(NCc2cc3c(cn2)...,Cl.O=c1ccc2ncc(F)c3c2n1CC3CN1CC[C@@H](NCc2cc3c...,Cl.O=c1ccc2ncc(F)c3c2n1CC3CN1CC[C@H](NCc2cc3c(...,Cl.Cl.O=c1ccc2ncc(F)c3c2n1CC3CN1CC[C@@H](NCc2c...,Cl.Cl.O=c1ccc2ncc(F)c3c2n1CC3CN1CCC(NCc2cc3c(c...,-0.008115,-0.016474,-0.107739,-0.552855,-0.680572
3,REACTANT:C=C(C)C(=O)Cl.CC(C)=C1C(=O)N(c2ccc(O)...,C=C(C)C(=O)Oc1ccc(N2C(=O)C(=C(C)C)C(=C(C)c3cc(...,C=C(C)C(=O)Oc1ccc(N2C(=O)C(=C(C)C)C(=C(C)c3cc(...,C=C(C)C(=O)Oc1ccc(N2C(=O)C(=C(C)C)/C(=C(/C)c3c...,C=C(C)C(=O)Oc1ccc(N2C(=O)C(=C(C)C)C(=C(C)c4cc(...,C=C(C)C(=O)Oc1ccc(N2C(=O)C(=C(C)C)C(=C(C)c3cc(...,0.000000,-0.648712,-0.651420,-0.677820,-0.685970
4,REACTANT:O=Cc1cncc(Cl)c1COC1CCCCO1REAGENT:OCc1...,OCc1cncc(Cl)c1COC1CCCCO1,OCc1cncc(Cl)c1COC1CCCCO1)c1cncc(Cl)c1COC1CCCCO1,OCc1cncc(Cl)c1CO))))CCCCOC1CCCCO1,OCc1cncc(Cl)c1CO)))))CCOCCOC1CCCCO1,OCc1cncc(Cl)c1COC1CCCCO1)c1cncc(Cl)c1,0.000000,-1.514894,-2.409068,-2.469165,-2.501678
...,...,...,...,...,...,...,...,...,...,...,...
39995,REACTANT:COC(=O)NCC1Cc2c(Cl)cc3c(c2O1)CCC3REAG...,CNCC1Cc2c(Cl)cc3c(c2O1)CCC3,CNCC1Cc2c(Cl)cc3c(c2O1)CCC3.Cl,CNCC1Cc3c(Cl)cc4c(c3O1)CCC4,CNCC1Cc2cc3c(c(Cl)c2O1)CCC3,CNCC1Cc2cc3c(cc2O1)CCC3,0.000000,-0.975255,-1.750278,-1.826253,-1.893469
39996,REACTANT:COc1cccc(C(=O)Cl)c1.COc1ccccc1OCREAGE...,COc1cccc(C(=O)c2ccc(OC)c(OC)c2)c1,COc1cccc(C(=O)c2ccc(OC)c(OC)c2)c1.COc1cccc(C(=...,COc1cccc(C(=O)c2ccc(OC)c(OC)c2)c1.COc1ccc(C(=O...,COc1cccc(C(=O)c2ccc(OC)c(OC)c2)c1)c1ccc(OC)c(O...,COc1cccc(C(=O)c2ccc(Oc3ccccc3OC)c(OC)c2)c1,0.000000,-0.929118,-1.071079,-1.410985,-1.425890
39997,REACTANT:C#Cc1cccc(C2C(C(=O)OC)=C(C)NC(C(OC)OC...,C#Cc1cccc(C2C(C(=O)OC)=C(C)NC(C=O)=C2C(=O)OC)c1,C#Cc1cccc(C2C(C(=O)OC)=C(C)NC(C=O)=C2C(=O)OC,C#Cc1cccc(C2C(C(=O)OC)=C(C)Nc3cc(C(=O)OC)c(C)n...,C#Cc1cccc(C2C(C(=O)OC)=C(C)Nc3cc(OC)c(OC)cc32)c1,C#Cc1cccc(C2C(C(=O)NC)=C(C)NC(C=O)=C2C(=O)OC)c1,0.000000,-0.791577,-0.841979,-0.938112,-0.989227
39998,REACTANT:CO.COc1ccc(CC(=O)c2ccc(O)cc2O)cc1REAG...,COc1ccc(CC(O)c2ccc(O)cc2O)cc1,COc1ccc(CC(OC)c2ccc(O)cc2O)cc1,COc1ccc(-c2coc3cc(O)ccc3c2=O)cc1,COc1ccc(CCc2ccc(O)cc2O)cc1,COc1ccc(CC(OC)(OC)c2ccc(O)cc2O)cc1,-0.030015,-0.057779,-0.060773,-0.099768,-0.545348


In [4]:
# Ground-truth products are attached by row position, so the prediction CSV
# must list reactions in the same order as test.csv.
test_products = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_MIT/MIT_separated/test.csv")["PRODUCT"]
df['target'] = test_products.to_list()
df = df.apply(remove_space, axis=1)

In [5]:
# Mute RDKit's 'rdApp' log channels so that parse failures on invalid beam
# candidates in the scoring loops do not flood the cell output.
rdkit_log = rdkit.RDLogger
rdkit_log.DisableLog('rdApp.*')

In [6]:
# Score each reaction in `df`:
#   * topK_accuracy: 1 if the canonical SMILES of any of the first K candidates
#     equals the canonical target, else 0 (for K in top_ks);
#   * invalidity: how many of the first `top_k_invalidity` candidates RDKit
#     cannot parse.
# An unparsable *target* makes `canonicalize` raise rather than scoring a miss.
top_k_invalidity = 5
top_ks = (1, 2, 3, 5)
candidate_cols = [f'{i}th' for i in range(max(top_ks))]

hit_ranks = []   # 0-based position of the first correct candidate, or None
invalidity = []  # per-reaction count of unparsable candidates

for _, row in tqdm(df.iterrows(), total=len(df), desc='scoring'):
    target = canonicalize(row['target'])
    # next() stops at the first match, so later candidates are not canonicalized.
    hit_ranks.append(next(
        (rank for rank, col in enumerate(candidate_cols) if canonicalize2(row[col]) == target),
        None,
    ))

    candidates = [row[f'{i}th'] for i in range(top_k_invalidity)]
    # Trailing '.' separators are stripped before parsing; non-string cells
    # (e.g. NaN for a missing candidate) count as invalid.
    invalidity.append(sum(
        not isinstance(smi, str) or Chem.MolFromSmiles(smi.rstrip('.')) is None
        for smi in candidates
    ))

# A hit at rank r counts towards every cut-off k > r.
for k in top_ks:
    df[f'top{k}_accuracy'] = [int(rank is not None and rank < k) for rank in hit_ranks]
df['invalidity'] = invalidity

In [7]:
accuracy_cols = ['top1_accuracy', 'top2_accuracy', 'top3_accuracy', 'top5_accuracy']
print(*(sum(df[col]) / len(df) for col in accuracy_cols))
# Percentage of all inspected candidates (rows x top_k_invalidity) that RDKit could not parse.
invalid_fraction = sum(invalidity) / (len(invalidity) * top_k_invalidity)
print(invalid_fraction * 100)

0.94015 0.965175 0.969825 0.9742
25.384


In [8]:
# Evaluate the `t5_finetune_sampling10_similartotestall` run with the same
# metrics: top-k accuracy over canonical SMILES and the percentage of
# unparsable candidates.
df = pd.read_csv('/data1/ReactionT5_neword/task_forward/t5_finetune_sampling10_similartotestall/output.csv')
seed_everything(seed=CFG.seed)
# Targets are attached by row position; the CSV must follow test.csv order.
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_MIT/MIT_separated/test.csv")["PRODUCT"].to_list()
df = df.apply(remove_space, axis=1)

top_k_invalidity = 5
top_ks = (1, 2, 3, 5)
candidate_cols = [f'{i}th' for i in range(max(top_ks))]

hit_ranks = []   # 0-based position of the first correct candidate, or None
invalidity = []  # per-reaction count of unparsable candidates

for _, row in tqdm(df.iterrows(), total=len(df), desc='scoring'):
    # An unparsable target makes `canonicalize` raise rather than scoring a miss.
    target = canonicalize(row['target'])
    # next() stops at the first match, so later candidates are not canonicalized.
    hit_ranks.append(next(
        (rank for rank, col in enumerate(candidate_cols) if canonicalize2(row[col]) == target),
        None,
    ))

    candidates = [row[f'{i}th'] for i in range(top_k_invalidity)]
    # Trailing '.' separators are stripped before parsing; non-string cells
    # (e.g. NaN for a missing candidate) count as invalid.
    invalidity.append(sum(
        not isinstance(smi, str) or Chem.MolFromSmiles(smi.rstrip('.')) is None
        for smi in candidates
    ))

# A hit at rank r counts towards every cut-off k > r.
for k in top_ks:
    df[f'top{k}_accuracy'] = [int(rank is not None and rank < k) for rank in hit_ranks]
df['invalidity'] = invalidity

print(*(df[f'top{k}_accuracy'].sum() / len(df) for k in top_ks))
print(sum(invalidity) / (len(invalidity) * top_k_invalidity) * 100)

0.9389 0.960075 0.965325 0.9707
20.811
