In [1]:
import pandas as pd
import random
import dgl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgllife.model import model_zoo
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import EarlyStopping, Meter
from dgllife.utils import AttentiveFPAtomFeaturizer
from dgllife.utils import AttentiveFPBondFeaturizer

import torch
import os
import random
import numpy as np
import ast

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG, display
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import pickle
import argparse
from rdkit import RDLogger 
import warnings
warnings.filterwarnings("ignore")
RDLogger.DisableLog('rdApp.*') # switch off RDKit warning messages

In [2]:
from utils import get_values_at_positions, atom_finder, smiles_augmentation, concat_feature_reactive_atom, collate_molgraphs, Canon_SMILES_similarity
from model import AttentiveFPPredictor_rxn, weighted_binary_cross_entropy

In [3]:
atom_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')
bond_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he')
n_feats = atom_featurizer.feat_size('hv')
e_feats = bond_featurizer.feat_size('he')
print( 'Number of features in graph : ' , n_feats)

Number of features in graph :  39


In [4]:
#Assign device 
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

In [5]:
df_elementary = pd.read_csv('elementary_step_100000.csv')
df_elementary

Unnamed: 0,smiles,atom_mapped,class_label
0,CC(C)(C)[P+]([Pd])(c1ccnn1-c1c(-c2ccccc2)nn(-c...,"[5, 39, 38]",0
1,COc1ccc([Pd](Br)[P+](c2nnn(-c3ccccc3)c2-c2c(OC...,"[56, 57, 51, 52]",1
2,COc1cc(F)cc([Pd](c2cn(S(=O)(=O)c3ccccc3)c3cccc...,"[7, 8, 9]",3
3,COc1cccc([Pd](N2CCC(O)CC2)[P+](c2ccc(N3CCN(C)C...,"[6, 7, 8]",3
4,CC(C)CN1CCN2CCN(CC(C)C)P1(=N[P+](C1CCCCC1)(C1C...,"[70, 71, 60, 61]",1
...,...,...,...
99995,CNCc1cccc([Pd](Cl)[P+](c2ccc3ccccc3c2-c2cccc3c...,"[8, 9, 44, 41]",2
99996,CC(C)(C)[P+]([Pd](Br)c1cnc2ccccc2c1)(C(C)(C)C)...,"[5, 28]",4
99997,CC(C)Oc1cccc(OC(C)C)c1-c1ccccc1[P+](C1CCCCC1)(...,"[59, 60, 54, 55]",1
99998,CCN(CC)c1ccc([P+]([Pd](Cl)c2ccccc2C#N)([C@]23C...,"[10, 46]",4


In [6]:
train_datasets_, test_datasets = train_test_split( df_elementary, test_size=0.2, random_state=42, shuffle = False)
train_datasets, valid_datasets = train_test_split( train_datasets_, test_size=0.125, random_state=42, shuffle = False)

In [7]:
valid_datasets = valid_datasets.reset_index(drop=True)
test_datasets = test_datasets.reset_index(drop=True)

In [8]:
train_augm_smiles = smiles_augmentation(train_datasets, 5, augmentation =False)
valid_augm_smiles = smiles_augmentation(valid_datasets, 5, augmentation =False)
test_augm_smiles = smiles_augmentation(test_datasets, 5, augmentation =False)

In [9]:
print ( 'Total number of reaction steps after SMILES augmentation : ', len(train_augm_smiles) + len(valid_augm_smiles)+ len(test_augm_smiles))

Total number of reaction steps after SMILES augmentation :  100000


In [10]:
def graph_generation(df_augm_smiles):    
    graph_for_rxn = []
    for i in range(len(df_augm_smiles)):
        graph_for_rxn.append(smiles_to_bigraph(df_augm_smiles[i][0], node_featurizer=atom_featurizer,edge_featurizer=bond_featurizer, canonical_atom_order=False))
    return graph_for_rxn

In [11]:
train_graph_for_rxn = graph_generation(train_augm_smiles)
valid_graph_for_rxn = graph_generation(valid_augm_smiles)
test_graph_for_rxn = graph_generation(test_augm_smiles)

In [12]:
train_graph_dataset = concat_feature_reactive_atom(train_graph_for_rxn, train_augm_smiles)
valid_graph_dataset = concat_feature_reactive_atom(valid_graph_for_rxn, valid_augm_smiles)
test_graph_dataset = concat_feature_reactive_atom(test_graph_for_rxn, test_augm_smiles)

In [13]:
train_loader = DataLoader(train_graph_dataset, batch_size=256,shuffle=False,
                          collate_fn=collate_molgraphs)
valid_loader = DataLoader(valid_graph_dataset, batch_size=256,shuffle=False,
                          collate_fn=collate_molgraphs)
test_loader = DataLoader(test_graph_dataset, batch_size=256,shuffle=False,
                          collate_fn=collate_molgraphs)


In [14]:
# Out of distribution dataloader preparation

In [15]:
df_ood_1000 = pd.read_csv('OOD_elementary_step_3647.csv')
print(df_ood_1000)
ood_augm_smiles = smiles_augmentation(df_ood_1000, 5, augmentation =False)
ood_graph_for_rxn = graph_generation(ood_augm_smiles)
ood_graph_dataset = concat_feature_reactive_atom(ood_graph_for_rxn, ood_augm_smiles)
ood_loader = DataLoader(ood_graph_dataset, batch_size=256,shuffle=False,
                          collate_fn=collate_molgraphs)


                                                 smiles       atom_mapped  \
0     [Pd][P+](c1ccccc1)(c1ccccc1)c1ccccc1.COC(=O)c1...       [0, 29, 30]   
1     COC(=O)c1cncc2c([Pd](Br)[P+](c3ccccc3)(c3ccccc...  [47, 48, 42, 43]   
2     COC(=O)c1cncc2c([Pd](Br)[P+](c3ccccc3)(c3ccccc...  [10, 11, 42, 43]   
3     COC(=O)c1cncc2c([Pd](c3cccc(OC)c3)[P+](c3ccccc...       [9, 10, 11]   
4     [Pd][P+](c1ccccc1)(c1ccccc1)c1ccccc1.Cc1cc(C(=...       [0, 31, 32]   
...                                                 ...               ...   
3642  FC(F)(F)c1ccc([Pd](Cl)[P+](C2CCCCC2)(C2CCCCC2)...    [8, 9, 35, 36]   
3643  C[Pd](c1ccc(C(F)(F)F)cc1C(Cl)(Cl)Cl)[P+](C1CCC...         [0, 1, 2]   
3644  CC(C)(C)[P+]([Pd])(c1cc2ccccc2[cH-]1)C(C)(C)C....       [5, 32, 33]   
3645  Cc1cc(C(=O)O)cc(S(=O)(=O)F)c1[Pd](Br)[P+](c1cc...  [14, 15, 38, 39]   
3646  Cc1ccc([Pd](c2c(C)cc(C(=O)O)cc2S(=O)(=O)F)[P+]...         [4, 5, 6]   

      class_label  
0               0  
1               1  
2              

In [16]:
# Modify the model to fit your classification task
model = AttentiveFPPredictor_rxn(node_feat_size=n_feats,
                                   edge_feat_size=e_feats,
                                   num_layers=2,
                                   num_timesteps=1,
                                   graph_feat_size=200,
                                   n_tasks=8,
                                   dropout=0.1
                                    )



In [17]:
model.to(device)

AttentiveFPPredictor_rxn(
  (gnn): AttentiveFPGNN(
    (init_context): GetContext(
      (project_node): Sequential(
        (0): Linear(in_features=39, out_features=200, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (project_edge1): Sequential(
        (0): Linear(in_features=49, out_features=200, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (project_edge2): Sequential(
        (0): Dropout(p=0.1, inplace=False)
        (1): Linear(in_features=400, out_features=1, bias=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
      (attentive_gru): AttentiveGRU1(
        (edge_transform): Sequential(
          (0): Dropout(p=0.1, inplace=False)
          (1): Linear(in_features=200, out_features=200, bias=True)
        )
        (gru): GRUCell(200, 200)
      )
    )
    (gnn_layers): ModuleList(
      (0): GNNLayer(
        (project_edge): Sequential(
          (0): Dropout(p=0.1, inplace=False)
          (1): Linear(in_features=400, out

In [18]:
# Define loss function and optimizer
loss_fn_graph = nn.CrossEntropyLoss()
loss_fn_node = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.000001)

In [19]:
def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion1, loss_criterion2, optimizer):
    model.train()
    losses = []
    
    y_true = []
    y_pred = []
    loss_node_app = []
    loss_graph_radomize_app = []
    y_true_node = []
    y_pred_node = []
    
    for batch_id, batch_data in enumerate(data_loader):
        
        smiles, bg, labels = batch_data
        
        bg = bg.to(device)
        labels = labels.to(device)
        n_feats_w_l = bg.ndata.pop('hv').to(device)
        e_feats_ = bg.edata.pop('he').to(device)
        n_feats_ = n_feats_w_l[:,:n_feats]
        prediction1, prediction2, graph_feat = model(bg, n_feats_, e_feats_)
        n_labels = n_feats_w_l[:,n_feats].unsqueeze(1)
    
        # Calculate the weights
        counts = torch.bincount(n_labels.view(-1).long())
        class_weights = 1.0 / counts.float()
        class_weights = class_weights / class_weights.sum()
    
        
        loss_graph = loss_fn_graph(prediction1, labels.squeeze(1).long())
        loss_node = weighted_binary_cross_entropy(prediction2,n_labels ,class_weights)
        loss = loss_graph + loss_node 
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss_graph.data.item())
        loss_node_app.append(loss_node.data.item())
        
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(prediction1.detach().cpu().numpy())
    
        y_true_node.extend(n_labels.cpu().numpy())
        y_pred_node.extend(prediction2.detach().cpu().numpy())

    
    total_loss = np.mean(losses)
    total_loss_node = np.mean(loss_node_app)
    total_loss_graph_random = np.mean(loss_graph_radomize_app)
    accuracy = accuracy_score(y_true, np.argmax(y_pred, axis=1))
    print('F1 score classification task:', f1_score(y_true,np.argmax(y_pred, axis=1), average='macro'))

    # Threshold for binary prediction
    threshold_1 = 0.5
    # Convert predicted probabilities to binary values
    y_pred_binary = [1 if pred >= threshold_1 else 0 for pred in np.concatenate(y_pred_node)]
    y_true_flat = np.concatenate(y_true_node)
    # Calculate accuracy score
    accuracy_node = accuracy_score(y_true_flat, y_pred_binary)
    print('F1 score reactive atom task:', f1_score(y_true_flat,np.array(y_pred_binary,dtype=np.float32), average='macro'))


    if epoch % 1 == 0:
        print('epoch {:d}/{:d},train_acc_classification {:.4f},train_node_acc {:.4f},train_loss {:.4f},train_node_loss {:.4f}'.format(
            epoch + 1, n_epochs, accuracy,accuracy_node, total_loss, total_loss_node))
    return accuracy, total_loss, labels, prediction1, y_true_node, y_pred_node, model

In [20]:
def run_a_valid_epoch(n_epochs, epoch, model, data_loader, loss_criterion1, loss_criterion2):
    model.eval()
    losses = []
    
    y_true = []
    y_pred = []
    loss_node_app = []
    loss_graph_radomize_app = []
    y_true_node = []
    y_pred_node = []

    
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            
            smiles, bg, labels = batch_data
            
            bg = bg.to(device)
            labels = labels.to(device)
            n_feats_w_l = bg.ndata.pop('hv').to(device)
            e_feats_ = bg.edata.pop('he').to(device)
            n_feats_ = n_feats_w_l[:,:n_feats]
            prediction1, prediction2, graph_feat = model(bg, n_feats_, e_feats_)
            n_labels = n_feats_w_l[:,n_feats].unsqueeze(1)
        
            # Calculate the weights
            counts = torch.bincount(n_labels.view(-1).long())
            class_weights = 1.0 / counts.float()
            class_weights = class_weights / class_weights.sum()
            
            loss_graph = loss_fn_graph(prediction1, labels.squeeze(1).long())
            loss_node = weighted_binary_cross_entropy(prediction2,n_labels ,class_weights) #class_weights

            loss = loss_graph + loss_node 
            
            losses.append(loss_graph.data.item())
            loss_node_app.append(loss_node.data.item())
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(prediction1.detach().cpu().numpy())
        
            y_true_node.extend(n_labels.cpu().numpy())
            y_pred_node.extend(prediction2.detach().cpu().numpy())

    
    total_loss = np.mean(losses)
    total_loss_node = np.mean(loss_node_app)
    total_loss_graph_random = np.mean(loss_graph_radomize_app)
    accuracy = accuracy_score(y_true, np.argmax(y_pred, axis=1))
    print('F1 score classification task:', f1_score(y_true,np.argmax(y_pred, axis=1), average='macro'))

    # Threshold for binary prediction
    threshold_1 = 0.5
    # Convert predicted probabilities to binary values
    y_pred_binary = [1 if pred >= threshold_1 else 0 for pred in np.concatenate(y_pred_node)]
    y_true_flat = np.concatenate(y_true_node)
    # Calculate accuracy score
    accuracy_node = accuracy_score(y_true_flat, y_pred_binary)
    print('F1 score reactive atom task:', f1_score(y_true_flat,np.array(y_pred_binary,dtype=np.float32), average='macro'))

    if epoch % 1 == 0:
        print('epoch {:d}/{:d},valid_acc_classification {:.4f}, valid_node_acc {:.4f},valid_loss {:.4f},valid_node_loss {:.4f}'.format(
            epoch + 1, n_epochs, accuracy, accuracy_node, total_loss, total_loss_node))
    return accuracy, total_loss, labels, prediction1, y_true_node, y_pred_node, model

In [21]:
import time
st_time = time.time()
stopper = EarlyStopping(mode='higher', patience=5)
n_epochs = 5
for e in range(n_epochs):
    accuracy, total_loss, labels, prediction, y_true_node, y_pred_node, train_model= run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn_graph, loss_fn_node, optimizer)
    accuracy_, total_loss_, labels_, prediction_, y_true_node_, y_pred_node_, train_model_= run_a_valid_epoch(n_epochs, e, model, valid_loader, loss_fn_graph, loss_fn_node)

    #fn = 'model_' + str(e)
        #torch.save(train_model.state_dict(), fn)
en_time = time.time()
print('time required:', (en_time-st_time)/60)

F1 score classification task: 0.8474743598379123
F1 score reactive atom task: 0.5779244759219162
epoch 1/5,train_acc_classification 0.8731,train_node_acc 0.7493,train_loss 0.3512,train_node_loss 0.0455
F1 score classification task: 1.0
F1 score reactive atom task: 0.7857063021831172
epoch 1/5,valid_acc_classification 1.0000, valid_node_acc 0.9280,valid_loss 0.0006,valid_node_loss 0.0177
F1 score classification task: 0.9853952685086285
F1 score reactive atom task: 0.8427954227755883
epoch 2/5,train_acc_classification 0.9967,train_node_acc 0.9533,train_loss 0.0176,train_node_loss 0.0134
F1 score classification task: 0.9999181805540261
F1 score reactive atom task: 0.8476891967809697
epoch 2/5,valid_acc_classification 0.9999, valid_node_acc 0.9555,valid_loss 0.0030,valid_node_loss 0.0183
F1 score classification task: 0.9999880434324548
F1 score reactive atom task: 0.8595148759211537
epoch 3/5,train_acc_classification 1.0000,train_node_acc 0.9595,train_loss 0.0006,train_node_loss 0.0103
F1 

In [22]:
# Test accuracy calculation
accuracy_, total_loss_, labels_, prediction_, y_true_node_, y_pred_node_, train_model_= run_a_valid_epoch(1, 1, model, test_loader, loss_fn_graph, loss_fn_node)

F1 score classification task: 1.0
F1 score reactive atom task: 0.8712175687627253
epoch 2/1,valid_acc_classification 1.0000, valid_node_acc 0.9639,valid_loss 0.0000,valid_node_loss 0.0076


In [23]:
# OOD accuracy calculation
accuracy_, total_loss_, labels_, prediction_, y_true_node_, y_pred_node_, train_model_= run_a_valid_epoch(1, 1, model, ood_loader, loss_fn_graph, loss_fn_node)

F1 score classification task: 0.9937953987601318
F1 score reactive atom task: 0.8438283029242029
epoch 2/1,valid_acc_classification 0.9956, valid_node_acc 0.9493,valid_loss 0.0300,valid_node_loss 0.0212


In [26]:
# if you want to save your model run this cell
# fn = 'final_trained_ReactAIvate_model'
# torch.save(model.state_dict(), fn)