In [1]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils import data
from torch import nn 
import copy

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
from time import time
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, roc_curve, confusion_matrix, precision_score, recall_score, auc
from sklearn.model_selection import KFold
torch.manual_seed(1)    # reproducible torch:2 np:3
np.random.seed(1)

from config import BIN_config_DBPE
from models import BIN_Interaction_Flat, BIN_Transformer_Single
from stream import BIN_Data_Encoder, BIN_combined_encoder

use_cuda = torch.cuda.is_available()
print(use_cuda)
device = torch.device("cuda:0" if use_cuda else "cpu")

False


In [2]:
import sys
print(sys.executable)

/Users/zdx_macos/.pyenv/versions/meta-rl/bin/python3


In [3]:
def test(data_generator, model):
    y_pred = []
    y_label = []
    model.eval()
    loss_accumulate = 0.0
    count = 0.0
#     for i, (d, p, d_mask, p_mask, label) in enumerate(data_generator):
# #         score = model(d.long().cuda(), p.long().cuda(), d_mask.long().cuda(), p_mask.long().cuda())
#         score = model(d.long(), p.long(), d_mask.long(), p_mask.long())
#         m = torch.nn.Sigmoid()
#         logits = torch.squeeze(m(score))
        
#         loss_fct = torch.nn.BCELoss()            
        
# #         label = Variable(torch.from_numpy(np.array(label)).float()).cuda()
#         label = Variable(torch.from_numpy(np.array(label)).float())

#         loss = loss_fct(logits, label)
        
#         loss_accumulate += loss
#         count += 1
        
#         logits = logits.detach().cpu().numpy()
        
#         label_ids = label.to('cpu').numpy()
#         y_label = y_label + label_ids.flatten().tolist()
#         y_pred = y_pred + logits.flatten().tolist()
        
    for i, (feature, mask, label) in enumerate(data_generator):
#         score = model(d.long().cuda(), p.long().cuda(), d_mask.long().cuda(), p_mask.long().cuda())
        score = model(feature.long(), mask.long())
        m = torch.nn.Sigmoid()
        logits = torch.squeeze(m(score))
        
        loss_fct = torch.nn.BCELoss()            
        
#         label = Variable(torch.from_numpy(np.array(label)).float()).cuda()
        label = Variable(torch.from_numpy(np.array(label)).float())

        loss = loss_fct(logits, label)
        
        loss_accumulate += loss
        count += 1
        
        logits = logits.detach().cpu().numpy()
        
        label_ids = label.to('cpu').numpy()
        y_label = y_label + label_ids.flatten().tolist()
        y_pred = y_pred + logits.flatten().tolist()
        
    loss = loss_accumulate/count
    
    fpr, tpr, thresholds = roc_curve(y_label, y_pred)

    precision = tpr / (tpr + fpr)

    f1 = 2 * precision * tpr / (tpr + precision + 0.00001)

    thred_optim = thresholds[5:][np.argmax(f1[5:])]

    print("optimal threshold: " + str(thred_optim))

    y_pred_s = [1 if i else 0 for i in (y_pred >= thred_optim)]

    auc_k = auc(fpr, tpr)
    print("AUROC:" + str(auc_k))
    print("AUPRC: "+ str(average_precision_score(y_label, y_pred)))

    cm1 = confusion_matrix(y_label, y_pred_s)
    print('Confusion Matrix : \n', cm1)
    print('Recall : ', recall_score(y_label, y_pred_s))
    print('Precision : ', precision_score(y_label, y_pred_s))

    total1=sum(sum(cm1))
    #####from confusion matrix calculate accuracy
    accuracy1=(cm1[0,0]+cm1[1,1])/total1
    print ('Accuracy : ', accuracy1)

    sensitivity1 = cm1[0,0]/(cm1[0,0]+cm1[0,1])
    print('Sensitivity : ', sensitivity1 )

    specificity1 = cm1[1,1]/(cm1[1,0]+cm1[1,1])
    print('Specificity : ', specificity1)

    outputs = np.asarray([1 if i else 0 for i in (np.asarray(y_pred) >= 0.5)])
    return roc_auc_score(y_label, y_pred), average_precision_score(y_label, y_pred), f1_score(y_label, outputs), y_pred, loss.item()


def main(fold_n, lr):
    config = BIN_config_DBPE()
    
    lr = lr
    BATCH_SIZE = config['batch_size']
    train_epoch = 10
    
    loss_history = []
    
#     model = BIN_Interaction_Flat(**config)
    model = BIN_Transformer_Single(**config)
    
    if use_cuda:
        model = model.cuda()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, dim = 0)
            
    opt = torch.optim.Adam(model.parameters(), lr = lr)
    #opt = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
    
    print('--- Data Preparation ---')
    
    params = {'batch_size': BATCH_SIZE,
              'shuffle': True,
              'num_workers': 6, 
              'drop_last': True}

    dataFolder = './dataset/BindingDB'
    df_train = pd.read_csv(dataFolder + '/train.csv')
    df_val = pd.read_csv(dataFolder + '/val.csv')
    df_test = pd.read_csv(dataFolder + '/test.csv')
    
#     training_set = BIN_Data_Encoder(df_train.index.values, df_train.Label.values, df_train)
#     training_generator = data.DataLoader(training_set, **params)

#     validation_set = BIN_Data_Encoder(df_val.index.values, df_val.Label.values, df_val)
#     validation_generator = data.DataLoader(validation_set, **params)
    
#     testing_set = BIN_Data_Encoder(df_test.index.values, df_test.Label.values, df_test)
#     testing_generator = data.DataLoader(testing_set, **params)
    
    train_set = BIN_combined_encoder(df_train.index.values, df_train.Label.values, df_train)
    train_gen = data.DataLoader(train_set, **params)
    
    valid_set = BIN_combined_encoder(df_val.index.values, df_val.Label.values, df_val)
    valid_gen = data.DataLoader(valid_set, **params)
    
    test_set = BIN_combined_encoder(df_test.index.values, df_test.Label.values, df_test)
    test_gen = data.DataLoader(test_set, **params)
    # early stopping
    max_auc = 0
    model_max = copy.deepcopy(model)
    
    print('--- Go for Training ---')
    torch.backends.cudnn.benchmark = True
    for epo in range(train_epoch):
        model.train()
        print(len(train_gen))
        for i, (feature, mask, label) in enumerate(train_gen):
            score = model(feature.long(), mask.long())
            label = Variable(torch.from_numpy(np.array(label)).float())
            loss_fct = torch.nn.BCELoss()
            m = torch.nn.Sigmoid()
            n = torch.squeeze(m(score))
            
            loss = loss_fct(n, label)
            loss_history.append(loss)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            if (i % 100 == 0):
                print('Training at Epoch ' + str(epo + 1) + ' iteration ' + str(i) + ' with loss ' + str(loss.cpu().detach().numpy()))
            if (i % 1000 == 0):
                # every epoch test
                with torch.set_grad_enabled(False):
                    auc, auprc, f1, logits, loss = test(valid_gen, model)
                    if auc > max_auc:
                        model_max = copy.deepcopy(model)
                        max_auc = auc
            
                print('Validation at Epoch '+ str(epo + 1) + ' , AUROC: '+ str(auc) + ' , AUPRC: ' + str(auprc) + ' , F1: '+str(f1))
            
#         for i, (d, p, d_mask, p_mask, label) in enumerate(training_generator):
# #             score = model(d.long().cuda(), p.long().cuda(), d_mask.long().cuda(), p_mask.long().cuda())
#             score = model(d.long(), p.long(), d_mask.long(), p_mask.long())
# #             label = Variable(torch.from_numpy(np.array(label)).float()).cuda()
#             label = Variable(torch.from_numpy(np.array(label)).float())
#             loss_fct = torch.nn.BCELoss()
#             m = torch.nn.Sigmoid()
#             n = torch.squeeze(m(score))
            
#             loss = loss_fct(n, label)
#             loss_history.append(loss)
            
#             opt.zero_grad()
#             loss.backward()
#             opt.step()
            
#             if (i % 100 == 0):
#                 print('Training at Epoch ' + str(epo + 1) + ' iteration ' + str(i) + ' with loss ' + str(loss.cpu().detach().numpy()))
            # every epoch test
        with torch.set_grad_enabled(False):
            auc, auprc, f1, logits, loss = test(valid_gen, model)
            if auc > max_auc:
                model_max = copy.deepcopy(model)
                max_auc = auc
        
    print('--- Go for Testing ---')
    try:
        with torch.set_grad_enabled(False):
            auc, auprc, f1, logits, loss = test(test_gen, model_max)
            print('Testing AUROC: ' + str(auc) + ' , AUPRC: ' + str(auprc) + ' , F1: '+str(f1) + ' , Test loss: '+str(loss))
    except:
        print('testing failed')
    return model_max, loss_history

In [None]:
# fold 1
#biosnap interaction times 1e-6, flat, batch size 64, len 205, channel 3, epoch 50
s = time()
model_max, loss_history = main(1, 5e-6)
e = time()
print(e-s)
lh = list(filter(lambda x: x < 1, loss_history))
plt.plot(lh)

--- Data Preparation ---
--- Go for Training ---
3167
Training at Epoch 1 iteration 0 with loss 0.61577725


  precision = tpr / (tpr + fpr)


optimal threshold: 0.4846193492412567
AUROC:0.4983798957631048
AUPRC: 0.13455971626822696
Confusion Matrix : 
 [[ 110 5607]
 [   5  922]]
Recall :  0.9946062567421791
Precision :  0.1412161127278297
Accuracy :  0.15532811559301626
Sensitivity :  0.01924086059121917
Specificity :  0.9946062567421791
Validation at Epoch 1 , AUROC: 0.4983798957631048 , AUPRC: 0.13455971626822696 , F1: 0.06906077348066299
Training at Epoch 1 iteration 100 with loss 0.69492555
Training at Epoch 1 iteration 200 with loss 0.7137166
Training at Epoch 1 iteration 300 with loss 0.67332923
Training at Epoch 1 iteration 400 with loss 0.606265
Training at Epoch 1 iteration 500 with loss 0.6076871
Training at Epoch 1 iteration 600 with loss 0.54651904
Training at Epoch 1 iteration 700 with loss 0.7958389
Training at Epoch 1 iteration 800 with loss 0.7409822
Training at Epoch 1 iteration 900 with loss 0.735723
Training at Epoch 1 iteration 1000 with loss 0.7000242


  precision = tpr / (tpr + fpr)


optimal threshold: 0.21927951276302338
AUROC:0.8113720146900019
AUPRC: 0.40630383336549664
Confusion Matrix : 
 [[3100 2617]
 [ 108  819]]
Recall :  0.883495145631068
Precision :  0.23835855646100115
Accuracy :  0.589855508729681
Sensitivity :  0.5422424348434494
Specificity :  0.883495145631068
Validation at Epoch 1 , AUROC: 0.8113720146900019 , AUPRC: 0.40630383336549664 , F1: 0.33216538192011213
Training at Epoch 1 iteration 1100 with loss 0.3320628
Training at Epoch 1 iteration 1200 with loss 0.497371
Training at Epoch 1 iteration 1300 with loss 0.9707896
Training at Epoch 1 iteration 1400 with loss 0.35332763
Training at Epoch 1 iteration 1500 with loss 0.69269955
Training at Epoch 1 iteration 1600 with loss 0.5850935
Training at Epoch 1 iteration 1700 with loss 0.43192643
Training at Epoch 1 iteration 1800 with loss 0.7938974
Training at Epoch 1 iteration 1900 with loss 0.24118125
Training at Epoch 1 iteration 2000 with loss 0.42056596


  precision = tpr / (tpr + fpr)


optimal threshold: 0.3317345380783081
AUROC:0.8399227195561073
AUPRC: 0.4467011329092356
Confusion Matrix : 
 [[3656 2061]
 [ 114  813]]
Recall :  0.8770226537216829
Precision :  0.2828810020876827
Accuracy :  0.6726369656833233
Sensitivity :  0.639496239286339
Specificity :  0.8770226537216829
Validation at Epoch 1 , AUROC: 0.8399227195561073 , AUPRC: 0.4467011329092356 , F1: 0.4850976361767729
Training at Epoch 1 iteration 2100 with loss 0.19934243
Training at Epoch 1 iteration 2200 with loss 0.65585375
Training at Epoch 1 iteration 2300 with loss 1.2823445
Training at Epoch 1 iteration 2400 with loss 0.8955481
Training at Epoch 1 iteration 2500 with loss 0.40252626
Training at Epoch 1 iteration 2600 with loss 0.5986549
Training at Epoch 1 iteration 2700 with loss 0.32788062
Training at Epoch 1 iteration 2800 with loss 0.21964283
Training at Epoch 1 iteration 2900 with loss 0.39691642
Training at Epoch 1 iteration 3000 with loss 0.47516686


  precision = tpr / (tpr + fpr)


optimal threshold: 0.5726822018623352
AUROC:0.8584557044141896
AUPRC: 0.4720614615047078
Confusion Matrix : 
 [[4007 1710]
 [ 121  806]]
Recall :  0.8694714131607335
Precision :  0.3203497615262321
Accuracy :  0.7244130042143287
Sensitivity :  0.7008920762637747
Specificity :  0.8694714131607335
Validation at Epoch 1 , AUROC: 0.8584557044141896 , AUPRC: 0.4720614615047078 , F1: 0.4373208235600729
Training at Epoch 1 iteration 3100 with loss 0.29240456


  precision = tpr / (tpr + fpr)


optimal threshold: 0.4996640682220459
AUROC:0.8588705424254655
AUPRC: 0.4733989777048316
Confusion Matrix : 
 [[4130 1587]
 [ 131  796]]
Recall :  0.8586839266450917
Precision :  0.3340327318506085
Accuracy :  0.7414208308248044
Sensitivity :  0.722406856743047
Specificity :  0.8586839266450917
3167
Training at Epoch 2 iteration 0 with loss 0.5406074


  precision = tpr / (tpr + fpr)


optimal threshold: 0.43817105889320374
AUROC:0.8588199731341206
AUPRC: 0.47234001928282277
Confusion Matrix : 
 [[4041 1676]
 [ 125  802]]
Recall :  0.8651564185544768
Precision :  0.32364810330912025
Accuracy :  0.7289283564118001
Sensitivity :  0.7068392513556061
Specificity :  0.8651564185544768
Validation at Epoch 2 , AUROC: 0.8588199731341206 , AUPRC: 0.47234001928282277 , F1: 0.49417098445595853
Training at Epoch 2 iteration 100 with loss 0.44319266
Training at Epoch 2 iteration 200 with loss 0.2846982
Training at Epoch 2 iteration 300 with loss 0.4985318
Training at Epoch 2 iteration 400 with loss 0.27463728
Training at Epoch 2 iteration 500 with loss 0.802832
Training at Epoch 2 iteration 600 with loss 0.46320063
Training at Epoch 2 iteration 700 with loss 0.19527961
Training at Epoch 2 iteration 800 with loss 0.45500228
Training at Epoch 2 iteration 900 with loss 0.37262326
Training at Epoch 2 iteration 1000 with loss 0.8798284


  precision = tpr / (tpr + fpr)


optimal threshold: 0.42454206943511963
AUROC:0.8642106973297716
AUPRC: 0.48650331455527596
Confusion Matrix : 
 [[4069 1648]
 [ 119  808]]
Recall :  0.8716289104638619
Precision :  0.3289902280130293
Accuracy :  0.7340457555689344
Sensitivity :  0.7117369249606437
Specificity :  0.8716289104638619
Validation at Epoch 2 , AUROC: 0.8642106973297716 , AUPRC: 0.48650331455527596 , F1: 0.4987046632124352
Training at Epoch 2 iteration 1100 with loss 0.95102763
Training at Epoch 2 iteration 1200 with loss 0.38340774
Training at Epoch 2 iteration 1300 with loss 0.17756128
Training at Epoch 2 iteration 1400 with loss 0.52574444
Training at Epoch 2 iteration 1500 with loss 0.49427295
Training at Epoch 2 iteration 1600 with loss 0.5835446
Training at Epoch 2 iteration 1700 with loss 0.3931974
Training at Epoch 2 iteration 1800 with loss 0.5070167
Training at Epoch 2 iteration 1900 with loss 0.16478011
Training at Epoch 2 iteration 2000 with loss 0.4687624


  precision = tpr / (tpr + fpr)


optimal threshold: 0.36018383502960205
AUROC:0.8686909855898275
AUPRC: 0.491732100548838
Confusion Matrix : 
 [[3905 1812]
 [  94  833]]
Recall :  0.8985976267529665
Precision :  0.3149338374291115
Accuracy :  0.7131246237206502
Sensitivity :  0.6830505509882806
Specificity :  0.8985976267529665
Validation at Epoch 2 , AUROC: 0.8686909855898275 , AUPRC: 0.491732100548838 , F1: 0.5080321285140562
Training at Epoch 2 iteration 2100 with loss 0.25949788
Training at Epoch 2 iteration 2200 with loss 0.16244513
Training at Epoch 2 iteration 2300 with loss 0.5551017
Training at Epoch 2 iteration 2400 with loss 0.52883434
Training at Epoch 2 iteration 2500 with loss 0.7644173
Training at Epoch 2 iteration 2600 with loss 0.866174
Training at Epoch 2 iteration 2700 with loss 0.08938089
Training at Epoch 2 iteration 2800 with loss 0.35048765
Training at Epoch 2 iteration 2900 with loss 0.96930885
Training at Epoch 2 iteration 3000 with loss 0.3660919


  precision = tpr / (tpr + fpr)


optimal threshold: 0.27352046966552734
AUROC:0.8676690330453336
AUPRC: 0.4934242986869045
Confusion Matrix : 
 [[4052 1665]
 [ 104  823]]
Recall :  0.8878101402373247
Precision :  0.3307877813504823
Accuracy :  0.733744732089103
Sensitivity :  0.708763337414728
Specificity :  0.8878101402373247
Validation at Epoch 2 , AUROC: 0.8676690330453336 , AUPRC: 0.4934242986869045 , F1: 0.5263940520446097
Training at Epoch 2 iteration 3100 with loss 0.19492024


  precision = tpr / (tpr + fpr)


optimal threshold: 0.47487708926200867
AUROC:0.8687277804100226
AUPRC: 0.4968558959667218
Confusion Matrix : 
 [[4048 1669]
 [ 102  825]]
Recall :  0.889967637540453
Precision :  0.33079390537289494
Accuracy :  0.7334437086092715
Sensitivity :  0.7080636697568655
Specificity :  0.889967637540453
3167
Training at Epoch 3 iteration 0 with loss 0.6430591


  precision = tpr / (tpr + fpr)


optimal threshold: 0.518994927406311
AUROC:0.8676809206026275
AUPRC: 0.4961984321636985
Confusion Matrix : 
 [[4044 1673]
 [ 106  821]]
Recall :  0.8856526429341963
Precision :  0.32919005613472335
Accuracy :  0.7322396146899458
Sensitivity :  0.707364002099003
Specificity :  0.8856526429341963
Validation at Epoch 3 , AUROC: 0.8676809206026275 , AUPRC: 0.4961984321636985 , F1: 0.46992054483541434
Training at Epoch 3 iteration 100 with loss 0.53027356
Training at Epoch 3 iteration 200 with loss 0.38584924
Training at Epoch 3 iteration 300 with loss 0.20619969
Training at Epoch 3 iteration 400 with loss 0.22243239
Training at Epoch 3 iteration 500 with loss 0.31105506
Training at Epoch 3 iteration 600 with loss 0.15269217
Training at Epoch 3 iteration 700 with loss 0.2923265
Training at Epoch 3 iteration 800 with loss 0.21346644
Training at Epoch 3 iteration 900 with loss 0.21843973
Training at Epoch 3 iteration 1000 with loss 0.54255563


  precision = tpr / (tpr + fpr)


optimal threshold: 0.3297687768936157
AUROC:0.8690642171505751
AUPRC: 0.49265638884292157
Confusion Matrix : 
 [[4092 1625]
 [ 107  820]]
Recall :  0.8845738942826321
Precision :  0.33537832310838445
Accuracy :  0.7393136664659844
Sensitivity :  0.7157600139933532
Specificity :  0.8845738942826321
Validation at Epoch 3 , AUROC: 0.8690642171505751 , AUPRC: 0.49265638884292157 , F1: 0.5201129146083274
Training at Epoch 3 iteration 1100 with loss 0.36407232
Training at Epoch 3 iteration 1200 with loss 0.69899493
Training at Epoch 3 iteration 1300 with loss 0.45425254
Training at Epoch 3 iteration 1400 with loss 0.14813574
Training at Epoch 3 iteration 1500 with loss 0.39443246
Training at Epoch 3 iteration 1600 with loss 0.2902448
Training at Epoch 3 iteration 1700 with loss 0.12585858
Training at Epoch 3 iteration 1800 with loss 0.1421633
Training at Epoch 3 iteration 1900 with loss 0.24421185
Training at Epoch 3 iteration 2000 with loss 1.317399


  precision = tpr / (tpr + fpr)


optimal threshold: 0.3346344828605652
AUROC:0.8700582433699978
AUPRC: 0.5064592862751999
Confusion Matrix : 
 [[4004 1713]
 [  93  834]]
Recall :  0.8996763754045307
Precision :  0.32744405182567726
Accuracy :  0.7281757977122215
Sensitivity :  0.7003673255203778
Specificity :  0.8996763754045307
Validation at Epoch 3 , AUROC: 0.8700582433699978 , AUPRC: 0.5064592862751999 , F1: 0.5232308791994282
Training at Epoch 3 iteration 2100 with loss 0.28367004
Training at Epoch 3 iteration 2200 with loss 0.3160927
Training at Epoch 3 iteration 2300 with loss 0.6100583
Training at Epoch 3 iteration 2400 with loss 0.5121267
Training at Epoch 3 iteration 2500 with loss 0.8503914
Training at Epoch 3 iteration 2600 with loss 0.27623165
Training at Epoch 3 iteration 2700 with loss 0.39404356
Training at Epoch 3 iteration 2800 with loss 0.35859782
Training at Epoch 3 iteration 2900 with loss 0.7085028
Training at Epoch 3 iteration 3000 with loss 0.3573506


  precision = tpr / (tpr + fpr)


optimal threshold: 0.35699698328971863
AUROC:0.8743504063185952
AUPRC: 0.5156011951686714
Confusion Matrix : 
 [[3587 2130]
 [  55  872]]
Recall :  0.9406688241639698
Precision :  0.290473017988008
Accuracy :  0.6711318482841662
Sensitivity :  0.6274269721882106
Specificity :  0.9406688241639698
Validation at Epoch 3 , AUROC: 0.8743504063185952 , AUPRC: 0.5156011951686714 , F1: 0.48714069591528
Training at Epoch 3 iteration 3100 with loss 0.6029926
optimal threshold: 0.4508635103702545
AUROC:0.872596331197913
AUPRC: 0.5131072117735732
Confusion Matrix : 
 [[4132 1585]
 [ 112  815]]
Recall :  0.8791801510248112
Precision :  0.33958333333333335
Accuracy :  0.7445815773630343
Sensitivity :  0.7227566905719783
Specificity :  0.8791801510248112
3167


  precision = tpr / (tpr + fpr)


Training at Epoch 4 iteration 0 with loss 0.5726141


  precision = tpr / (tpr + fpr)


optimal threshold: 0.45463812351226807
AUROC:0.8730231511121754
AUPRC: 0.514534706551943
Confusion Matrix : 
 [[4165 1552]
 [ 115  812]]
Recall :  0.8759439050701187
Precision :  0.34348561759729274
Accuracy :  0.7490969295605057
Sensitivity :  0.728528948749344
Specificity :  0.8759439050701187
Validation at Epoch 4 , AUROC: 0.8730231511121754 , AUPRC: 0.514534706551943 , F1: 0.5111480865224626
Training at Epoch 4 iteration 100 with loss 0.9380932
Training at Epoch 4 iteration 200 with loss 0.50664675
Training at Epoch 4 iteration 300 with loss 0.25501102
Training at Epoch 4 iteration 400 with loss 0.23696336
Training at Epoch 4 iteration 500 with loss 0.62910193
Training at Epoch 4 iteration 600 with loss 0.40599987
Training at Epoch 4 iteration 700 with loss 0.8589818
Training at Epoch 4 iteration 800 with loss 0.15249366
Training at Epoch 4 iteration 900 with loss 0.45786887
Training at Epoch 4 iteration 1000 with loss 0.6256413


  precision = tpr / (tpr + fpr)


optimal threshold: 0.5452971458435059
AUROC:0.8713939896887706
AUPRC: 0.5074862687001727
Confusion Matrix : 
 [[4243 1474]
 [ 122  805]]
Recall :  0.8683926645091694
Precision :  0.35322509872751207
Accuracy :  0.7597832630945214
Sensitivity :  0.7421724680776631
Specificity :  0.8683926645091694
Validation at Epoch 4 , AUROC: 0.8713939896887706 , AUPRC: 0.5074862687001727 , F1: 0.4842543077837195
Training at Epoch 4 iteration 1100 with loss 0.4358356
Training at Epoch 4 iteration 1200 with loss 0.18841448
Training at Epoch 4 iteration 1300 with loss 0.55863225
Training at Epoch 4 iteration 1400 with loss 0.27078298
Training at Epoch 4 iteration 1500 with loss 0.4914625
Training at Epoch 4 iteration 1600 with loss 0.27992606
Training at Epoch 4 iteration 1700 with loss 0.45280135
Training at Epoch 4 iteration 1800 with loss 0.28927088
Training at Epoch 4 iteration 1900 with loss 0.08814472
Training at Epoch 4 iteration 2000 with loss 0.19959417


  precision = tpr / (tpr + fpr)


optimal threshold: 0.32143062353134155
AUROC:0.8750702639547184
AUPRC: 0.514047658958491
Confusion Matrix : 
 [[3992 1725]
 [  89  838]]
Recall :  0.9039913700107874
Precision :  0.3269605930550137
Accuracy :  0.7269717037928959
Sensitivity :  0.6982683225467903
Specificity :  0.9039913700107874
Validation at Epoch 4 , AUROC: 0.8750702639547184 , AUPRC: 0.514047658958491 , F1: 0.5186206896551725
Training at Epoch 4 iteration 2100 with loss 0.11176818
Training at Epoch 4 iteration 2200 with loss 0.25227857
Training at Epoch 4 iteration 2300 with loss 0.24332327
Training at Epoch 4 iteration 2400 with loss 0.7401309
Training at Epoch 4 iteration 2500 with loss 0.08320383
Training at Epoch 4 iteration 2600 with loss 0.19413322
Training at Epoch 4 iteration 2700 with loss 0.7347281
Training at Epoch 4 iteration 2800 with loss 0.67531663
Training at Epoch 4 iteration 2900 with loss 0.75467783
Training at Epoch 4 iteration 3000 with loss 0.67413193


  precision = tpr / (tpr + fpr)


optimal threshold: 0.3288610279560089
AUROC:0.8767998091952709
AUPRC: 0.4988578575171326
Confusion Matrix : 
 [[4093 1624]
 [  97  830]]
Recall :  0.895361380798274
Precision :  0.3382233088834556
Accuracy :  0.7409692956050572
Sensitivity :  0.7159349309078188
Specificity :  0.895361380798274
Validation at Epoch 4 , AUROC: 0.8767998091952709 , AUPRC: 0.4988578575171326 , F1: 0.529889298892989
Training at Epoch 4 iteration 3100 with loss 0.20807824


  precision = tpr / (tpr + fpr)


optimal threshold: 0.4555772542953491
AUROC:0.8785384116223327
AUPRC: 0.5153392637578263
Confusion Matrix : 
 [[4170 1547]
 [ 112  815]]
Recall :  0.8791801510248112
Precision :  0.34504657070279426
Accuracy :  0.7503010234798314
Sensitivity :  0.7294035333216722
Specificity :  0.8791801510248112
3167
Training at Epoch 5 iteration 0 with loss 0.74280906


  precision = tpr / (tpr + fpr)


optimal threshold: 0.4804984927177429
AUROC:0.8787537084933201
AUPRC: 0.5153346193157738
Confusion Matrix : 
 [[4216 1501]
 [ 116  811]]
Recall :  0.8748651564185544
Precision :  0.3507785467128028
Accuracy :  0.7566225165562914
Sensitivity :  0.7374497113870911
Specificity :  0.8748651564185544
Validation at Epoch 5 , AUROC: 0.8787537084933201 , AUPRC: 0.5153346193157738 , F1: 0.5036061461273127
Training at Epoch 5 iteration 100 with loss 0.29731187
Training at Epoch 5 iteration 200 with loss 0.2557396
Training at Epoch 5 iteration 300 with loss 0.32559615
Training at Epoch 5 iteration 400 with loss 0.37961954
Training at Epoch 5 iteration 500 with loss 0.6437191
Training at Epoch 5 iteration 600 with loss 0.60949945
Training at Epoch 5 iteration 700 with loss 0.5424442
Training at Epoch 5 iteration 800 with loss 0.26197898
Training at Epoch 5 iteration 900 with loss 0.4417314
Training at Epoch 5 iteration 1000 with loss 0.18043064


  precision = tpr / (tpr + fpr)


optimal threshold: 0.42705538868904114
AUROC:0.8747979822852754
AUPRC: 0.5052901430447462
Confusion Matrix : 
 [[4071 1646]
 [ 103  824]]
Recall :  0.8888888888888888
Precision :  0.33360323886639676
Accuracy :  0.7367549668874173
Sensitivity :  0.712086758789575
Specificity :  0.8888888888888888
Validation at Epoch 5 , AUROC: 0.8747979822852754 , AUPRC: 0.5052901430447462 , F1: 0.4977945809703844
Training at Epoch 5 iteration 1100 with loss 0.44116685
Training at Epoch 5 iteration 1200 with loss 0.61942273
Training at Epoch 5 iteration 1300 with loss 0.60887796
Training at Epoch 5 iteration 1400 with loss 0.24517645
Training at Epoch 5 iteration 1500 with loss 0.14967307
Training at Epoch 5 iteration 1600 with loss 1.0075098
Training at Epoch 5 iteration 1700 with loss 0.12782757
Training at Epoch 5 iteration 1800 with loss 1.2360415
Training at Epoch 5 iteration 1900 with loss 0.809906
Training at Epoch 5 iteration 2000 with loss 0.27326286


  precision = tpr / (tpr + fpr)


optimal threshold: 0.40213173627853394
AUROC:0.8780080001373674
AUPRC: 0.5197874189979754
Confusion Matrix : 
 [[3986 1731]
 [  89  838]]
Recall :  0.9039913700107874
Precision :  0.32619696379914365
Accuracy :  0.7260686333534015
Sensitivity :  0.6972188210599966
Specificity :  0.9039913700107874
Validation at Epoch 5 , AUROC: 0.8780080001373674 , AUPRC: 0.5197874189979754 , F1: 0.5081495685522531
Training at Epoch 5 iteration 2100 with loss 0.18143326
Training at Epoch 5 iteration 2200 with loss 0.47404596
Training at Epoch 5 iteration 2300 with loss 0.6512688
Training at Epoch 5 iteration 2400 with loss 0.46703595
Training at Epoch 5 iteration 2500 with loss 0.38641798
Training at Epoch 5 iteration 2600 with loss 0.37323907
Training at Epoch 5 iteration 2700 with loss 0.34269845
Training at Epoch 5 iteration 2800 with loss 0.12169886
Training at Epoch 5 iteration 2900 with loss 0.22300425
Training at Epoch 5 iteration 3000 with loss 0.6153275


  precision = tpr / (tpr + fpr)


optimal threshold: 0.3984602987766266
AUROC:0.878924474197302
AUPRC: 0.5311418998115128
Confusion Matrix : 
 [[4158 1559]
 [ 106  821]]
Recall :  0.8856526429341963
Precision :  0.3449579831932773
Accuracy :  0.7493979530403372
Sensitivity :  0.7273045303480846
Specificity :  0.8856526429341963
Validation at Epoch 5 , AUROC: 0.878924474197302 , AUPRC: 0.5311418998115128 , F1: 0.5209176788124157
Training at Epoch 5 iteration 3100 with loss 0.44426954


  precision = tpr / (tpr + fpr)


optimal threshold: 0.5869678854942322
AUROC:0.8791899629768634
AUPRC: 0.5277006620650021
Confusion Matrix : 
 [[4176 1541]
 [ 112  815]]
Recall :  0.8791801510248112
Precision :  0.34592529711375214
Accuracy :  0.7512040939193257
Sensitivity :  0.730453034808466
Specificity :  0.8791801510248112
3167
Training at Epoch 6 iteration 0 with loss 0.5905967


  precision = tpr / (tpr + fpr)


optimal threshold: 0.4684278666973114
AUROC:0.8806342068423647
AUPRC: 0.5279369481575501
Confusion Matrix : 
 [[4051 1666]
 [  98  829]]
Recall :  0.8942826321467098
Precision :  0.3322645290581162
Accuracy :  0.7344972907886815
Sensitivity :  0.7085884205002624
Specificity :  0.8942826321467098
Validation at Epoch 6 , AUROC: 0.8806342068423647 , AUPRC: 0.5279369481575501 , F1: 0.4923076923076923
Training at Epoch 6 iteration 100 with loss 0.37067845
Training at Epoch 6 iteration 200 with loss 0.14700752
Training at Epoch 6 iteration 300 with loss 0.3302036
Training at Epoch 6 iteration 400 with loss 0.1884324
Training at Epoch 6 iteration 500 with loss 0.70143217
Training at Epoch 6 iteration 600 with loss 0.42875558
Training at Epoch 6 iteration 700 with loss 0.13495888
Training at Epoch 6 iteration 800 with loss 0.12944032
Training at Epoch 6 iteration 900 with loss 0.4136523
Training at Epoch 6 iteration 1000 with loss 0.6433211


  precision = tpr / (tpr + fpr)


optimal threshold: 0.46199095249176025
AUROC:0.8804098527848678
AUPRC: 0.5253315584233873
Confusion Matrix : 
 [[4146 1571]
 [ 106  821]]
Recall :  0.8856526429341963
Precision :  0.3432274247491639
Accuracy :  0.7475918121613486
Sensitivity :  0.7252055273744971
Specificity :  0.8856526429341963
Validation at Epoch 6 , AUROC: 0.8804098527848678 , AUPRC: 0.5253315584233873 , F1: 0.5061301477522792
Training at Epoch 6 iteration 1100 with loss 0.55111456
Training at Epoch 6 iteration 1200 with loss 0.87118477
Training at Epoch 6 iteration 1300 with loss 0.59121567
Training at Epoch 6 iteration 1400 with loss 0.35409078
Training at Epoch 6 iteration 1500 with loss 0.43342987
Training at Epoch 6 iteration 1600 with loss 0.57846844
Training at Epoch 6 iteration 1700 with loss 0.9027875
Training at Epoch 6 iteration 1800 with loss 0.41834843
Training at Epoch 6 iteration 1900 with loss 0.52590233
Training at Epoch 6 iteration 2000 with loss 0.19708519


  precision = tpr / (tpr + fpr)


optimal threshold: 0.43920204043388367
AUROC:0.8793129897602846
AUPRC: 0.5333951982885149
Confusion Matrix : 
 [[4213 1504]
 [ 116  811]]
Recall :  0.8748651564185544
Precision :  0.35032397408207344
Accuracy :  0.7561709813365443
Sensitivity :  0.7369249606436943
Specificity :  0.8748651564185544
Validation at Epoch 6 , AUROC: 0.8793129897602846 , AUPRC: 0.5333951982885149 , F1: 0.515032679738562
Training at Epoch 6 iteration 2100 with loss 0.19356833
Training at Epoch 6 iteration 2200 with loss 0.20691343
Training at Epoch 6 iteration 2300 with loss 0.098263144
Training at Epoch 6 iteration 2400 with loss 0.14012179
Training at Epoch 6 iteration 2500 with loss 0.70674783
Training at Epoch 6 iteration 2600 with loss 0.20204628
Training at Epoch 6 iteration 2700 with loss 0.34325114
Training at Epoch 6 iteration 2800 with loss 0.20856737
Training at Epoch 6 iteration 2900 with loss 0.4556668
Training at Epoch 6 iteration 3000 with loss 0.23133174


  precision = tpr / (tpr + fpr)


optimal threshold: 0.46521618962287903
AUROC:0.8817976779260703
AUPRC: 0.5385472983112274
Confusion Matrix : 
 [[4089 1628]
 [ 101  826]]
Recall :  0.8910463861920173
Precision :  0.33659331703341483
Accuracy :  0.7397652016857315
Sensitivity :  0.7152352632499562
Specificity :  0.8910463861920173
Validation at Epoch 6 , AUROC: 0.8817976779260703 , AUPRC: 0.5385472983112274 , F1: 0.4946630070143336
Training at Epoch 6 iteration 3100 with loss 0.6039671
optimal threshold: 0.4373539388179779
AUROC:0.8809401699241404
AUPRC: 0.5375813481698636
Confusion Matrix : 
 [[4234 1483]
 [ 116  811]]
Recall :  0.8748651564185544
Precision :  0.35353095030514387
Accuracy :  0.7593317278747742
Sensitivity :  0.7405982158474724
Specificity :  0.8748651564185544
3167


  precision = tpr / (tpr + fpr)


Training at Epoch 7 iteration 0 with loss 0.65147936


  precision = tpr / (tpr + fpr)


optimal threshold: 0.4399678409099579
AUROC:0.8811787701812512
AUPRC: 0.5392874262979078
Confusion Matrix : 
 [[4237 1480]
 [ 116  811]]
Recall :  0.8748651564185544
Precision :  0.3539938891313837
Accuracy :  0.7597832630945214
Sensitivity :  0.7411229665908693
Specificity :  0.8748651564185544
Validation at Epoch 7 , AUROC: 0.8811787701812512 , AUPRC: 0.5392874262979078 , F1: 0.5208053691275167
Training at Epoch 7 iteration 100 with loss 0.9035216
Training at Epoch 7 iteration 200 with loss 0.37151074
Training at Epoch 7 iteration 300 with loss 0.65344715
Training at Epoch 7 iteration 400 with loss 0.37424666
Training at Epoch 7 iteration 500 with loss 0.118147954
Training at Epoch 7 iteration 600 with loss 0.31557864
Training at Epoch 7 iteration 700 with loss 0.11143083
Training at Epoch 7 iteration 800 with loss 0.519423
Training at Epoch 7 iteration 900 with loss 0.3015175
Training at Epoch 7 iteration 1000 with loss 0.22564462


  precision = tpr / (tpr + fpr)


optimal threshold: 0.30973008275032043
AUROC:0.8828647277117263
AUPRC: 0.5363769827367689
Confusion Matrix : 
 [[4014 1703]
 [  80  847]]
Recall :  0.9137001078748651
Precision :  0.33215686274509804
Accuracy :  0.731637567730283
Sensitivity :  0.7021164946650341
Specificity :  0.9137001078748651
Validation at Epoch 7 , AUROC: 0.8828647277117263 , AUPRC: 0.5363769827367689 , F1: 0.5301632363378282
Training at Epoch 7 iteration 1100 with loss 0.23497975
Training at Epoch 7 iteration 1200 with loss 0.7589859
Training at Epoch 7 iteration 1300 with loss 0.19997072
Training at Epoch 7 iteration 1400 with loss 1.1100525
Training at Epoch 7 iteration 1500 with loss 0.30895907
Training at Epoch 7 iteration 1600 with loss 0.2906143
Training at Epoch 7 iteration 1700 with loss 0.27142146


In [None]:
torch.tensor([1.0, 2.0]).cuda()

In [4]:
print(torch.__version__)

1.11.0.dev20211106+cu113


In [23]:
config = BIN_config_DBPE()

BATCH_SIZE = config['batch_size']

params = {'batch_size': BATCH_SIZE,
              'shuffle': True,
              'num_workers': 6, 
              'drop_last': True}

dataFolder = './dataset/BIOSNAP/full_data'
df_train = pd.read_csv(dataFolder + '/train.csv')
df_val = pd.read_csv(dataFolder + '/val.csv')
df_test = pd.read_csv(dataFolder + '/test.csv')

training_set = BIN_Data_Encoder(df_train.index.values, df_train.Label.values, df_train)
training_generator = data.DataLoader(training_set, **params)

In [24]:
df_train.head()

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Unnamed: 0.1.1,DrugBank ID,Gene,Label,SMILES,Target Sequence
0,0,3,4,DB08533,P49862,0.0,CC1=CN=C2N1C=CN=C2NCC1=CC=NC=C1,MARSLLLPLQILLLSLALETAGEEAQGDKIIDGAPCARGSHPWQVA...
1,1,4,5,DB00755,P48443,1.0,C\C(\C=C\C1=C(C)CCCC1(C)C)=C/C=C/C(/C)=C/C(O)=O,MYGNYSHFMKFPAGYGGSPGHTGSTSMSPSAALSTGKPMDSHPSYT...
2,2,5,6,DB00361,O60218,0.0,[H][C@@]12N(C)C3=CC(OC)=C(C=C3[C@@]11CCN3CC=C[...,MATFVELSTKAKMPIVGLGTWKSPLGKVKEAVKVAIDAGYRHIDCA...
3,3,7,8,DB01136,P08588,1.0,COC1=CC=CC=C1OCCNCC(O)COC1=CC=CC2=C1C1=CC=CC=C1N2,MGAGVLVLGASEPGNLSSAAPLPDGAATAARLLVPASPPASLLPPA...
4,4,8,9,DB06963,Q9Y691,0.0,[H][C@@](C)(NC1=CC2=C(C=N1)C(C)=NN2C1=CC=CC(CC...,MFIWTSGRTSSSYRHDEKRNIYQKIRDHDLLDKRKTVTALKAGEDR...


In [25]:
df_train["drug_encoding"][0]

KeyError: 'drug_encoding'

In [26]:
for i, (d, p, d_mask, p_mask, label) in enumerate(training_generator):
    print("DRUG:")
    print(d.shape)
    print(d)
    print("PROTEIN:")
    print(p.shape)
    print(p)
    print("DRUG MASK:")
    print(d_mask)
    print("PROTEIN MASK:")
    print(p_mask)
    print("LABEL")
    print(label.shape)
    print(label)
    break

DRUG:
torch.Size([16, 50])
tensor([[ 3291,   300,   471,   117,  7522,  1804,   147,  2268,    91,  1855,
            82,   211,   623,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    7,    66,  5763,    63,  3821, 13901,     7,    66,    11,  3901,
             7,    66,   189,  5374,   156,    10,     7,    66,    11,  1077,
          1176,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [10440, 18363,  6525, 20141,   569,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0, 