# GIN Model for oral bioavailability dataset 

1. This notebook focuses on building a GIN model for oral bioavailabilty dataset using the GIN convolutional technique which is available on Pytorch Geometric 
2. The hyperparameters were found using the TPE algorithm using Optuna library after 30 trials 
3. Train/validate/test using 5-fold CV and process was repeated for 10 times.

### Note
1. Ensure that this notebook is in the same working directory as data folder, config.py, utils.py, engine.py and model.py
2. To load models, please download the saved models provided in google drive link from README.md
3. Comment away the training function to load saved models and reproduce results

In [1]:
from utils import seed_everything, LoadHOBDataset
from config import SEED_NO, NUM_FEATURES, NUM_GRAPHS_PER_BATCH, NUM_TARGET, DEVICE, PATIENCE, EPOCHS, N_SPLITS, params_gin
from engine import EngineHOB_no_edge
from model import GIN

import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold
from torch_geometric.loader import DataLoader
import os 

## 2. Tuning of model (Suggest to skip and used the hyperparameters saved in config.py)
1. To tune the model, ensure data folder and this notebook are in the same working directory. Then, we make use of the Optuna library which allows us to use the Tree-structure Parzen Estimator Algorithm to find the best hyperparameters for us 
2. First, create a run_tuning function to include the training and validation step with early stopping mechansim 
3. Then, create an objective function of Optuna to evaluate the best hyperparameters in 30 trials. 

In [6]:
def run_tuning(train_loader, valid_loader, params):
    '''
    This function controls the tuning step of the model.

    Args:
    train_loader: Pytorch geometric DataLoader Class of train dataset
    valid_loader: Pytorch geometric DataLoader Class of validation dataset 
    params (dict): dictionary containing the hyperparameters (num_layers, hidden_size and learning_rate)
    '''
    model = GIN(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_layers=params['num_layers'], hidden_size=params['hidden_size'])
    model.to(DEVICE)
    optimizer=torch.optim.Adam(model.parameters(),lr = params['learning_rate'])
    eng = EngineHOB_no_edge(model, optimizer, device=DEVICE)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0 

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        valid_loss = eng.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        if valid_loss < best_loss:
            best_loss = valid_loss 
            early_stopping_counter=0 #reset counter
        else:
            early_stopping_counter +=1

        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')
    
    return best_loss

In [6]:
def objective(trial):
    params = {
        'num_layers' : trial.suggest_int('num_layers', 1,3),
        'hidden_size' : trial.suggest_int('hidden_size', 64, 512),
        'learning_rate' : trial.suggest_float('learning_rate', 1e-3, 9e-3, log=True)
    }

    #load dataset 
    dataset_for_cv = LoadHOBDataset(root='./data/graph_data/data_oral_avail_train/', raw_filename='data_oral_avail_train_50.csv')
    kf = KFold(n_splits=N_SPLITS)
    fold_loss = 0

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        print(f'Fold {fold_no}')
        train_dataset= []
        valid_dataset = []
        for t_idx in train_idx:
            train_dataset.append(torch.load(f'./data/graph_data/data_oral_avail_train/processed/molecule_{t_idx}.pt'))
        for v_idx in valid_idx:
            valid_dataset.append(torch.load(f'./data/graph_data/data_oral_avail_train/processed/molecule_{v_idx}.pt'))

        seed_everything(SEED_NO)
        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        loss = run_tuning(train_loader, valid_loader, params)
        fold_loss += loss

    return fold_loss/5

study = optuna.create_study(direction = 'minimize')
study.optimize(objective, n_trials=30)
print(f'best trial:')
trial_ = study.best_trial
print(trial_.values)
print(f'Best parameters: {trial_.params}')

[32m[I 2023-03-08 17:26:41,550][0m A new study created in memory with name: no-name-4b8d77f9-2fca-4a34-8d44-8289bb1bb401[0m


Fold 0
Epoch: 1/300, train loss : 1.6848286837339401, validation loss : 0.6327188611030579
Early stop counter: 0
Epoch: 2/300, train loss : 0.7388698011636734, validation loss : 0.7309691309928894
Early stop counter: 1
Epoch: 3/300, train loss : 0.6848915070295334, validation loss : 0.6669426560401917
Early stop counter: 2
Epoch: 4/300, train loss : 0.6840784102678299, validation loss : 0.6986954808235168
Early stop counter: 3
Epoch: 5/300, train loss : 0.6805710345506668, validation loss : 0.658903181552887
Early stop counter: 4
Epoch: 6/300, train loss : 0.6755501329898834, validation loss : 0.7144209742546082
Early stop counter: 5
Epoch: 7/300, train loss : 0.680072546005249, validation loss : 0.6758288741111755
Early stop counter: 6
Epoch: 8/300, train loss : 0.6730537563562393, validation loss : 0.6894646286964417
Early stop counter: 7
Epoch: 9/300, train loss : 0.6720863878726959, validation loss : 0.6833751201629639
Early stop counter: 8
Epoch: 10/300, train loss : 0.66546277701

[32m[I 2023-03-08 17:26:50,785][0m Trial 0 finished with value: 0.646298611164093 and parameters: {'num_layers': 2, 'hidden_size': 135, 'learning_rate': 0.00693648872676605}. Best is trial 0 with value: 0.646298611164093.[0m


Epoch: 7/300, train loss : 0.6747991144657135, validation loss : 0.7073723673820496
Early stop counter: 6
Epoch: 8/300, train loss : 0.6734083294868469, validation loss : 0.697421133518219
Early stop counter: 7
Epoch: 9/300, train loss : 0.6735455095767975, validation loss : 0.7038393616676331
Early stop counter: 8
Epoch: 10/300, train loss : 0.6684256196022034, validation loss : 0.7008915543556213
Early stop counter: 9
Epoch: 11/300, train loss : 0.6659853756427765, validation loss : 0.700866162776947
Early stop counter: 10
Epoch: 12/300, train loss : 0.6623090207576752, validation loss : 0.7017859220504761
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.8657525181770325, validation loss : 0.6871082782745361
Early stop counter: 0
Epoch: 2/300, train loss : 0.6993994861841202, validation loss : 0.6894806027412415
Early stop counter: 1
Epoch: 3/300, train loss : 0.6888601183891296, validation loss : 0.7056416273117065
Early stop counter: 2
Epoch: 4/300, train loss : 0.695075169205

[32m[I 2023-03-08 17:27:00,981][0m Trial 1 finished with value: 0.6417305946350098 and parameters: {'num_layers': 3, 'hidden_size': 236, 'learning_rate': 0.0029548898157339566}. Best is trial 1 with value: 0.6417305946350098.[0m


Epoch: 41/300, train loss : 0.5783738195896149, validation loss : 0.7282333970069885
Early stop counter: 10
Epoch: 42/300, train loss : 0.5630632489919662, validation loss : 0.7511899471282959
Early stopping...
Fold 0
Epoch: 1/300, train loss : 8.63159091770649, validation loss : 0.6432852745056152
Early stop counter: 0
Epoch: 2/300, train loss : 0.7895463407039642, validation loss : 0.7142685651779175
Early stop counter: 1
Epoch: 3/300, train loss : 0.695825532078743, validation loss : 0.7048888206481934
Early stop counter: 2
Epoch: 4/300, train loss : 0.6898724138736725, validation loss : 0.692415714263916
Early stop counter: 3
Epoch: 5/300, train loss : 0.6848676055669785, validation loss : 0.6970462203025818
Early stop counter: 4
Epoch: 6/300, train loss : 0.6827777475118637, validation loss : 0.6725988984107971
Early stop counter: 5
Epoch: 7/300, train loss : 0.6812479794025421, validation loss : 0.6855850219726562
Early stop counter: 6
Epoch: 8/300, train loss : 0.677332460880279

[32m[I 2023-03-08 17:27:08,264][0m Trial 2 finished with value: 0.6520640134811402 and parameters: {'num_layers': 2, 'hidden_size': 388, 'learning_rate': 0.003996657346619775}. Best is trial 1 with value: 0.6417305946350098.[0m


Epoch: 11/300, train loss : 0.6750897467136383, validation loss : 0.68843674659729
Early stop counter: 8
Epoch: 12/300, train loss : 0.6795693933963776, validation loss : 0.7054633498191833
Early stop counter: 9
Epoch: 13/300, train loss : 0.671191468834877, validation loss : 0.6909047961235046
Early stop counter: 10
Epoch: 14/300, train loss : 0.6679084450006485, validation loss : 0.6928791999816895
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.9522322714328766, validation loss : 0.6588367819786072
Early stop counter: 0
Epoch: 2/300, train loss : 0.7018965184688568, validation loss : 0.6440079808235168
Early stop counter: 0
Epoch: 3/300, train loss : 0.7251817435026169, validation loss : 0.7277675867080688
Early stop counter: 1
Epoch: 4/300, train loss : 0.6890691667795181, validation loss : 0.6930416226387024
Early stop counter: 2
Epoch: 5/300, train loss : 0.6847665905952454, validation loss : 0.6995918154716492
Early stop counter: 3
Epoch: 6/300, train loss : 0.683430790901

[32m[I 2023-03-08 17:27:14,762][0m Trial 3 finished with value: 0.6552070617675781 and parameters: {'num_layers': 2, 'hidden_size': 316, 'learning_rate': 0.002431084610980964}. Best is trial 1 with value: 0.6417305946350098.[0m


Epoch: 11/300, train loss : 0.6663386821746826, validation loss : 0.6851768493652344
Early stop counter: 10
Epoch: 12/300, train loss : 0.6496738344430923, validation loss : 0.7077553272247314
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.5876177549362183, validation loss : 0.7121186256408691
Early stop counter: 0
Epoch: 2/300, train loss : 0.7168682366609573, validation loss : 0.6967899799346924
Early stop counter: 0
Epoch: 3/300, train loss : 0.7027144283056259, validation loss : 0.713668167591095
Early stop counter: 1
Epoch: 4/300, train loss : 0.6898261308670044, validation loss : 0.6893742084503174
Early stop counter: 0
Epoch: 5/300, train loss : 0.6893288940191269, validation loss : 0.696516215801239
Early stop counter: 1
Epoch: 6/300, train loss : 0.6885122805833817, validation loss : 0.6948941349983215
Early stop counter: 2
Epoch: 7/300, train loss : 0.6818639636039734, validation loss : 0.6830993890762329
Early stop counter: 0
Epoch: 8/300, train loss : 0.6799502670764

[32m[I 2023-03-08 17:27:25,093][0m Trial 4 finished with value: 0.6378164291381836 and parameters: {'num_layers': 3, 'hidden_size': 139, 'learning_rate': 0.005165445540071036}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 26/300, train loss : 0.6057845950126648, validation loss : 0.7546636462211609
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7991126030683517, validation loss : 0.7379488348960876
Early stop counter: 0
Epoch: 2/300, train loss : 0.6918845176696777, validation loss : 0.6857739090919495
Early stop counter: 0
Epoch: 3/300, train loss : 0.6821094900369644, validation loss : 0.6851747632026672
Early stop counter: 0
Epoch: 4/300, train loss : 0.682020902633667, validation loss : 0.66643226146698
Early stop counter: 0
Epoch: 5/300, train loss : 0.6821908801794052, validation loss : 0.6824136972427368
Early stop counter: 1
Epoch: 6/300, train loss : 0.6767149418592453, validation loss : 0.6735844016075134
Early stop counter: 2
Epoch: 7/300, train loss : 0.6744603961706161, validation loss : 0.6809825301170349
Early stop counter: 3
Epoch: 8/300, train loss : 0.6789331436157227, validation loss : 0.6681862473487854
Early stop counter: 4
Epoch: 9/300, train loss : 0.6757929176092148

[32m[I 2023-03-08 17:27:32,829][0m Trial 5 finished with value: 0.6506686687469483 and parameters: {'num_layers': 2, 'hidden_size': 117, 'learning_rate': 0.0015631623194900962}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 11/300, train loss : 0.6572900265455246, validation loss : 0.6808527708053589
Early stop counter: 10
Epoch: 12/300, train loss : 0.6538654565811157, validation loss : 0.6798198819160461
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.1226523965597153, validation loss : 0.6315076947212219
Early stop counter: 0
Epoch: 2/300, train loss : 0.8188349157571793, validation loss : 0.7081707715988159
Early stop counter: 1
Epoch: 3/300, train loss : 0.7100297510623932, validation loss : 0.747790515422821
Early stop counter: 2
Epoch: 4/300, train loss : 0.6917257457971573, validation loss : 0.6696569919586182
Early stop counter: 3
Epoch: 5/300, train loss : 0.6878180503845215, validation loss : 0.6777934432029724
Early stop counter: 4
Epoch: 6/300, train loss : 0.6755833178758621, validation loss : 0.6856722831726074
Early stop counter: 5
Epoch: 7/300, train loss : 0.6721468269824982, validation loss : 0.6805617213249207
Early stop counter: 6
Epoch: 8/300, train loss : 0.670506864786

[32m[I 2023-03-08 17:27:38,067][0m Trial 6 finished with value: 0.6552642345428467 and parameters: {'num_layers': 1, 'hidden_size': 163, 'learning_rate': 0.0016991330098596411}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 12/300, train loss : 0.6577487289905548, validation loss : 0.6848124861717224
Early stop counter: 8
Epoch: 13/300, train loss : 0.658298134803772, validation loss : 0.6982117891311646
Early stop counter: 9
Epoch: 14/300, train loss : 0.6473336815834045, validation loss : 0.6924318075180054
Early stop counter: 10
Epoch: 15/300, train loss : 0.6483405232429504, validation loss : 0.7057172656059265
Early stopping...
Fold 0
Epoch: 1/300, train loss : 3.1733203530311584, validation loss : 0.6855664849281311
Early stop counter: 0
Epoch: 2/300, train loss : 0.6956512629985809, validation loss : 0.6989365816116333
Early stop counter: 1
Epoch: 3/300, train loss : 0.7080753892660141, validation loss : 0.6570600867271423
Early stop counter: 0
Epoch: 4/300, train loss : 0.6908502131700516, validation loss : 0.7445995807647705
Early stop counter: 1
Epoch: 5/300, train loss : 0.6934053301811218, validation loss : 0.6845641732215881
Early stop counter: 2
Epoch: 6/300, train loss : 0.6875078529

[32m[I 2023-03-08 17:27:45,056][0m Trial 7 finished with value: 0.6522061347961425 and parameters: {'num_layers': 2, 'hidden_size': 300, 'learning_rate': 0.003120006155795529}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 31/300, train loss : 0.6148907542228699, validation loss : 0.679157018661499
Early stop counter: 10
Epoch: 32/300, train loss : 0.6173705011606216, validation loss : 0.6824929118156433
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.9851315915584564, validation loss : 0.6734644770622253
Early stop counter: 0
Epoch: 2/300, train loss : 0.7114749699831009, validation loss : 0.6986746788024902
Early stop counter: 1
Epoch: 3/300, train loss : 0.6869338899850845, validation loss : 0.6964678764343262
Early stop counter: 2
Epoch: 4/300, train loss : 0.6804896742105484, validation loss : 0.7061346769332886
Early stop counter: 3
Epoch: 5/300, train loss : 0.6769802421331406, validation loss : 0.6936851739883423
Early stop counter: 4
Epoch: 6/300, train loss : 0.6777147054672241, validation loss : 0.6800137758255005
Early stop counter: 5
Epoch: 7/300, train loss : 0.6711672842502594, validation loss : 0.7031738758087158
Early stop counter: 6
Epoch: 8/300, train loss : 0.670084878802

[32m[I 2023-03-08 17:27:54,129][0m Trial 8 finished with value: 0.6387906670570374 and parameters: {'num_layers': 2, 'hidden_size': 118, 'learning_rate': 0.003083866169761709}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 34/300, train loss : 0.5986854135990143, validation loss : 0.6844010353088379
Early stop counter: 9
Epoch: 35/300, train loss : 0.5833568572998047, validation loss : 0.754139244556427
Early stop counter: 10
Epoch: 36/300, train loss : 0.5781930238008499, validation loss : 0.7224237322807312
Early stopping...
Fold 0
Epoch: 1/300, train loss : 4.358331680297852, validation loss : 0.6432421207427979
Early stop counter: 0
Epoch: 2/300, train loss : 0.8158969432115555, validation loss : 0.7575024366378784
Early stop counter: 1
Epoch: 3/300, train loss : 0.7120636403560638, validation loss : 0.6840638518333435
Early stop counter: 2
Epoch: 4/300, train loss : 0.694124773144722, validation loss : 0.7232581973075867
Early stop counter: 3
Epoch: 5/300, train loss : 0.691305011510849, validation loss : 0.7080774903297424
Early stop counter: 4
Epoch: 6/300, train loss : 0.687859982252121, validation loss : 0.6988059282302856
Early stop counter: 5
Epoch: 7/300, train loss : 0.687110498547554

[32m[I 2023-03-08 17:28:00,668][0m Trial 9 finished with value: 0.6525529861450196 and parameters: {'num_layers': 1, 'hidden_size': 167, 'learning_rate': 0.007447175838035798}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 21/300, train loss : 0.5866739749908447, validation loss : 0.738029956817627
Early stop counter: 7
Epoch: 22/300, train loss : 0.5834020674228668, validation loss : 0.7834510803222656
Early stop counter: 8
Epoch: 23/300, train loss : 0.5692538768053055, validation loss : 0.8122541308403015
Early stop counter: 9
Epoch: 24/300, train loss : 0.5784130692481995, validation loss : 0.8279455900192261
Early stop counter: 10
Epoch: 25/300, train loss : 0.5717411488294601, validation loss : 0.8447062373161316
Early stopping...
Fold 0
Epoch: 1/300, train loss : 199.81769295036793, validation loss : 2.0797181129455566
Early stop counter: 0
Epoch: 2/300, train loss : 1.3958954215049744, validation loss : 0.9844027757644653
Early stop counter: 0
Epoch: 3/300, train loss : 0.8770518749952316, validation loss : 0.8980817794799805
Early stop counter: 0
Epoch: 4/300, train loss : 0.7544014900922775, validation loss : 0.6654353737831116
Early stop counter: 0
Epoch: 5/300, train loss : 0.698793902

[32m[I 2023-03-08 17:28:14,693][0m Trial 10 finished with value: 0.6421699047088623 and parameters: {'num_layers': 3, 'hidden_size': 500, 'learning_rate': 0.005321456089895417}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 17/300, train loss : 0.6640294045209885, validation loss : 0.6796482801437378
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7290194779634476, validation loss : 0.6916245222091675
Early stop counter: 0
Epoch: 2/300, train loss : 0.6941112130880356, validation loss : 0.7151503562927246
Early stop counter: 1
Epoch: 3/300, train loss : 0.6903774440288544, validation loss : 0.6803003549575806
Early stop counter: 0
Epoch: 4/300, train loss : 0.6824297904968262, validation loss : 0.7049717307090759
Early stop counter: 1
Epoch: 5/300, train loss : 0.6863228380680084, validation loss : 0.6762908101081848
Early stop counter: 0
Epoch: 6/300, train loss : 0.6832897961139679, validation loss : 0.6972962617874146
Early stop counter: 1
Epoch: 7/300, train loss : 0.6784681528806686, validation loss : 0.6780009269714355
Early stop counter: 2
Epoch: 8/300, train loss : 0.6774783283472061, validation loss : 0.6631118655204773
Early stop counter: 0
Epoch: 9/300, train loss : 0.6734953224658

[32m[I 2023-03-08 17:28:24,384][0m Trial 11 finished with value: 0.6412081122398376 and parameters: {'num_layers': 3, 'hidden_size': 79, 'learning_rate': 0.00468031132691518}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 30/300, train loss : 0.5784414708614349, validation loss : 0.6813012361526489
Early stop counter: 8
Epoch: 31/300, train loss : 0.5803799778223038, validation loss : 0.7121219635009766
Early stop counter: 9
Epoch: 32/300, train loss : 0.5827294737100601, validation loss : 0.7273911833763123
Early stop counter: 10
Epoch: 33/300, train loss : 0.585410863161087, validation loss : 0.7178755402565002
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7733245640993118, validation loss : 0.6567745804786682
Early stop counter: 0
Epoch: 2/300, train loss : 0.6969194412231445, validation loss : 0.6649909019470215
Early stop counter: 1
Epoch: 3/300, train loss : 0.6907249540090561, validation loss : 0.7139082551002502
Early stop counter: 2
Epoch: 4/300, train loss : 0.6843252629041672, validation loss : 0.6591064929962158
Early stop counter: 3
Epoch: 5/300, train loss : 0.6848434954881668, validation loss : 0.7083081007003784
Early stop counter: 4
Epoch: 6/300, train loss : 0.6835127472

[32m[I 2023-03-08 17:28:31,340][0m Trial 12 finished with value: 0.6638363599777222 and parameters: {'num_layers': 3, 'hidden_size': 222, 'learning_rate': 0.0010565629393094187}. Best is trial 4 with value: 0.6378164291381836.[0m


Epoch: 11/300, train loss : 0.6713817268610001, validation loss : 0.6990258693695068
Early stop counter: 8
Epoch: 12/300, train loss : 0.6677765399217606, validation loss : 0.7100552916526794
Early stop counter: 9
Epoch: 13/300, train loss : 0.6652011126279831, validation loss : 0.698284387588501
Early stop counter: 10
Epoch: 14/300, train loss : 0.6637815982103348, validation loss : 0.7263676524162292
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.2434329092502594, validation loss : 0.9742926359176636
Early stop counter: 0
Epoch: 2/300, train loss : 0.7705877423286438, validation loss : 0.7194169759750366
Early stop counter: 0
Epoch: 3/300, train loss : 0.6947411149740219, validation loss : 0.7146766781806946
Early stop counter: 0
Epoch: 4/300, train loss : 0.6893445551395416, validation loss : 0.7021282911300659
Early stop counter: 0
Epoch: 5/300, train loss : 0.6844590902328491, validation loss : 0.7086272239685059
Early stop counter: 1
Epoch: 6/300, train loss : 0.6815961301

[32m[I 2023-03-08 17:28:39,196][0m Trial 13 finished with value: 0.6263818264007568 and parameters: {'num_layers': 1, 'hidden_size': 66, 'learning_rate': 0.00889495369073538}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 19/300, train loss : 0.6267538070678711, validation loss : 0.7128151655197144
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.0951155573129654, validation loss : 0.7488470673561096
Early stop counter: 0
Epoch: 2/300, train loss : 0.706323504447937, validation loss : 0.6789425611495972
Early stop counter: 0
Epoch: 3/300, train loss : 0.6931903660297394, validation loss : 0.7241808176040649
Early stop counter: 1
Epoch: 4/300, train loss : 0.6873575150966644, validation loss : 0.6789000034332275
Early stop counter: 0
Epoch: 5/300, train loss : 0.6800819933414459, validation loss : 0.6905644536018372
Early stop counter: 1
Epoch: 6/300, train loss : 0.6830859482288361, validation loss : 0.690194308757782
Early stop counter: 2
Epoch: 7/300, train loss : 0.6768437474966049, validation loss : 0.6682340502738953
Early stop counter: 0
Epoch: 8/300, train loss : 0.6758593916893005, validation loss : 0.6830906271934509
Early stop counter: 1
Epoch: 9/300, train loss : 0.669304639101028

[32m[I 2023-03-08 17:28:47,006][0m Trial 14 finished with value: 0.6360784888267517 and parameters: {'num_layers': 1, 'hidden_size': 64, 'learning_rate': 0.008929222242246302}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 25/300, train loss : 0.5959462523460388, validation loss : 0.6677186489105225
Early stop counter: 7
Epoch: 26/300, train loss : 0.5915579050779343, validation loss : 0.6786025166511536
Early stop counter: 8
Epoch: 27/300, train loss : 0.591002568602562, validation loss : 0.6742405295372009
Early stop counter: 9
Epoch: 28/300, train loss : 0.5825726985931396, validation loss : 0.6888130307197571
Early stop counter: 10
Epoch: 29/300, train loss : 0.5931829810142517, validation loss : 0.7136247158050537
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.5068057477474213, validation loss : 0.9306325912475586
Early stop counter: 0
Epoch: 2/300, train loss : 0.7558557540178299, validation loss : 0.6596933007240295
Early stop counter: 0
Epoch: 3/300, train loss : 0.7223005443811417, validation loss : 0.7309799194335938
Early stop counter: 1
Epoch: 4/300, train loss : 0.7013391256332397, validation loss : 0.7472975850105286
Early stop counter: 2
Epoch: 5/300, train loss : 0.691942378

[32m[I 2023-03-08 17:28:52,772][0m Trial 15 finished with value: 0.6566055655479431 and parameters: {'num_layers': 1, 'hidden_size': 64, 'learning_rate': 0.007292736012825063}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 11/300, train loss : 0.6620710790157318, validation loss : 0.6876334547996521
Early stop counter: 8
Epoch: 12/300, train loss : 0.6553209275007248, validation loss : 0.6749137043952942
Early stop counter: 9
Epoch: 13/300, train loss : 0.6418420672416687, validation loss : 0.6854418516159058
Early stop counter: 10
Epoch: 14/300, train loss : 0.6444240510463715, validation loss : 0.6773446202278137
Early stopping...
Fold 0
Epoch: 1/300, train loss : 5.373830944299698, validation loss : 1.2680652141571045
Early stop counter: 0
Epoch: 2/300, train loss : 0.8730428665876389, validation loss : 0.6781675219535828
Early stop counter: 0
Epoch: 3/300, train loss : 0.7084522694349289, validation loss : 0.7429341673851013
Early stop counter: 1
Epoch: 4/300, train loss : 0.6880520880222321, validation loss : 0.6981235146522522
Early stop counter: 2
Epoch: 5/300, train loss : 0.6781488060951233, validation loss : 0.6943881511688232
Early stop counter: 3
Epoch: 6/300, train loss : 0.6757802516

[32m[I 2023-03-08 17:28:59,422][0m Trial 16 finished with value: 0.6489893198013306 and parameters: {'num_layers': 1, 'hidden_size': 222, 'learning_rate': 0.008240560724089978}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 24/300, train loss : 0.5892730802297592, validation loss : 0.7403672933578491
Early stop counter: 7
Epoch: 25/300, train loss : 0.578959733247757, validation loss : 0.76836758852005
Early stop counter: 8
Epoch: 26/300, train loss : 0.58347949385643, validation loss : 0.7103424668312073
Early stop counter: 9
Epoch: 27/300, train loss : 0.5735098123550415, validation loss : 0.7507114410400391
Early stop counter: 10
Epoch: 28/300, train loss : 0.556683674454689, validation loss : 0.7856432199478149
Early stopping...
Fold 0
Epoch: 1/300, train loss : 7.5111203789711, validation loss : 0.8870044946670532
Early stop counter: 0
Epoch: 2/300, train loss : 0.7481473237276077, validation loss : 0.6602721214294434
Early stop counter: 0
Epoch: 3/300, train loss : 0.6928695887327194, validation loss : 0.7082285284996033
Early stop counter: 1
Epoch: 4/300, train loss : 0.6831671297550201, validation loss : 0.6848911643028259
Early stop counter: 2
Epoch: 5/300, train loss : 0.6755701154470444,

[32m[I 2023-03-08 17:29:08,262][0m Trial 17 finished with value: 0.637643301486969 and parameters: {'num_layers': 1, 'hidden_size': 371, 'learning_rate': 0.006008755042092575}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 25/300, train loss : 0.5850342214107513, validation loss : 0.6843056678771973
Early stop counter: 8
Epoch: 26/300, train loss : 0.555417388677597, validation loss : 0.7489967346191406
Early stop counter: 9
Epoch: 27/300, train loss : 0.5824613124132156, validation loss : 0.7136040329933167
Early stop counter: 10
Epoch: 28/300, train loss : 0.5693350583314896, validation loss : 0.7250462770462036
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.3308832496404648, validation loss : 0.6393871307373047
Early stop counter: 0
Epoch: 2/300, train loss : 0.8025074750185013, validation loss : 0.7079101800918579
Early stop counter: 1
Epoch: 3/300, train loss : 0.6874194294214249, validation loss : 0.7263430953025818
Early stop counter: 2
Epoch: 4/300, train loss : 0.6862214952707291, validation loss : 0.6834619641304016
Early stop counter: 3
Epoch: 5/300, train loss : 0.6846596151590347, validation loss : 0.6791771054267883
Early stop counter: 4
Epoch: 6/300, train loss : 0.6832658350

[32m[I 2023-03-08 17:29:14,050][0m Trial 18 finished with value: 0.6540427684783936 and parameters: {'num_layers': 1, 'hidden_size': 71, 'learning_rate': 0.008465761977345514}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 11/300, train loss : 0.6615748107433319, validation loss : 0.6882041096687317
Early stop counter: 8
Epoch: 12/300, train loss : 0.6505308151245117, validation loss : 0.6952441930770874
Early stop counter: 9
Epoch: 13/300, train loss : 0.6505171209573746, validation loss : 0.6936350464820862
Early stop counter: 10
Epoch: 14/300, train loss : 0.6532359570264816, validation loss : 0.6970239877700806
Early stopping...
Fold 0
Epoch: 1/300, train loss : 45.43764539062977, validation loss : 0.796555757522583
Early stop counter: 0
Epoch: 2/300, train loss : 0.8541149348020554, validation loss : 0.8350351452827454
Early stop counter: 1
Epoch: 3/300, train loss : 0.7600741386413574, validation loss : 0.7443764805793762
Early stop counter: 0
Epoch: 4/300, train loss : 0.7103397399187088, validation loss : 0.6661850214004517
Early stop counter: 0
Epoch: 5/300, train loss : 0.7010574191808701, validation loss : 0.6859911680221558
Early stop counter: 1
Epoch: 6/300, train loss : 0.68476712703

[32m[I 2023-03-08 17:29:21,840][0m Trial 19 finished with value: 0.6520354747772217 and parameters: {'num_layers': 1, 'hidden_size': 511, 'learning_rate': 0.008972838201724386}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 13/300, train loss : 0.6879296898841858, validation loss : 0.7000828981399536
Early stop counter: 6
Epoch: 14/300, train loss : 0.6897188872098923, validation loss : 0.7023356556892395
Early stop counter: 7
Epoch: 15/300, train loss : 0.6875670105218887, validation loss : 0.703394889831543
Early stop counter: 8
Epoch: 16/300, train loss : 0.6884216368198395, validation loss : 0.7044783234596252
Early stop counter: 9
Epoch: 17/300, train loss : 0.6851161867380142, validation loss : 0.7041011452674866
Early stop counter: 10
Epoch: 18/300, train loss : 0.685950756072998, validation loss : 0.7054831981658936
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.1569988131523132, validation loss : 0.6327881813049316
Early stop counter: 0
Epoch: 2/300, train loss : 0.7408038377761841, validation loss : 0.8018288016319275
Early stop counter: 1
Epoch: 3/300, train loss : 0.7185826897621155, validation loss : 0.6808540225028992
Early stop counter: 2
Epoch: 4/300, train loss : 0.706297472

[32m[I 2023-03-08 17:29:27,843][0m Trial 20 finished with value: 0.6520498633384705 and parameters: {'num_layers': 1, 'hidden_size': 187, 'learning_rate': 0.004138220435837427}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 24/300, train loss : 0.609794870018959, validation loss : 0.6915411353111267
Early stopping...
Fold 0
Epoch: 1/300, train loss : 14.3370431214571, validation loss : 0.745606005191803
Early stop counter: 0
Epoch: 2/300, train loss : 0.6950158029794693, validation loss : 0.7219639420509338
Early stop counter: 0
Epoch: 3/300, train loss : 0.6855319291353226, validation loss : 0.6921920776367188
Early stop counter: 0
Epoch: 4/300, train loss : 0.6833368390798569, validation loss : 0.6975943446159363
Early stop counter: 1
Epoch: 5/300, train loss : 0.6798612922430038, validation loss : 0.6820301413536072
Early stop counter: 0
Epoch: 6/300, train loss : 0.6765985041856766, validation loss : 0.6774940490722656
Early stop counter: 0
Epoch: 7/300, train loss : 0.676329717040062, validation loss : 0.6767802834510803
Early stop counter: 0
Epoch: 8/300, train loss : 0.6716559678316116, validation loss : 0.6572036743164062
Early stop counter: 0
Epoch: 9/300, train loss : 0.6671964079141617, 

[32m[I 2023-03-08 17:29:35,187][0m Trial 21 finished with value: 0.6384324789047241 and parameters: {'num_layers': 1, 'hidden_size': 382, 'learning_rate': 0.005975320357374162}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 13/300, train loss : 0.6441105753183365, validation loss : 0.6863042116165161
Early stopping...
Fold 0
Epoch: 1/300, train loss : 12.486841216683388, validation loss : 0.6431815028190613
Early stop counter: 0
Epoch: 2/300, train loss : 0.7855774462223053, validation loss : 0.6784481406211853
Early stop counter: 1
Epoch: 3/300, train loss : 0.6959907859563828, validation loss : 0.7069223523139954
Early stop counter: 2
Epoch: 4/300, train loss : 0.6855828762054443, validation loss : 0.7106520533561707
Early stop counter: 3
Epoch: 5/300, train loss : 0.6916572153568268, validation loss : 0.6970338225364685
Early stop counter: 4
Epoch: 6/300, train loss : 0.6816179901361465, validation loss : 0.6894561052322388
Early stop counter: 5
Epoch: 7/300, train loss : 0.67868272960186, validation loss : 0.6697362661361694
Early stop counter: 6
Epoch: 8/300, train loss : 0.6737911552190781, validation loss : 0.6882695555686951
Early stop counter: 7
Epoch: 9/300, train loss : 0.675361260771751

[32m[I 2023-03-08 17:29:41,861][0m Trial 22 finished with value: 0.643317449092865 and parameters: {'num_layers': 1, 'hidden_size': 378, 'learning_rate': 0.006118207912171744}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 12/300, train loss : 0.6446778327226639, validation loss : 0.6968487501144409
Early stopping...
Fold 0
Epoch: 1/300, train loss : 19.34402060508728, validation loss : 0.8873581886291504
Early stop counter: 0
Epoch: 2/300, train loss : 1.03233303129673, validation loss : 0.9028478264808655
Early stop counter: 1
Epoch: 3/300, train loss : 0.7671102434396744, validation loss : 0.7007587552070618
Early stop counter: 0
Epoch: 4/300, train loss : 0.6903787404298782, validation loss : 0.690047025680542
Early stop counter: 0
Epoch: 5/300, train loss : 0.6900122314691544, validation loss : 0.6984509229660034
Early stop counter: 1
Epoch: 6/300, train loss : 0.6800097227096558, validation loss : 0.7066562175750732
Early stop counter: 2
Epoch: 7/300, train loss : 0.6761295199394226, validation loss : 0.6884707808494568
Early stop counter: 0
Epoch: 8/300, train loss : 0.6792749166488647, validation loss : 0.6879791617393494
Early stop counter: 0
Epoch: 9/300, train loss : 0.670460894703865, 

[32m[I 2023-03-08 17:29:50,558][0m Trial 23 finished with value: 0.6369202971458435 and parameters: {'num_layers': 1, 'hidden_size': 456, 'learning_rate': 0.0065138763934949}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 26/300, train loss : 0.5530513823032379, validation loss : 0.7332910299301147
Early stop counter: 10
Epoch: 27/300, train loss : 0.5509796291589737, validation loss : 0.7548772096633911
Early stopping...
Fold 0
Epoch: 1/300, train loss : 18.40979328751564, validation loss : 0.640316903591156
Early stop counter: 0
Epoch: 2/300, train loss : 0.7562841325998306, validation loss : 0.7704113125801086
Early stop counter: 1
Epoch: 3/300, train loss : 0.705472931265831, validation loss : 0.6833069920539856
Early stop counter: 2
Epoch: 4/300, train loss : 0.6869908571243286, validation loss : 0.7214648723602295
Early stop counter: 3
Epoch: 5/300, train loss : 0.686518669128418, validation loss : 0.6682657599449158
Early stop counter: 4
Epoch: 6/300, train loss : 0.6860639154911041, validation loss : 0.6945240497589111
Early stop counter: 5
Epoch: 7/300, train loss : 0.6826287508010864, validation loss : 0.6983597874641418
Early stop counter: 6
Epoch: 8/300, train loss : 0.678141489624977

[32m[I 2023-03-08 17:29:55,896][0m Trial 24 finished with value: 0.6566458582878113 and parameters: {'num_layers': 1, 'hidden_size': 455, 'learning_rate': 0.006932428006517908}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 11/300, train loss : 0.6626231819391251, validation loss : 0.6871582865715027
Early stop counter: 8
Epoch: 12/300, train loss : 0.6554464101791382, validation loss : 0.6962214112281799
Early stop counter: 9
Epoch: 13/300, train loss : 0.6528105437755585, validation loss : 0.7007930874824524
Early stop counter: 10
Epoch: 14/300, train loss : 0.6441080868244171, validation loss : 0.6947892308235168
Early stopping...
Fold 0
Epoch: 1/300, train loss : 9.32884594798088, validation loss : 0.9820179343223572
Early stop counter: 0
Epoch: 2/300, train loss : 1.197615996003151, validation loss : 0.7746633291244507
Early stop counter: 0
Epoch: 3/300, train loss : 0.7156254351139069, validation loss : 0.6786531805992126
Early stop counter: 0
Epoch: 4/300, train loss : 0.687994197010994, validation loss : 0.721997082233429
Early stop counter: 1
Epoch: 5/300, train loss : 0.6813667565584183, validation loss : 0.7020596265792847
Early stop counter: 2
Epoch: 6/300, train loss : 0.67380812764167

[32m[I 2023-03-08 17:30:02,387][0m Trial 25 finished with value: 0.6484314441680908 and parameters: {'num_layers': 1, 'hidden_size': 328, 'learning_rate': 0.007959733526900763}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 12/300, train loss : 0.6720413863658905, validation loss : 0.7063478231430054
Early stop counter: 9
Epoch: 13/300, train loss : 0.6659236997365952, validation loss : 0.7015162706375122
Early stop counter: 10
Epoch: 14/300, train loss : 0.6637903451919556, validation loss : 0.6991452574729919
Early stopping...
Fold 0
Epoch: 1/300, train loss : 14.774766713380814, validation loss : 3.9021756649017334
Early stop counter: 0
Epoch: 2/300, train loss : 1.7601355165243149, validation loss : 0.692152202129364
Early stop counter: 0
Epoch: 3/300, train loss : 0.6952750980854034, validation loss : 0.6717875003814697
Early stop counter: 0
Epoch: 4/300, train loss : 0.6930428296327591, validation loss : 0.7338197827339172
Early stop counter: 1
Epoch: 5/300, train loss : 0.6924112737178802, validation loss : 0.689782977104187
Early stop counter: 2
Epoch: 6/300, train loss : 0.6863170862197876, validation loss : 0.6939197182655334
Early stop counter: 3
Epoch: 7/300, train loss : 0.685636326670

[32m[I 2023-03-08 17:30:13,645][0m Trial 26 finished with value: 0.6386957287788391 and parameters: {'num_layers': 2, 'hidden_size': 459, 'learning_rate': 0.004104739353512161}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 41/300, train loss : 0.5775585323572159, validation loss : 0.6672062277793884
Early stop counter: 10
Epoch: 42/300, train loss : 0.5901357531547546, validation loss : 0.7057383060455322
Early stopping...
Fold 0
Epoch: 1/300, train loss : 11.205159902572632, validation loss : 0.6504194140434265
Early stop counter: 0
Epoch: 2/300, train loss : 0.7159443348646164, validation loss : 0.6827235817909241
Early stop counter: 1
Epoch: 3/300, train loss : 0.7009281069040298, validation loss : 0.7492140531539917
Early stop counter: 2
Epoch: 4/300, train loss : 0.6903630942106247, validation loss : 0.6726157069206238
Early stop counter: 3
Epoch: 5/300, train loss : 0.6961796283721924, validation loss : 0.6821300983428955
Early stop counter: 4
Epoch: 6/300, train loss : 0.6835672110319138, validation loss : 0.7094243764877319
Early stop counter: 5
Epoch: 7/300, train loss : 0.6871622800827026, validation loss : 0.6718339323997498
Early stop counter: 6
Epoch: 8/300, train loss : 0.67411707341

[32m[I 2023-03-08 17:30:20,003][0m Trial 27 finished with value: 0.640365743637085 and parameters: {'num_layers': 1, 'hidden_size': 264, 'learning_rate': 0.008886286354965165}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 13/300, train loss : 0.6234825849533081, validation loss : 0.7275334000587463
Early stop counter: 10
Epoch: 14/300, train loss : 0.6175742000341415, validation loss : 0.6916236281394958
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.4392645359039307, validation loss : 0.6363632082939148
Early stop counter: 0
Epoch: 2/300, train loss : 0.7332737296819687, validation loss : 0.7198876738548279
Early stop counter: 1
Epoch: 3/300, train loss : 0.6942693889141083, validation loss : 0.7095155119895935
Early stop counter: 2
Epoch: 4/300, train loss : 0.6912022829055786, validation loss : 0.690505862236023
Early stop counter: 3
Epoch: 5/300, train loss : 0.6890516877174377, validation loss : 0.7084537148475647
Early stop counter: 4
Epoch: 6/300, train loss : 0.6867453902959824, validation loss : 0.6814353466033936
Early stop counter: 5
Epoch: 7/300, train loss : 0.6775660812854767, validation loss : 0.716947615146637
Early stop counter: 6
Epoch: 8/300, train loss : 0.6824343949556

[32m[I 2023-03-08 17:30:26,689][0m Trial 28 finished with value: 0.6440068125724793 and parameters: {'num_layers': 1, 'hidden_size': 99, 'learning_rate': 0.006318000724240208}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 12/300, train loss : 0.6475618779659271, validation loss : 0.6905208230018616
Early stop counter: 10
Epoch: 13/300, train loss : 0.6425556987524033, validation loss : 0.6812742948532104
Early stopping...
Fold 0
Epoch: 1/300, train loss : 41.74121472239494, validation loss : 1.5528576374053955
Early stop counter: 0
Epoch: 2/300, train loss : 1.9624030590057373, validation loss : 0.6341509222984314
Early stop counter: 0
Epoch: 3/300, train loss : 0.7329131364822388, validation loss : 0.7390626072883606
Early stop counter: 1
Epoch: 4/300, train loss : 0.6985687762498856, validation loss : 0.6848932504653931
Early stop counter: 2
Epoch: 5/300, train loss : 0.6915372610092163, validation loss : 0.6954882740974426
Early stop counter: 3
Epoch: 6/300, train loss : 0.6900333166122437, validation loss : 0.7153339982032776
Early stop counter: 4
Epoch: 7/300, train loss : 0.6860326379537582, validation loss : 0.6842126250267029
Early stop counter: 5
Epoch: 8/300, train loss : 0.682306110858

[32m[I 2023-03-08 17:30:32,719][0m Trial 29 finished with value: 0.65691579580307 and parameters: {'num_layers': 2, 'hidden_size': 427, 'learning_rate': 0.00667738334985686}. Best is trial 13 with value: 0.6263818264007568.[0m


Epoch: 11/300, train loss : 0.6682745516300201, validation loss : 0.7099030613899231
Early stop counter: 8
Epoch: 12/300, train loss : 0.6685814410448074, validation loss : 0.7033895254135132
Early stop counter: 9
Epoch: 13/300, train loss : 0.6641942262649536, validation loss : 0.6898854970932007
Early stop counter: 10
Epoch: 14/300, train loss : 0.6578496545553207, validation loss : 0.6835047006607056
Early stopping...
best trial:
[0.6263818264007568]
Best parameters: {'num_layers': 1, 'hidden_size': 66, 'learning_rate': 0.00889495369073538}


## 3. Train/validate/test

1. After hyperparameters tuning, the best parameters saved to config.py as params_gin
2. Next, we will train/validate/test the model using 5-fold CV followed by repeating the process for 5 times.
3. Ensure that data folder, from_scratch_trained_models folder and this notebook are in the same working directory
4. run_training to repeat the training process, please create a new folder to retrain, save and load saved model
5. Else, comment the run_training function and do testing to get the results which were obtained in the paper. 

In [2]:
def run_training(train_loader, valid_loader, params, trained_model_path):
    model = GIN(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_layers=params['num_layers'], hidden_size=params['hidden_size'])
    model.to(DEVICE)
    optimizer=torch.optim.Adam(model.parameters(),lr = params['learning_rate'])
    eng = EngineHOB_no_edge(model, optimizer, device=DEVICE)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0 

    for epoch in range(EPOCHS):
        train_loss= eng.train(train_loader)
        valid_loss= eng.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        if valid_loss < best_loss:
            best_loss = valid_loss 
            early_stopping_counter=0 #reset counter
            print('Saving model...')
            torch.save(model.state_dict(), trained_model_path)
        else:
            early_stopping_counter +=1

        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')
    
    return best_loss

def run_testing(test_loader, params, trained_model_path):
    model = GIN(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_layers=params['num_layers'], hidden_size=params['hidden_size'])
    model.load_state_dict(torch.load(trained_model_path))
    model.to(DEVICE)
    optimizer=torch.optim.Adam(model.parameters(),lr = params['learning_rate'])
    eng = EngineHOB_no_edge(model, optimizer, device=DEVICE)

    print('Begin testing...')
    bce, acc, f1, roc_auc= eng.test(test_loader)
    print('Test completed!')
    print(f'bce:{bce}, acc :{acc}, f1: {f1}, roc_auc: {roc_auc}')
    return bce, acc, f1, roc_auc

In [7]:
n_repetitions = 5
params = params_gin
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test/'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
path_to_save_trained_model = './from_scratch_trained_models/GIN/'

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

#load dataset 
dataset_for_cv = LoadHOBDataset(root=train_data_root_path, raw_filename=train_data_raw_filename)
test_dataset = LoadHOBDataset(root=test_data_root_path, raw_filename=test_data_raw_filename)

kf = KFold(n_splits= N_SPLITS)

for repeat in range(n_repetitions):
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        print(f'For rep: {repeat}, fold: {fold_no}')
        seed_everything(SEED_NO)
        train_dataset= []
        valid_dataset = []
        for t_idx in train_idx:
            train_dataset.append(torch.load(f'./data/graph_data/data_oral_avail_train/processed/molecule_{t_idx}.pt'))
        for v_idx in valid_idx:
            valid_dataset.append(torch.load(f'./data/graph_data/data_oral_avail_train/processed/molecule_{v_idx}.pt'))

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        run_training(train_loader, valid_loader, params, os.path.join(path_to_save_trained_model, f'gin_repeat_{repeat}_fold_{fold_no}.pt'))
        bce, acc, f1, roc_auc = run_testing(test_loader, params, os.path.join(path_to_save_trained_model, f'gin_repeat_{repeat}_fold_{fold_no}.pt'))
        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean= np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean= np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean= np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

For rep: 0, fold: 0
Epoch: 1/300, train loss : 1.4474536180496216, validation loss : 0.6468539237976074
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.7944114208221436, validation loss : 0.7610831260681152
Early stop counter: 1
Epoch: 3/300, train loss : 0.7047408521175385, validation loss : 0.7292219996452332
Early stop counter: 2
Epoch: 4/300, train loss : 0.6896027475595474, validation loss : 0.7106878161430359
Early stop counter: 3
Epoch: 5/300, train loss : 0.686515599489212, validation loss : 0.6983362436294556
Early stop counter: 4
Epoch: 6/300, train loss : 0.6845705062150955, validation loss : 0.6978377103805542
Early stop counter: 5
Epoch: 7/300, train loss : 0.6809114813804626, validation loss : 0.7046172022819519
Early stop counter: 6
Epoch: 8/300, train loss : 0.6812890619039536, validation loss : 0.6997275948524475
Early stop counter: 7
Epoch: 9/300, train loss : 0.6773758679628372, validation loss : 0.6876591444015503
Early stop counter: 8
Epoch: 10/3