In [None]:
import os
import pickle
import torch.nn as nn
from models import Model_TGCN, batch_size
from data_pretreatment import func
import numpy as np
import pandas as pd
from rdkit import Chem
import torch
from torch.utils.data import DataLoader
import dgl
from dgllife.utils import *
from dgllife.model.model_zoo.gcn_predictor import GCNPredictor
torch.cuda.empty_cache()
import numpy as np
from sklearn.metrics import matthews_corrcoef
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import EarlyStopping, Meter
from dgllife.utils import AttentiveFPAtomFeaturizer
from dgllife.utils import AttentiveFPBondFeaturizer
from dgllife.utils import SMILESToBigraph, ScaffoldSplitter, RandomSplitter
from functools import partial
from dgllife.model.model_zoo.attentivefp_predictor import AttentiveFPPredictor
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.metrics import roc_auc_score, matthews_corrcoef, recall_score, f1_score
import random
# Seed every random number generator used by the notebook so repeated runs
# start from the same state.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Release cached GPU memory before any model is built.
torch.cuda.empty_cache()

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Example molecule parsed with RDKit (not referenced again in this cell).
mol = Chem.MolFromSmiles('CNC(=O)/C=C/c1cnc2ccc(OC(C)c3c(Cl)ccc(F)c3Cl)cc2c1')

# Atom features are stored under 'hv' and bond features under 'he'; the
# training and evaluation loops read the same two field names.
bond_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he')
atom_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')

# Feature widths, later passed to the model constructor.
e_feats = bond_featurizer.feat_size('he')
n_feats = atom_featurizer.feat_size('hv')
print(n_feats, e_feats, sep='\n')

In [None]:
# SMILES-string -> graph converter that attaches the atom and bond
# featurizers configured earlier in the notebook.
smiles_to_g = SMILESToBigraph(
    node_featurizer=atom_featurizer,
    edge_featurizer=bond_featurizer,
)

In [None]:
def get_data(df):
    """Build molecular graphs and integer labels from a data frame.

    Args:
        df: Data frame with a ``'SMILES'`` column and a ``'label'`` column.

    Returns:
        Tuple ``(graphs, labels)``: a list with one ``smiles_to_g`` result per
        row, and a NumPy ``int64`` array of the labels in the same order.

    Raises:
        ValueError: If ``smiles_to_g`` returns ``None`` for any row, i.e. the
            SMILES string could not be turned into a graph.
    """
    smiles = df['SMILES'].tolist()
    graphs = [smiles_to_g(s) for s in smiles]

    # Fail here with the offending rows instead of letting a None graph
    # surface later as an obscure error while batching.
    invalid = [(pos, s) for pos, (s, g) in enumerate(zip(smiles, graphs)) if g is None]
    if invalid:
        preview = ', '.join('row {}: {!r}'.format(pos, s) for pos, s in invalid[:5])
        raise ValueError(
            '{} SMILES could not be converted to a graph ({})'.format(len(invalid), preview)
        )

    labels = np.asarray(df['label'].tolist()).astype(np.int64)
    return graphs, labels

In [None]:
def train(epoch, train_d):
    """Run one training epoch over ``train_d`` and return epoch-level metrics.

    Reads the notebook-level globals ``gcn_net`` (model), ``optimizer`` and
    ``device``.

    Args:
        epoch: Index of the current epoch. Not used inside the function; kept
            so existing call sites keep working.
        train_d: Iterable of batches laid out as
            ``(sequences, list_num, batched_graph, labels, index)``.

    Returns:
        Tuple ``(accuracy, mean_batch_loss, auc, mcc, balanced_accuracy, f1)``.
        Accuracy is correct predictions divided by samples seen; the loss is
        the mean of the per-batch losses.

    Raises:
        ValueError: If ``train_d`` yields no batches.
    """
    gcn_net.train()
    # NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this assumes
    # gcn_net's forward already applies a sigmoid. The stock dgllife
    # AttentiveFPPredictor appears to end in a plain Linear layer (logits) --
    # confirm, and switch to BCEWithLogitsLoss plus an explicit sigmoid if so.
    loss_fn = nn.BCELoss()

    loss_sum = 0.0
    n_correct, n_seen, n_batches = 0, 0, 0
    y_l, y_p = [], []

    for _, _, graph, labels, _ in train_d:
        # Features are popped before the graph is moved, so they are
        # transferred to the device only once.
        node_feats = graph.ndata.pop('hv').to(device)
        edge_feats = graph.edata.pop('he').to(device)
        graph = graph.to(device)

        # Flatten to one score per graph without relying on a global batch
        # size, so a final partial batch is handled as well.
        probs = gcn_net(graph, node_feats, edge_feats).to('cpu').reshape(-1)
        targets = labels.to('cpu')

        loss = loss_fn(probs, targets.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

        batch_scores = probs.detach().numpy()
        batch_true = targets.numpy()
        n_correct += int(((batch_scores >= 0.5).astype(np.int64) == batch_true).sum())
        n_seen += int(batch_true.shape[0])
        n_batches += 1

        y_l.extend(batch_true.tolist())
        y_p.extend(batch_scores.tolist())

    if n_batches == 0:
        raise ValueError('train_d produced no batches; cannot compute training metrics')

    train_epoch_acc = n_correct / n_seen
    train_epoch_loss = loss_sum / n_batches

    y_l_np = np.array(y_l)
    # One threshold (>= 0.5) for accuracy and for every thresholded metric.
    y_p_binary_np = (np.array(y_p) >= 0.5).astype(np.int64)

    auc = roc_auc_score(y_l_np, np.array(y_p))
    mcc = matthews_corrcoef(y_l_np, y_p_binary_np)
    recall_positive = recall_score(y_l_np, y_p_binary_np, pos_label=1)
    recall_negative = recall_score(y_l_np, y_p_binary_np, pos_label=0)
    ba = (recall_positive + recall_negative) / 2
    f1 = f1_score(y_l_np, y_p_binary_np)
    return train_epoch_acc, train_epoch_loss, auc, mcc, ba, f1

In [None]:
def test(epoch,test_d):
    epoch_loss, epoch_acc = 0, 0
    mlist = []
    gcn_net.eval()
    y_l=[]
    y_p=[]
    for i, (X, list_num, graph, labels, index) in enumerate(test_d):
        labels = labels.to(device)
        n_feats = graph.ndata.pop('hv').to(device)
        e_feats = graph.edata.pop('he').to(device)
        #print(n_feats)
        #print(e_feats)
        n_feats, e_feats, labels = n_feats.to(device), e_feats.to(device), labels.to(device)
        graph = graph.to(device)
        y_ = gcn_net(graph, n_feats, e_feats).to('cpu')
        y = torch.reshape(y_, [test_d.batch_size])
        loss = nn.BCELoss()(y, labels.float().to('cpu'))
        epoch_loss += loss.detach().item()
        pred_cls = y.detach().numpy()
        true_label = labels.to('cpu').numpy()
        yy = [1 if m >= 0.5 else 0 for m in y.detach().numpy()]
        mlist.extend(pred_cls)
        epoch_acc += sum(true_label == yy)
        y_l.extend(labels.cpu().tolist())
        y_p.extend(y_.cpu().tolist())
    auc = roc_auc_score(y_l, y_p)
    y_p_binary = [1 if p > 0.5 else 0 for p in np.array(y_p)]
    mcc = matthews_corrcoef(y_l, y_p_binary)
    epoch_acc = epoch_acc / true_label.shape[0]
    epoch_acc /= (i + 1)
    epoch_loss /= (i + 1)
    y_l_np = np.array(y_l)
    y_p_binary_np = np.array(y_p_binary)
    recall_positive = recall_score(y_l_np, y_p_binary_np, pos_label=1)
    recall_negative = recall_score(y_l_np, y_p_binary_np, pos_label=0)
    ba = (recall_positive + recall_negative) / 2
    f1 = f1_score(y_l_np, y_p_binary_np)

    return epoch_acc, epoch_loss, auc, mcc,ba,f1

In [None]:
def collate(sample):
    """Combine per-molecule sample tuples into a single mini-batch.

    Args:
        sample: Sequence of 5-tuples
            ``(sequence, list_num, graph, label, index)``.

    Returns:
        Tuple ``(sequences, list_nums, batched_graph, labels, indices)`` where
        ``batched_graph`` is the ``dgl.batch`` of all graphs, ``labels`` is a
        tensor built from the per-sample labels, and the remaining entries are
        plain lists in the original sample order.
    """
    sequences, list_num, graphs, labels, index = (list(column) for column in zip(*sample))

    batched_graph = dgl.batch(graphs)
    # Register the zero initializer for node features, then for edge features.
    for install_initializer in (batched_graph.set_n_initializer, batched_graph.set_e_initializer):
        install_initializer(dgl.init.zero_initializer)

    label_tensor = torch.tensor(labels)
    return sequences, list_num, batched_graph, label_tensor, index

In [None]:
def eval(test_d):
    """Evaluate the current model on ``test_d`` and return its metrics.

    This performs exactly the computation of ``test`` (which ignores its
    ``epoch`` argument), so it delegates to it instead of keeping a second
    copy of the evaluation loop.

    Note that this function shadows Python's built-in ``eval`` for the rest of
    the notebook; the name is kept because existing cells call it.

    Args:
        test_d: Iterable of batches laid out as
            ``(sequences, list_num, batched_graph, labels, index)``.

    Returns:
        Tuple ``(accuracy, mean_batch_loss, auc, mcc, balanced_accuracy, f1)``
        as returned by ``test``.
    """
    return test(None, test_d)

In [None]:
# Per-split validation metrics; one value is appended per data split.
val_losses, val_acc, val_auc, val_mcc, val_ba, val_f1 = ([] for _ in range(6))

# Per-split test-set metrics (same layout as the validation lists).
test_losses1, test_acc1, test_auc1, test_mcc1, test_ba1, test_f11 = ([] for _ in range(6))

In [None]:
# Train one AttentiveFP model per pre-generated data split, checkpoint the
# epoch with the best weighted validation score, then re-evaluate that
# checkpoint on the validation set and record its metrics.
for i in range(5):
    # Best validation values seen so far in this split and the epochs that
    # produced them. "valid_score" starts at -inf so the first epoch always
    # writes a checkpoint; otherwise the torch.load below could fail or pick
    # up a stale file left behind by an earlier run.
    best_param = {
        "roc_epoch": 0,
        "mcc_epoch": 0,
        "loss_epoch": 0,
        "score_epoch": 0,
        "valid_roc": 0,
        "valid_mcc": 0,
        "valid_score": float("-inf"),
        "valid_loss": 9e8,
    }
    print('Split '+str(i+1)+' ......')
    PATH_x_train = 'X_train{}.csv'.format(i+1)
    PATH_x_val = 'X_val{}.csv'.format(i+1)
    #PATH_x_test = 'X_test{}.csv'.format(i+1)

    df_seq_train, y_train_tensor, y_true_train, list_num_train = func(PATH_x_train)
    df_seq_val, y_val_tensor, y_true_val, list_num_val = func(PATH_x_val)
    #df_seq_test, y_test_tensor, y_true_test, list_num_test = func(PATH_x_test)

    # NOTE(review): shuffle=False feeds the training samples in the same order
    # every epoch -- confirm this is intentional.
    train_X = pd.read_csv(PATH_x_train)
    x_train, y_train = get_data(train_X)
    train_data = list(zip(df_seq_train, list_num_train, x_train, y_train, range(len(train_X))))
    train_loader_ = DataLoader(train_data, batch_size=batch_size, shuffle=False, collate_fn=collate, drop_last=True)

    # NOTE(review): drop_last=True discards the final partial batch, so up to
    # 89 validation molecules are left out of every reported metric (and a
    # split with fewer than 90 rows yields no batch at all). Confirm that the
    # evaluation functions accept partial batches before turning it off.
    val_X = pd.read_csv(PATH_x_val)
    x_val, y_val = get_data(val_X)
    val_data = list(zip(df_seq_val, list_num_val, x_val, y_val, range(len(val_X))))
    val_loader_val = DataLoader(val_data, batch_size=90, shuffle=False, collate_fn=collate, drop_last=True)

    #test_X = pd.read_csv(PATH_x_test)
    #x_test, y_test = get_data(test_X)
    #test_data = list(zip(df_seq_test, list_num_test, x_test, y_test, [i for i in range(len(test_X))]))
    #test_loader_test = DataLoader(test_data, batch_size=1, shuffle=False, collate_fn=collate, drop_last=True)

    torch.cuda.empty_cache()
    gcn_net = AttentiveFPPredictor(
        node_feat_size=n_feats,
        edge_feat_size=e_feats,
        num_layers=2,
        num_timesteps=1,
        graph_feat_size=100,
        n_tasks=1,
        dropout=0.4,
    )
    # Use the notebook-wide device so the run also works without a GPU.
    gcn_net = gcn_net.to(device)

    optimizer = torch.optim.Adam(gcn_net.parameters(), lr=0.0005)

    checkpoint_path = '{}_gcn.pt'.format(i+1)

    for epoch in range(800):
        train_accuracy,train_loss,train_auc, train_mcc,train_ba,train_f1 = train(epoch,train_loader_)
        test_accuracy,test_loss,test_auc, test_mcc,test_ba,test_f1 = test(epoch,val_loader_val)

        if test_auc > best_param["valid_roc"]:
            best_param["roc_epoch"] = epoch
            best_param["valid_roc"] = test_auc

        if test_mcc > best_param["valid_mcc"]:
            best_param["mcc_epoch"] = epoch
            best_param["valid_mcc"] = test_mcc

        # Model selection criterion: 30% MCC + 70% ROC-AUC on validation data.
        valid_score = test_mcc*0.3+test_auc*0.7
        if valid_score > best_param["valid_score"]:
            best_param["score_epoch"] = epoch
            best_param["valid_score"] = valid_score
            torch.save(gcn_net, checkpoint_path)

        print("EPOCH:\t"+str(epoch)+'\n'
              +"train_roc"+":"+str(train_auc)+'\n'
              +"valid_roc"+":"+str(test_auc)+'\n'
              +"train_mcc"+":"+str(train_mcc)+'\n'
              +"valid_mcc"+":"+str(test_mcc)+'\n'
              )

        # Early stopping: stop once neither AUC (15 epochs) nor MCC (20 epochs)
        # has improved recently.
        if (epoch - best_param["roc_epoch"] >15) and (epoch - best_param["mcc_epoch"] >20):
            break

    # Reload the best checkpoint of this split. weights_only=False unpickles a
    # full module, so only load files written by this notebook.
    gcn_net = torch.load(checkpoint_path, weights_only=False, map_location=device)

    val_epoch_acc, val_epoch_loss, val_epoch_auc,val_epoch_mcc,val_epoch_ba,val_epoch_f1 = eval(val_loader_val)
    #test_epoch_acc, test_epoch_loss, test_epoch_auc,test_epoch_mcc,test_epoch_ba,test_epoch_f1 = eval(test_loader_test)

    val_losses.append(val_epoch_loss)
    val_acc.append(val_epoch_acc)
    val_auc.append(val_epoch_auc)
    val_mcc.append(val_epoch_mcc)
    val_ba.append(val_epoch_ba)
    val_f1.append(val_epoch_f1)
    #test_losses1.append(test_epoch_loss)
    #test_acc1.append(test_epoch_acc)
    #test_auc1.append(test_epoch_auc)
    #test_mcc1.append(test_epoch_mcc)
    #test_ba1.append(test_epoch_ba)
    #test_f11.append(test_epoch_f1)

In [None]:
# Report mean and standard deviation of every validation metric across the
# splits, one line per metric.
for metric_name, per_split_values in (
    ('val_losses', val_losses),
    ('val_acc', val_acc),
    ('val_auc', val_auc),
    ('val_mcc', val_mcc),
    ('val_ba', val_ba),
    ('val_f1', val_f1),
):
    print(metric_name + ':', np.mean(per_split_values), '+/-', np.std(per_split_values))