In [1]:
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import tokenizers
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup
import datasets
from datasets import load_dataset, load_metric
import sentencepiece
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import AdamW
import pickle
import time
import math
from sklearn.preprocessing import MinMaxScaler
from datasets.utils.logging import disable_progress_bar
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
disable_progress_bar()

def parse_args(argv=None):
    """Build the command-line parser for the yield-regression script and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Pass ``[]`` or an explicit list
            when calling from a notebook, where ``sys.argv`` typically holds the
            kernel's own flags and argparse would reject them.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=False)
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="sagawa/ZINC-t5", required=False)
    parser.add_argument("--model_name_or_path", type=str, required=False)
    parser.add_argument("--scaler_path", type=str, default="/data2/sagawa/tcrp-regression-model-archive/10-23-1st-new-metric-reactant-product", required=False)
    parser.add_argument("--debug", action='store_true', default=False, required=False)
    parser.add_argument("--batch_size", type=int, default=5, required=False)
    parser.add_argument("--max_len", type=int, default=512, required=False)
    parser.add_argument("--num_workers", type=int, default=1, required=False)
    parser.add_argument("--fc_dropout", type=float, default=0.1, required=False)
    parser.add_argument("--output_dir", type=str, default='./', required=False)
    parser.add_argument("--seed", type=int, default=42, required=False)

    return parser.parse_args(argv)

class CFG:
    """Static settings for the exploratory setup, read as class attributes (e.g. ``CFG.seed``)."""
    # --- data / model locations ---
    data_path = '../../all_ord_reaction_uniq_with_attr_v3.tsv'
    model = 'sagawa/ZINC-t5'
    model_name_or_path = '/data2/sagawa/tcrp-regression-model-archive/10-23-1st-new-metric-reactant-product'
    scaler_path = '/data2/sagawa/tcrp-regression-model-archive/10-23-1st-new-metric-reactant-product'
    # --- runtime ---
    batch_size = 5  # lowered from 15 to 5 because a larger max_len ran out of memory
    seed = 42
    num_workers = 4
    output_dir = './'

In [2]:
# Load the preprocessed regression input (reaction string in `input`, target in `YIELD`).
train_ds = pd.read_csv('../../regression-input-train.csv')
train_ds

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP,input
0,,CC(=O)Cl.CC(C)(C)OC(=O)N1CCC2(CC1)CC(=O)c1nn(C...,,CO,,,Cc1cc(Cc2cc(C)c(O)c(C=O)c2)cc(C=O)c1O,96.0,0.0,REACTANT:CC(=O)Cl.CC(C)(C)OC(=O)N1CCC2(CC1)CC(...
1,Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P]...,CCCC[Sn](CCCC)(CCCC)c1cccs1.CCCCc1nc(C)c(Br)c(...,,CCOC(C)=O.CN(C)C=O,,,CCCCCCCC/C=C\CCCCCCCC(=O)O,75.0,,REACTANT:CCCC[Sn](CCCC)(CCCC)c1cccs1.CCCCc1nc(...
2,,CC(C)CCBr.[Li]c1cccs1,,O,,,CCc1c(Cc2[nH]c(C(=O)OCc3ccccc3)c(C)c2CCC(=O)OC...,88.0,0.0,REACTANT:CC(C)CCBr.[Li]c1cccs1PRODUCT:CCc1c(Cc...
3,,CS(=O)(=O)O.O=P12OP3(=O)OP(=O)(O1)OP(=O)(O2)O3...,,,,,Oc1cc2cc[nH]c2cc1O,95.0,90.0,REACTANT:CS(=O)(=O)O.O=P12OP3(=O)OP(=O)(O1)OP(...
4,,C=C[Mg]Br.CC(C)CC=O.CCOC(=O)CC(=O)OCC.[K+].[OH-],,C1CCOC1.CCO,,,NN=C(C=Cc1ccccc1)c1ccccc1,30.0,0.0,REACTANT:C=C[Mg]Br.CC(C)CC=O.CCOC(=O)CC(=O)OCC...
...,...,...,...,...,...,...,...,...,...,...
545630,,Nc1cccc2c1C(=O)N(C1CCC(=O)NC1=O)C2=O.O=C(Cl)c1...,,C1CCOC1,,,CCCCCCCC/C=C\CCCCCCCC(=O)O,88.0,,REACTANT:Nc1cccc2c1C(=O)N(C1CCC(=O)NC1=O)C2=O....
545631,,CC(=O)O[BH-](OC(C)=O)OC(C)=O.O=C(CNC(=O)c1cccc...,,CC(=O)O.CCOC(C)=O.ClCCl,,,O=C(O)c1ccccc1-c1c2ccc(=O)cc-2oc2cc(O)ccc12,90.0,,REACTANT:CC(=O)O[BH-](OC(C)=O)OC(C)=O.O=C(CNC(...
545632,,CN1CCOCC1.COC(=O)c1ccc(Cc2cn(C)c3ccc(N)cc23)c(...,,ClCCl,,,CC#CN1C(=O)C(C)Oc2ccc(-n3c(=O)cc(C(F)(F)F)[nH]...,74.0,,REACTANT:CN1CCOCC1.COC(=O)c1ccc(Cc2cn(C)c3ccc(...
545633,,Cc1ccc(N)c(C#C[Si](C)(C)C)n1.[Na+].[OH-],,CO,,,CC(C)CN=C1C(c2ccccc2)=C(c2ccccc2)C(c2ccccc2)=C...,75.0,0.0,REACTANT:Cc1ccc(N)c(C#C[Si](C)(C)C)n1.[Na+].[O...


In [3]:
# Keep only the model input string and the regression target.
train_ds = train_ds.loc[:, ['input', 'YIELD']]
train_ds

Unnamed: 0,input,YIELD
0,REACTANT:CC(=O)Cl.CC(C)(C)OC(=O)N1CCC2(CC1)CC(...,96.0
1,REACTANT:CCCC[Sn](CCCC)(CCCC)c1cccs1.CCCCc1nc(...,75.0
2,REACTANT:CC(C)CCBr.[Li]c1cccs1PRODUCT:CCc1c(Cc...,88.0
3,REACTANT:CS(=O)(=O)O.O=P12OP3(=O)OP(=O)(O1)OP(...,95.0
4,REACTANT:C=C[Mg]Br.CC(C)CC=O.CCOC(=O)CC(=O)OCC...,30.0
...,...,...
545630,REACTANT:Nc1cccc2c1C(=O)N(C1CCC(=O)NC1=O)C2=O....,88.0
545631,REACTANT:CC(=O)O[BH-](OC(C)=O)OC(C)=O.O=C(CNC(...,90.0
545632,REACTANT:CN1CCOCC1.COC(=O)c1ccc(Cc2cn(C)c3ccc(...,74.0
545633,REACTANT:Cc1ccc(N)c(C#C[Si](C)(C)C)n1.[Na+].[O...,75.0


In [55]:
# Load the raw ORD reaction table and drop exact duplicate rows.
# NOTE(review): the file has a .tsv extension but is read with the default comma
# separator; the displayed columns parse correctly, so it is presumably comma-separated.
ori = pd.read_csv('../../all_ord_reaction_uniq_with_attr_v3.tsv').drop_duplicates().reset_index(drop=True)
ori

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP
0,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)OCc1ccccc1,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,93.0,23.0
1,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCCCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,82.0,23.0
2,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCOCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,61.0,23.0
3,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)(C)C)C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,72.0,23.0
4,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(Cc1cn(C(=O...,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,83.0,23.0
...,...,...,...,...,...,...,...,...,...
2058485,,COc1ccccc1CCCCBr.Fc1ccc2c(C3CCNCC3)noc2c1.O=C(...,,CC#N.CCO,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058486,,CC(O)=S.CCOC(=O)N=NC(=O)OCC.O=C1C[C@@H](O)CN1....,,C1CCOC1,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058487,,C[C@@H](O[Si](C)(C)C(C)(C)C)[C@H]1C(=O)N2C(C(=...,,CC#N,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058488,CC(=O)[O-].O.[Cu+2],C=O.CC(=O)O.CC(=O)[O-].CCOC(=O)CC(=O)OCC.[K+],,,,,O=S(=O)(c1ccc(Cl)cc1)C(F)(F)F,,


In [59]:
# Distinct reactant / product strings (NaN counted as one value).
ori['REACTANT'].nunique(dropna=False), ori['PRODUCT'].nunique(dropna=False)

(1045802, 440212)

In [60]:
# Keep only reactions with a reported yield. `.copy()` makes `df` an independent
# frame, so the column assignment below does not write into a slice of `ori`. The
# SettingWithCopyWarning this would raise is silenced by the global filterwarnings('ignore').
df = ori[ori['YIELD'].notna()].copy()
# Clamp yields to the [0, 100] range.
df['YIELD'] = df['YIELD'].clip(0, 100)
df

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP
0,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)OCc1ccccc1,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,93.0,23.0
1,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCCCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,82.0,23.0
2,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCOCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,61.0,23.0
3,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)(C)C)C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,72.0,23.0
4,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(Cc1cn(C(=O...,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,83.0,23.0
...,...,...,...,...,...,...,...,...,...
2058408,[Na+].[OH-],COc1ccc(CC#N)cc1OC.O=Cc1ccc2ccccc2c1,,CCO,,,O=C(F)OCC(F)(F)F,82.0,
2058469,,CNCC[C@@H](O)c1ccccc1.FC(F)(F)c1ccc(Cl)cc1.[H-...,,CC(=O)N(C)C,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,86.0,90.0
2058470,,COc1cc(OC)c(Br)c(OC)c1.O.O=C1CCCCC1,,C1CCOC1,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,65.0,-30.0
2058473,,CN(C)Cc1ccccc1.OCC1CO1.OCCS,,CC(=O)CC(C)C,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,90.0,50.0


In [62]:
# Distinct reactant / product strings among rows with a yield (NaN counted as one value).
df['REACTANT'].nunique(dropna=False), df['PRODUCT'].nunique(dropna=False)

(399405, 358)

In [44]:
# Missing REACTANT, missing PRODUCT, and rows where both are missing.
(
    df['REACTANT'].isna().sum(),
    df['PRODUCT'].isna().sum(),
    df[['REACTANT', 'PRODUCT']].isna().all(axis=1).sum(),
)

(144, 142, 142)

In [63]:
# Both reactant and product are needed to build the input string.
df = df.dropna(subset=['REACTANT', 'PRODUCT'])

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP
0,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)OCc1ccccc1,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,93.0,23.0
1,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCCCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,82.0,23.0
2,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCOCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,61.0,23.0
3,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)(C)C)C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,72.0,23.0
4,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(Cc1cn(C(=O...,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,83.0,23.0
...,...,...,...,...,...,...,...,...,...
2058408,[Na+].[OH-],COc1ccc(CC#N)cc1OC.O=Cc1ccc2ccccc2c1,,CCO,,,O=C(F)OCC(F)(F)F,82.0,
2058469,,CNCC[C@@H](O)c1ccccc1.FC(F)(F)c1ccc(Cl)cc1.[H-...,,CC(=O)N(C)C,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,86.0,90.0
2058470,,COc1cc(OC)c(Br)c(OC)c1.O.O=C1CCCCC1,,C1CCOC1,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,65.0,-30.0
2058473,,CN(C)Cc1ccccc1.OCC1CO1.OCCS,,CC(=O)CC(C)C,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,90.0,50.0


In [64]:
# Encode each reaction as "REACTANT:<smiles>PRODUCT:<smiles>" and keep unique (input, YIELD) pairs.
df = (
    df.assign(input='REACTANT:' + df['REACTANT'] + 'PRODUCT:' + df['PRODUCT'])
    .loc[:, ['input', 'YIELD']]
    .drop_duplicates()
    .reset_index(drop=True)
)
df

Unnamed: 0,input,YIELD
0,REACTANT:CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)...,93.0
1,REACTANT:CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CC...,82.0
2,REACTANT:CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CC...,61.0
3,REACTANT:CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)...,72.0
4,REACTANT:CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(C...,83.0
...,...,...
598384,REACTANT:COc1ccc(CC#N)cc1OC.O=Cc1ccc2ccccc2c1P...,82.0
598385,REACTANT:CNCC[C@@H](O)c1ccccc1.FC(F)(F)c1ccc(C...,86.0
598386,REACTANT:COc1cc(OC)c(Br)c(OC)c1.O.O=C1CCCCC1PR...,65.0
598387,REACTANT:CN(C)Cc1ccccc1.OCC1CO1.OCCSPRODUCT:CC...,90.0


In [51]:
# Number of distinct reaction strings; compare with len(df) to see repeated inputs.
df['input'].nunique(dropna=False)

589667

In [54]:
# Inspect one reaction string that appears with many different yields.
df.loc[df['input'].eq('REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1cnc(C#N)c1-c1ccccc1F.Cn1cnc(C#N)c1-c1ccccc1F')]

Unnamed: 0,input,YIELD
137311,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,29.0
137312,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,45.0
137313,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,50.0
137314,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,6.0
137315,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,4.0
...,...,...
312939,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,92.0
312940,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,98.0
312941,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,100.0
312942,REACTANT:Cn1cnc(C#N)c1.Fc1ccccc1BrPRODUCT:Cn1c...,96.0


In [67]:
# Collapse repeated reaction strings into one row each, using the mean yield.
dfagg = df.groupby('input', as_index=False)['YIELD'].mean()
dfagg

Unnamed: 0,input,YIELD
0,REACTANT:*C(F)(F)C(*)(F)F.C1=CCCC1.Cl[SiH](Cl)...,91.0
1,REACTANT:*C(F)(F)C(*)(F)F.C1CCNCC1.CC(C)(C)[O-...,77.0
2,REACTANT:*C(F)(F)C(*)(F)F.CC(=O)c1ccccc1.[BH4-...,95.0
3,REACTANT:*C(F)(F)C(*)(F)F.CC(=O)c1ccccc1.[H][H...,95.0
4,REACTANT:*C(F)(F)C(*)(F)F.CCCCC(N)N.CCOCC.Cc1c...,21.0
...,...,...
589661,REACTANT:c1cncc(C2CCCC2)c1PRODUCT:C[C@H](CO)CO...,84.0
589662,REACTANT:c1cncc(OCCOC2CCCCO2)c1PRODUCT:CC(C)(C...,86.0
589663,REACTANT:c1cncc(OCCOC2CCCCO2)c1PRODUCT:CC1CO1,86.0
589664,REACTANT:c1cncc(OCCOC2CCCCO2)c1PRODUCT:COC(=O)...,86.0


In [78]:
# Character length of each reaction string (a character count, not a token count).
lens = dfagg['input'].str.len()
# Remove reactions whose string is longer than 512 characters.
dfagg = dfagg.loc[lens <= 512].reset_index(drop=True)
dfagg

Unnamed: 0,input,YIELD
0,REACTANT:*C(F)(F)C(*)(F)F.C1=CCCC1.Cl[SiH](Cl)...,91.0
1,REACTANT:*C(F)(F)C(*)(F)F.C1CCNCC1.CC(C)(C)[O-...,77.0
2,REACTANT:*C(F)(F)C(*)(F)F.CC(=O)c1ccccc1.[BH4-...,95.0
3,REACTANT:*C(F)(F)C(*)(F)F.CC(=O)c1ccccc1.[H][H...,95.0
4,REACTANT:*C(F)(F)C(*)(F)F.CCCCC(N)N.CCOCC.Cc1c...,21.0
...,...,...
589651,REACTANT:c1cncc(C2CCCC2)c1PRODUCT:C[C@H](CO)CO...,84.0
589652,REACTANT:c1cncc(OCCOC2CCCCO2)c1PRODUCT:CC(C)(C...,86.0
589653,REACTANT:c1cncc(OCCOC2CCCCO2)c1PRODUCT:CC1CO1,86.0
589654,REACTANT:c1cncc(OCCOC2CCCCO2)c1PRODUCT:COC(=O)...,86.0


In [2]:
# multiinput
ori = pd.read_csv('../../all_ord_reaction_uniq_with_attr_v3.tsv').drop_duplicates().reset_index(drop=True)
ori

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP
0,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)OCc1ccccc1,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,93.0,23.0
1,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCCCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,82.0,23.0
2,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCOCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,61.0,23.0
3,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)(C)C)C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,72.0,23.0
4,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(Cc1cn(C(=O...,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,83.0,23.0
...,...,...,...,...,...,...,...,...,...
2058485,,COc1ccccc1CCCCBr.Fc1ccc2c(C3CCNCC3)noc2c1.O=C(...,,CC#N.CCO,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058486,,CC(O)=S.CCOC(=O)N=NC(=O)OCC.O=C1C[C@@H](O)CN1....,,C1CCOC1,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058487,,C[C@@H](O[Si](C)(C)C(C)(C)C)[C@H]1C(=O)N2C(C(=...,,CC#N,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058488,CC(=O)[O-].O.[Cu+2],C=O.CC(=O)O.CC(=O)[O-].CCOC(=O)CC(=O)OCC.[K+],,,,,O=S(=O)(c1ccc(Cl)cc1)C(F)(F)F,,


In [3]:
# Keep reactions that have a recorded product.
df = ori[ori['PRODUCT'].notna()]
df

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP
0,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)OCc1ccccc1,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,93.0,23.0
1,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCCCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,82.0,23.0
2,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCOCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,61.0,23.0
3,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)(C)C)C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,72.0,23.0
4,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(Cc1cn(C(=O...,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,83.0,23.0
...,...,...,...,...,...,...,...,...,...
2058485,,COc1ccccc1CCCCBr.Fc1ccc2c(C3CCNCC3)noc2c1.O=C(...,,CC#N.CCO,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058486,,CC(O)=S.CCOC(=O)N=NC(=O)OCC.O=C1C[C@@H](O)CN1....,,C1CCOC1,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058487,,C[C@@H](O[Si](C)(C)C(C)(C)C)[C@H]1C(=O)N2C(C(=...,,CC#N,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058488,CC(=O)[O-].O.[Cu+2],C=O.CC(=O)O.CC(=O)[O-].CCOC(=O)CC(=O)OCC.[K+],,,,,O=S(=O)(c1ccc(Cl)cc1)C(F)(F)F,,


In [4]:
# Distinct reactant / product strings among rows with a product (NaN counted as one value).
df['REACTANT'].nunique(dropna=False), df['PRODUCT'].nunique(dropna=False)

(1045802, 440211)

In [6]:
# Additionally require a recorded reactant.
dfr = df[df['REACTANT'].notna()]
dfr

Unnamed: 0,CATALYST,REACTANT,REAGENT,SOLVENT,INTERNAL_STANDARD,NoData,PRODUCT,YIELD,TEMP
0,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.O=C(O)C1CCCN1C(=O)OCc1ccccc1,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,93.0,23.0
1,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCCCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,82.0,23.0
2,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)N1CCOCC1C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,61.0,23.0
3,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)C(NC(=O)OC(C)(C)C)C(=O)O,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,72.0,23.0
4,CC(C)(C)c1ccn2->[Ir+]34(<-n5cc(C(F)(F)F)ccc5-c...,CC(=O)c1ccc(Br)cc1.CC(C)(C)OC(=O)NC(Cc1cn(C(=O...,O=C([O-])[O-].[Cs+],CN(C)C=O,,,CC(=O)c1ccc(C2CCCN2C(=O)OCc2ccccc2)cc1,83.0,23.0
...,...,...,...,...,...,...,...,...,...
2058485,,COc1ccccc1CCCCBr.Fc1ccc2c(C3CCNCC3)noc2c1.O=C(...,,CC#N.CCO,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058486,,CC(O)=S.CCOC(=O)N=NC(=O)OCC.O=C1C[C@@H](O)CN1....,,C1CCOC1,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058487,,C[C@@H](O[Si](C)(C)C(C)(C)C)[C@H]1C(=O)N2C(C(=...,,CC#N,,,CC(O)CC(=O)[O-].O=C([O-])CCCO,,
2058488,CC(=O)[O-].O.[Cu+2],C=O.CC(=O)O.CC(=O)[O-].CCOC(=O)CC(=O)OCC.[K+],,,,,O=S(=O)(c1ccc(Cl)cc1)C(F)(F)F,,


In [7]:
# Distinct reactant / product strings among rows with both present (NaN counted as one value).
dfr['REACTANT'].nunique(dropna=False), dfr['PRODUCT'].nunique(dropna=False)

(1045801, 447)

In [11]:
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import tokenizers
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup
import datasets
from datasets import load_dataset, load_metric
import sentencepiece
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import AdamW
import pickle
import time
import math
from sklearn.preprocessing import MinMaxScaler
from datasets.utils.logging import disable_progress_bar
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
disable_progress_bar()


# Run configuration for this experiment (plain dict; read as CFG['key']).
CFG = {
    'data_path': '../../all_ord_reaction_uniq_with_attr_v3.tsv',
    'pretrained_model_name_or_path': 'sagawa/ZINC-t5',
    'model': 'sagawa/ZINC-t5',
    'debug': True,
    'epochs': 5,
    'batch_size': 5,  # lowered from 15 to 5 because a larger max_len ran out of memory
    'max_len': 512,
    'seed': 42,
    'num_workers': 4,
    'fc_dropout': 0.1,
    'eps': 1e-6,
    'max_grad_norm': 1000,
    'gradient_accumulation_steps': 3,
    'num_warmup_steps': 0,
    'n_trials': 100,
    'batch_scheduler': True,
    'print_freq': 100,
    'use_apex': False,
    'output_dir': './',
}

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the output directory on first use.
OUTPUT_DIR = CFG['output_dir']
needs_dir = not os.path.exists(OUTPUT_DIR)
if needs_dir:
    os.makedirs(OUTPUT_DIR)
del needs_dir

def seed_everything(seed=42):
    """Seed every RNG the notebook uses and request deterministic cuDNN behaviour.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG and
            PyTorch (CPU and every visible CUDA device).
    """
    random.seed(seed)
    # Only affects Python processes started afterwards; the running interpreter's
    # hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one (this machine has several).
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not enough: with benchmark mode on, cuDNN may
    # still auto-tune to a non-deterministic algorithm, so turn it off explicitly.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(seed=CFG['seed'])

# Load the reaction table, keep rows with a yield, and rescale yield from [0, 100] to [0, 1].
df = pd.read_csv(CFG['data_path']).drop_duplicates().reset_index(drop=True)
df = df.dropna(subset=['YIELD']).reset_index(drop=True)
df = df.assign(YIELD=df['YIELD'].clip(0, 100) / 100)
# Reactant and product are both required to build the input string.
df = df.dropna(subset=['REACTANT', 'PRODUCT'])
# Use ' ' as the "missing" placeholder in every text column.
for col in ['CATALYST', 'REACTANT', 'REAGENT', 'SOLVENT', 'INTERNAL_STANDARD', 'NoData','PRODUCT']:
    df[col] = df[col].fillna(' ')
    
    
###############################################
def clean(row):
    """Strip separator artefacts left when a ' ' placeholder is dot-joined to another field.

    Replacements are applied in order: drop '. ', drop ' .', then collapse '  ' to ' '.
    """
    for needle, replacement in (('. ', ''), (' .', ''), ('  ', ' ')):
        row = row.replace(needle, replacement)
    return row
# Fold catalysts into the reagent field, then remove separators left by ' ' placeholders.
merged_reagents = df['CATALYST'] + '.' + df['REAGENT']
df['REAGENT'] = merged_reagents.map(clean)
del merged_reagents

from rdkit import Chem
def canonicalize(mol):
    """Return RDKit's canonical isomeric SMILES for a SMILES string.

    Args:
        mol: SMILES string (despite the name, not an RDKit Mol object). In this
            notebook it may be a dot-joined multi-component SMILES.

    Returns:
        The canonical SMILES string.

    Raises:
        ValueError: if RDKit cannot parse ``mol``. ``Chem.MolFromSmiles`` returns
            ``None`` on failure; passing that on to ``MolToSmiles`` would otherwise
            fail with an opaque argument-type error that does not name the bad input.
    """
    parsed = Chem.MolFromSmiles(mol)
    if parsed is None:
        raise ValueError(f'RDKit could not parse SMILES: {mol!r}')
    return Chem.MolToSmiles(parsed, isomericSmiles=True)

# Canonicalize merged reagent SMILES; the ' ' placeholder for "no reagent" is kept as-is.
df['REAGENT'] = df['REAGENT'].apply(lambda smi: ' ' if smi == ' ' else canonicalize(smi))
###############################################
    

# Encode each reaction as "REACTANT:<smiles>REAGENT:<smiles>PRODUCT:<smiles>".
df['input'] = 'REACTANT:' + df['REACTANT']  + 'REAGENT:' + df['REAGENT'] + 'PRODUCT:' + df['PRODUCT']
df = df[['input', 'YIELD']].drop_duplicates().reset_index(drop=True)

# Remove data whose input string is too long. This counts characters, used as a proxy
# for the 512-token model limit; it does not guarantee the tokenized length fits
# (special tokens are added on top) — TODO confirm against the tokenizer.
lens = df['input'].apply(lambda x: len(x))
df = df[lens <= 512].reset_index(drop=True)

# 80/10/10 train/valid/test split (both hold-outs are 10% of the full frame).
# Pass the seed explicitly so the split is reproducible when this cell is re-run on
# its own. Otherwise it depends on whatever NumPy global RNG state earlier cells left.
# NOTE(review): drop_duplicates only removes identical (input, YIELD) pairs. The same
# input with different yields (see the duplicate-input inspection earlier in this
# notebook) can still land in different splits — consider a group-aware split.
train_ds, test_ds = train_test_split(df, test_size=int(len(df)*0.1), random_state=CFG['seed'])
train_ds, valid_ds = train_test_split(train_ds, test_size=int(len(df)*0.1), random_state=CFG['seed'])



In [13]:
# Check GPU memory and utilisation before loading the model (IPython shell escape).
!nvidia-smi

Sat Nov 26 16:32:36 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:19:00.0 Off |                  N/A |
| 52%   84C    P2   221W / 250W |   8792MiB / 11019MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:1A:00.0 Off |                  N/A |
| 32%   50C    P8    11W / 250W |   5131MiB / 11019MiB |      0%      Default |
|       

In [4]:
class CFG:
    """Model and training hyper-parameters, read as class attributes (e.g. ``CFG.lr``)."""
    # --- data & backbone ---
    data_path = '../../all_ord_reaction_uniq_with_attr_v3.tsv'
    pretrained_model_name_or_path = 'sagawa/ZINC-t5'
    model = 'sagawa/ZINC-t5'
    # --- run control ---
    debug = True
    seed = 42
    num_workers = 4
    print_freq = 100
    use_apex = False
    output_dir = './'
    # --- optimisation ---
    epochs = 5
    lr = 2e-5
    weight_decay = 0.01
    eps = 1e-6
    max_grad_norm = 1000
    gradient_accumulation_steps = 3
    num_warmup_steps = 0
    batch_scheduler = True
    # --- batching & head ---
    batch_size = 5  # lowered from 15 to 5 because a larger max_len ran out of memory
    max_len = 512
    fc_dropout = 0.1
    
class RegressionModel(nn.Module):
    """Transformer backbone plus a two-layer linear head that outputs one scalar per example.

    Args:
        cfg: config object read via attributes: ``pretrained_model_name_or_path``,
            ``fc_dropout`` and, when ``pretrained=False``, ``model_name_or_path``.
        config_path: optional path to a saved config object (loaded with
            ``torch.load``). If ``None``, the config is fetched with
            ``AutoConfig.from_pretrained(cfg.pretrained_model_name_or_path)``.
        pretrained: if True, load pretrained backbone weights from
            ``cfg.pretrained_model_name_or_path``. Otherwise build the backbone from
            ``self.config`` with fresh weights (presumably overwritten later by a
            fine-tuned ``state_dict`` — confirm with the caller).
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.pretrained_model_name_or_path, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            # Read the checkpoint name from the `cfg` argument, not the global CFG.
            if 't5' in cfg.pretrained_model_name_or_path:
                # Encoder-only T5; the checkpoint's decoder weights are left unused.
                self.model = T5EncoderModel.from_pretrained(cfg.pretrained_model_name_or_path)
            else:
                self.model = AutoModel.from_pretrained(cfg.pretrained_model_name_or_path)
        else:
            if 't5' in cfg.model_name_or_path:
                # Build from self.config, mirroring the AutoModel branch below. The
                # previous code downloaded a hard-coded hub checkpoint and ignored
                # both `cfg` and `config_path`.
                self.model = T5EncoderModel(self.config)
            else:
                self.model = AutoModel.from_config(self.config)
        self.fc_dropout1 = nn.Dropout(cfg.fc_dropout)
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc_dropout2 = nn.Dropout(cfg.fc_dropout)
        self.fc2 = nn.Linear(self.config.hidden_size, 1)

    def forward(self, inputs):
        """Return predictions of shape (batch, 1).

        Args:
            inputs: mapping of keyword arguments forwarded to the backbone —
                presumably a tokenizer batch (``input_ids``, ``attention_mask``).
        """
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]  # (batch, seq_len, hidden_size)
        # Pool by taking position 0. NOTE(review): T5 has no [CLS] token, so this is
        # the first input token's representation — confirm this is intended.
        # Selecting before dropout avoids running dropout over the whole sequence
        # when only position 0 is used.
        pooled = last_hidden_states[:, 0, :]
        output = self.fc1(self.fc_dropout1(pooled))
        # NOTE(review): there is no non-linearity between fc1 and fc2, so in eval
        # mode the head reduces to a single affine map.
        output = self.fc2(self.fc_dropout2(output))
        return output
    
# Instantiate the regressor with pretrained backbone weights (config fetched from the hub name).
model = RegressionModel(CFG, pretrained=True)
model

Some weights of the model checkpoint at sagawa/ZINC-t5 were not used when initializing T5EncoderModel: ['decoder.block.11.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.1.layer.1.EncDecAttention.k.weight', 'decoder.block.10.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.10.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'decoder.block.11.layer.1.layer_norm.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.q.w

RegressionModel(
  (model): T5EncoderModel(
    (shared): Embedding(221, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(221, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseGatedActDense(
                (wi_0): Linear(in_features=768, out_features=2048, bias=False)
                (wi_1): Linear(in_features=7

In [21]:
# Freeze the pretrained backbone (parameters under `model.`) except block.11,
# presumably the last encoder block. The regression head (fc1/fc2) stays trainable.
# The previous condition used `or` (plus the invalid token `noin`), which would have
# frozen every parameter, including the head.
for name, param in model.named_parameters():
    if ('model' in name) and ('block.11' not in name):
        print(name)
        param.requires_grad = False

All Flax model weights were used when initializing T5ForConditionalGeneration.

Some weights of T5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


T5ForConditionalGeneration(
  (shared): Embedding(221, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(221, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo): Lin

In [22]:
# Remove the LM head from a T5ForConditionalGeneration.
# NOTE(review): relies on `model` being a T5ForConditionalGeneration loaded in a cell
# not shown here (see the output above). On a fresh run `model` is the RegressionModel,
# which has no `lm_head`, so this raises AttributeError.
del model.lm_head

In [25]:
# Attach a new hidden_size -> hidden_size linear layer as `model.lm`.
# NOTE(review): `config` is not defined in any visible cell (hidden kernel state).
# It is presumably the config of the T5 model loaded in a missing cell — confirm.
model.lm = nn.Linear(config.hidden_size, config.hidden_size)
model

T5ForConditionalGeneration(
  (shared): Embedding(221, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(221, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo): Lin

T5Stack(
  (embed_tokens): Embedding(221, 768)
  (block): ModuleList(
    (0): T5Block(
      (layer): ModuleList(
        (0): T5LayerSelfAttention(
          (SelfAttention): T5Attention(
            (q): Linear(in_features=768, out_features=768, bias=False)
            (k): Linear(in_features=768, out_features=768, bias=False)
            (v): Linear(in_features=768, out_features=768, bias=False)
            (o): Linear(in_features=768, out_features=768, bias=False)
            (relative_attention_bias): Embedding(32, 12)
          )
          (layer_norm): T5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): T5LayerCrossAttention(
          (EncDecAttention): T5Attention(
            (q): Linear(in_features=768, out_features=768, bias=False)
            (k): Linear(in_features=768, out_features=768, bias=False)
            (v): Linear(in_features=768, out_features=768, bias=False)
            (o): Linear(in_features=768, out_features=768, bias=Fal

In [40]:
# List the parameter names of the decoder blocks (print only; nothing is frozen here).
# NOTE(review): `model` must be the encoder-decoder T5 from a cell not shown here.
# RegressionModel wraps a T5EncoderModel and has no `.decoder` attribute.
for name, param in model.decoder.block.named_parameters():
#     if ('model' in name) or ('block.11' noin name):
    print(name)
#         param.requires_grad = False

0.layer.0.SelfAttention.q.weight
0.layer.0.SelfAttention.k.weight
0.layer.0.SelfAttention.v.weight
0.layer.0.SelfAttention.o.weight
0.layer.0.SelfAttention.relative_attention_bias.weight
0.layer.0.layer_norm.weight
0.layer.1.EncDecAttention.q.weight
0.layer.1.EncDecAttention.k.weight
0.layer.1.EncDecAttention.v.weight
0.layer.1.EncDecAttention.o.weight
0.layer.1.layer_norm.weight
0.layer.2.DenseReluDense.wi_0.weight
0.layer.2.DenseReluDense.wi_1.weight
0.layer.2.DenseReluDense.wo.weight
0.layer.2.layer_norm.weight
1.layer.0.SelfAttention.q.weight
1.layer.0.SelfAttention.k.weight
1.layer.0.SelfAttention.v.weight
1.layer.0.SelfAttention.o.weight
1.layer.0.layer_norm.weight
1.layer.1.EncDecAttention.q.weight
1.layer.1.EncDecAttention.k.weight
1.layer.1.EncDecAttention.v.weight
1.layer.1.EncDecAttention.o.weight
1.layer.1.layer_norm.weight
1.layer.2.DenseReluDense.wi_0.weight
1.layer.2.DenseReluDense.wi_1.weight
1.layer.2.DenseReluDense.wo.weight
1.layer.2.layer_norm.weight
2.layer.0.SelfA

ModuleList(
  (0): T5Block(
    (layer): ModuleList(
      (0): T5LayerSelfAttention(
        (SelfAttention): T5Attention(
          (q): Linear(in_features=768, out_features=768, bias=False)
          (k): Linear(in_features=768, out_features=768, bias=False)
          (v): Linear(in_features=768, out_features=768, bias=False)
          (o): Linear(in_features=768, out_features=768, bias=False)
          (relative_attention_bias): Embedding(32, 12)
        )
        (layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): T5LayerCrossAttention(
        (EncDecAttention): T5Attention(
          (q): Linear(in_features=768, out_features=768, bias=False)
          (k): Linear(in_features=768, out_features=768, bias=False)
          (v): Linear(in_features=768, out_features=768, bias=False)
          (o): Linear(in_features=768, out_features=768, bias=False)
        )
        (layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=Fals

In [4]:
import pandas as pd
# Inspect the t5chem C-N yield training split using the same
# "REACTANT:...REAGENT:...PRODUCT:..." encoding.
df = pd.read_csv('/data2/sagawa/t5chem/data/C_N_yield/MFF_FullCV_01/train.csv').drop_duplicates().reset_index(drop=True)
df['input'] = 'REACTANT:' + df['REACTANT'] + 'REAGENT:' + df['REAGENT'] + 'PRODUCT:' + df['PRODUCT']

# Character length of each input; the <= 512 filter is only measured here, not applied.
lens = df['input'].map(len)

In [5]:
# Total rows vs rows that would survive the 512-character filter.
len(df), int((lens <= 512).sum())

(2767, 2767)

In [8]:
# Sanity check of MSELoss reduction. With the default reduction='mean', the loss is
# the mean over all elements, so it should not depend on tensor shape (compare with
# the reshaped call in the next cell).
# NOTE: `input` shadows the Python builtin in this notebook's namespace.
# NOTE: if input and target shapes differ (e.g. a (batch, 1) prediction against a
# (batch,) target), PyTorch broadcasts them and the loss is silently wrong.
from torch.nn import MSELoss

loss = MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)
output

tensor(2.7675, grad_fn=<MseLossBackward0>)

In [9]:
# Same tensors flattened to (15, 1): the loss is unchanged, because input and target
# are reshaped identically and 'mean' averages over all elements.
output = loss(input.reshape(-1, 1), target.reshape(-1, 1))
output

tensor(2.7675, grad_fn=<MseLossBackward0>)