In [1]:
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import tokenizers
import transformers
from transformers import AutoTokenizer, EncoderDecoderModel, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import datasets
from datasets import load_dataset, load_metric
import sentencepiece
import argparse
from datasets.utils.logging import disable_progress_bar
from rdkit import Chem
import rdkit
disable_progress_bar()

class CFG:
    """Notebook-wide settings, read as plain class attributes (e.g. ``CFG.seed``)."""

    # Reproducibility / run mode. `seed` is what the cells pass to seed_everything().
    seed = 42
    debug = True

    # Model identification: short family tag and the checkpoint name/path that
    # is handed to `from_pretrained`.
    model = 't5'
    model_name_or_path = 'sagawa/ReactionT5v2-retrosynthesis'

    # Input CSV; only referenced from a commented-out line in the visible cells.
    dataset_path = 'multiinput_prediction_output.csv'

    # Presumably beam-search generation settings (not used in the visible
    # cells) -- TODO confirm against the generation script.
    num_beams = 5
    num_return_sequences = 5
    

# Device string, fixed to CPU. No other visible cell reads this variable.
device = 'cpu'

def seed_everything(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one. Safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats `deterministic = True`; switch it off as well.
    torch.backends.cudnn.benchmark = False
# Fix all RNG seeds before anything else runs.
seed_everything(seed=CFG.seed)

# dataset = pd.read_csv(CFG.dataset_path)

# Tokenizer belonging to the checkpoint named in CFG.
# `return_tensors` is an argument of the tokenizer *call* (tokenizer(text,
# return_tensors='pt')), not of `from_pretrained`, so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path)

In [2]:
from rdkit import Chem
def canonicalize(mol):
    """Return the canonical isomeric SMILES for a SMILES string.

    Args:
        mol: SMILES string to canonicalize (a string, despite the name).

    Returns:
        The SMILES string RDKit writes for the parsed molecule.

    Raises:
        ValueError: if RDKit cannot parse ``mol``.
    """
    parsed = Chem.MolFromSmiles(mol)
    if parsed is None:
        # MolFromSmiles reports a parse failure by returning None. Fail here
        # with a readable message instead of an argument-type error raised
        # from inside MolToSmiles.
        raise ValueError(f'Invalid SMILES: {mol!r}')
    return Chem.MolToSmiles(parsed, isomericSmiles=True)
def canonicalize2(mol):
    """Best-effort variant of ``canonicalize``.

    Args:
        mol: SMILES string to canonicalize.

    Returns:
        The result of ``canonicalize(mol)``, or ``None`` if it raised.
    """
    try:
        return canonicalize(mol)
    except Exception:
        # Only ordinary errors mean "not canonicalizable". KeyboardInterrupt
        # and SystemExit must still be able to stop a long evaluation loop.
        return None
def remove_space(row):
    """Delete all space characters from the five prediction fields of ``row``.

    The entries ``'0th'`` to ``'4th'`` are overwritten in place; every other
    entry is left untouched.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) whose ``'0th'`` ..
            ``'4th'`` values are strings.

    Returns:
        The same ``row`` object, after modification.
    """
    prediction_columns = ('0th', '1th', '2th', '3th', '4th')
    for column in prediction_columns:
        cleaned = row[column].replace(' ', '')
        row[column] = cleaned
    return row


In [3]:
# Load the predictions of the `t5_finetune_sampling10_similarall` run.
# NOTE: absolute path; only resolves on a machine with this directory layout.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis/t5_finetune_sampling10_similarall/output.csv')
# Reset all RNG seeds (nothing else in this cell draws random numbers).
seed_everything(seed=CFG.seed)  
# df['target'] = pd.read_csv('sampled.csv')['PRODUCT']
# Display the frame: `input`, the predictions `0th`..`4th` and their scores.
df

Unnamed: 0,input,0th,1th,2th,3th,4th,0th score,1th score,2th score,3th score,4th score
0,COC(=O)CCC(=O)c1ccc(OC2CCCCO2)cc1O,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O.Cc1ccc(S(...,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O.Cc1ccc(S(...,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O.Cc1ccc(S(...,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O.Cc1ccc(S(...,-0.026484,-0.031880,-0.045180,-0.055616,-0.076609
1,COC(=O)c1cccc(-c2nc3cccnc3[nH]2)c1,COC(=O)c1cccc(C(=O)O)c1.Nc1cccnc1N.O=P(Cl)(Cl)Cl,COC(=O)c1cccc(C(=O)O)c1.Nc1cccnc1N,CCN(CC)CC.COC(=O)c1cccc(C(=O)O)c1.Nc1cccnc1N.O...,CCN(CC)CC.COC(=O)c1cccc(C(=O)O)c1.Nc1cccnc1N.O...,CC(=O)O.COC(=O)c1cccc(C(=O)Nc2cccnc2N)c1,-0.046570,-0.048827,-0.085147,-0.088241,-0.090641
2,CON(C)C(=O)C1CCC(NC(=O)OC(C)(C)C)CC1,CC(C)(C)OC(=O)NC1CCC(C(=O)O)CC1.CCN=C=NCCCN(C)...,C(=NC1CCCCC1)=NC1CCCCC1.CC(C)(C)OC(=O)NC1CCC(C...,CC(C)(C)OC(=O)NC1CCC(C(=O)O)CC1.CCN=C=NCCCN(C)...,C(=NC1CCCCC1)=NC1CCCCC1.CC(C)(C)OC(=O)NC1CCC(C...,C(=NC1CCCCC1)=NC1CCCCC1.CC(C)(C)OC(=O)NC1CCC(C...,-0.033483,-0.043998,-0.044669,-0.052927,-0.065151
3,O=[N+]([O-])c1ccc(Cl)nc1Nc1ccc(O)cc1,Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)nc1Cl,CCO.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)nc1Cl,Nc1ccc(O)cc1.O=C([O-])O.O=[N+]([O-])c1ccc(Cl)n...,CCN(CC)CC.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)nc1Cl,C1CCOC1.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)nc1Cl,-0.026931,-0.052209,-0.061678,-0.062336,-0.082540
4,NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([N+](=O)[O-...,[N-]=[N+]=NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([...,[N-]=[N+]=NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([...,N.[N-]=[N+]=NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]...,N.N#CC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([N+](=O)...,N.N#CC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([N+](=O)...,-0.007578,-0.041177,-0.042563,-0.051556,-0.055493
...,...,...,...,...,...,...,...,...,...,...,...
4999,Cc1cc([N+](=O)[O-])ccc1Oc1ccnc(N)c1,CCN(C(C)C)C(C)C.Cc1cc([N+](=O)[O-])ccc1O.Nc1cc...,CCN(C(C)C)C(C)C.CN1CCCC1=O.Cc1cc([N+](=O)[O-])...,Cc1cc([N+](=O)[O-])ccc1Oc1ccnc(NC(=O)OC(C)(C)C...,CCN(C(C)C)C(C)C.Cc1cc([N+](=O)[O-])ccc1O.Nc1cc...,CCN(C(C)C)C(C)C.CN1CCCC1=O.Cc1cc([N+](=O)[O-])...,-0.019575,-0.034775,-0.045381,-0.055425,-0.058150
5000,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)O)CC3,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OC(C)(C)C)...,CO.COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OCc1ccc...,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OC(C)(C)C)...,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OC(C)(C)C)CC3,CO.COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OC(C)(C...,-0.037330,-0.041904,-0.042261,-0.053992,-0.057731
5001,COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)CC(O...,COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)CC(=...,CO.COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)C...,CCO.COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)...,CCOC(C)=O.CO.COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(...,CO.COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)C...,-0.017691,-0.026993,-0.036331,-0.051202,-0.058533
5002,Cc1nn(CC(C)CO)c(-c2ccc(F)cc2)c1Br,CC(CO)CBr.Cc1n[nH]c(-c2ccc(F)cc2)c1Br.O=C([O-]...,CC(CO)CBr.Cc1n[nH]c(-c2ccc(F)cc2)c1Br,CC(CO)CBr.CN(C)C=O.Cc1n[nH]c(-c2ccc(F)cc2)c1Br...,CC(CO)CBr.CN(C)C=O.Cc1n[nH]c(-c2ccc(F)cc2)c1Br...,CC(CO)CBr.CCOC(C)=O.CN(C)C=O.Cc1n[nH]c(-c2ccc(...,-0.027789,-0.036755,-0.044898,-0.047607,-0.056572


In [4]:
# Attach the REACTANT column of the USPTO_50k test file as ground truth.
# The values are assigned by position, so this file must have the same number
# of rows, in the same order, as the predictions loaded into `df`.
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list() 
# Strip spaces from the prediction columns `0th`..`4th`, one row at a time.
df = df.apply(remove_space, axis=1)

In [5]:
# Silence RDKit log messages by disabling the loggers matching 'rdApp.*'.
# Presumably this hides the parse errors RDKit would otherwise emit for each
# invalid prediction in the evaluation cells below.
rdkit.RDLogger.DisableLog('rdApp.*')

In [6]:
# Number of leading predictions per row that are checked for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 flags: is the target among the first 1 / 2 / 3 / 5 predictions?
top1, top2, top3, top5 = [], [], [], []
# Per-row count of predictions that RDKit cannot parse.
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first of the five predictions whose canonical SMILES
    # equals the canonical target; None if there is no such prediction.
    # Predictions that fail to canonicalize yield None and never match.
    # NOTE(review): predictions are parsed as-is here, whereas the validity
    # check below strips trailing '.' first -- confirm this is intended.
    hit_rank = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    for hits, k in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(hit_rank is not None and hit_rank < k))

    # A prediction is invalid when MolFromSmiles returns None for it (after
    # trailing '.' characters have been removed).
    predictions = [row[f'{i}th'] for i in range(top_k_invalidity)]
    invalidity.append(sum(
        Chem.MolFromSmiles(prediction.rstrip('.')) is None
        for prediction in predictions
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

In [7]:
# Top-1 / top-2 / top-3 / top-5 accuracy: fraction of rows whose flag is 1.
print(*(sum(df[f'top{k}_accuracy']) / len(df) for k in (1, 2, 3, 5)))
# Percentage of unparsable predictions among all checked predictions
# (rows x top_k_invalidity).
print(sum(invalidity) / (len(invalidity) * top_k_invalidity) * 100)

0.15547561950439648 0.201638689048761 0.22721822541966427 0.28037569944044766
2.841726618705036


In [8]:
# Evaluate the predictions of the `t5_finetune_sampling10_similartotestall` run.
# NOTE: absolute paths; they only resolve on a machine with this layout.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis/t5_finetune_sampling10_similartotestall/output.csv')
seed_everything(seed=CFG.seed)
# Ground truth is assigned by position: the test file must have the same
# number of rows, in the same order, as the predictions.
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list()
df = df.apply(remove_space, axis=1)

# Number of leading predictions per row that are checked for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 flags: is the target among the first 1 / 2 / 3 / 5 predictions?
top1, top2, top3, top5 = [], [], [], []
# Per-row count of predictions that RDKit cannot parse.
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first of the five predictions whose canonical SMILES
    # equals the canonical target; None if there is no such prediction.
    # Predictions that fail to canonicalize yield None and never match.
    # NOTE(review): predictions are parsed as-is here, whereas the validity
    # check below strips trailing '.' first -- confirm this is intended.
    hit_rank = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    for hits, k in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(hit_rank is not None and hit_rank < k))

    # A prediction is invalid when MolFromSmiles returns None for it (after
    # trailing '.' characters have been removed).
    predictions = [row[f'{i}th'] for i in range(top_k_invalidity)]
    invalidity.append(sum(
        Chem.MolFromSmiles(prediction.rstrip('.')) is None
        for prediction in predictions
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

# Top-1 / top-2 / top-3 / top-5 accuracy, then the percentage of unparsable
# predictions among all checked predictions.
print(sum(df['top1_accuracy']) / len(df), sum(df['top2_accuracy']) / len(df), sum(df['top3_accuracy']) / len(df), sum(df['top5_accuracy']) / len(df))
print(sum(invalidity)/(len(invalidity)*top_k_invalidity)*100)

0.15267785771382894 0.2068345323741007 0.23741007194244604 0.28776978417266186
2.677857713828937
