# Vertical GNN Model for Oral Bioavailability Dataset 

1. This notebook builds the vertical GNN model to predict oral bioavailability.
2. Hyperparameters were found with the Optuna library, using the Tree-structured Parzen Estimator (TPE) algorithm over 30 trials.
3. The model was trained, validated and tested with the best parameters using 5-fold cross-validation. The whole process was repeated 10 times and the results were averaged.

### Note
1. Ensure that this notebook is in the same working directory as the `data` folder, `config.py`, `utils.py`, `engine.py` and `model.py`.
2. To load the trained models, download the saved models from the Google Drive link given in README.md.
3. To reproduce the results from the saved models, comment out the training function so that the saved models are loaded instead of retrained.

In [1]:
from utils import seed_everything, LoadHOBDataset
from config import SEED_NO, NUM_FEATURES, NUM_GRAPHS_PER_BATCH, NUM_TARGET, EDGE_DIM, DEVICE, PATIENCE, EPOCHS, N_SPLITS, params_vertical_gnn
from engine import EngineHOB
from model import VerticalGNN

import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold
from torch_geometric.loader import DataLoader
import os 

## 2. Tuning the Model (we suggest skipping this step and using the best parameters saved in config.py)

1. The model was tuned with the Optuna library using the Tree-structured Parzen Estimator (TPE) algorithm for 30 trials.
2. The `run_tuning` function trains and validates the model on a single fold, using an early stopping mechanism.
3. The `objective` function drives the tuning process: it samples a set of hyperparameters and returns the mean best validation loss across the CV folds.

In [5]:
def run_tuning(train_loader, valid_loader, params):
    """Train a fresh VerticalGNN on one fold and return its best validation loss.

    A new model is built from the sampled hyperparameters. Each epoch runs one
    training pass followed by one validation pass, for at most EPOCHS epochs.
    Training stops early once the validation loss has failed to beat its best
    value for more than PATIENCE consecutive epochs.

    Args:
        train_loader: DataLoader over the fold's training graphs.
        valid_loader: DataLoader over the fold's validation graphs.
        params: dict of hyperparameters. Keys read here: 'num_gin_layers',
            'num_graph_trans_layers', 'hidden_size', 'n_heads', 'dropout',
            'learning_rate'.

    Returns:
        The lowest validation loss seen over all epochs that were run
        (np.inf if no epoch ran).
    """
    net = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    net.to(DEVICE)
    adam = torch.optim.Adam(net.parameters(), lr=params['learning_rate'])
    engine = EngineHOB(net, adam, device=DEVICE)

    lowest_valid_loss = np.inf
    epochs_without_gain = 0

    for epoch_no in range(1, EPOCHS + 1):
        train_loss = engine.train(train_loader)
        valid_loss = engine.validate(valid_loader)
        print(f'Epoch: {epoch_no}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')

        improved = valid_loss < lowest_valid_loss
        if improved:
            lowest_valid_loss = valid_loss
        epochs_without_gain = 0 if improved else epochs_without_gain + 1

        # NOTE(review): the strict `>` means training stops only after
        # PATIENCE + 1 consecutive non-improving epochs; kept as-is so the
        # tuning behaviour matches the recorded runs.
        if epochs_without_gain > PATIENCE:
            print('Early stopping...')
            break
        print(f'Early stop counter: {epochs_without_gain}')

    return lowest_valid_loss


In [6]:
def objective(trial):
    """Optuna objective: mean best validation loss of one hyperparameter set.

    Samples a hyperparameter configuration from ``trial`` and runs
    ``run_tuning`` on each KFold split (``N_SPLITS`` folds) of the
    oral-bioavailability training set. It returns the average of the
    per-fold best validation losses, which the study minimises.

    Args:
        trial: the Optuna trial object supplied by ``study.optimize``.

    Returns:
        Mean of the best validation losses over all folds.
    """
    params = {
        'num_gin_layers': trial.suggest_int('num_gin_layers', 1, 3),
        'num_graph_trans_layers': trial.suggest_int('num_graph_trans_layers', 1, 3),
        'hidden_size': trial.suggest_int('hidden_size', 64, 512),
        'n_heads': trial.suggest_int('n_heads', 1, 5),
        'dropout': trial.suggest_float('dropout', 0.1, 0.4),
        # log=True samples the learning rate uniformly in log space
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 9e-3, log=True),
    }

    # Single source of truth for the dataset location (previously repeated).
    data_root = './data/graph_data/data_oral_avail_train/'
    processed_dir = f'{data_root}processed'

    # dataset_for_cv is only passed to kf.split (which needs its length);
    # samples are read from the processed .pt files below. Constructing it
    # presumably also writes those files if missing — TODO confirm in LoadHOBDataset.
    dataset_for_cv = LoadHOBDataset(root=data_root, raw_filename='data_oral_avail_train_50.csv')
    kf = KFold(n_splits=N_SPLITS)
    fold_losses = []

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        print(f'Fold {fold_no}')
        train_dataset = [torch.load(f'{processed_dir}/molecule_{t_idx}.pt') for t_idx in train_idx]
        valid_dataset = [torch.load(f'{processed_dir}/molecule_{v_idx}.pt') for v_idx in valid_idx]

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        fold_losses.append(run_tuning(train_loader, valid_loader, params))

    # Divide by the number of folds actually run instead of a hard-coded 5,
    # so the result remains a true mean if N_SPLITS is changed in config.py.
    return sum(fold_losses) / len(fold_losses)
if __name__ == '__main__':
    # Seed the TPE sampler so the sequence of suggested hyperparameters is
    # repeatable. Fully repeatable trials also require seeding the global
    # RNGs used during training (utils.seed_everything is imported for that;
    # TODO confirm it is called before tuning).
    sampler = optuna.samplers.TPESampler(seed=SEED_NO)
    study = optuna.create_study(direction='minimize', sampler=sampler)
    study.optimize(objective, n_trials=30)
    print('best trial:')
    trial_ = study.best_trial
    print(trial_.values)
    print(f'Best parameters: {trial_.params}')

[32m[I 2023-03-15 14:32:09,925][0m A new study created in memory with name: no-name-35b513d2-770b-4ff8-a043-0ce4ff834b9b[0m


Fold 0
Epoch: 1/300, train loss : 6.748892426490784, validation loss : 3.036999464035034
Early stop counter: 0
Epoch: 2/300, train loss : 3.125407263636589, validation loss : 0.9028497934341431
Early stop counter: 0
Epoch: 3/300, train loss : 0.8272596895694733, validation loss : 0.7369771599769592
Early stop counter: 0
Epoch: 4/300, train loss : 0.7255954593420029, validation loss : 0.7451308369636536
Early stop counter: 1
Epoch: 5/300, train loss : 0.8327586799860001, validation loss : 0.6561927795410156
Early stop counter: 0
Epoch: 6/300, train loss : 0.6997271925210953, validation loss : 0.756048858165741
Early stop counter: 1
Epoch: 7/300, train loss : 0.6929575055837631, validation loss : 0.6666305661201477
Early stop counter: 2
Epoch: 8/300, train loss : 0.6753508299589157, validation loss : 0.7957857847213745
Early stop counter: 3
Epoch: 9/300, train loss : 0.717637374997139, validation loss : 0.6530423164367676
Early stop counter: 0
Epoch: 10/300, train loss : 0.69207571446895

[32m[I 2023-03-15 14:32:44,340][0m Trial 0 finished with value: 0.6306962013244629 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 3, 'hidden_size': 371, 'n_heads': 1, 'dropout': 0.17015542288272026, 'learning_rate': 0.004964634551335803}. Best is trial 0 with value: 0.6306962013244629.[0m


Epoch: 17/300, train loss : 0.6473052054643631, validation loss : 0.7519872188568115
Early stop counter: 10
Epoch: 18/300, train loss : 0.6421261727809906, validation loss : 0.7374377250671387
Early stopping...
Fold 0
Epoch: 1/300, train loss : 10.128466337919235, validation loss : 1.6040163040161133
Early stop counter: 0
Epoch: 2/300, train loss : 0.9656190276145935, validation loss : 0.6335681080818176
Early stop counter: 0
Epoch: 3/300, train loss : 0.6925968676805496, validation loss : 0.7112663984298706
Early stop counter: 1
Epoch: 4/300, train loss : 0.6760530322790146, validation loss : 0.6515933871269226
Early stop counter: 2
Epoch: 5/300, train loss : 0.7940814197063446, validation loss : 0.6253662109375
Early stop counter: 0
Epoch: 6/300, train loss : 0.693479061126709, validation loss : 0.654772937297821
Early stop counter: 1
Epoch: 7/300, train loss : 0.668770357966423, validation loss : 0.6547449231147766
Early stop counter: 2
Epoch: 8/300, train loss : 0.6681523472070694,

[32m[I 2023-03-15 14:33:05,052][0m Trial 1 finished with value: 0.6325212240219116 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 1, 'hidden_size': 237, 'n_heads': 3, 'dropout': 0.309002774883653, 'learning_rate': 0.006668373458318256}. Best is trial 0 with value: 0.6306962013244629.[0m


Epoch: 21/300, train loss : 0.6814696192741394, validation loss : 0.7364004254341125
Early stop counter: 10
Epoch: 22/300, train loss : 0.640919640660286, validation loss : 0.809412956237793
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.8795355409383774, validation loss : 0.6591995358467102
Early stop counter: 0
Epoch: 2/300, train loss : 0.7181357592344284, validation loss : 0.7176674008369446
Early stop counter: 1
Epoch: 3/300, train loss : 0.6851832270622253, validation loss : 0.7273228168487549
Early stop counter: 2
Epoch: 4/300, train loss : 0.6847313344478607, validation loss : 0.6415801644325256
Early stop counter: 0
Epoch: 5/300, train loss : 0.6830648183822632, validation loss : 0.6510316133499146
Early stop counter: 1
Epoch: 6/300, train loss : 0.6744579523801804, validation loss : 0.6620153188705444
Early stop counter: 2
Epoch: 7/300, train loss : 0.6694928109645844, validation loss : 0.6471895575523376
Early stop counter: 3
Epoch: 8/300, train loss : 0.6687404066324

[32m[I 2023-03-15 14:33:37,993][0m Trial 2 finished with value: 0.6227349638938904 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 166, 'n_heads': 1, 'dropout': 0.336112649153141, 'learning_rate': 0.004708191915674545}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 13/300, train loss : 0.6291117817163467, validation loss : 0.7867037653923035
Early stop counter: 10
Epoch: 14/300, train loss : 0.62836953997612, validation loss : 0.777125895023346
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.226545736193657, validation loss : 0.7679882049560547
Early stop counter: 0
Epoch: 2/300, train loss : 0.7084696441888809, validation loss : 0.7146729826927185
Early stop counter: 0
Epoch: 3/300, train loss : 0.6889614164829254, validation loss : 0.696922779083252
Early stop counter: 0
Epoch: 4/300, train loss : 0.6856266856193542, validation loss : 0.6433396339416504
Early stop counter: 0
Epoch: 5/300, train loss : 0.6793364584445953, validation loss : 0.6668523550033569
Early stop counter: 1
Epoch: 6/300, train loss : 0.6901926845312119, validation loss : 0.6837530732154846
Early stop counter: 2
Epoch: 7/300, train loss : 0.6803929954767227, validation loss : 0.6513552069664001
Early stop counter: 3
Epoch: 8/300, train loss : 0.6706144660711288

[32m[I 2023-03-15 14:34:10,538][0m Trial 3 finished with value: 0.6251201748847961 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 376, 'n_heads': 1, 'dropout': 0.2475618140839548, 'learning_rate': 0.002001657308967741}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6310981661081314, validation loss : 0.7628094553947449
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.22432704269886, validation loss : 0.6434658765792847
Early stop counter: 0
Epoch: 2/300, train loss : 0.6917158663272858, validation loss : 0.645368218421936
Early stop counter: 1
Epoch: 3/300, train loss : 0.7134869545698166, validation loss : 0.6585476398468018
Early stop counter: 2
Epoch: 4/300, train loss : 0.6897723227739334, validation loss : 0.6770774722099304
Early stop counter: 3
Epoch: 5/300, train loss : 0.6817363947629929, validation loss : 0.673626184463501
Early stop counter: 4
Epoch: 6/300, train loss : 0.6758892089128494, validation loss : 0.6497521996498108
Early stop counter: 5
Epoch: 7/300, train loss : 0.6656743288040161, validation loss : 0.6445830464363098
Early stop counter: 6
Epoch: 8/300, train loss : 0.6780639439821243, validation loss : 0.6515500545501709
Early stop counter: 7
Epoch: 9/300, train loss : 0.6697481423616409,

[32m[I 2023-03-15 14:34:21,718][0m Trial 4 finished with value: 0.6277772903442382 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 1, 'hidden_size': 186, 'n_heads': 3, 'dropout': 0.3594802244086408, 'learning_rate': 0.0033250054076598623}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 18/300, train loss : 0.6322613507509232, validation loss : 0.7625434994697571
Early stop counter: 10
Epoch: 19/300, train loss : 0.6173858344554901, validation loss : 0.7570589184761047
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.6241798847913742, validation loss : 0.6419551372528076
Early stop counter: 0
Epoch: 2/300, train loss : 0.7672439515590668, validation loss : 0.7028898000717163
Early stop counter: 1
Epoch: 3/300, train loss : 0.6845470666885376, validation loss : 0.6846045851707458
Early stop counter: 2
Epoch: 4/300, train loss : 0.6823577731847763, validation loss : 0.6554393172264099
Early stop counter: 3
Epoch: 5/300, train loss : 0.6735333651304245, validation loss : 0.6701056361198425
Early stop counter: 4
Epoch: 6/300, train loss : 0.6729575842618942, validation loss : 0.640725314617157
Early stop counter: 0
Epoch: 7/300, train loss : 0.6710345596075058, validation loss : 0.6396326422691345
Early stop counter: 0
Epoch: 8/300, train loss : 0.668784961104

[32m[I 2023-03-15 14:34:55,152][0m Trial 5 finished with value: 0.6397471904754639 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 1, 'hidden_size': 417, 'n_heads': 5, 'dropout': 0.21282727193138923, 'learning_rate': 0.0011752009346610113}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 19/300, train loss : 0.6240558475255966, validation loss : 0.7065423130989075
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.4886675924062729, validation loss : 0.719188928604126
Early stop counter: 0
Epoch: 2/300, train loss : 0.692494198679924, validation loss : 0.7020728588104248
Early stop counter: 0
Epoch: 3/300, train loss : 0.6870514154434204, validation loss : 0.723892331123352
Early stop counter: 1
Epoch: 4/300, train loss : 0.681236982345581, validation loss : 0.642857015132904
Early stop counter: 0
Epoch: 5/300, train loss : 0.6735983490943909, validation loss : 0.6825004816055298
Early stop counter: 1
Epoch: 6/300, train loss : 0.6750924736261368, validation loss : 0.659145176410675
Early stop counter: 2
Epoch: 7/300, train loss : 0.6659624576568604, validation loss : 0.6302772760391235
Early stop counter: 0
Epoch: 8/300, train loss : 0.6603005826473236, validation loss : 0.6455551385879517
Early stop counter: 1
Epoch: 9/300, train loss : 0.6645642518997192, v

[32m[I 2023-03-15 14:35:21,137][0m Trial 6 finished with value: 0.6238780975341797 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 312, 'n_heads': 1, 'dropout': 0.2799823427453557, 'learning_rate': 0.0024229900554274017}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 13/300, train loss : 0.6522752642631531, validation loss : 0.6727927923202515
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.2742075473070145, validation loss : 1.1035287380218506
Early stop counter: 0
Epoch: 2/300, train loss : 0.8154382705688477, validation loss : 0.8077448606491089
Early stop counter: 0
Epoch: 3/300, train loss : 0.7089271545410156, validation loss : 0.6803175806999207
Early stop counter: 0
Epoch: 4/300, train loss : 0.6859500706195831, validation loss : 0.7096381187438965
Early stop counter: 1
Epoch: 5/300, train loss : 0.6983318030834198, validation loss : 0.6840311288833618
Early stop counter: 2
Epoch: 6/300, train loss : 0.6785029321908951, validation loss : 0.6597457528114319
Early stop counter: 0
Epoch: 7/300, train loss : 0.6749707609415054, validation loss : 0.6623382568359375
Early stop counter: 1
Epoch: 8/300, train loss : 0.6788096576929092, validation loss : 0.6307743787765503
Early stop counter: 0
Epoch: 9/300, train loss : 0.6775702536106

[32m[I 2023-03-15 14:35:59,729][0m Trial 7 finished with value: 0.6291199564933777 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 3, 'hidden_size': 459, 'n_heads': 1, 'dropout': 0.1975490451071128, 'learning_rate': 0.002292727920945042}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 15/300, train loss : 0.6660800129175186, validation loss : 0.6679283976554871
Early stopping...
Fold 0
Epoch: 1/300, train loss : 4.939049571752548, validation loss : 4.398576736450195
Early stop counter: 0
Epoch: 2/300, train loss : 14.466031789779663, validation loss : 3.3392417430877686
Early stop counter: 0
Epoch: 3/300, train loss : 10.15799406170845, validation loss : 2.2053234577178955
Early stop counter: 0
Epoch: 4/300, train loss : 1.971287339925766, validation loss : 2.4001047611236572
Early stop counter: 1
Epoch: 5/300, train loss : 1.4453069865703583, validation loss : 0.7353413105010986
Early stop counter: 0
Epoch: 6/300, train loss : 0.7328620553016663, validation loss : 0.6366521120071411
Early stop counter: 0
Epoch: 7/300, train loss : 0.7298051863908768, validation loss : 0.6720116138458252
Early stop counter: 1
Epoch: 8/300, train loss : 0.720588818192482, validation loss : 0.7570087313652039
Early stop counter: 2
Epoch: 9/300, train loss : 0.77742500603199, va

[32m[I 2023-03-15 14:36:28,573][0m Trial 8 finished with value: 0.6466650247573853 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 3, 'hidden_size': 110, 'n_heads': 4, 'dropout': 0.10660227379975493, 'learning_rate': 0.008901460985116814}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 25/300, train loss : 0.6606292277574539, validation loss : 0.7397141456604004
Early stop counter: 10
Epoch: 26/300, train loss : 0.663544699549675, validation loss : 0.7052003145217896
Early stopping...
Fold 0
Epoch: 1/300, train loss : 4.387152940034866, validation loss : 0.67697674036026
Early stop counter: 0
Epoch: 2/300, train loss : 0.6912118941545486, validation loss : 0.6463011503219604
Early stop counter: 0
Epoch: 3/300, train loss : 0.6784722357988358, validation loss : 0.6703953146934509
Early stop counter: 1
Epoch: 4/300, train loss : 0.676847517490387, validation loss : 0.6693186163902283
Early stop counter: 2
Epoch: 5/300, train loss : 0.6707155108451843, validation loss : 0.652658998966217
Early stop counter: 3
Epoch: 6/300, train loss : 0.6658669114112854, validation loss : 0.6501250267028809
Early stop counter: 4
Epoch: 7/300, train loss : 0.6612366884946823, validation loss : 0.6306907534599304
Early stop counter: 0
Epoch: 8/300, train loss : 0.6618857383728027,

[32m[I 2023-03-15 14:37:18,680][0m Trial 9 finished with value: 0.634007203578949 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 1, 'hidden_size': 424, 'n_heads': 5, 'dropout': 0.36020395980740016, 'learning_rate': 0.0018152932821810762}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6357371360063553, validation loss : 0.73529052734375
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.745319664478302, validation loss : 0.7278715372085571
Early stop counter: 0
Epoch: 2/300, train loss : 0.6906544119119644, validation loss : 0.6987801790237427
Early stop counter: 0
Epoch: 3/300, train loss : 0.688118040561676, validation loss : 0.6560823917388916
Early stop counter: 0
Epoch: 4/300, train loss : 0.6813690513372421, validation loss : 0.6653218269348145
Early stop counter: 1
Epoch: 5/300, train loss : 0.6748672723770142, validation loss : 0.6571424007415771
Early stop counter: 2
Epoch: 6/300, train loss : 0.6754325926303864, validation loss : 0.6363283395767212
Early stop counter: 0
Epoch: 7/300, train loss : 0.6663512587547302, validation loss : 0.6714757680892944
Early stop counter: 1
Epoch: 8/300, train loss : 0.6625195741653442, validation loss : 0.6331610679626465
Early stop counter: 0
Epoch: 9/300, train loss : 0.6562195718288422,

[32m[I 2023-03-15 14:37:30,809][0m Trial 10 finished with value: 0.6233137011528015 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 72, 'n_heads': 2, 'dropout': 0.3953833842972444, 'learning_rate': 0.004159297869331124}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6345505863428116, validation loss : 0.7487611174583435
Early stop counter: 8
Epoch: 18/300, train loss : 0.633037805557251, validation loss : 0.7423619627952576
Early stop counter: 9
Epoch: 19/300, train loss : 0.6223310679197311, validation loss : 0.7714013457298279
Early stop counter: 10
Epoch: 20/300, train loss : 0.612584576010704, validation loss : 0.7567546367645264
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.8070500940084457, validation loss : 0.7290478348731995
Early stop counter: 0
Epoch: 2/300, train loss : 0.6919244676828384, validation loss : 0.6532654166221619
Early stop counter: 0
Epoch: 3/300, train loss : 0.6941053569316864, validation loss : 0.6648375391960144
Early stop counter: 1
Epoch: 4/300, train loss : 0.6835976541042328, validation loss : 0.67082279920578
Early stop counter: 2
Epoch: 5/300, train loss : 0.6764417588710785, validation loss : 0.6611746549606323
Early stop counter: 3
Epoch: 6/300, train loss : 0.6706157028675

[32m[I 2023-03-15 14:37:52,641][0m Trial 11 finished with value: 0.6235228896141052 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 86, 'n_heads': 2, 'dropout': 0.3955631274133836, 'learning_rate': 0.004156314603583666}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6204242259263992, validation loss : 0.7462309002876282
Early stopping...
Fold 0
Epoch: 1/300, train loss : 4.005604088306427, validation loss : 0.6657474040985107
Early stop counter: 0
Epoch: 2/300, train loss : 0.698283240199089, validation loss : 0.7028370499610901
Early stop counter: 1
Epoch: 3/300, train loss : 0.6823550015687943, validation loss : 0.6561192274093628
Early stop counter: 0
Epoch: 4/300, train loss : 0.6850352138280869, validation loss : 0.6295526623725891
Early stop counter: 0
Epoch: 5/300, train loss : 0.6960565000772476, validation loss : 0.6673061847686768
Early stop counter: 1
Epoch: 6/300, train loss : 0.6760401278734207, validation loss : 0.6524321436882019
Early stop counter: 2
Epoch: 7/300, train loss : 0.6710426360368729, validation loss : 0.636035680770874
Early stop counter: 3
Epoch: 8/300, train loss : 0.6622183918952942, validation loss : 0.6316590309143066
Early stop counter: 4
Epoch: 9/300, train loss : 0.6649216115474701

[32m[I 2023-03-15 14:38:23,435][0m Trial 12 finished with value: 0.6264269709587097 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 167, 'n_heads': 2, 'dropout': 0.39726609365407256, 'learning_rate': 0.004696532617103457}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 21/300, train loss : 0.7080032676458359, validation loss : 0.6608945727348328
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7584007382392883, validation loss : 0.6689407229423523
Early stop counter: 0
Epoch: 2/300, train loss : 0.691729873418808, validation loss : 0.661119282245636
Early stop counter: 0
Epoch: 3/300, train loss : 0.6786101907491684, validation loss : 0.6284720301628113
Early stop counter: 0
Epoch: 4/300, train loss : 0.7048585265874863, validation loss : 0.6598718762397766
Early stop counter: 1
Epoch: 5/300, train loss : 0.6794360280036926, validation loss : 0.6494351625442505
Early stop counter: 2
Epoch: 6/300, train loss : 0.67885522544384, validation loss : 0.6702460646629333
Early stop counter: 3
Epoch: 7/300, train loss : 0.6704425811767578, validation loss : 0.6499454379081726
Early stop counter: 4
Epoch: 8/300, train loss : 0.6665206104516983, validation loss : 0.6564459204673767
Early stop counter: 5
Epoch: 9/300, train loss : 0.6653123646974564,

[32m[I 2023-03-15 14:38:53,410][0m Trial 13 finished with value: 0.6264236807823181 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 78, 'n_heads': 2, 'dropout': 0.31432427008964303, 'learning_rate': 0.003852520930958293}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6226976960897446, validation loss : 0.7221500277519226
Early stop counter: 10
Epoch: 18/300, train loss : 0.6139584332704544, validation loss : 0.703427255153656
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.2604379504919052, validation loss : 0.760235607624054
Early stop counter: 0
Epoch: 2/300, train loss : 0.7079179883003235, validation loss : 0.6534056067466736
Early stop counter: 0
Epoch: 3/300, train loss : 0.6996525973081589, validation loss : 0.6371732354164124
Early stop counter: 0
Epoch: 4/300, train loss : 0.6926891654729843, validation loss : 0.7092811465263367
Early stop counter: 1
Epoch: 5/300, train loss : 0.6848883479833603, validation loss : 0.6660487651824951
Early stop counter: 2
Epoch: 6/300, train loss : 0.6720506995916367, validation loss : 0.6393045783042908
Early stop counter: 3
Epoch: 7/300, train loss : 0.6837348639965057, validation loss : 0.6661136746406555
Early stop counter: 4
Epoch: 8/300, train loss : 0.6671392321586

[32m[I 2023-03-15 14:39:18,315][0m Trial 14 finished with value: 0.6274917721748352 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 2, 'hidden_size': 155, 'n_heads': 2, 'dropout': 0.3427714459799618, 'learning_rate': 0.0032276633603003707}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6138404905796051, validation loss : 0.7294663190841675
Early stopping...
Fold 0
Epoch: 1/300, train loss : 24.520649448037148, validation loss : 3.3157358169555664
Early stop counter: 0
Epoch: 2/300, train loss : 1.7072747647762299, validation loss : 0.6364015340805054
Early stop counter: 0
Epoch: 3/300, train loss : 0.7497516721487045, validation loss : 0.6385849118232727
Early stop counter: 1
Epoch: 4/300, train loss : 0.7819387763738632, validation loss : 0.6340721249580383
Early stop counter: 0
Epoch: 5/300, train loss : 0.7574281692504883, validation loss : 0.6428711414337158
Early stop counter: 1
Epoch: 6/300, train loss : 0.6881403177976608, validation loss : 0.6336440443992615
Early stop counter: 0
Epoch: 7/300, train loss : 0.6918198019266129, validation loss : 0.6302230954170227
Early stop counter: 0
Epoch: 8/300, train loss : 0.6849380135536194, validation loss : 0.6873286366462708
Early stop counter: 1
Epoch: 9/300, train loss : 0.6624999195337

[32m[I 2023-03-15 14:39:42,003][0m Trial 15 finished with value: 0.6471531391143799 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 241, 'n_heads': 2, 'dropout': 0.39660268585345304, 'learning_rate': 0.0069825894678739436}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.630401000380516, validation loss : 0.7165343761444092
Early stop counter: 10
Epoch: 18/300, train loss : 0.6353053599596024, validation loss : 0.7351731061935425
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.7379689067602158, validation loss : 0.7209340929985046
Early stop counter: 0
Epoch: 2/300, train loss : 0.7387937158346176, validation loss : 0.7229207754135132
Early stop counter: 1
Epoch: 3/300, train loss : 0.7053660154342651, validation loss : 0.6654079556465149
Early stop counter: 0
Epoch: 4/300, train loss : 0.6865002065896988, validation loss : 0.6525206565856934
Early stop counter: 0
Epoch: 5/300, train loss : 0.6811643689870834, validation loss : 0.6864967346191406
Early stop counter: 1
Epoch: 6/300, train loss : 0.6734438240528107, validation loss : 0.6408601403236389
Early stop counter: 0
Epoch: 7/300, train loss : 0.6670289039611816, validation loss : 0.6688544154167175
Early stop counter: 1
Epoch: 8/300, train loss : 0.670453935861

[32m[I 2023-03-15 14:40:05,202][0m Trial 16 finished with value: 0.6410023212432862 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 135, 'n_heads': 3, 'dropout': 0.32869017102659154, 'learning_rate': 0.005335305991578674}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 19/300, train loss : 0.6463110148906708, validation loss : 0.7397531867027283
Early stop counter: 10
Epoch: 20/300, train loss : 0.6251230388879776, validation loss : 0.7478785514831543
Early stopping...
Fold 0
Epoch: 1/300, train loss : 4.6195559203624725, validation loss : 1.9402042627334595
Early stop counter: 0
Epoch: 2/300, train loss : 1.785611867904663, validation loss : 1.3542051315307617
Early stop counter: 0
Epoch: 3/300, train loss : 1.0104966461658478, validation loss : 1.5928215980529785
Early stop counter: 1
Epoch: 4/300, train loss : 0.9661788940429688, validation loss : 0.6272243857383728
Early stop counter: 0
Epoch: 5/300, train loss : 0.7118723541498184, validation loss : 0.6474635601043701
Early stop counter: 1
Epoch: 6/300, train loss : 0.6746677309274673, validation loss : 0.6403450965881348
Early stop counter: 2
Epoch: 7/300, train loss : 0.6729228347539902, validation loss : 0.6491101384162903
Early stop counter: 3
Epoch: 8/300, train loss : 0.661466404795

[32m[I 2023-03-15 14:40:41,093][0m Trial 17 finished with value: 0.632448697090149 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 2, 'hidden_size': 213, 'n_heads': 4, 'dropout': 0.2967060038170862, 'learning_rate': 0.002890728516676225}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6500114947557449, validation loss : 0.6973151564598083
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.8371545076370239, validation loss : 0.7129909992218018
Early stop counter: 0
Epoch: 2/300, train loss : 0.6899519264698029, validation loss : 0.707292377948761
Early stop counter: 0
Epoch: 3/300, train loss : 0.6842165440320969, validation loss : 0.6630771160125732
Early stop counter: 0
Epoch: 4/300, train loss : 0.6752964556217194, validation loss : 0.7085954546928406
Early stop counter: 1
Epoch: 5/300, train loss : 0.681494951248169, validation loss : 0.6437923312187195
Early stop counter: 0
Epoch: 6/300, train loss : 0.6772195100784302, validation loss : 0.665226399898529
Early stop counter: 1
Epoch: 7/300, train loss : 0.671209529042244, validation loss : 0.6369608640670776
Early stop counter: 0
Epoch: 8/300, train loss : 0.6690821200609207, validation loss : 0.6524055600166321
Early stop counter: 1
Epoch: 9/300, train loss : 0.6644243001937866,

[32m[I 2023-03-15 14:41:02,311][0m Trial 18 finished with value: 0.6302677035331726 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 67, 'n_heads': 1, 'dropout': 0.36388032874709747, 'learning_rate': 0.005686177253244872}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6352608948945999, validation loss : 0.7183231115341187
Early stop counter: 9
Epoch: 18/300, train loss : 0.6250704377889633, validation loss : 0.7297820448875427
Early stop counter: 10
Epoch: 19/300, train loss : 0.6261201351881027, validation loss : 0.7871959209442139
Early stopping...
Fold 0
Epoch: 1/300, train loss : 7.366481438279152, validation loss : 0.667879045009613
Early stop counter: 0
Epoch: 2/300, train loss : 0.8322846293449402, validation loss : 0.6383465528488159
Early stop counter: 0
Epoch: 3/300, train loss : 0.6894145756959915, validation loss : 0.6493169665336609
Early stop counter: 1
Epoch: 4/300, train loss : 0.6852497607469559, validation loss : 0.6400882005691528
Early stop counter: 2
Epoch: 5/300, train loss : 0.6720666587352753, validation loss : 0.6813063025474548
Early stop counter: 3
Epoch: 6/300, train loss : 0.6699598729610443, validation loss : 0.6349718570709229
Early stop counter: 0
Epoch: 7/300, train loss : 0.661526724696

[32m[I 2023-03-15 14:41:22,822][0m Trial 19 finished with value: 0.6299411654472351 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 1, 'hidden_size': 289, 'n_heads': 3, 'dropout': 0.2748508895838155, 'learning_rate': 0.003891239420470826}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 24/300, train loss : 0.61461441218853, validation loss : 0.7761399149894714
Early stop counter: 10
Epoch: 25/300, train loss : 0.6001279652118683, validation loss : 0.8165405988693237
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1975.7034999281168, validation loss : 15.538107872009277
Early stop counter: 0
Epoch: 2/300, train loss : 14.286682963371277, validation loss : 20.77129364013672
Early stop counter: 1
Epoch: 3/300, train loss : 10.766449809074402, validation loss : 1.3646594285964966
Early stop counter: 0
Epoch: 4/300, train loss : 4.2465967535972595, validation loss : 1.586684226989746
Early stop counter: 1
Epoch: 5/300, train loss : 2.1038645207881927, validation loss : 6.855830669403076
Early stop counter: 2
Epoch: 6/300, train loss : 3.9393365383148193, validation loss : 1.7546005249023438
Early stop counter: 3
Epoch: 7/300, train loss : 1.8867939561605453, validation loss : 1.6723898649215698
Early stop counter: 4
Epoch: 8/300, train loss : 1.1076807379722595

[32m[I 2023-03-15 14:41:59,109][0m Trial 20 finished with value: 0.6568501353263855 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 509, 'n_heads': 2, 'dropout': 0.3337584954533939, 'learning_rate': 0.00632280728423463}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 23/300, train loss : 0.6519289761781693, validation loss : 0.7528040409088135
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.032891422510147, validation loss : 0.6611330509185791
Early stop counter: 0
Epoch: 2/300, train loss : 0.6965923458337784, validation loss : 0.6730406880378723
Early stop counter: 1
Epoch: 3/300, train loss : 0.6807992160320282, validation loss : 0.7196676135063171
Early stop counter: 2
Epoch: 4/300, train loss : 0.6809573471546173, validation loss : 0.6437917351722717
Early stop counter: 0
Epoch: 5/300, train loss : 0.6785984635353088, validation loss : 0.650452196598053
Early stop counter: 1
Epoch: 6/300, train loss : 0.6748433411121368, validation loss : 0.6840813755989075
Early stop counter: 2
Epoch: 7/300, train loss : 0.6791532635688782, validation loss : 0.6408788561820984
Early stop counter: 0
Epoch: 8/300, train loss : 0.6669837981462479, validation loss : 0.6604483127593994
Early stop counter: 1
Epoch: 9/300, train loss : 0.660811468958854

[32m[I 2023-03-15 14:42:13,617][0m Trial 21 finished with value: 0.6295963048934936 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 99, 'n_heads': 2, 'dropout': 0.3998766331805553, 'learning_rate': 0.004128711976412381}. Best is trial 2 with value: 0.6227349638938904.[0m


Epoch: 17/300, train loss : 0.6308725476264954, validation loss : 0.7597692012786865
Early stop counter: 9
Epoch: 18/300, train loss : 0.6477007865905762, validation loss : 0.7232088446617126
Early stop counter: 10
Epoch: 19/300, train loss : 0.6348482072353363, validation loss : 0.7550041675567627
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.9783162921667099, validation loss : 0.6833848357200623
Early stop counter: 0
Epoch: 2/300, train loss : 0.6978511065244675, validation loss : 0.6643924713134766
Early stop counter: 0
Epoch: 3/300, train loss : 0.6876787394285202, validation loss : 0.6769387722015381
Early stop counter: 1
Epoch: 4/300, train loss : 0.6792959421873093, validation loss : 0.6612162590026855
Early stop counter: 0
Epoch: 5/300, train loss : 0.6674123108386993, validation loss : 0.6249630451202393
Early stop counter: 0
Epoch: 6/300, train loss : 0.6957174092531204, validation loss : 0.6902725100517273
Early stop counter: 1
Epoch: 7/300, train loss : 0.6699984073

[32m[I 2023-03-15 14:42:29,683][0m Trial 22 finished with value: 0.6207164764404297 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 122, 'n_heads': 2, 'dropout': 0.36738054656589025, 'learning_rate': 0.00452976319043267}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 16/300, train loss : 0.643414169549942, validation loss : 0.7463109493255615
Early stop counter: 10
Epoch: 17/300, train loss : 0.6444884836673737, validation loss : 0.6917109489440918
Early stopping...
Fold 0
Epoch: 1/300, train loss : 3.2129150480031967, validation loss : 0.6644167900085449
Early stop counter: 0
Epoch: 2/300, train loss : 0.6991538852453232, validation loss : 0.7092569470405579
Early stop counter: 1
Epoch: 3/300, train loss : 0.6888102144002914, validation loss : 0.6821714043617249
Early stop counter: 2
Epoch: 4/300, train loss : 0.6834046840667725, validation loss : 0.7084105610847473
Early stop counter: 3
Epoch: 5/300, train loss : 0.6843112260103226, validation loss : 0.6573765277862549
Early stop counter: 0
Epoch: 6/300, train loss : 0.6820706427097321, validation loss : 0.6914446353912354
Early stop counter: 1
Epoch: 7/300, train loss : 0.6793266832828522, validation loss : 0.6582872867584229
Early stop counter: 2
Epoch: 8/300, train loss : 0.676413625478

[32m[I 2023-03-15 14:42:42,374][0m Trial 23 finished with value: 0.6294723153114319 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 134, 'n_heads': 1, 'dropout': 0.3644322983667116, 'learning_rate': 0.005296092583200545}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 21/300, train loss : 0.6159840226173401, validation loss : 0.703702449798584
Early stop counter: 10
Epoch: 22/300, train loss : 0.6156003475189209, validation loss : 0.7282737493515015
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.72357040643692, validation loss : 0.6925073266029358
Early stop counter: 0
Epoch: 2/300, train loss : 0.7352586984634399, validation loss : 0.6329557299613953
Early stop counter: 0
Epoch: 3/300, train loss : 0.7083595842123032, validation loss : 0.6853950619697571
Early stop counter: 1
Epoch: 4/300, train loss : 0.6873253285884857, validation loss : 0.7035406827926636
Early stop counter: 2
Epoch: 5/300, train loss : 0.7038541436195374, validation loss : 0.6660692691802979
Early stop counter: 3
Epoch: 6/300, train loss : 0.6729782074689865, validation loss : 0.660906970500946
Early stop counter: 4
Epoch: 7/300, train loss : 0.6722167134284973, validation loss : 0.662970244884491
Early stop counter: 5
Epoch: 8/300, train loss : 0.6683626621961594

[32m[I 2023-03-15 14:43:00,217][0m Trial 24 finished with value: 0.6320836901664734 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 2, 'hidden_size': 202, 'n_heads': 2, 'dropout': 0.37795746635738514, 'learning_rate': 0.00458219019357291}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 12/300, train loss : 0.6395412385463715, validation loss : 0.7294265031814575
Early stop counter: 9
Epoch: 13/300, train loss : 0.6246156245470047, validation loss : 0.7759663462638855
Early stop counter: 10
Epoch: 14/300, train loss : 0.6222911924123764, validation loss : 0.7810398936271667
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.0797342509031296, validation loss : 0.7569546699523926
Early stop counter: 0
Epoch: 2/300, train loss : 0.6964000463485718, validation loss : 0.6877037286758423
Early stop counter: 0
Epoch: 3/300, train loss : 0.6849812716245651, validation loss : 0.753411054611206
Early stop counter: 1
Epoch: 4/300, train loss : 0.6948285698890686, validation loss : 0.6701574921607971
Early stop counter: 0
Epoch: 5/300, train loss : 0.6836667060852051, validation loss : 0.6755722165107727
Early stop counter: 1
Epoch: 6/300, train loss : 0.6804885864257812, validation loss : 0.6688629388809204
Early stop counter: 0
Epoch: 7/300, train loss : 0.67489522695

[32m[I 2023-03-15 14:43:17,072][0m Trial 25 finished with value: 0.6301232576370239 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 124, 'n_heads': 1, 'dropout': 0.338835929528963, 'learning_rate': 0.0035971339843088096}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 25/300, train loss : 0.6152494698762894, validation loss : 0.8043540716171265
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.8016041070222855, validation loss : 0.7150043845176697
Early stop counter: 0
Epoch: 2/300, train loss : 0.6884400844573975, validation loss : 0.6980074048042297
Early stop counter: 0
Epoch: 3/300, train loss : 0.6833521276712418, validation loss : 0.6502328515052795
Early stop counter: 0
Epoch: 4/300, train loss : 0.6742886900901794, validation loss : 0.6876153349876404
Early stop counter: 1
Epoch: 5/300, train loss : 0.679464116692543, validation loss : 0.6399984955787659
Early stop counter: 0
Epoch: 6/300, train loss : 0.6736467331647873, validation loss : 0.648555338382721
Early stop counter: 1
Epoch: 7/300, train loss : 0.6680441796779633, validation loss : 0.6650196313858032
Early stop counter: 2
Epoch: 8/300, train loss : 0.6678393185138702, validation loss : 0.6373257040977478
Early stop counter: 0
Epoch: 9/300, train loss : 0.660387724637985

[32m[I 2023-03-15 14:43:33,659][0m Trial 26 finished with value: 0.6392564415931702 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 1, 'hidden_size': 158, 'n_heads': 4, 'dropout': 0.37537069660557243, 'learning_rate': 0.002793785742219527}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 19/300, train loss : 0.6338939070701599, validation loss : 0.699033796787262
Early stop counter: 10
Epoch: 20/300, train loss : 0.6274111568927765, validation loss : 0.7004782557487488
Early stopping...
Fold 0
Epoch: 1/300, train loss : 2.7689284086227417, validation loss : 0.6650418043136597
Early stop counter: 0
Epoch: 2/300, train loss : 0.8539640307426453, validation loss : 0.6347543597221375
Early stop counter: 0
Epoch: 3/300, train loss : 0.8100211918354034, validation loss : 0.683683454990387
Early stop counter: 1
Epoch: 4/300, train loss : 0.6881619542837143, validation loss : 0.6953772902488708
Early stop counter: 2
Epoch: 5/300, train loss : 0.7027537822723389, validation loss : 0.6348609924316406
Early stop counter: 3
Epoch: 6/300, train loss : 0.6910440027713776, validation loss : 0.647830605506897
Early stop counter: 4
Epoch: 7/300, train loss : 0.7480867654085159, validation loss : 0.6724045872688293
Early stop counter: 5
Epoch: 8/300, train loss : 0.72617679834365

[32m[I 2023-03-15 14:43:59,954][0m Trial 27 finished with value: 0.6383547425270081 and parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 3, 'hidden_size': 250, 'n_heads': 2, 'dropout': 0.34452621513408305, 'learning_rate': 0.0034158757157527336}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 27/300, train loss : 0.6365452706813812, validation loss : 0.6806275844573975
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7262941151857376, validation loss : 0.6819860935211182
Early stop counter: 0
Epoch: 2/300, train loss : 0.6951090842485428, validation loss : 0.7066640257835388
Early stop counter: 1
Epoch: 3/300, train loss : 0.6911783665418625, validation loss : 0.6842896938323975
Early stop counter: 2
Epoch: 4/300, train loss : 0.6887910664081573, validation loss : 0.697506844997406
Early stop counter: 3
Epoch: 5/300, train loss : 0.6846088469028473, validation loss : 0.6651571989059448
Early stop counter: 0
Epoch: 6/300, train loss : 0.6757830828428268, validation loss : 0.683417558670044
Early stop counter: 1
Epoch: 7/300, train loss : 0.6733037978410721, validation loss : 0.6313380599021912
Early stop counter: 0
Epoch: 8/300, train loss : 0.6815074384212494, validation loss : 0.6558529138565063
Early stop counter: 1
Epoch: 9/300, train loss : 0.667828619480133

[32m[I 2023-03-15 14:44:13,560][0m Trial 28 finished with value: 0.6322495341300964 and parameters: {'num_gin_layers': 3, 'num_graph_trans_layers': 2, 'hidden_size': 69, 'n_heads': 1, 'dropout': 0.3199990661760711, 'learning_rate': 0.0043841732304825435}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 16/300, train loss : 0.636401429772377, validation loss : 0.6701109409332275
Early stop counter: 9
Epoch: 17/300, train loss : 0.6319057196378708, validation loss : 0.7000323534011841
Early stop counter: 10
Epoch: 18/300, train loss : 0.6296630799770355, validation loss : 0.7435418367385864
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7940010130405426, validation loss : 0.7035757899284363
Early stop counter: 0
Epoch: 2/300, train loss : 0.6939452290534973, validation loss : 0.6708582639694214
Early stop counter: 0
Epoch: 3/300, train loss : 0.6914562731981277, validation loss : 0.7078711986541748
Early stop counter: 1
Epoch: 4/300, train loss : 0.682930126786232, validation loss : 0.7248508930206299
Early stop counter: 2
Epoch: 5/300, train loss : 0.6960975825786591, validation loss : 0.6629201769828796
Early stop counter: 0
Epoch: 6/300, train loss : 0.6860993504524231, validation loss : 0.6317836046218872
Early stop counter: 0
Epoch: 7/300, train loss : 0.681071087718

[32m[I 2023-03-15 14:44:29,899][0m Trial 29 finished with value: 0.6306508302688598 and parameters: {'num_gin_layers': 1, 'num_graph_trans_layers': 3, 'hidden_size': 117, 'n_heads': 1, 'dropout': 0.37418226724769776, 'learning_rate': 0.004902723569833367}. Best is trial 22 with value: 0.6207164764404297.[0m


Epoch: 20/300, train loss : 0.6076437681913376, validation loss : 0.7799637317657471
Early stopping...
best trial:
[0.6207164764404297]
Best parameters: {'num_gin_layers': 2, 'num_graph_trans_layers': 2, 'hidden_size': 122, 'n_heads': 2, 'dropout': 0.36738054656589025, 'learning_rate': 0.00452976319043267}


## 3. Train/validate/test model

1. After tuning, the best parameters were saved in config.py as params_vertical_gnn
2. Model was trained/validated/tested using the best parameters and 5-fold CV. Process was repeated 5 times and the final results were averaged.
3. Ensure that data folder, from_scratch_trained_models folder and this notebook are in the same working directory
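4. To repeat the training process, run `run_training`. Create a new folder first and point `path_to_save_trained_model` at it, so the retrained models are saved there and reloaded from there for testing.
5. Otherwise, comment out the `run_training` call and run only the testing step. This loads the provided saved models and reproduces the results reported in the paper.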

In [2]:
def run_training(train_loader, valid_loader, params, trained_model_path):
    """Train a VerticalGNN with early stopping, checkpointing the best model.

    Args:
        train_loader: DataLoader over the training graphs.
        valid_loader: DataLoader over the validation graphs; its loss drives
            both checkpointing and early stopping.
        params: hyperparameter dict with keys 'num_gin_layers',
            'num_graph_trans_layers', 'hidden_size', 'n_heads', 'dropout'
            and 'learning_rate'.
        trained_model_path: file the best model's state_dict is written to.
            Its parent directory is created if it does not exist yet.

    Returns:
        The lowest validation loss observed. It stays np.inf if no epoch's
        validation loss ever compared below it (e.g. NaN losses); in that
        case no checkpoint is written.

    Training runs for at most EPOCHS epochs. It stops early once the
    validation loss has failed to improve for more than PATIENCE consecutive
    epochs, i.e. after PATIENCE + 1 non-improving epochs.
    """
    model = VerticalGNN(num_features=NUM_FEATURES, num_targets=NUM_TARGET,
                        num_gin_layers=params['num_gin_layers'],
                        num_graph_trans_layers=params['num_graph_trans_layers'],
                        hidden_size=params['hidden_size'], n_heads=params['n_heads'],
                        dropout=params['dropout'], edge_dim=EDGE_DIM)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    # torch.save does not create missing folders. Create the target folder up
    # front so a fresh retraining run does not crash on its first checkpoint.
    save_dir = os.path.dirname(trained_model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        valid_loss = eng.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        if valid_loss < best_loss:
            best_loss = valid_loss
            early_stopping_counter = 0  # improvement: reset the patience counter
            print('Saving model...')
            torch.save(model.state_dict(), trained_model_path)
        else:
            early_stopping_counter += 1

        # Strict '>' means training tolerates PATIENCE non-improving epochs
        # and stops on the next one.
        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')

    return best_loss

def run_testing(test_loader, params, trained_model_path):
    """Load a saved VerticalGNN checkpoint and evaluate it on a test set.

    Args:
        test_loader: DataLoader over the test graphs.
        params: the same hyperparameter dict used to train the checkpoint, so
            the rebuilt architecture matches the saved state_dict.
        trained_model_path: path to a state_dict saved by run_training (or one
            of the provided pre-trained models).

    Returns:
        Tuple (bce, acc, f1, roc_auc) exactly as returned by EngineHOB.test.
    """
    model = VerticalGNN(num_features=NUM_FEATURES, num_targets=NUM_TARGET,
                        num_gin_layers=params['num_gin_layers'],
                        num_graph_trans_layers=params['num_graph_trans_layers'],
                        hidden_size=params['hidden_size'], n_heads=params['n_heads'],
                        dropout=params['dropout'], edge_dim=EDGE_DIM)
    # map_location remaps stored tensors onto DEVICE. Without it, torch.load
    # restores them to the device they were saved from, so GPU-saved
    # checkpoints fail to load on a CPU-only machine.
    model.load_state_dict(torch.load(trained_model_path, map_location=DEVICE))
    model.to(DEVICE)
    # EngineHOB's constructor takes an optimizer; presumably eng.test does not
    # step it (NOTE(review): confirm in engine.py).
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    print('Begin testing...')
    bce, acc, f1, roc_auc = eng.test(test_loader)
    print('Test completed!')
    print(f'bce:{bce}, acc:{acc}, f1: {f1}, roc_auc: {roc_auc}')
    return bce, acc, f1, roc_auc

In [4]:
params = params_vertical_gnn
n_repetitions = 5
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test/'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
path_to_save_trained_model = './from_scratch_trained_models/vertical/'

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

#load dataset
dataset_for_cv = LoadHOBDataset(root=train_data_root_path, raw_filename=train_data_raw_filename)
test_dataset = LoadHOBDataset(root=test_data_root_path, raw_filename=test_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

# Read every processed training graph once. Each fold then just selects
# indices from this list instead of re-reading the same files from disk on
# every fold of every repetition. KFold.split yields indices
# 0..len(dataset_for_cv)-1, so this covers every index the folds can request.
processed_train_dir = os.path.join(train_data_root_path, 'processed')
all_train_graphs = [
    torch.load(os.path.join(processed_train_dir, f'molecule_{idx}.pt'))
    for idx in range(len(dataset_for_cv))
]

# The test set is identical for every fold, so build its loader once.
test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

for repeat in range(n_repetitions):
    # Seed each repetition differently. With a single fixed seed (and the
    # unshuffled KFold above), every repetition would retrain identical models
    # on identical splits, so averaging over repetitions would add nothing.
    # Repetition 0 keeps SEED_NO, matching the original setup.
    repeat_seed = SEED_NO + repeat
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        seed_everything(repeat_seed)
        train_dataset = [all_train_graphs[i] for i in train_idx]
        valid_dataset = [all_train_graphs[i] for i in valid_idx]

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        model_path = os.path.join(path_to_save_trained_model, f'vertical_repeat_{repeat}_fold_{fold_no}.pt')
        # Comment out the next line to skip retraining and evaluate the saved models instead.
        run_training(train_loader, valid_loader, params, model_path)
        bce, acc, f1, roc_auc = run_testing(test_loader, params, model_path)
        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

# One test result per (repetition, fold) pair is expected.
assert len(bce_list) == n_repetitions * N_SPLITS


def _summarize_metric(label, values):
    """Print 'label mean±sd' (3 d.p.) and return (array, mean, sd).

    The sd is numpy's default population standard deviation (ddof=0).
    """
    arr = np.array(values)
    mean = np.mean(arr)
    sd = np.std(arr)
    print(f'{label}{mean:.3f}±{sd:.3f}')
    return arr, mean, sd


bce_arr, mean_bce, sd_bce = _summarize_metric('bce:', bce_list)
acc_arr, acc_mean, acc_sd = _summarize_metric('acc:', acc_list)
f1_arr, f1_mean, f1_sd = _summarize_metric('f1: ', f1_list)
roc_auc_arr, roc_auc_mean, roc_auc_sd = _summarize_metric('roc_auc: ', roc_auc_list)

Epoch: 1/300, train loss : 0.9067558795213699, validation loss : 0.7183759808540344
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.7035864740610123, validation loss : 0.6651412844657898
Saving model...
Early stop counter: 0
Epoch: 3/300, train loss : 0.6858863681554794, validation loss : 0.6628757119178772
Saving model...
Early stop counter: 0
Epoch: 4/300, train loss : 0.6888501644134521, validation loss : 0.6296036839485168
Saving model...
Early stop counter: 0
Epoch: 5/300, train loss : 0.6920813173055649, validation loss : 0.655631959438324
Early stop counter: 1
Epoch: 6/300, train loss : 0.6713947802782059, validation loss : 0.6383062601089478
Early stop counter: 2
Epoch: 7/300, train loss : 0.6747748851776123, validation loss : 0.6325588822364807
Early stop counter: 3
Epoch: 8/300, train loss : 0.6658817827701569, validation loss : 0.6485603451728821
Early stop counter: 4
Epoch: 9/300, train loss : 0.6621689051389694, validation loss : 0.6198813319206238
Savin