# Parallel GNN Model for Oral Bioavailability Dataset 

1. This notebook focuses on building the parallel GNN model from scratch.
2. The model's hyperparameters were tuned with the Optuna library, using the Tree-structured Parzen Estimator (TPE) algorithm over 30 trials.
3. The model was trained/validated/tested using the best parameters and 5-fold CV. The process was repeated 10 times, and the averaged results are reported at the end.

### Note
1. Ensure that this notebook is in the same working directory as the data folder, config.py, utils.py, engine.py and model.py.
2. To load models, please download the saved models from the Google Drive link provided in README.md.
3. Comment out the training function to load the saved models and reproduce the results.

In [1]:
from utils import seed_everything, LoadHOBDataset
from config import SEED_NO, NUM_FEATURES, NUM_GRAPHS_PER_BATCH, NUM_TARGET, EDGE_DIM, DEVICE, PATIENCE, EPOCHS, N_SPLITS, params_parallel_gnn
from engine import EngineHOB
from model import ParallelGNN

import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold
from torch_geometric.loader import DataLoader
import os 

## 2. Tuning the Model (optional: we suggest skipping this step and using the hyperparameters already saved in config.py)

1. The model was tuned with the Optuna library, using the Tree-structured Parzen Estimator algorithm for 30 trials.
2. The `run_tuning` function runs the training and validation steps with an early stopping mechanism.
3. The `objective` function defines the search space and drives the tuning process.

In [5]:
def run_tuning(train_loader, valid_loader, params):
    """Train a fresh ParallelGNN with early stopping and return its best validation loss.

    Args:
        train_loader: loader passed to ``EngineHOB.train`` once per epoch.
        valid_loader: loader passed to ``EngineHOB.validate`` once per epoch.
        params: dict of hyperparameters; the keys read here are
            ``num_gin_layers``, ``num_graph_trans_layers``, ``hidden_size``,
            ``n_heads``, ``dropout`` and ``learning_rate``.

    Returns:
        The lowest validation loss seen before training stopped.
    """
    network = ParallelGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    network.to(DEVICE)
    adam = torch.optim.Adam(network.parameters(), lr=params['learning_rate'])
    engine = EngineHOB(network, adam, device=DEVICE)

    lowest_valid_loss = np.inf
    stale_epochs = 0  # consecutive epochs without a new best validation loss

    for epoch in range(EPOCHS):
        train_loss = engine.train(train_loader)
        valid_loss = engine.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')

        improved = valid_loss < lowest_valid_loss
        if improved:
            lowest_valid_loss = valid_loss
        stale_epochs = 0 if improved else stale_epochs + 1

        # Strict '>' means training stops only once the stale count exceeds PATIENCE.
        if stale_epochs > PATIENCE:
            print('Early stopping...')
            return lowest_valid_loss
        print(f'Early stop counter: {stale_epochs}')

    return lowest_valid_loss

In [6]:
from sklearn.model_selection import KFold

def objective(trial):
    """Optuna objective: mean best validation loss over the K-fold CV splits.

    Samples one hyperparameter configuration from ``trial``, then for each of
    the ``N_SPLITS`` folds loads the pre-processed molecule graphs from disk,
    trains with ``run_tuning`` and accumulates the best validation loss.

    Args:
        trial: the ``optuna`` trial used to sample hyperparameters.

    Returns:
        The per-fold best validation loss averaged over ``N_SPLITS`` folds.
    """
    params = {
        'num_gin_layers' : trial.suggest_int('num_gin_layers', 1,3),
        'num_graph_trans_layers' : trial.suggest_int('num_graph_trans_layers', 1,3),
        'hidden_size' : trial.suggest_int('hidden_size', 64, 512),
        'n_heads' : trial.suggest_int('n_heads', 1, 5),
        'dropout': trial.suggest_float('dropout', 0.1,0.4),
        'learning_rate' : trial.suggest_float('learning_rate', 1e-3, 9e-3, log=True)
    }
    #load dataset
    # NOTE(review): the dataset object is only used to give KFold the sample
    # count; the graphs themselves are read from the processed/ folder below.
    dataset_for_cv = LoadHOBDataset(root='./data/graph_data/data_oral_avail_train/', raw_filename='data_oral_avail_train_50.csv')
    processed_dir = './data/graph_data/data_oral_avail_train/processed'
    kf = KFold(n_splits=N_SPLITS)
    fold_loss = 0

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        print(f'Fold {fold_no}')
        train_dataset = [torch.load(f'{processed_dir}/molecule_{t_idx}.pt') for t_idx in train_idx]
        valid_dataset = [torch.load(f'{processed_dir}/molecule_{v_idx}.pt') for v_idx in valid_idx]
        # Re-seed before building the loaders and the model for every fold.
        seed_everything(SEED_NO)
        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)
        fold_loss += run_tuning(train_loader, valid_loader, params)

    # Divide by the configured fold count rather than a literal 5, so the mean
    # stays correct if N_SPLITS in config.py is changed.
    return fold_loss / N_SPLITS

if __name__ == '__main__':
    # Run the hyperparameter search: minimise the value returned by `objective`
    # over 30 trials.
    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=30)

    # Report the best trial's objective value(s) and the hyperparameters that produced it.
    print('best trial:')
    trial_ = study.best_trial
    best_parameters = trial_.params
    print(trial_.values)
    print(f'Best parameters: {best_parameters}')

[32m[I 2023-03-08 18:18:09,338][0m A new study created in memory with name: no-name-7664fac9-a620-45d5-b1e9-d0cc90e84306[0m


Fold 0
Epoch: 1/300, train loss : 2.4939538538455963, validation loss : 0.7173627018928528
Early stop counter: 0
Epoch: 2/300, train loss : 0.7330368608236313, validation loss : 0.643173336982727
Early stop counter: 0
Epoch: 3/300, train loss : 0.7579510807991028, validation loss : 0.7866787314414978
Early stop counter: 1
Epoch: 4/300, train loss : 0.789401575922966, validation loss : 0.7663371562957764
Early stop counter: 2
Epoch: 5/300, train loss : 0.7215721011161804, validation loss : 0.7744593620300293
Early stop counter: 3
Epoch: 6/300, train loss : 0.7517992556095123, validation loss : 0.6887840628623962
Early stop counter: 4
Epoch: 7/300, train loss : 0.6680254936218262, validation loss : 0.6480331420898438
Early stop counter: 5
Epoch: 8/300, train loss : 0.6980799287557602, validation loss : 0.6425188183784485
Early stop counter: 0
Epoch: 9/300, train loss : 0.6851582080125809, validation loss : 0.666217029094696
Early stop counter: 1
Epoch: 10/300, train loss : 0.677630230784

[32m[I 2023-03-08 18:18:38,130][0m Trial 0 finished with value: 0.639851176738739 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 262, 'n_heads': 2, 'dropout': 0.1643323224060143, 'learning_rate': 0.003910770545963381}. Best is trial 0 with value: 0.639851176738739.[0m


Epoch: 17/300, train loss : 0.6131858229637146, validation loss : 0.8122575283050537
Early stopping...
Fold 0
Epoch: 1/300, train loss : 14.046315759420395, validation loss : 2.5939178466796875
Early stop counter: 0
Epoch: 2/300, train loss : 2.439465045928955, validation loss : 0.7556847333908081
Early stop counter: 0
Epoch: 3/300, train loss : 0.863432914018631, validation loss : 0.8677375912666321
Early stop counter: 1
Epoch: 4/300, train loss : 0.7609700560569763, validation loss : 0.6680319905281067
Early stop counter: 0
Epoch: 5/300, train loss : 0.6839437335729599, validation loss : 0.755303144454956
Early stop counter: 1
Epoch: 6/300, train loss : 0.7025461494922638, validation loss : 0.6212304830551147
Early stop counter: 0
Epoch: 7/300, train loss : 0.6692395359277725, validation loss : 0.7014698386192322
Early stop counter: 1
Epoch: 8/300, train loss : 0.678019106388092, validation loss : 0.626433789730072
Early stop counter: 2
Epoch: 9/300, train loss : 0.6762939244508743, 

[32m[I 2023-03-08 18:19:33,793][0m Trial 1 finished with value: 0.6356054782867432 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 480, 'n_heads': 3, 'dropout': 0.3779991183223783, 'learning_rate': 0.0037923301889909054}. Best is trial 1 with value: 0.6356054782867432.[0m


Epoch: 21/300, train loss : 0.6138100177049637, validation loss : 0.9740639328956604
Early stopping...
Fold 0
Epoch: 1/300, train loss : 3.0664569586515427, validation loss : 0.6473145484924316
Early stop counter: 0
Epoch: 2/300, train loss : 0.7778340280056, validation loss : 0.6807993054389954
Early stop counter: 1
Epoch: 3/300, train loss : 0.7000665813684464, validation loss : 0.6222838163375854
Early stop counter: 0
Epoch: 4/300, train loss : 0.6967761069536209, validation loss : 0.6398777365684509
Early stop counter: 1
Epoch: 5/300, train loss : 0.6694557368755341, validation loss : 0.6322041749954224
Early stop counter: 2
Epoch: 6/300, train loss : 0.6577355563640594, validation loss : 0.6146419644355774
Early stop counter: 0
Epoch: 7/300, train loss : 0.6707749515771866, validation loss : 0.6050277352333069
Early stop counter: 0
Epoch: 8/300, train loss : 0.6863097548484802, validation loss : 0.6105940341949463
Early stop counter: 1
Epoch: 9/300, train loss : 0.6615946590900421

[32m[I 2023-03-08 18:20:00,001][0m Trial 2 finished with value: 0.6332100749015808 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 205, 'n_heads': 4, 'dropout': 0.2540067653173666, 'learning_rate': 0.005282303864855328}. Best is trial 2 with value: 0.6332100749015808.[0m


Epoch: 17/300, train loss : 0.6195158213376999, validation loss : 0.8248212933540344
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.8329921662807465, validation loss : 0.6957094669342041
Early stop counter: 0
Epoch: 2/300, train loss : 0.7445800751447678, validation loss : 0.6652548909187317
Early stop counter: 0
Epoch: 3/300, train loss : 0.6903203129768372, validation loss : 0.6869654655456543
Early stop counter: 1
Epoch: 4/300, train loss : 0.6840294003486633, validation loss : 0.6551342606544495
Early stop counter: 0
Epoch: 5/300, train loss : 0.6815444380044937, validation loss : 0.6324110627174377
Early stop counter: 0
Epoch: 6/300, train loss : 0.6727654933929443, validation loss : 0.6878629326820374
Early stop counter: 1
Epoch: 7/300, train loss : 0.6772772818803787, validation loss : 0.6244980096817017
Early stop counter: 0
Epoch: 8/300, train loss : 0.6715857237577438, validation loss : 0.6653940677642822
Early stop counter: 1
Epoch: 9/300, train loss : 0.6608133465051

[32m[I 2023-03-08 18:21:12,116][0m Trial 3 finished with value: 0.6282815575599671 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 303, 'n_heads': 4, 'dropout': 0.15291000325767937, 'learning_rate': 0.001013496255090026}. Best is trial 3 with value: 0.6282815575599671.[0m


Epoch: 13/300, train loss : 0.6493795663118362, validation loss : 0.716230571269989
Early stopping...
Fold 0
Epoch: 1/300, train loss : 3.122802644968033, validation loss : 0.8108410835266113
Early stop counter: 0
Epoch: 2/300, train loss : 0.7093391865491867, validation loss : 0.7067710161209106
Early stop counter: 0
Epoch: 3/300, train loss : 0.682949885725975, validation loss : 0.6477230191230774
Early stop counter: 0
Epoch: 4/300, train loss : 0.6873316168785095, validation loss : 0.6423975825309753
Early stop counter: 0
Epoch: 5/300, train loss : 0.6803435981273651, validation loss : 0.6139384508132935
Early stop counter: 0
Epoch: 6/300, train loss : 0.689430370926857, validation loss : 0.6558903455734253
Early stop counter: 1
Epoch: 7/300, train loss : 0.6652538329362869, validation loss : 0.6306726932525635
Early stop counter: 2
Epoch: 8/300, train loss : 0.6564293801784515, validation loss : 0.6270866394042969
Early stop counter: 3
Epoch: 9/300, train loss : 0.6495833992958069,

[32m[I 2023-03-08 18:21:36,356][0m Trial 4 finished with value: 0.6284029841423034 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 233, 'n_heads': 3, 'dropout': 0.2674822804707799, 'learning_rate': 0.004219955536396461}. Best is trial 3 with value: 0.6282815575599671.[0m


Epoch: 13/300, train loss : 0.6397076994180679, validation loss : 0.7344841361045837
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.2262698411941528, validation loss : 0.6343650817871094
Early stop counter: 0
Epoch: 2/300, train loss : 0.7402738779783249, validation loss : 0.7538187503814697
Early stop counter: 1
Epoch: 3/300, train loss : 0.7027489244937897, validation loss : 0.6792540550231934
Early stop counter: 2
Epoch: 4/300, train loss : 0.6883874088525772, validation loss : 0.655296802520752
Early stop counter: 3
Epoch: 5/300, train loss : 0.6864889711141586, validation loss : 0.6983345746994019
Early stop counter: 4
Epoch: 6/300, train loss : 0.6781614571809769, validation loss : 0.6574143767356873
Early stop counter: 5
Epoch: 7/300, train loss : 0.6716096252202988, validation loss : 0.6496698260307312
Early stop counter: 6
Epoch: 8/300, train loss : 0.6692304164171219, validation loss : 0.6614819765090942
Early stop counter: 7
Epoch: 9/300, train loss : 0.66903729736804

[32m[I 2023-03-08 18:21:56,411][0m Trial 5 finished with value: 0.6337059497833252 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 1, 'hidden_size': 295, 'n_heads': 3, 'dropout': 0.14844832489847876, 'learning_rate': 0.0012590933758312558}. Best is trial 3 with value: 0.6282815575599671.[0m


Epoch: 19/300, train loss : 0.6210044026374817, validation loss : 0.7553722262382507
Early stop counter: 10
Epoch: 20/300, train loss : 0.6201108694076538, validation loss : 0.7666088342666626
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.1476402282714844, validation loss : 0.7066643834114075
Early stop counter: 0
Epoch: 2/300, train loss : 0.6848025321960449, validation loss : 0.6737179160118103
Early stop counter: 0
Epoch: 3/300, train loss : 0.685208335518837, validation loss : 0.6714944243431091
Early stop counter: 0
Epoch: 4/300, train loss : 0.6731132566928864, validation loss : 0.6339462995529175
Early stop counter: 0
Epoch: 5/300, train loss : 0.673894539475441, validation loss : 0.6327696442604065
Early stop counter: 0
Epoch: 6/300, train loss : 0.6730838418006897, validation loss : 0.6640677452087402
Early stop counter: 1
Epoch: 7/300, train loss : 0.6654458343982697, validation loss : 0.629962146282196
Early stop counter: 0
Epoch: 8/300, train loss : 0.67031583189964

[32m[I 2023-03-08 18:22:28,223][0m Trial 6 finished with value: 0.6273252248764039 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 238, 'n_heads': 2, 'dropout': 0.3085123248122944, 'learning_rate': 0.002063729228053148}. Best is trial 6 with value: 0.6273252248764039.[0m


Epoch: 13/300, train loss : 0.6348036974668503, validation loss : 0.7438912391662598
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.3950881659984589, validation loss : 0.6500511169433594
Early stop counter: 0
Epoch: 2/300, train loss : 0.6920879632234573, validation loss : 0.6642302870750427
Early stop counter: 1
Epoch: 3/300, train loss : 0.6764357686042786, validation loss : 0.6401227712631226
Early stop counter: 0
Epoch: 4/300, train loss : 0.6689997613430023, validation loss : 0.6358216404914856
Early stop counter: 0
Epoch: 5/300, train loss : 0.6778045445680618, validation loss : 0.6221375465393066
Early stop counter: 0
Epoch: 6/300, train loss : 0.667238637804985, validation loss : 0.6303021907806396
Early stop counter: 1
Epoch: 7/300, train loss : 0.664186641573906, validation loss : 0.6516015529632568
Early stop counter: 2
Epoch: 8/300, train loss : 0.6676724404096603, validation loss : 0.6153518557548523
Early stop counter: 0
Epoch: 9/300, train loss : 0.672151282429695

[32m[I 2023-03-08 18:22:54,727][0m Trial 7 finished with value: 0.6358532547950745 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 1, 'hidden_size': 385, 'n_heads': 3, 'dropout': 0.3490573865160821, 'learning_rate': 0.00154137111029525}. Best is trial 6 with value: 0.6273252248764039.[0m


Epoch: 15/300, train loss : 0.6334307491779327, validation loss : 0.7294367551803589
Early stop counter: 10
Epoch: 16/300, train loss : 0.6364044696092606, validation loss : 0.7452634572982788
Early stopping...
Fold 0
Epoch: 1/300, train loss : 8.899005264043808, validation loss : 0.8464029431343079
Early stop counter: 0
Epoch: 2/300, train loss : 1.0569775998592377, validation loss : 1.2304880619049072
Early stop counter: 1
Epoch: 3/300, train loss : 1.126981407403946, validation loss : 0.8914934396743774
Early stop counter: 2
Epoch: 4/300, train loss : 0.8496424406766891, validation loss : 1.0499292612075806
Early stop counter: 3
Epoch: 5/300, train loss : 0.7862498164176941, validation loss : 0.7219454050064087
Early stop counter: 0
Epoch: 6/300, train loss : 0.6952113807201385, validation loss : 0.7883784770965576
Early stop counter: 1
Epoch: 7/300, train loss : 0.7124496400356293, validation loss : 0.6326088309288025
Early stop counter: 0
Epoch: 8/300, train loss : 0.6704098284244

[32m[I 2023-03-08 18:24:25,699][0m Trial 8 finished with value: 0.6333232641220092 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 2, 'hidden_size': 467, 'n_heads': 4, 'dropout': 0.14540687689292747, 'learning_rate': 0.004037874152406392}. Best is trial 6 with value: 0.6273252248764039.[0m


Epoch: 17/300, train loss : 0.6277913004159927, validation loss : 0.7470327019691467
Early stopping...
Fold 0
Epoch: 1/300, train loss : 77.26173876225948, validation loss : 8.191191673278809
Early stop counter: 0
Epoch: 2/300, train loss : 7.432128190994263, validation loss : 6.774438381195068
Early stop counter: 0
Epoch: 3/300, train loss : 5.911448359489441, validation loss : 3.2387311458587646
Early stop counter: 0
Epoch: 4/300, train loss : 3.05622598528862, validation loss : 0.976057767868042
Early stop counter: 0
Epoch: 5/300, train loss : 1.2085831612348557, validation loss : 0.9026588797569275
Early stop counter: 0
Epoch: 6/300, train loss : 0.9944210648536682, validation loss : 0.9721174240112305
Early stop counter: 1
Epoch: 7/300, train loss : 0.8782038539648056, validation loss : 0.680586576461792
Early stop counter: 0
Epoch: 8/300, train loss : 0.7931916862726212, validation loss : 0.8603593111038208
Early stop counter: 1
Epoch: 9/300, train loss : 0.9185087233781815, vali

[32m[I 2023-03-08 18:25:27,895][0m Trial 9 finished with value: 0.6460881352424621 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 482, 'n_heads': 3, 'dropout': 0.3177271441368341, 'learning_rate': 0.0069902332920308684}. Best is trial 6 with value: 0.6273252248764039.[0m


Epoch: 16/300, train loss : 0.6259364783763885, validation loss : 0.9877999424934387
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6984277367591858, validation loss : 0.6836860179901123
Early stop counter: 0
Epoch: 2/300, train loss : 0.6893937289714813, validation loss : 0.669323742389679
Early stop counter: 0
Epoch: 3/300, train loss : 0.6856524795293808, validation loss : 0.6844540238380432
Early stop counter: 1
Epoch: 4/300, train loss : 0.6741533130407333, validation loss : 0.6631792187690735
Early stop counter: 0
Epoch: 5/300, train loss : 0.6744329184293747, validation loss : 0.6356397867202759
Early stop counter: 0
Epoch: 6/300, train loss : 0.6775569915771484, validation loss : 0.6649448871612549
Early stop counter: 1
Epoch: 7/300, train loss : 0.6719945222139359, validation loss : 0.6681332588195801
Early stop counter: 2
Epoch: 8/300, train loss : 0.6676302999258041, validation loss : 0.6452798247337341
Early stop counter: 3
Epoch: 9/300, train loss : 0.67381986975669

[32m[I 2023-03-08 18:25:58,810][0m Trial 10 finished with value: 0.6365692853927613 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 69, 'n_heads': 1, 'dropout': 0.21125908405379046, 'learning_rate': 0.002188212451772836}. Best is trial 6 with value: 0.6273252248764039.[0m


Epoch: 25/300, train loss : 0.6042482554912567, validation loss : 0.7650191187858582
Early stop counter: 10
Epoch: 26/300, train loss : 0.602934405207634, validation loss : 0.7975103259086609
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7515048682689667, validation loss : 0.6768089532852173
Early stop counter: 0
Epoch: 2/300, train loss : 0.6849499642848969, validation loss : 0.7479279637336731
Early stop counter: 1
Epoch: 3/300, train loss : 0.7024220824241638, validation loss : 0.6591842770576477
Early stop counter: 0
Epoch: 4/300, train loss : 0.6786878556013107, validation loss : 0.6883030533790588
Early stop counter: 1
Epoch: 5/300, train loss : 0.6749297380447388, validation loss : 0.6537299752235413
Early stop counter: 0
Epoch: 6/300, train loss : 0.671771764755249, validation loss : 0.6463850140571594
Early stop counter: 0
Epoch: 7/300, train loss : 0.6639823168516159, validation loss : 0.6339784860610962
Early stop counter: 0
Epoch: 8/300, train loss : 0.6574694216251

[32m[I 2023-03-08 18:26:31,198][0m Trial 11 finished with value: 0.6374643802642822 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 151, 'n_heads': 5, 'dropout': 0.10188487662780901, 'learning_rate': 0.0010506640099765402}. Best is trial 6 with value: 0.6273252248764039.[0m


Epoch: 17/300, train loss : 0.5864483714103699, validation loss : 0.8605859279632568
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.966016635298729, validation loss : 0.651385486125946
Early stop counter: 0
Epoch: 2/300, train loss : 0.7017583698034286, validation loss : 0.662064254283905
Early stop counter: 1
Epoch: 3/300, train loss : 0.6968808472156525, validation loss : 0.6414375305175781
Early stop counter: 0
Epoch: 4/300, train loss : 0.6817531734704971, validation loss : 0.6713016629219055
Early stop counter: 1
Epoch: 5/300, train loss : 0.6770351827144623, validation loss : 0.6530677080154419
Early stop counter: 2
Epoch: 6/300, train loss : 0.6721996366977692, validation loss : 0.6189485192298889
Early stop counter: 0
Epoch: 7/300, train loss : 0.6724823713302612, validation loss : 0.717102587223053
Early stop counter: 1
Epoch: 8/300, train loss : 0.6758623123168945, validation loss : 0.6239417195320129
Early stop counter: 2
Epoch: 9/300, train loss : 0.67644102871418, v

[32m[I 2023-03-08 18:26:52,690][0m Trial 12 finished with value: 0.626483964920044 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 340, 'n_heads': 1, 'dropout': 0.2902901535513505, 'learning_rate': 0.002141928838095207}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 19/300, train loss : 0.6219053417444229, validation loss : 0.7544025778770447
Early stop counter: 10
Epoch: 20/300, train loss : 0.6200795322656631, validation loss : 0.7990614175796509
Early stopping...
Fold 0
Epoch: 1/300, train loss : 3.445525974035263, validation loss : 0.6593570113182068
Early stop counter: 0
Epoch: 2/300, train loss : 0.7122109979391098, validation loss : 0.6634157299995422
Early stop counter: 1
Epoch: 3/300, train loss : 0.6981222033500671, validation loss : 0.6716336011886597
Early stop counter: 2
Epoch: 4/300, train loss : 0.6886319071054459, validation loss : 0.651577889919281
Early stop counter: 0
Epoch: 5/300, train loss : 0.6756789535284042, validation loss : 0.6836816668510437
Early stop counter: 1
Epoch: 6/300, train loss : 0.6720159202814102, validation loss : 0.6398669481277466
Early stop counter: 0
Epoch: 7/300, train loss : 0.6725600957870483, validation loss : 0.6566201448440552
Early stop counter: 1
Epoch: 8/300, train loss : 0.6736810207366

[32m[I 2023-03-08 18:27:05,093][0m Trial 13 finished with value: 0.6334024429321289 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 1, 'hidden_size': 376, 'n_heads': 1, 'dropout': 0.2986682822515492, 'learning_rate': 0.002334605090689267}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 14/300, train loss : 0.6233761608600616, validation loss : 0.7691138386726379
Early stop counter: 10
Epoch: 15/300, train loss : 0.6278494000434875, validation loss : 0.787321150302887
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.5067407935857773, validation loss : 0.8001697063446045
Early stop counter: 0
Epoch: 2/300, train loss : 0.7186295986175537, validation loss : 0.6506810188293457
Early stop counter: 0
Epoch: 3/300, train loss : 0.7867259979248047, validation loss : 0.9062061905860901
Early stop counter: 1
Epoch: 4/300, train loss : 0.7187607735395432, validation loss : 0.6661841869354248
Early stop counter: 2
Epoch: 5/300, train loss : 0.6863533407449722, validation loss : 0.7342085242271423
Early stop counter: 3
Epoch: 6/300, train loss : 0.687626451253891, validation loss : 0.6581243872642517
Early stop counter: 4
Epoch: 7/300, train loss : 0.675747886300087, validation loss : 0.6480746865272522
Early stop counter: 0
Epoch: 8/300, train loss : 0.66654178500175

[32m[I 2023-03-08 18:27:45,143][0m Trial 14 finished with value: 0.6324865937232971 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 370, 'n_heads': 2, 'dropout': 0.39792132239704375, 'learning_rate': 0.0018483240356309098}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 17/300, train loss : 0.6250402182340622, validation loss : 0.7795653343200684
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.0814793705940247, validation loss : 0.6406800746917725
Early stop counter: 0
Epoch: 2/300, train loss : 0.7020580619573593, validation loss : 0.713262677192688
Early stop counter: 1
Epoch: 3/300, train loss : 0.6845222413539886, validation loss : 0.6599865555763245
Early stop counter: 2
Epoch: 4/300, train loss : 0.681761160492897, validation loss : 0.6707145571708679
Early stop counter: 3
Epoch: 5/300, train loss : 0.6777154058218002, validation loss : 0.6551600098609924
Early stop counter: 4
Epoch: 6/300, train loss : 0.6711774319410324, validation loss : 0.6593573093414307
Early stop counter: 5
Epoch: 7/300, train loss : 0.6677936166524887, validation loss : 0.6412084102630615
Early stop counter: 6
Epoch: 8/300, train loss : 0.6631640195846558, validation loss : 0.622248113155365
Early stop counter: 0
Epoch: 9/300, train loss : 0.656763106584549,

[32m[I 2023-03-08 18:27:59,959][0m Trial 15 finished with value: 0.6290222406387329 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 170, 'n_heads': 2, 'dropout': 0.3074992829203516, 'learning_rate': 0.0026877373405416534}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 16/300, train loss : 0.6374734938144684, validation loss : 0.7082340121269226
Early stop counter: 10
Epoch: 17/300, train loss : 0.6307713240385056, validation loss : 0.7309616804122925
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.1097168326377869, validation loss : 0.6791792511940002
Early stop counter: 0
Epoch: 2/300, train loss : 0.6872386932373047, validation loss : 0.6932333111763
Early stop counter: 1
Epoch: 3/300, train loss : 0.6778950095176697, validation loss : 0.6372545957565308
Early stop counter: 0
Epoch: 4/300, train loss : 0.6755769997835159, validation loss : 0.6926737427711487
Early stop counter: 1
Epoch: 5/300, train loss : 0.6984847635030746, validation loss : 0.6598134636878967
Early stop counter: 2
Epoch: 6/300, train loss : 0.6804216802120209, validation loss : 0.6640033721923828
Early stop counter: 3
Epoch: 7/300, train loss : 0.6698323041200638, validation loss : 0.6412436366081238
Early stop counter: 4
Epoch: 8/300, train loss : 0.66668616235256

[32m[I 2023-03-08 18:28:22,323][0m Trial 16 finished with value: 0.6268911838531495 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 328, 'n_heads': 1, 'dropout': 0.22256183330929216, 'learning_rate': 0.001649169343173939}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 15/300, train loss : 0.6040416210889816, validation loss : 0.7920339703559875
Early stop counter: 10
Epoch: 16/300, train loss : 0.6163129657506943, validation loss : 0.789689302444458
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.1704580932855606, validation loss : 0.6304512023925781
Early stop counter: 0
Epoch: 2/300, train loss : 0.7914424538612366, validation loss : 0.8193125128746033
Early stop counter: 1
Epoch: 3/300, train loss : 0.7326686680316925, validation loss : 0.6901851296424866
Early stop counter: 2
Epoch: 4/300, train loss : 0.69322070479393, validation loss : 0.6670675873756409
Early stop counter: 3
Epoch: 5/300, train loss : 0.6875872761011124, validation loss : 0.7131337523460388
Early stop counter: 4
Epoch: 6/300, train loss : 0.686006560921669, validation loss : 0.6639062166213989
Early stop counter: 5
Epoch: 7/300, train loss : 0.6768019050359726, validation loss : 0.6728531122207642
Early stop counter: 6
Epoch: 8/300, train loss : 0.671872034668922

[32m[I 2023-03-08 18:28:38,574][0m Trial 17 finished with value: 0.6337471842765808 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 3, 'hidden_size': 338, 'n_heads': 1, 'dropout': 0.20830079105043148, 'learning_rate': 0.0015373676413761625}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 19/300, train loss : 0.6183717250823975, validation loss : 0.8175073862075806
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.6729237735271454, validation loss : 0.6440417170524597
Early stop counter: 0
Epoch: 2/300, train loss : 0.7238806784152985, validation loss : 0.7446349263191223
Early stop counter: 1
Epoch: 3/300, train loss : 0.7010809183120728, validation loss : 0.6671257019042969
Early stop counter: 2
Epoch: 4/300, train loss : 0.6843334436416626, validation loss : 0.6693592667579651
Early stop counter: 3
Epoch: 5/300, train loss : 0.6761268526315689, validation loss : 0.6603164076805115
Early stop counter: 4
Epoch: 6/300, train loss : 0.670737236738205, validation loss : 0.6451408267021179
Early stop counter: 5
Epoch: 7/300, train loss : 0.6677006632089615, validation loss : 0.6334590315818787
Early stop counter: 0
Epoch: 8/300, train loss : 0.6634684801101685, validation loss : 0.6798432469367981
Early stop counter: 1
Epoch: 9/300, train loss : 0.67723560333251

[32m[I 2023-03-08 18:29:02,215][0m Trial 18 finished with value: 0.6350647330284118 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 3, 'hidden_size': 415, 'n_heads': 1, 'dropout': 0.2175484952583547, 'learning_rate': 0.0015793835395170024}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 13/300, train loss : 0.6364284008741379, validation loss : 0.7816250324249268
Early stopping...
Fold 0
Epoch: 1/300, train loss : 14.649573504924774, validation loss : 0.7987235188484192
Early stop counter: 0
Epoch: 2/300, train loss : 0.7166242897510529, validation loss : 0.7202543020248413
Early stop counter: 0
Epoch: 3/300, train loss : 0.692073792219162, validation loss : 0.7007927298545837
Early stop counter: 0
Epoch: 4/300, train loss : 0.6841224879026413, validation loss : 0.6285472512245178
Early stop counter: 0
Epoch: 5/300, train loss : 0.6867532134056091, validation loss : 0.709362268447876
Early stop counter: 1
Epoch: 6/300, train loss : 0.6778994202613831, validation loss : 0.6558545827865601
Early stop counter: 2
Epoch: 7/300, train loss : 0.6721829622983932, validation loss : 0.6600181460380554
Early stop counter: 3
Epoch: 8/300, train loss : 0.6685891300439835, validation loss : 0.6427370309829712
Early stop counter: 4
Epoch: 9/300, train loss : 0.663321360945701

[32m[I 2023-03-08 18:29:26,524][0m Trial 19 finished with value: 0.6306280851364136 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 429, 'n_heads': 1, 'dropout': 0.27928994094169296, 'learning_rate': 0.002787541169951241}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 17/300, train loss : 0.6490472257137299, validation loss : 0.6773761510848999
Early stopping...
Fold 0
Epoch: 1/300, train loss : 3.560432493686676, validation loss : 0.662259042263031
Early stop counter: 0
Epoch: 2/300, train loss : 0.7086225301027298, validation loss : 0.8324496150016785
Early stop counter: 1
Epoch: 3/300, train loss : 0.8449989259243011, validation loss : 1.0975674390792847
Early stop counter: 2
Epoch: 4/300, train loss : 0.7896507978439331, validation loss : 0.6772955656051636
Early stop counter: 3
Epoch: 5/300, train loss : 0.6741520017385483, validation loss : 0.7118306756019592
Early stop counter: 4
Epoch: 6/300, train loss : 0.6845822185277939, validation loss : 0.6320321559906006
Early stop counter: 0
Epoch: 7/300, train loss : 0.6827774941921234, validation loss : 0.647460401058197
Early stop counter: 1
Epoch: 8/300, train loss : 0.6748842895030975, validation loss : 0.6694908142089844
Early stop counter: 2
Epoch: 9/300, train loss : 0.6624189168214798

[32m[I 2023-03-08 18:29:54,289][0m Trial 20 finished with value: 0.6374416828155518 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 341, 'n_heads': 2, 'dropout': 0.23174845325981164, 'learning_rate': 0.003132189699970318}. Best is trial 12 with value: 0.626483964920044.[0m


Epoch: 15/300, train loss : 0.6382597237825394, validation loss : 0.7849808931350708
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.1143578886985779, validation loss : 0.7088374495506287
Early stop counter: 0
Epoch: 2/300, train loss : 0.6961649358272552, validation loss : 0.7852005958557129
Early stop counter: 1
Epoch: 3/300, train loss : 0.6978995501995087, validation loss : 0.6412951946258545
Early stop counter: 0
Epoch: 4/300, train loss : 0.6820433586835861, validation loss : 0.6587911248207092
Early stop counter: 1
Epoch: 5/300, train loss : 0.6768922954797745, validation loss : 0.6539334058761597
Early stop counter: 2
Epoch: 6/300, train loss : 0.672736719250679, validation loss : 0.6616366505622864
Early stop counter: 3
Epoch: 7/300, train loss : 0.6731900274753571, validation loss : 0.6324825286865234
Early stop counter: 0
Epoch: 8/300, train loss : 0.6672771126031876, validation loss : 0.6731951236724854
Early stop counter: 1
Epoch: 9/300, train loss : 0.65957096219062

[32m[I 2023-03-08 18:30:15,070][0m Trial 21 finished with value: 0.624030876159668 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 248, 'n_heads': 2, 'dropout': 0.343018493669972, 'learning_rate': 0.002091824465119609}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 13/300, train loss : 0.6515361666679382, validation loss : 0.6935710906982422
Early stop counter: 10
Epoch: 14/300, train loss : 0.6346854567527771, validation loss : 0.7274290919303894
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.0742025077342987, validation loss : 0.6557216048240662
Early stop counter: 0
Epoch: 2/300, train loss : 0.6959473788738251, validation loss : 0.6714611649513245
Early stop counter: 1
Epoch: 3/300, train loss : 0.6817672550678253, validation loss : 0.6831558346748352
Early stop counter: 2
Epoch: 4/300, train loss : 0.6742902398109436, validation loss : 0.6490635275840759
Early stop counter: 0
Epoch: 5/300, train loss : 0.6714750528335571, validation loss : 0.6220293641090393
Early stop counter: 0
Epoch: 6/300, train loss : 0.6794154047966003, validation loss : 0.6346258521080017
Early stop counter: 1
Epoch: 7/300, train loss : 0.6668739020824432, validation loss : 0.6311656832695007
Early stop counter: 2
Epoch: 8/300, train loss : 0.66666549444

[32m[I 2023-03-08 18:30:33,255][0m Trial 22 finished with value: 0.6331879496574402 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 311, 'n_heads': 1, 'dropout': 0.3483847335270847, 'learning_rate': 0.0017429985375586653}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 11/300, train loss : 0.6497674137353897, validation loss : 0.7399892807006836
Early stop counter: 9
Epoch: 12/300, train loss : 0.6649683564901352, validation loss : 0.717595100402832
Early stop counter: 10
Epoch: 13/300, train loss : 0.6416233777999878, validation loss : 0.7302404046058655
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.9270865172147751, validation loss : 0.6845552921295166
Early stop counter: 0
Epoch: 2/300, train loss : 0.6943684816360474, validation loss : 0.6521631479263306
Early stop counter: 0
Epoch: 3/300, train loss : 0.6911061704158783, validation loss : 0.7298539280891418
Early stop counter: 1
Epoch: 4/300, train loss : 0.6791773289442062, validation loss : 0.6437825560569763
Early stop counter: 0
Epoch: 5/300, train loss : 0.6798640638589859, validation loss : 0.6280087232589722
Early stop counter: 0
Epoch: 6/300, train loss : 0.6976183354854584, validation loss : 0.6708685755729675
Early stop counter: 1
Epoch: 7/300, train loss : 0.67437973618

[32m[I 2023-03-08 18:31:05,602][0m Trial 23 finished with value: 0.6240952134132385 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 278, 'n_heads': 2, 'dropout': 0.34029138801445824, 'learning_rate': 0.0013461897041275462}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 18/300, train loss : 0.6415664255619049, validation loss : 0.7195770144462585
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.8179685771465302, validation loss : 0.7262129783630371
Early stop counter: 0
Epoch: 2/300, train loss : 0.6886090636253357, validation loss : 0.6700384616851807
Early stop counter: 0
Epoch: 3/300, train loss : 0.6794845312833786, validation loss : 0.6424809098243713
Early stop counter: 0
Epoch: 4/300, train loss : 0.6740519106388092, validation loss : 0.6742764711380005
Early stop counter: 1
Epoch: 5/300, train loss : 0.6741428971290588, validation loss : 0.6393031477928162
Early stop counter: 0
Epoch: 6/300, train loss : 0.6682565957307816, validation loss : 0.647302508354187
Early stop counter: 1
Epoch: 7/300, train loss : 0.6650256961584091, validation loss : 0.6313790678977966
Early stop counter: 0
Epoch: 8/300, train loss : 0.6626593917608261, validation loss : 0.6238999366760254
Early stop counter: 0
Epoch: 9/300, train loss : 0.66357462108135

[32m[I 2023-03-08 18:31:28,671][0m Trial 24 finished with value: 0.6343602657318115 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 1, 'hidden_size': 254, 'n_heads': 2, 'dropout': 0.34875213047685694, 'learning_rate': 0.001264619832264282}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 17/300, train loss : 0.6281279176473618, validation loss : 0.7695703506469727
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.8086089640855789, validation loss : 0.6439380645751953
Early stop counter: 0
Epoch: 2/300, train loss : 0.6933643966913223, validation loss : 0.7287093997001648
Early stop counter: 1
Epoch: 3/300, train loss : 0.6921730488538742, validation loss : 0.6465936303138733
Early stop counter: 2
Epoch: 4/300, train loss : 0.680350735783577, validation loss : 0.6879406571388245
Early stop counter: 3
Epoch: 5/300, train loss : 0.6759736388921738, validation loss : 0.6418716907501221
Early stop counter: 0
Epoch: 6/300, train loss : 0.6698229759931564, validation loss : 0.646696925163269
Early stop counter: 1
Epoch: 7/300, train loss : 0.6650950312614441, validation loss : 0.6365209221839905
Early stop counter: 0
Epoch: 8/300, train loss : 0.6606481373310089, validation loss : 0.6326572895050049
Early stop counter: 0
Epoch: 9/300, train loss : 0.660895496606826

[32m[I 2023-03-08 18:31:53,258][0m Trial 25 finished with value: 0.6331281185150146 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 174, 'n_heads': 2, 'dropout': 0.3313141741041185, 'learning_rate': 0.0012511472098751864}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 15/300, train loss : 0.6518146693706512, validation loss : 0.7074944376945496
Early stop counter: 10
Epoch: 16/300, train loss : 0.648928314447403, validation loss : 0.6907677054405212
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7160546630620956, validation loss : 0.7369830012321472
Early stop counter: 0
Epoch: 2/300, train loss : 0.6899518072605133, validation loss : 0.6714261770248413
Early stop counter: 0
Epoch: 3/300, train loss : 0.6802912056446075, validation loss : 0.6338697075843811
Early stop counter: 0
Epoch: 4/300, train loss : 0.6765812486410141, validation loss : 0.6830011606216431
Early stop counter: 1
Epoch: 5/300, train loss : 0.6726303100585938, validation loss : 0.6390283703804016
Early stop counter: 2
Epoch: 6/300, train loss : 0.664801225066185, validation loss : 0.6567137241363525
Early stop counter: 3
Epoch: 7/300, train loss : 0.6623780727386475, validation loss : 0.6240559220314026
Early stop counter: 0
Epoch: 8/300, train loss : 0.6603812277317

[32m[I 2023-03-08 18:32:19,462][0m Trial 26 finished with value: 0.6330269575119019 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 106, 'n_heads': 2, 'dropout': 0.3807513345337173, 'learning_rate': 0.0024110800158514316}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 17/300, train loss : 0.6237828135490417, validation loss : 0.7526057362556458
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.218531683087349, validation loss : 0.6461711525917053
Early stop counter: 0
Epoch: 2/300, train loss : 0.698594331741333, validation loss : 0.6269339323043823
Early stop counter: 0
Epoch: 3/300, train loss : 0.6747196614742279, validation loss : 0.6896173357963562
Early stop counter: 1
Epoch: 4/300, train loss : 0.6757209300994873, validation loss : 0.646864116191864
Early stop counter: 2
Epoch: 5/300, train loss : 0.6692270189523697, validation loss : 0.6200777292251587
Early stop counter: 0
Epoch: 6/300, train loss : 0.6668005287647247, validation loss : 0.619050145149231
Early stop counter: 0
Epoch: 7/300, train loss : 0.6580386608839035, validation loss : 0.6267351508140564
Early stop counter: 1
Epoch: 8/300, train loss : 0.6521382182836533, validation loss : 0.6573284864425659
Early stop counter: 2
Epoch: 9/300, train loss : 0.6451257467269897,

[32m[I 2023-03-08 18:32:44,509][0m Trial 27 finished with value: 0.6348784804344177 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 267, 'n_heads': 2, 'dropout': 0.2883612613047224, 'learning_rate': 0.0032301664044156546}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 13/300, train loss : 0.6362909078598022, validation loss : 0.8439246416091919
Early stop counter: 10
Epoch: 14/300, train loss : 0.6352681815624237, validation loss : 0.8290923833847046
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.9711443185806274, validation loss : 0.7117790579795837
Early stop counter: 0
Epoch: 2/300, train loss : 0.6864632219076157, validation loss : 0.6813191175460815
Early stop counter: 0
Epoch: 3/300, train loss : 0.6910950243473053, validation loss : 0.6658748984336853
Early stop counter: 0
Epoch: 4/300, train loss : 0.6792733669281006, validation loss : 0.7293267846107483
Early stop counter: 1
Epoch: 5/300, train loss : 0.6814297735691071, validation loss : 0.6351163387298584
Early stop counter: 0
Epoch: 6/300, train loss : 0.6789785325527191, validation loss : 0.6725941896438599
Early stop counter: 1
Epoch: 7/300, train loss : 0.6716247946023941, validation loss : 0.6369011998176575
Early stop counter: 2
Epoch: 8/300, train loss : 0.66643063724

[32m[I 2023-03-08 18:33:13,054][0m Trial 28 finished with value: 0.6292212843894959 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 210, 'n_heads': 1, 'dropout': 0.33163199030560864, 'learning_rate': 0.0019514110839291537}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 15/300, train loss : 0.6429032236337662, validation loss : 0.8135614991188049
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.868861049413681, validation loss : 0.7544718980789185
Early stop counter: 0
Epoch: 2/300, train loss : 0.6900448948144913, validation loss : 0.6496054530143738
Early stop counter: 0
Epoch: 3/300, train loss : 0.6739673167467117, validation loss : 0.6580160856246948
Early stop counter: 1
Epoch: 4/300, train loss : 0.6796716153621674, validation loss : 0.6836300492286682
Early stop counter: 2
Epoch: 5/300, train loss : 0.6751524657011032, validation loss : 0.6502513289451599
Early stop counter: 3
Epoch: 6/300, train loss : 0.6754467040300369, validation loss : 0.6243088245391846
Early stop counter: 0
Epoch: 7/300, train loss : 0.6838980615139008, validation loss : 0.6693140268325806
Early stop counter: 1
Epoch: 8/300, train loss : 0.6720767468214035, validation loss : 0.6575260758399963
Early stop counter: 2
Epoch: 9/300, train loss : 0.66969393193721

[32m[I 2023-03-08 18:33:51,325][0m Trial 29 finished with value: 0.6365422248840332 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 279, 'n_heads': 2, 'dropout': 0.3668419607678221, 'learning_rate': 0.0013495001997228366}. Best is trial 21 with value: 0.624030876159668.[0m


Epoch: 16/300, train loss : 0.6573258489370346, validation loss : 0.700151264667511
Early stopping...
best trial:
[0.624030876159668]
Best parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 248, 'n_heads': 2, 'dropout': 0.343018493669972, 'learning_rate': 0.002091824465119609}


## 3. Train/validate/test model

1. After tuning, the best parameters were saved to config.py as params_parallel_gnn.
2. Model was trained/validated/tested using the best parameters and 5-fold CV. Process was repeated 5 times and the final results were averaged.
3. Ensure that data folder, from_scratch_trained_models folder and this notebook are in the same working directory
4. To repeat the training process, call run_training; please create a new folder in which the retrained models are saved and from which they are loaded.
5. Otherwise, comment out the run_training call and run only the testing step to reproduce the results reported in the paper.

In [3]:
def run_training(train_loader, valid_loader, params, trained_model_path):
    """Train a ParallelGNN with early stopping and checkpoint the best weights.

    After every epoch the model is evaluated on ``valid_loader``. Whenever the
    validation loss improves on the best value seen so far, the model's
    ``state_dict`` is written to ``trained_model_path``. Training stops once
    the number of consecutive non-improving epochs exceeds ``PATIENCE`` or
    after ``EPOCHS`` epochs, whichever comes first.

    Args:
        train_loader: Loader handed to ``EngineHOB.train`` each epoch.
        valid_loader: Loader handed to ``EngineHOB.validate`` each epoch.
        params: Mapping providing ``num_gin_layers``,
            ``num_graph_trans_layers``, ``hidden_size``, ``n_heads``,
            ``dropout`` and ``learning_rate``.
        trained_model_path: File path the best ``state_dict`` is saved to.
            Its parent directory is created if it does not exist yet.

    Returns:
        The lowest validation loss observed during training.
    """
    model = ParallelGNN(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_gin_layers=params['num_gin_layers'], num_graph_trans_layers=params['num_graph_trans_layers'], hidden_size=params['hidden_size'], 
                            n_heads=params['n_heads'], dropout=params['dropout'], edge_dim=EDGE_DIM)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    # Create the checkpoint directory up front so the first torch.save cannot
    # fail (and waste a full training epoch) because the folder is missing.
    save_dir = os.path.dirname(trained_model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        valid_loss = eng.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        if valid_loss < best_loss:
            best_loss = valid_loss
            early_stopping_counter = 0  # reset counter on improvement
            print('Saving model...')
            torch.save(model.state_dict(), trained_model_path)
        else:
            early_stopping_counter += 1

        # Strictly greater-than: PATIENCE non-improving epochs are tolerated,
        # the next one triggers the stop.
        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')

    return best_loss

def run_testing(test_loader, params, trained_model_path):
    """Load a saved ParallelGNN checkpoint and evaluate it on ``test_loader``.

    Args:
        test_loader: Loader handed to ``EngineHOB.test``.
        params: Mapping providing ``num_gin_layers``,
            ``num_graph_trans_layers``, ``hidden_size``, ``n_heads``,
            ``dropout`` and ``learning_rate``; the architecture keys must match
            the ones the checkpoint was trained with.
        trained_model_path: Path of the ``state_dict`` file to load.

    Returns:
        The tuple ``(bce, acc, f1, roc_auc)`` as returned by ``EngineHOB.test``.
    """
    model = ParallelGNN(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_gin_layers=params['num_gin_layers'], num_graph_trans_layers=params['num_graph_trans_layers'], hidden_size=params['hidden_size'], 
                        n_heads=params['n_heads'], dropout=params['dropout'], edge_dim=EDGE_DIM)
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
    # on a GPU can still be loaded on a CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(trained_model_path, map_location=DEVICE))
    model.to(DEVICE)
    # EngineHOB is constructed with an optimizer; presumably it is not stepped
    # during testing — verify against engine.py.
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    print('Begin testing...')
    bce, acc, f1, roc_auc = eng.test(test_loader)
    print('Test completed!')
    print(f'bce:{bce}, acc:{acc}, f1: {f1}, roc_auc: {roc_auc}')
    return bce, acc, f1, roc_auc

In [4]:
# Repeated K-fold cross-validation: for every repetition and fold, train on the
# training split (checkpointing on the validation split), evaluate the saved
# checkpoint on the held-out test set, then report mean ± std of each metric.
params = params_parallel_gnn
n_repetitions = 5
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test/'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
path_to_save_trained_model = './from_scratch_trained_models/parallel/'

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

# load dataset
dataset_for_cv = LoadHOBDataset(root=train_data_root_path, raw_filename=train_data_raw_filename)
test_dataset = LoadHOBDataset(root=test_data_root_path, raw_filename=test_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

# Read the per-molecule graph files from the configured training root instead
# of a second hard-coded copy of the path, so changing train_data_root_path
# cannot silently load graphs from a different dataset.
processed_dir = os.path.join(train_data_root_path, 'processed')

# The test loader does not depend on the repetition or fold; build it once.
test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

# NOTE(review): KFold is not shuffled and every fold is seeded with the same
# SEED_NO, so all repetitions appear to use identical splits and seeds —
# confirm this is intended before interpreting the reported std.
for repeat in range(n_repetitions):
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        seed_everything(SEED_NO)
        train_dataset = [torch.load(os.path.join(processed_dir, f'molecule_{t_idx}.pt')) for t_idx in train_idx]
        valid_dataset = [torch.load(os.path.join(processed_dir, f'molecule_{v_idx}.pt')) for v_idx in valid_idx]

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        # One checkpoint per (repetition, fold); written by run_training and
        # read back by run_testing.
        model_path = os.path.join(path_to_save_trained_model, f'parallel_repeat_{repeat}_fold_{fold_no}.pt')

        run_training(train_loader, valid_loader, params, model_path)
        bce, acc, f1, roc_auc = run_testing(test_loader, params, model_path)
        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

# Summarise each metric over all repetitions x folds (population std, np.std default).
bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

Epoch: 1/300, train loss : 1.0919382125139236, validation loss : 0.7346024513244629
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.6976696997880936, validation loss : 0.6866105198860168
Saving model...
Early stop counter: 0
Epoch: 3/300, train loss : 0.6784620434045792, validation loss : 0.6410688757896423
Saving model...
Early stop counter: 0
Epoch: 4/300, train loss : 0.6682775914669037, validation loss : 0.6641042828559875
Early stop counter: 1
Epoch: 5/300, train loss : 0.6798685938119888, validation loss : 0.646278440952301
Early stop counter: 2
Epoch: 6/300, train loss : 0.6772485971450806, validation loss : 0.6281134486198425
Saving model...
Early stop counter: 0
Epoch: 7/300, train loss : 0.6712305098772049, validation loss : 0.664862334728241
Early stop counter: 1
Epoch: 8/300, train loss : 0.6656370908021927, validation loss : 0.6376354694366455
Early stop counter: 2
Epoch: 9/300, train loss : 0.6571381986141205, validation loss : 0.6277853846549988
Saving