In [8]:
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import tokenizers
import transformers
from transformers import AutoTokenizer, EncoderDecoderModel, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import datasets
from datasets import load_dataset, load_metric
import sentencepiece
import argparse
from datasets.utils.logging import disable_progress_bar
from rdkit import Chem
import rdkit
disable_progress_bar()

class CFG:
    """Static settings for this notebook, kept as plain class attributes."""

    # Model selection: short family tag and the checkpoint identifier/path.
    model = 't5'
    model_name_or_path = 'sagawa/ReactionT5v2-retrosynthesis'

    # Input data file name.
    dataset_path = 'multiinput_prediction_output.csv'

    # Generation settings (presumably beam-search options; not used in the visible cells).
    num_beams = 5
    num_return_sequences = 5

    # Run control: RNG seed passed to seed_everything, and a debug switch.
    seed = 42
    debug = True
    

# Device identifier string; hard-coded to the CPU in this notebook.
device = 'cpu'

def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed``), exports the seed as
    the ``PYTHONHASHSEED`` environment variable, and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Exported as a string because environment values must be text.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for each library's own seeding entry point.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
# Seed all RNGs once at start-up, using the seed configured on CFG.
seed_everything(seed=CFG.seed)  
    

# dataset = pd.read_csv(CFG.dataset_path)

# Load the tokenizer for the checkpoint named in CFG.model_name_or_path.
# NOTE(review): `return_tensors` is normally an argument of the tokenizer *call*, not of
# `from_pretrained` — presumably it has no effect here; confirm and drop it if so.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
from rdkit import Chem
def canonicalize(mol):
    """Return RDKit's SMILES rendering of the given SMILES string.

    The input is parsed with ``Chem.MolFromSmiles`` and written back with
    ``Chem.MolToSmiles``; ``True`` is forwarded as that function's second
    positional argument.

    Args:
        mol: SMILES string to normalise (despite the name, not a Mol object).

    Returns:
        The SMILES string produced by ``Chem.MolToSmiles``.

    Note:
        The parse result is not checked before it is passed on, so an input
        that RDKit cannot parse makes this function raise.
    """
    parsed = Chem.MolFromSmiles(mol)
    return Chem.MolToSmiles(parsed, True)
def canonicalize2(mol):
    """Canonicalize a SMILES string, returning ``None`` instead of raising.

    Thin guard around ``canonicalize`` for model predictions, which may be
    malformed.

    Args:
        mol: Candidate SMILES string (or any value read from a prediction column).

    Returns:
        The result of ``canonicalize(mol)``, or ``None`` if that call raised.
    """
    try:
        return canonicalize(mol)
    except Exception:
        # Only ordinary errors are swallowed; KeyboardInterrupt and SystemExit
        # still propagate so a long evaluation loop can be interrupted.
        return None
def remove_space(row, n_predictions=5):
    """Strip space characters from the ranked prediction columns of one row.

    The columns are named ``'0th'``, ``'1th'``, ... ``'{n_predictions-1}th'``.
    The row is modified in place and also returned, so the function can be
    used directly with ``DataFrame.apply(remove_space, axis=1)``.

    Args:
        row: Mapping-like row (e.g. a pandas Series or dict) holding the
            prediction columns.
        n_predictions: How many ranked prediction columns to clean. Defaults
            to 5, the number of columns used throughout this notebook.

    Returns:
        The same ``row`` object with spaces removed from its string predictions.
    """
    for i in range(n_predictions):
        key = f'{i}th'
        value = row[key]
        # An empty CSV cell is read by pandas as NaN (a float), which has no
        # `.replace`; leave such non-string values untouched instead of crashing.
        if isinstance(value, str):
            row[key] = value.replace(' ', '')
    return row


In [10]:
# Load the model output CSV of the `diff5` run (columns shown in the cell output:
# input, 0th–4th predictions, and their scores).
# NOTE(review): absolute, machine-specific path — consider a configurable base directory.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis_length_check/diff5/output.csv')
seed_everything(seed=CFG.seed)  
# df['target'] = pd.read_csv('sampled.csv')['PRODUCT']
# Display the loaded frame as the cell output.
df

Unnamed: 0,input,0th,1th,2th,3th,4th,0th score,1th score,2th score,3th score,4th score
0,COC(=O)CCC(=O)c1ccc(OC2CCCCO2)cc1O,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O,C1=COCCC1.COC(=O)CCC(=O)c1ccc(O)cc1O,C1=COCCC1.C1=COCCC1.COC(=O)CCC(=O)c1ccc,C1=COCCC1.CCOCC.COC(=O)CCC(=O)c1ccc(O)cc,C1=COCCC1.Cc1ccc(S(=O)(=O)O)cc1.O,-0.000126,-0.237458,-0.256045,-0.278268,-0.284757
1,COC(=O)c1cccc(-c2nc3cccnc3[nH]2)c1,COC(=O)c1cccc(C(=O)O)c1.Nc1cccnc1N,COC(=O)c1cccc(C(N)=O)c1.Nc1cccnc1N,COC(=O)c1cccc(C(=O)Cl)c1.Nc1cccnc1N,COC(=O)c1cccc(C(=O)Nc2cccnc2N)c1,CC(=O)O.COC(=O)c1cccc(C(=O)O)c1.Nc1,-0.004330,-0.124449,-0.133902,-0.166011,-0.186553
2,CON(C)C(=O)C1CCC(NC(=O)OC(C)(C)C)CC1,CC(C)(C)OC(=O)NC1CCC(C(=O)O)CC1.CNOC,CC(C)(C)OC(=O)N[C@H]1CC[C@H](C(=O)O)CC,CC(C)(C)OC(=O)N[C@H]1CC[C@@H](C(=O)O)C,CC(C)(C)OC(=O)NC1CCC(C(=O)O)CC1.CN(C)C=O,C1CCOC1.CC(C)(C)OC(=O)NC1CCC(C(=O)O)CC1.CNOC,-0.004454,-0.068799,-0.120171,-0.125859,-0.132321
3,O=[N+]([O-])c1ccc(Cl)nc1Nc1ccc(O)cc1,Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)nc1Cl,CCO.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)nc1,CC#N.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl)n,CN(C)C=O.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(,NC(N)=O.Nc1ccc(O)cc1.O=[N+]([O-])c1ccc(Cl,-0.000184,-0.174721,-0.181740,-0.193402,-0.214376
4,NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([N+](=O)[O-...,[N-]=[N+]=NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@H]([...,N.O=[N+]([O-])[C@@H]1CC(CBr)=CC[C@H]1c1ccc(Cl)...,O=C1c2ccccc2C(=O)N1CC1=CC[C@@H](c2ccc(Cl)cc2Cl...,CC(C)(C)OC(=O)NCC1=CC[C@@H](c2ccc(Cl)cc2Cl)[C@...,N.O=[N+]([O-])[C@@H]1CC(CCl)=CC[C@H]1c1ccc(Cl)...,-0.003364,-0.057931,-0.085116,-0.088348,-0.103804
...,...,...,...,...,...,...,...,...,...,...,...
4999,Cc1cc([N+](=O)[O-])ccc1Oc1ccnc(N)c1,Cc1cc([N+](=O)[O-])ccc1Oc1ccnc(NC(=O)OC,Cc1cc([N+](=O)[O-])ccc1O.Nc1cc(Cl)c,Cc1cc([N+](=O)[O-])ccc1Oc1ccnc(Cl)c1.N,Cc1cc([N+](=O)[O-])ccc1Oc1ccnc(Cl)c1.[,CCN(C(C)C)C(C)C.Cc1cc([N+](=O)[O-])ccc,-0.016411,-0.034820,-0.090390,-0.126439,-0.142715
5000,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)O)CC3,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OC(C)(C)C,COC(=O)CC1CCc2cc(Cl)cc3[nH]c(C(=O)OC)c(c23)C1,COC(=O)c1[nH]c2cc(Cl)cc3c2c1C(CC(=O)OCc1ccccc1)C,CCOC(=O)CC1CCc2cc(Cl)cc3[nH]c(C(=O)OC)c(c23)C1,COC(=O)CC1CCc2cc(Cl)cc3[nH]c(C(=O)OC)c1c23,-0.011357,-0.052449,-0.071756,-0.084066,-0.100000
5001,COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)CC(O...,COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)CC(=...,COCOC1CN(C)CC(NC(=O)c2c(OC)cc(C(F)(F)F)cc2SC)(...,CN1CC(O)CC(N)(c2ccccc2)C1.COc1cc(C(F)(F)F)cc(S...,COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)CC(O...,COc1cc(C(F)(F)F)cc(SC)c1C(=O)NC1(c2ccccc2)CC(O...,-0.018316,-0.031270,-0.046004,-0.048696,-0.061745
5002,Cc1nn(CC(C)CO)c(-c2ccc(F)cc2)c1Br,CC(CO)CBr.Cc1n[nH]c(-c2ccc(F)cc2)c1Br,CC(CBr)CO.Cc1n[nH]c(-c2ccc(F)cc2)c1Br,Cc1nn(CC(C)C(=O)O)c(-c2ccc(F)cc2)c1Br,CC(CO)Cn1nc(C(F)(F)F)c(Br)c1-c1ccc(F),Cc1n[nH]c(-c2ccc(F)cc2)c1Br.OCC(CBr)C,-0.003974,-0.099226,-0.144819,-0.167172,-0.167784


In [11]:
# Attach the ground truth: the REACTANT column of the USPTO_50k test CSV, assigned by position
# (assumes the test rows are in the same order as the prediction rows — TODO confirm).
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list() 
# Strip spaces from the prediction columns ('0th'..'4th') of every row.
df = df.apply(remove_space, axis=1)

In [12]:
# Silence RDKit's log channels matching 'rdApp.*' so that the many unparsable
# predicted SMILES evaluated below do not flood the notebook output.
rdkit.RDLogger.DisableLog('rdApp.*')

In [13]:
# Number of ranked predictions per row that are inspected for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 hit flags for each top-k cutoff, and the per-row count of
# predictions that RDKit could not parse.
top1, top2, top3, top5 = [], [], [], []
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first prediction whose canonical SMILES equals the
    # target, or None when none of the five predictions matches. The generator
    # stops at the first match, so later predictions are not canonicalized.
    first_hit = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    # A hit at rank r counts for every cutoff k with r < k.
    for hits, cutoff in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(first_hit is not None and first_hit < cutoff))

    # Chem.MolFromSmiles returns None for an unparsable string; trailing '.'
    # separators are dropped before parsing.
    invalidity.append(sum(
        Chem.MolFromSmiles(row[f'{i}th'].rstrip('.')) is None
        for i in range(top_k_invalidity)
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

In [14]:
# Top-1 / top-2 / top-3 / top-5 accuracy, each as a fraction of all rows.
print(sum(df['top1_accuracy']) / len(df), sum(df['top2_accuracy']) / len(df), sum(df['top3_accuracy']) / len(df), sum(df['top5_accuracy']) / len(df))
# Percentage of all inspected predictions (rows × top_k_invalidity) that RDKit could not parse.
print(sum(invalidity)/(len(invalidity)*top_k_invalidity)*100)

0.5079936051159073 0.5769384492406076 0.5965227817745803 0.6123101518784972
42.27418065547562


In [15]:
# Evaluate the `diff10` run: load its predictions, attach the ground truth
# (assigned by position — assumes both files share the same row order), then score.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis_length_check/diff10/output.csv')
seed_everything(seed=CFG.seed)
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list()
df = df.apply(remove_space, axis=1)

# Number of ranked predictions per row that are inspected for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 hit flags for each top-k cutoff, and the per-row count of
# predictions that RDKit could not parse.
top1, top2, top3, top5 = [], [], [], []
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first prediction whose canonical SMILES equals the
    # target, or None when none of the five predictions matches.
    first_hit = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    # A hit at rank r counts for every cutoff k with r < k.
    for hits, cutoff in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(first_hit is not None and first_hit < cutoff))

    # Chem.MolFromSmiles returns None for an unparsable string; trailing '.'
    # separators are dropped before parsing.
    invalidity.append(sum(
        Chem.MolFromSmiles(row[f'{i}th'].rstrip('.')) is None
        for i in range(top_k_invalidity)
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

# Top-k accuracies as fractions of all rows, then the percentage of unparsable predictions.
print(sum(df['top1_accuracy']) / len(df), sum(df['top2_accuracy']) / len(df), sum(df['top3_accuracy']) / len(df), sum(df['top5_accuracy']) / len(df))
print(sum(invalidity)/(len(invalidity)*top_k_invalidity)*100)

0.6033173461231015 0.6924460431654677 0.7194244604316546 0.7434052757793765
25.463629096722624


In [16]:
# Evaluate the `diff20` run: load its predictions, attach the ground truth
# (assigned by position — assumes both files share the same row order), then score.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis_length_check/diff20/output.csv')
seed_everything(seed=CFG.seed)
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list()
df = df.apply(remove_space, axis=1)

# Number of ranked predictions per row that are inspected for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 hit flags for each top-k cutoff, and the per-row count of
# predictions that RDKit could not parse.
top1, top2, top3, top5 = [], [], [], []
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first prediction whose canonical SMILES equals the
    # target, or None when none of the five predictions matches.
    first_hit = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    # A hit at rank r counts for every cutoff k with r < k.
    for hits, cutoff in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(first_hit is not None and first_hit < cutoff))

    # Chem.MolFromSmiles returns None for an unparsable string; trailing '.'
    # separators are dropped before parsing.
    invalidity.append(sum(
        Chem.MolFromSmiles(row[f'{i}th'].rstrip('.')) is None
        for i in range(top_k_invalidity)
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

# Top-k accuracies as fractions of all rows, then the percentage of unparsable predictions.
print(sum(df['top1_accuracy']) / len(df), sum(df['top2_accuracy']) / len(df), sum(df['top3_accuracy']) / len(df), sum(df['top5_accuracy']) / len(df))
print(sum(invalidity)/(len(invalidity)*top_k_invalidity)*100)

0.7006394884092726 0.8051558752997602 0.8385291766586731 0.8693045563549161
8.181454836131095


In [17]:
# Evaluate the `no_specification` run: load its predictions, attach the ground truth
# (assigned by position — assumes both files share the same row order), then score.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis_length_check/no_specification/output.csv')
seed_everything(seed=CFG.seed)
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list()
df = df.apply(remove_space, axis=1)

# Number of ranked predictions per row that are inspected for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 hit flags for each top-k cutoff, and the per-row count of
# predictions that RDKit could not parse.
top1, top2, top3, top5 = [], [], [], []
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first prediction whose canonical SMILES equals the
    # target, or None when none of the five predictions matches.
    first_hit = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    # A hit at rank r counts for every cutoff k with r < k.
    for hits, cutoff in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(first_hit is not None and first_hit < cutoff))

    # Chem.MolFromSmiles returns None for an unparsable string; trailing '.'
    # separators are dropped before parsing.
    invalidity.append(sum(
        Chem.MolFromSmiles(row[f'{i}th'].rstrip('.')) is None
        for i in range(top_k_invalidity)
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

# Top-k accuracies as fractions of all rows, then the percentage of unparsable predictions.
print(sum(df['top1_accuracy']) / len(df), sum(df['top2_accuracy']) / len(df), sum(df['top3_accuracy']) / len(df), sum(df['top5_accuracy']) / len(df))
print(sum(invalidity)/(len(invalidity)*top_k_invalidity)*100)

0.7088329336530775 0.8145483613109512 0.8507194244604317 0.88189448441247
0.38369304556354916


In [18]:
# Evaluate the `percentile25-75` run: load its predictions, attach the ground truth
# (assigned by position — assumes both files share the same row order), then score.
df = pd.read_csv('/data1/ReactionT5_neword/task_retrosynthesis_length_check/percentile25-75/output.csv')
seed_everything(seed=CFG.seed)
df['target'] = pd.read_csv("/data1/ReactionT5_neword/data/USPTO_50k/test.csv")["REACTANT"].to_list()
df = df.apply(remove_space, axis=1)

# Number of ranked predictions per row that are inspected for SMILES validity.
top_k_invalidity = 5

# Per-row 0/1 hit flags for each top-k cutoff, and the per-row count of
# predictions that RDKit could not parse.
top1, top2, top3, top5 = [], [], [], []
invalidity = []

for _, row in df.iterrows():
    target = canonicalize(row['target'])

    # 0-based rank of the first prediction whose canonical SMILES equals the
    # target, or None when none of the five predictions matches.
    first_hit = next(
        (rank for rank in range(5) if canonicalize2(row[f'{rank}th']) == target),
        None,
    )
    # A hit at rank r counts for every cutoff k with r < k.
    for hits, cutoff in ((top1, 1), (top2, 2), (top3, 3), (top5, 5)):
        hits.append(int(first_hit is not None and first_hit < cutoff))

    # Chem.MolFromSmiles returns None for an unparsable string; trailing '.'
    # separators are dropped before parsing.
    invalidity.append(sum(
        Chem.MolFromSmiles(row[f'{i}th'].rstrip('.')) is None
        for i in range(top_k_invalidity)
    ))

df['top1_accuracy'] = top1
df['top2_accuracy'] = top2
df['top3_accuracy'] = top3
df['top5_accuracy'] = top5
df['invalidity'] = invalidity

# Top-k accuracies as fractions of all rows, then the percentage of unparsable predictions.
print(sum(df['top1_accuracy']) / len(df), sum(df['top2_accuracy']) / len(df), sum(df['top3_accuracy']) / len(df), sum(df['top5_accuracy']) / len(df))
print(sum(invalidity)/(len(invalidity)*top_k_invalidity)*100)

0.37589928057553956 0.42525979216626697 0.44184652278177455 0.456634692246203
36.11111111111111
