# Graph Transformer Model for Oral Bioavailability Dataset 

1. In this notebook, we will build a model based on the graph transformer convolution available in PyTorch Geometric.
2. Hyperparameters will be tuned with the Optuna library, using the Tree-structured Parzen Estimator (TPE) algorithm over 30 trials.
3. Models will be trained, validated and tested with 5-fold cross-validation (CV), and the results will be averaged across folds.
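To make item 1 concrete, here is a minimal pure-Python sketch of one single-head graph transformer convolution with edge features. It assumes that GraphTrans stacks PyTorch Geometric's `TransformerConv` layers (model.py is not shown here) and follows the formulation in the PyG documentation as we understand it: each target node i attends over its incoming neighbours j with weights α_ij = softmax((W3 x_i)ᵀ(W4 x_j + W6 e_ij) / √d) and outputs W1 x_i + Σ_j α_ij (W2 x_j + W6 e_ij). The helper names (`matvec`, `vadd`, `transformer_conv`) are hypothetical; multiple heads, dropout and the optional gating term are left out.

```python
import math


def matvec(W, v):
    """Multiply a matrix W (list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def vadd(u, v):
    """Element-wise sum of two vectors."""
    return [a + b for a, b in zip(u, v)]


def transformer_conv(x, edges, edge_attr, W1, W2, W3, W4, W6):
    """Single-head graph transformer convolution with edge features.

    x:         node feature vectors
    edges:     (j, i) pairs, a message flowing from source j to target i
    edge_attr: edge feature vectors aligned with edges
    Returns the updated node features and the attention weights per target node.
    """
    d = len(W3)  # query/key dimension used for the 1/sqrt(d) scaling
    out = [matvec(W1, xi) for xi in x]  # root term W1 x_i
    attention = {}
    for i in range(len(x)):
        query = matvec(W3, x[i])
        incoming = [(j, e) for (j, t), e in zip(edges, edge_attr) if t == i]
        if not incoming:
            continue  # nodes without incoming edges keep only the root term
        scores, values = [], []
        for j, e in incoming:
            e_proj = matvec(W6, e)  # one edge projection shared by key and value
            key = vadd(matvec(W4, x[j]), e_proj)
            scores.append(sum(q * k for q, k in zip(query, key)) / math.sqrt(d))
            values.append(vadd(matvec(W2, x[j]), e_proj))
        # Softmax over the incoming neighbours of node i (shifted for numerical stability)
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        alphas = [s / total for s in exps]
        attention[i] = alphas
        for a, v in zip(alphas, values):
            out[i] = vadd(out[i], [a * vk for vk in v])
    return out, attention


# Toy graph: nodes 1 and 2 both send a message to node 0; identity weights, zero edge projection
I2 = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, attention = transformer_conv(x, [(1, 0), (2, 0)], [[0.5], [0.5]],
                                  I2, I2, I2, I2, [[0.0], [0.0]])
```

Node 0 attends more strongly to node 2, whose key aligns with node 0's query. A multi-head layer would run several such heads in parallel and concatenate (or average) their outputs.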


### Note
1. Ensure that this notebook is in the same working directory as the data folder, config.py, utils.py, engine.py and model.py.
2. To load models, download the saved models from the Google Drive link provided in README.md.
3. To load the saved models and reproduce the results, comment out the training function.

In [1]:
from utils import seed_everything, LoadHOBDataset
from config import SEED_NO, NUM_FEATURES, NUM_GRAPHS_PER_BATCH, NUM_TARGET, EDGE_DIM, DEVICE, PATIENCE, EPOCHS, N_SPLITS, params_gt
from engine import EngineHOB
from model import GraphTrans

import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold
from torch_geometric.loader import DataLoader
import os 

## 2. Model tuning (we suggest skipping this step and using the hyperparameters already found and saved in config.py)

1. To tune the model, we use the Optuna library to find the best parameters over 30 trials with the Tree-structured Parzen Estimator (TPE) algorithm.
2. First, we create a run_tuning function that runs the training and validation steps, with an early stopping mechanism, and returns the best validation loss for a given set of parameters.
3. Then, we create an objective function that Optuna minimizes to find the best parameters.
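The idea behind TPE in step 1 can be illustrated with a deliberately simplified, one-dimensional sketch. After a few random start-up trials, past trials are split into a "good" fraction γ and the rest; a Parzen (Gaussian kernel) density is fitted to each group, candidates are drawn around good points, and the candidate with the largest ratio l(z)/g(z) is evaluated next. Because the notebook samples `learning_rate` with `log=True`, the sketch searches in log space. The function names, the fixed bandwidth, γ = 0.25 and the candidate count are assumptions for illustration; Optuna's actual `TPESampler` is considerably more elaborate.

```python
import math
import random


def gaussian_pdf(x, mu, sigma):
    """Density of a normal distribution N(mu, sigma^2) at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def parzen(x, centers, bandwidth):
    """Parzen (Gaussian kernel) density estimate at x built from observed points."""
    return sum(gaussian_pdf(x, c, bandwidth) for c in centers) / len(centers)


def tpe_minimize(objective, low, high, n_trials=30, n_startup=10, gamma=0.25,
                 n_candidates=24, seed=0):
    """Minimize objective(value) over [low, high], searching in log space (like log=True)."""
    rng = random.Random(seed)
    lo, hi = math.log(low), math.log(high)
    bandwidth = 0.1 * (hi - lo)  # fixed kernel width: a simplification
    history = []  # (log_value, loss) pairs
    for t in range(n_trials):
        if t < n_startup or len(history) < 2:
            z = rng.uniform(lo, hi)  # random start-up trials
        else:
            ranked = sorted(history, key=lambda h: h[1])
            n_good = max(1, math.ceil(gamma * len(ranked)))
            good = [h[0] for h in ranked[:n_good]]  # points behind l(z): best gamma fraction
            bad = [h[0] for h in ranked[n_good:]]   # points behind g(z): the rest
            best_z, best_ratio = None, -1.0
            for _ in range(n_candidates):
                # Draw a candidate around a good point, clip it to the range, score l(z) / g(z)
                z = min(max(rng.gauss(rng.choice(good), bandwidth), lo), hi)
                ratio = parzen(z, good, bandwidth) / (parzen(z, bad, bandwidth) + 1e-12)
                if ratio > best_ratio:
                    best_z, best_ratio = z, ratio
            z = best_z
        history.append((z, objective(math.exp(z))))
    best_z, best_loss = min(history, key=lambda h: h[1])
    return math.exp(best_z), best_loss, history


# Toy objective over the notebook's learning-rate range, minimal at 3e-3 (an arbitrary choice)
best_lr, best_loss, history = tpe_minimize(
    lambda lr: (math.log10(lr) - math.log10(3e-3)) ** 2, 1e-3, 9e-3)
```

The toy objective only stands in for the cross-validated loss returned by the objective function; in the notebook, Optuna tunes all five hyperparameters jointly.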

In [5]:
def run_tuning(train_loader, valid_loader, params):
    # Build a fresh model for this trial from the sampled hyperparameters
    model = GraphTrans(num_features=NUM_FEATURES, num_targets=NUM_TARGET,
                       num_layers=params['num_layers'], hidden_size=params['hidden_size'],
                       n_heads=params['n_heads'], dropout=params['dropout'], edge_dim=EDGE_DIM)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        valid_loss = eng.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        # Reset the counter whenever the validation loss improves
        if valid_loss < best_loss:
            best_loss = valid_loss
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        # Stop once the validation loss has not improved for more than PATIENCE epochs
        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')

    # The best validation loss seen is this fold's score
    return best_loss
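The early stopping rule inside run_tuning can be isolated and checked without a model. The sketch below is a hypothetical `EarlyStopping` helper with the same update rule; replaying the first nine validation losses of Trial 0, Fold 0 from the log below reproduces the printed counter values. `patience=10` is an assumption (PATIENCE lives in config.py, which is not shown), though it is consistent with the log, where "Early stopping..." appears on the epoch after the counter reaches 10.

```python
import math


class EarlyStopping:
    """Patience counter with the same rule as the loop in run_tuning."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = math.inf
        self.counter = 0

    def step(self, valid_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.counter = 0
        else:
            self.counter += 1
        # Strict '>' as in run_tuning: stop after patience + 1 epochs without improvement
        return self.counter > self.patience


# Validation losses of the first nine epochs of Trial 0, Fold 0 in the log below
losses = [0.6812695264816284, 0.6729214191436768, 0.682636022567749,
          0.6731582283973694, 0.6631226539611816, 0.6692750453948975,
          0.6720089912414551, 0.6541059017181396, 0.6494285464286804]
stopper = EarlyStopping(patience=10)  # assumed PATIENCE value
counters = []
stopped = False
for loss in losses:
    stopped = stopper.step(loss)
    counters.append(stopper.counter)
```

Because the comparison is strict, training actually runs PATIENCE + 1 non-improving epochs before stopping; using `>=` instead would stop one epoch earlier.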



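The objective function below splits the training set with `KFold(n_splits=N_SPLITS)`, which by default does not shuffle, so each validation fold is a contiguous block of molecule indices and any ordering in the CSV carries over into the folds. The hypothetical helpers here reproduce that index layout in pure Python (the first `n_samples % n_splits` folds get one extra sample) and average the per-fold losses by the number of folds actually run.

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train_idx, valid_idx) pairs laid out like sklearn's KFold without shuffling."""
    fold_sizes = [n_samples // n_splits] * n_splits
    for k in range(n_samples % n_splits):
        fold_sizes[k] += 1  # the first (n_samples % n_splits) folds get one extra sample
    start = 0
    for size in fold_sizes:
        stop = start + size
        valid_idx = list(range(start, stop))  # contiguous validation block
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, valid_idx
        start = stop


def cv_score(fold_losses):
    """Mean of the per-fold losses, divided by the number of folds actually run."""
    return sum(fold_losses) / len(fold_losses)


folds = list(kfold_indices(11, 5))
```

Dividing by the fold count rather than a hard-coded constant keeps the score correct if N_SPLITS is changed in config.py.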
In [6]:
def objective(trial):
    # Hyperparameter search space sampled by Optuna in each trial
    params = {
        'num_layers': trial.suggest_int('num_layers', 1, 3),
        'hidden_size': trial.suggest_int('hidden_size', 64, 512),
        'n_heads': trial.suggest_int('n_heads', 1, 5),
        'dropout': trial.suggest_float('dropout', 0.1, 0.4),
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 9e-3, log=True)
    }

    # Load the training set used for cross-validation
    dataset_for_cv = LoadHOBDataset(root='./data/graph_data/data_oral_avail_train/', raw_filename='data_oral_avail_train_50.csv')
    kf = KFold(n_splits=N_SPLITS)
    fold_loss = 0

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        print(f'Fold {fold_no}')
        # Load the processed molecule graphs that belong to this fold
        train_dataset = []
        valid_dataset = []
        for t_idx in train_idx:
            train_dataset.append(torch.load(f'./data/graph_data/data_oral_avail_train/processed/molecule_{t_idx}.pt'))
        for v_idx in valid_idx:
            valid_dataset.append(torch.load(f'./data/graph_data/data_oral_avail_train/processed/molecule_{v_idx}.pt'))
        seed_everything(SEED_NO)
        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        # Each fold's score is its best validation loss
        loss = run_tuning(train_loader, valid_loader, params)
        fold_loss += loss

    # Mean validation loss across all folds
    return fold_loss / N_SPLITS


# TPESampler is Optuna's default sampler; passing it explicitly documents the choice
study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler())
study.optimize(objective, n_trials=30)
print('Best trial:')
trial_ = study.best_trial
print(trial_.value)
print(f'Best parameters: {trial_.params}')

[32m[I 2023-03-15 13:30:33,387][0m A new study created in memory with name: no-name-bab4f733-f846-4c1e-9b5d-aab8ca74f92d[0m


Fold 0
Epoch: 1/300, train loss : 0.6911721676588058, validation loss : 0.6812695264816284
Early stop counter: 0
Epoch: 2/300, train loss : 0.6876598298549652, validation loss : 0.6729214191436768
Early stop counter: 0
Epoch: 3/300, train loss : 0.6792905032634735, validation loss : 0.682636022567749
Early stop counter: 1
Epoch: 4/300, train loss : 0.6789697259664536, validation loss : 0.6731582283973694
Early stop counter: 2
Epoch: 5/300, train loss : 0.674514576792717, validation loss : 0.6631226539611816
Early stop counter: 0
Epoch: 6/300, train loss : 0.6718787103891373, validation loss : 0.6692750453948975
Early stop counter: 1
Epoch: 7/300, train loss : 0.6740042120218277, validation loss : 0.6720089912414551
Early stop counter: 2
Epoch: 8/300, train loss : 0.6648592799901962, validation loss : 0.6541059017181396
Early stop counter: 0
Epoch: 9/300, train loss : 0.6653721779584885, validation loss : 0.6494285464286804
Early stop counter: 0
Epoch: 10/300, train loss : 0.66064697504

[32m[I 2023-03-15 13:30:50,265][0m Trial 0 finished with value: 0.64981027841568 and parameters: {'num_layers': 1, 'hidden_size': 161, 'n_heads': 5, 'dropout': 0.3871112302613642, 'learning_rate': 0.0011787497218360226}. Best is trial 0 with value: 0.64981027841568.[0m


Epoch: 19/300, train loss : 0.6503708958625793, validation loss : 0.6889975666999817
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6888761818408966, validation loss : 0.6358839869499207
Early stop counter: 0
Epoch: 2/300, train loss : 0.6903033405542374, validation loss : 0.7287794947624207
Early stop counter: 1
Epoch: 3/300, train loss : 0.68180713057518, validation loss : 0.6532896161079407
Early stop counter: 2
Epoch: 4/300, train loss : 0.6880973875522614, validation loss : 0.6693705320358276
Early stop counter: 3
Epoch: 5/300, train loss : 0.6761141568422318, validation loss : 0.7117066383361816
Early stop counter: 4
Epoch: 6/300, train loss : 0.6798339486122131, validation loss : 0.6660518050193787
Early stop counter: 5
Epoch: 7/300, train loss : 0.6744946539402008, validation loss : 0.6573882699012756
Early stop counter: 6
Epoch: 8/300, train loss : 0.6696155220270157, validation loss : 0.680306077003479
Early stop counter: 7
Epoch: 9/300, train loss : 0.6692708730697632

[32m[I 2023-03-15 13:30:59,912][0m Trial 1 finished with value: 0.6617109656333924 and parameters: {'num_layers': 1, 'hidden_size': 220, 'n_heads': 3, 'dropout': 0.37166155147046764, 'learning_rate': 0.002118823997957276}. Best is trial 0 with value: 0.64981027841568.[0m


Epoch: 16/300, train loss : 0.6504044681787491, validation loss : 0.6839689612388611
Early stop counter: 10
Epoch: 17/300, train loss : 0.6451492309570312, validation loss : 0.696839451789856
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.986307293176651, validation loss : 0.6442740559577942
Early stop counter: 0
Epoch: 2/300, train loss : 0.7411888390779495, validation loss : 0.6889698505401611
Early stop counter: 1
Epoch: 3/300, train loss : 0.7069151252508163, validation loss : 0.6902547478675842
Early stop counter: 2
Epoch: 4/300, train loss : 0.6900874674320221, validation loss : 0.7031627297401428
Early stop counter: 3
Epoch: 5/300, train loss : 0.68797966837883, validation loss : 0.7133988738059998
Early stop counter: 4
Epoch: 6/300, train loss : 0.676197960972786, validation loss : 0.6675675511360168
Early stop counter: 5
Epoch: 7/300, train loss : 0.6666582971811295, validation loss : 0.7230732440948486
Early stop counter: 6
Epoch: 8/300, train loss : 0.6650547534227371

[32m[I 2023-03-15 13:31:15,402][0m Trial 2 finished with value: 0.6505545020103455 and parameters: {'num_layers': 2, 'hidden_size': 202, 'n_heads': 3, 'dropout': 0.19805367666895574, 'learning_rate': 0.006678206506902249}. Best is trial 0 with value: 0.64981027841568.[0m


Epoch: 19/300, train loss : 0.6378530114889145, validation loss : 0.7617124319076538
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7702308744192123, validation loss : 0.7260724306106567
Early stop counter: 0
Epoch: 2/300, train loss : 0.6913025826215744, validation loss : 0.6427757143974304
Early stop counter: 0
Epoch: 3/300, train loss : 0.6932570785284042, validation loss : 0.6282889246940613
Early stop counter: 0
Epoch: 4/300, train loss : 0.6994657963514328, validation loss : 0.6893200874328613
Early stop counter: 1
Epoch: 5/300, train loss : 0.6907231211662292, validation loss : 0.6632497906684875
Early stop counter: 2
Epoch: 6/300, train loss : 0.6796337962150574, validation loss : 0.6969779133796692
Early stop counter: 3
Epoch: 7/300, train loss : 0.6777013838291168, validation loss : 0.6589230895042419
Early stop counter: 4
Epoch: 8/300, train loss : 0.6744575649499893, validation loss : 0.6711001992225647
Early stop counter: 5
Epoch: 9/300, train loss : 0.6681259125471

[32m[I 2023-03-15 13:31:48,447][0m Trial 3 finished with value: 0.6573238253593445 and parameters: {'num_layers': 2, 'hidden_size': 288, 'n_heads': 5, 'dropout': 0.30796144276734916, 'learning_rate': 0.001667314898632356}. Best is trial 0 with value: 0.64981027841568.[0m


Epoch: 16/300, train loss : 0.6596137136220932, validation loss : 0.6829082369804382
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7738612592220306, validation loss : 0.6835311055183411
Early stop counter: 0
Epoch: 2/300, train loss : 0.6922923177480698, validation loss : 0.6574801802635193
Early stop counter: 0
Epoch: 3/300, train loss : 0.7013497948646545, validation loss : 0.7714528441429138
Early stop counter: 1
Epoch: 4/300, train loss : 0.7011115252971649, validation loss : 0.6901012063026428
Early stop counter: 2
Epoch: 5/300, train loss : 0.7258543074131012, validation loss : 0.6936281323432922
Early stop counter: 3
Epoch: 6/300, train loss : 0.6867896318435669, validation loss : 0.7001267671585083
Early stop counter: 4
Epoch: 7/300, train loss : 0.7032410353422165, validation loss : 0.6639366149902344
Early stop counter: 5
Epoch: 8/300, train loss : 0.6758846491575241, validation loss : 0.6944116353988647
Early stop counter: 6
Epoch: 9/300, train loss : 0.6740438938140

[32m[I 2023-03-15 13:31:59,782][0m Trial 4 finished with value: 0.6598024725914001 and parameters: {'num_layers': 3, 'hidden_size': 110, 'n_heads': 2, 'dropout': 0.3240315216979728, 'learning_rate': 0.007536810936580852}. Best is trial 0 with value: 0.64981027841568.[0m


Epoch: 25/300, train loss : 0.6816695928573608, validation loss : 0.6933751106262207
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6909482777118683, validation loss : 0.7035908102989197
Early stop counter: 0
Epoch: 2/300, train loss : 0.6822709143161774, validation loss : 0.6675543785095215
Early stop counter: 0
Epoch: 3/300, train loss : 0.6818571239709854, validation loss : 0.6878848671913147
Early stop counter: 1
Epoch: 4/300, train loss : 0.6767648160457611, validation loss : 0.6704422235488892
Early stop counter: 2
Epoch: 5/300, train loss : 0.6756220608949661, validation loss : 0.6746172308921814
Early stop counter: 3
Epoch: 6/300, train loss : 0.6718937307596207, validation loss : 0.6653411388397217
Early stop counter: 0
Epoch: 7/300, train loss : 0.6672982424497604, validation loss : 0.6864749193191528
Early stop counter: 1
Epoch: 8/300, train loss : 0.6601445227861404, validation loss : 0.6523298025131226
Early stop counter: 0
Epoch: 9/300, train loss : 0.6601818501949

[32m[I 2023-03-15 13:32:09,997][0m Trial 5 finished with value: 0.6480085492134094 and parameters: {'num_layers': 1, 'hidden_size': 194, 'n_heads': 2, 'dropout': 0.10561257324030576, 'learning_rate': 0.002104161243337813}. Best is trial 5 with value: 0.6480085492134094.[0m


Epoch: 19/300, train loss : 0.6383937746286392, validation loss : 0.6907429099082947
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7132149934768677, validation loss : 0.6589840054512024
Early stop counter: 0
Epoch: 2/300, train loss : 0.6896563917398453, validation loss : 0.6902422308921814
Early stop counter: 1
Epoch: 3/300, train loss : 0.6881930083036423, validation loss : 0.6818257570266724
Early stop counter: 2
Epoch: 4/300, train loss : 0.6817255318164825, validation loss : 0.687897264957428
Early stop counter: 3
Epoch: 5/300, train loss : 0.6769954562187195, validation loss : 0.6542435884475708
Early stop counter: 0
Epoch: 6/300, train loss : 0.6780378669500351, validation loss : 0.6723196506500244
Early stop counter: 1
Epoch: 7/300, train loss : 0.6711662709712982, validation loss : 0.6680257320404053
Early stop counter: 2
Epoch: 8/300, train loss : 0.6686166822910309, validation loss : 0.6476453542709351
Early stop counter: 0
Epoch: 9/300, train loss : 0.66696180403232

[32m[I 2023-03-15 13:32:19,937][0m Trial 6 finished with value: 0.658790135383606 and parameters: {'num_layers': 1, 'hidden_size': 289, 'n_heads': 2, 'dropout': 0.25342336283576117, 'learning_rate': 0.0042151855834971766}. Best is trial 5 with value: 0.6480085492134094.[0m


Epoch: 16/300, train loss : 0.6377052962779999, validation loss : 0.8025647401809692
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.06396122276783, validation loss : 0.655879557132721
Early stop counter: 0
Epoch: 2/300, train loss : 0.8221970349550247, validation loss : 0.6581489443778992
Early stop counter: 1
Epoch: 3/300, train loss : 0.7898900061845779, validation loss : 1.0955724716186523
Early stop counter: 2
Epoch: 4/300, train loss : 0.7499321550130844, validation loss : 0.7419232130050659
Early stop counter: 3
Epoch: 5/300, train loss : 0.6921905279159546, validation loss : 0.6522857546806335
Early stop counter: 0
Epoch: 6/300, train loss : 0.6866988241672516, validation loss : 0.6408199667930603
Early stop counter: 0
Epoch: 7/300, train loss : 0.6861744672060013, validation loss : 0.8576681017875671
Early stop counter: 1
Epoch: 8/300, train loss : 0.693555161356926, validation loss : 0.7374670505523682
Early stop counter: 2
Epoch: 9/300, train loss : 0.6748857945203781,

[32m[I 2023-03-15 13:32:37,610][0m Trial 7 finished with value: 0.6388066053390503 and parameters: {'num_layers': 2, 'hidden_size': 439, 'n_heads': 1, 'dropout': 0.269754753387312, 'learning_rate': 0.007890910361468965}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 28/300, train loss : 0.6367311030626297, validation loss : 0.7440678477287292
Early stop counter: 10
Epoch: 29/300, train loss : 0.6469134837388992, validation loss : 0.7438643574714661
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.68833227455616, validation loss : 0.7544758319854736
Early stop counter: 0
Epoch: 2/300, train loss : 0.6868565678596497, validation loss : 0.6565973162651062
Early stop counter: 0
Epoch: 3/300, train loss : 0.6825407296419144, validation loss : 0.7065930366516113
Early stop counter: 1
Epoch: 4/300, train loss : 0.6800742000341415, validation loss : 0.6586053371429443
Early stop counter: 2
Epoch: 5/300, train loss : 0.6825584024190903, validation loss : 0.7019541263580322
Early stop counter: 3
Epoch: 6/300, train loss : 0.6759952902793884, validation loss : 0.6605185270309448
Early stop counter: 4
Epoch: 7/300, train loss : 0.6745304465293884, validation loss : 0.6629371643066406
Early stop counter: 5
Epoch: 8/300, train loss : 0.6709974408149

[32m[I 2023-03-15 13:32:48,963][0m Trial 8 finished with value: 0.652019488811493 and parameters: {'num_layers': 1, 'hidden_size': 509, 'n_heads': 1, 'dropout': 0.39903533489300635, 'learning_rate': 0.001904135807346741}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 16/300, train loss : 0.6437266767024994, validation loss : 0.7040446400642395
Early stop counter: 9
Epoch: 17/300, train loss : 0.6409951448440552, validation loss : 0.7078997492790222
Early stop counter: 10
Epoch: 18/300, train loss : 0.6523823142051697, validation loss : 0.7137081027030945
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7197123169898987, validation loss : 0.6561751365661621
Early stop counter: 0
Epoch: 2/300, train loss : 0.691544383764267, validation loss : 0.6982676982879639
Early stop counter: 1
Epoch: 3/300, train loss : 0.6871223747730255, validation loss : 0.6802908778190613
Early stop counter: 2
Epoch: 4/300, train loss : 0.6800472289323807, validation loss : 0.6610471606254578
Early stop counter: 3
Epoch: 5/300, train loss : 0.682014524936676, validation loss : 0.6879809498786926
Early stop counter: 4
Epoch: 6/300, train loss : 0.6818906217813492, validation loss : 0.6789426207542419
Early stop counter: 5
Epoch: 7/300, train loss : 0.676727101206

[32m[I 2023-03-15 13:33:08,311][0m Trial 9 finished with value: 0.6627998352050781 and parameters: {'num_layers': 1, 'hidden_size': 411, 'n_heads': 4, 'dropout': 0.13474993098465052, 'learning_rate': 0.0011219000800838656}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 25/300, train loss : 0.6513558477163315, validation loss : 0.6965051889419556
Early stop counter: 10
Epoch: 26/300, train loss : 0.6468485742807388, validation loss : 0.691743791103363
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.0346713811159134, validation loss : 0.7211251258850098
Early stop counter: 0
Epoch: 2/300, train loss : 0.7440380156040192, validation loss : 0.6649755239486694
Early stop counter: 0
Epoch: 3/300, train loss : 0.7183245271444321, validation loss : 0.6959592700004578
Early stop counter: 1
Epoch: 4/300, train loss : 0.6918513178825378, validation loss : 0.6590871214866638
Early stop counter: 0
Epoch: 5/300, train loss : 0.6857163459062576, validation loss : 0.6626026630401611
Early stop counter: 1
Epoch: 6/300, train loss : 0.6824850142002106, validation loss : 0.7306544780731201
Early stop counter: 2
Epoch: 7/300, train loss : 0.6751093119382858, validation loss : 0.6282528638839722
Early stop counter: 0
Epoch: 8/300, train loss : 0.672111451625

[32m[I 2023-03-15 13:33:34,279][0m Trial 10 finished with value: 0.6527358174324036 and parameters: {'num_layers': 3, 'hidden_size': 455, 'n_heads': 1, 'dropout': 0.22274318640434548, 'learning_rate': 0.004247295868658231}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 21/300, train loss : 0.6595242321491241, validation loss : 0.7311027646064758
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7818417847156525, validation loss : 0.6475486755371094
Early stop counter: 0
Epoch: 2/300, train loss : 0.6992170363664627, validation loss : 0.6968205571174622
Early stop counter: 1
Epoch: 3/300, train loss : 0.6872373670339584, validation loss : 0.6628683805465698
Early stop counter: 2
Epoch: 4/300, train loss : 0.6816396266222, validation loss : 0.6853429079055786
Early stop counter: 3
Epoch: 5/300, train loss : 0.6858441382646561, validation loss : 0.7024410963058472
Early stop counter: 4
Epoch: 6/300, train loss : 0.6943676322698593, validation loss : 0.6789358258247375
Early stop counter: 5
Epoch: 7/300, train loss : 0.6722781807184219, validation loss : 0.7110617160797119
Early stop counter: 6
Epoch: 8/300, train loss : 0.6839869767427444, validation loss : 0.7237022519111633
Early stop counter: 7
Epoch: 9/300, train loss : 0.6742177307605743

[32m[I 2023-03-15 13:33:59,276][0m Trial 11 finished with value: 0.6622324228286743 and parameters: {'num_layers': 2, 'hidden_size': 363, 'n_heads': 2, 'dropout': 0.10557532293609712, 'learning_rate': 0.0032551153604604486}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 31/300, train loss : 0.5999542474746704, validation loss : 0.7913646697998047
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6956493854522705, validation loss : 0.6585775017738342
Early stop counter: 0
Epoch: 2/300, train loss : 0.6954935789108276, validation loss : 0.6990618109703064
Early stop counter: 1
Epoch: 3/300, train loss : 0.6915328502655029, validation loss : 0.671927273273468
Early stop counter: 2
Epoch: 4/300, train loss : 0.6828477382659912, validation loss : 0.6974551677703857
Early stop counter: 3
Epoch: 5/300, train loss : 0.6843799650669098, validation loss : 0.6706282496452332
Early stop counter: 4
Epoch: 6/300, train loss : 0.6758583337068558, validation loss : 0.6735405921936035
Early stop counter: 5
Epoch: 7/300, train loss : 0.6722227334976196, validation loss : 0.6513606905937195
Early stop counter: 0
Epoch: 8/300, train loss : 0.6684891730546951, validation loss : 0.6529885530471802
Early stop counter: 1
Epoch: 9/300, train loss : 0.66130793094635

[32m[I 2023-03-15 13:34:21,437][0m Trial 12 finished with value: 0.646507716178894 and parameters: {'num_layers': 2, 'hidden_size': 73, 'n_heads': 1, 'dropout': 0.16811226492283524, 'learning_rate': 0.008103370959664282}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 13/300, train loss : 0.639606773853302, validation loss : 0.7039448618888855
Early stop counter: 10
Epoch: 14/300, train loss : 0.6414320170879364, validation loss : 0.6964887976646423
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6938871592283249, validation loss : 0.6922205090522766
Early stop counter: 0
Epoch: 2/300, train loss : 0.688055694103241, validation loss : 0.6902942657470703
Early stop counter: 0
Epoch: 3/300, train loss : 0.6825551837682724, validation loss : 0.6689253449440002
Early stop counter: 0
Epoch: 4/300, train loss : 0.6753766685724258, validation loss : 0.6615397930145264
Early stop counter: 0
Epoch: 5/300, train loss : 0.6781941652297974, validation loss : 0.6443039178848267
Early stop counter: 0
Epoch: 6/300, train loss : 0.6668312400579453, validation loss : 0.6579087376594543
Early stop counter: 1
Epoch: 7/300, train loss : 0.670466348528862, validation loss : 0.7083784937858582
Early stop counter: 2
Epoch: 8/300, train loss : 0.66341114044189

[32m[I 2023-03-15 13:34:44,342][0m Trial 13 finished with value: 0.6467840909957886 and parameters: {'num_layers': 2, 'hidden_size': 75, 'n_heads': 1, 'dropout': 0.1781577460691135, 'learning_rate': 0.008958833665484608}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 17/300, train loss : 0.6215587258338928, validation loss : 0.7277522087097168
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.808802455663681, validation loss : 0.6440176963806152
Early stop counter: 0
Epoch: 2/300, train loss : 0.6951721161603928, validation loss : 0.7318446040153503
Early stop counter: 1
Epoch: 3/300, train loss : 0.6922590583562851, validation loss : 0.6560274362564087
Early stop counter: 2
Epoch: 4/300, train loss : 0.6922609955072403, validation loss : 0.6896747946739197
Early stop counter: 3
Epoch: 5/300, train loss : 0.6945546418428421, validation loss : 0.6398727893829346
Early stop counter: 0
Epoch: 6/300, train loss : 0.6855252385139465, validation loss : 0.7285165190696716
Early stop counter: 1
Epoch: 7/300, train loss : 0.6799851953983307, validation loss : 0.6217381954193115
Early stop counter: 0
Epoch: 8/300, train loss : 0.6726098358631134, validation loss : 0.7811207175254822
Early stop counter: 1
Epoch: 9/300, train loss : 0.67717942595481

[32m[I 2023-03-15 13:35:06,768][0m Trial 14 finished with value: 0.658780312538147 and parameters: {'num_layers': 3, 'hidden_size': 332, 'n_heads': 1, 'dropout': 0.26081471413758955, 'learning_rate': 0.005920660027662791}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 29/300, train loss : 0.6815356016159058, validation loss : 0.6910282969474792
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.915458932518959, validation loss : 0.9633756279945374
Early stop counter: 0
Epoch: 2/300, train loss : 0.900442898273468, validation loss : 0.9642677903175354
Early stop counter: 1
Epoch: 3/300, train loss : 0.7512510269880295, validation loss : 0.7758955359458923
Early stop counter: 0
Epoch: 4/300, train loss : 0.7099994271993637, validation loss : 0.6776680946350098
Early stop counter: 0
Epoch: 5/300, train loss : 0.6957384496927261, validation loss : 0.7288016080856323
Early stop counter: 1
Epoch: 6/300, train loss : 0.6869827806949615, validation loss : 0.6907327175140381
Early stop counter: 2
Epoch: 7/300, train loss : 0.681951493024826, validation loss : 0.6981053948402405
Early stop counter: 3
Epoch: 8/300, train loss : 0.6778732091188431, validation loss : 0.6787140965461731
Early stop counter: 4
Epoch: 9/300, train loss : 0.6708990186452866

[32m[I 2023-03-15 13:35:22,660][0m Trial 15 finished with value: 0.6482647061347961 and parameters: {'num_layers': 2, 'hidden_size': 499, 'n_heads': 1, 'dropout': 0.16973301426403262, 'learning_rate': 0.005868079516242351}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 21/300, train loss : 0.6258390247821808, validation loss : 0.8147521018981934
Early stopping...
Fold 0
Epoch: 1/300, train loss : 9.76980359852314, validation loss : 2.4636991024017334
Early stop counter: 0
Epoch: 2/300, train loss : 3.988136947154999, validation loss : 1.612573504447937
Early stop counter: 0
Epoch: 3/300, train loss : 1.6156468838453293, validation loss : 0.826129674911499
Early stop counter: 0
Epoch: 4/300, train loss : 0.8988420963287354, validation loss : 0.6852701902389526
Early stop counter: 0
Epoch: 5/300, train loss : 0.793454572558403, validation loss : 0.7113128304481506
Early stop counter: 1
Epoch: 6/300, train loss : 0.6877724379301071, validation loss : 0.6752750277519226
Early stop counter: 0
Epoch: 7/300, train loss : 0.6864973902702332, validation loss : 0.680736780166626
Early stop counter: 1
Epoch: 8/300, train loss : 0.6954207569360733, validation loss : 0.657666802406311
Early stop counter: 0
Epoch: 9/300, train loss : 0.6932704597711563, val

[32m[I 2023-03-15 13:36:05,689][0m Trial 16 finished with value: 0.6639829516410828 and parameters: {'num_layers': 2, 'hidden_size': 412, 'n_heads': 4, 'dropout': 0.27821858942725325, 'learning_rate': 0.008509205601867388}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 27/300, train loss : 0.6685350686311722, validation loss : 0.9700232744216919
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.2047146409749985, validation loss : 0.6404708027839661
Early stop counter: 0
Epoch: 2/300, train loss : 0.7820105403661728, validation loss : 0.7743569612503052
Early stop counter: 1
Epoch: 3/300, train loss : 0.7784875333309174, validation loss : 0.7227848768234253
Early stop counter: 2
Epoch: 4/300, train loss : 0.6833872348070145, validation loss : 0.643106997013092
Early stop counter: 3
Epoch: 5/300, train loss : 0.6802446097135544, validation loss : 0.7362418174743652
Early stop counter: 4
Epoch: 6/300, train loss : 0.6878821402788162, validation loss : 0.707616925239563
Early stop counter: 5
Epoch: 7/300, train loss : 0.678944543004036, validation loss : 0.7718902826309204
Early stop counter: 6
Epoch: 8/300, train loss : 0.6974522918462753, validation loss : 0.660722553730011
Early stop counter: 7
Epoch: 9/300, train loss : 0.6865204572677612,

[32m[I 2023-03-15 13:36:23,230][0m Trial 17 finished with value: 0.6675217390060425 and parameters: {'num_layers': 3, 'hidden_size': 260, 'n_heads': 2, 'dropout': 0.21997436801435227, 'learning_rate': 0.0051856851117966654}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 19/300, train loss : 0.6545210480690002, validation loss : 0.7261375188827515
Early stop counter: 10
Epoch: 20/300, train loss : 0.6460286676883698, validation loss : 0.7039074301719666
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.1188310384750366, validation loss : 0.7051771283149719
Early stop counter: 0
Epoch: 2/300, train loss : 0.7514509856700897, validation loss : 0.735794186592102
Early stop counter: 1
Epoch: 3/300, train loss : 0.7033872604370117, validation loss : 0.7515780925750732
Early stop counter: 2
Epoch: 4/300, train loss : 0.6878402233123779, validation loss : 0.6650698781013489
Early stop counter: 0
Epoch: 5/300, train loss : 0.6824768036603928, validation loss : 0.6523149609565735
Early stop counter: 0
Epoch: 6/300, train loss : 0.687137171626091, validation loss : 0.6331974267959595
Early stop counter: 0
Epoch: 7/300, train loss : 0.6936207562685013, validation loss : 0.6491449475288391
Early stop counter: 1
Epoch: 8/300, train loss : 0.6783718913793

[32m[I 2023-03-15 13:36:37,182][0m Trial 18 finished with value: 0.6534631848335266 and parameters: {'num_layers': 2, 'hidden_size': 134, 'n_heads': 3, 'dropout': 0.22752920052209807, 'learning_rate': 0.007210856344774892}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 16/300, train loss : 0.6552429497241974, validation loss : 0.7289891242980957
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6937961131334305, validation loss : 0.7380276322364807
Early stop counter: 0
Epoch: 2/300, train loss : 0.6924740225076675, validation loss : 0.6639761924743652
Early stop counter: 0
Epoch: 3/300, train loss : 0.6864407062530518, validation loss : 0.7310504913330078
Early stop counter: 1
Epoch: 4/300, train loss : 0.683037057518959, validation loss : 0.6405654549598694
Early stop counter: 0
Epoch: 5/300, train loss : 0.6870906352996826, validation loss : 0.7027094960212708
Early stop counter: 1
Epoch: 6/300, train loss : 0.6784784197807312, validation loss : 0.6554176211357117
Early stop counter: 2
Epoch: 7/300, train loss : 0.6702664494514465, validation loss : 0.6550956964492798
Early stop counter: 3
Epoch: 8/300, train loss : 0.670955017209053, validation loss : 0.6602984070777893
Early stop counter: 4
Epoch: 9/300, train loss : 0.664301186800003

[32m[I 2023-03-15 13:36:48,138][0m Trial 19 finished with value: 0.6444326162338256 and parameters: {'num_layers': 3, 'hidden_size': 65, 'n_heads': 1, 'dropout': 0.16953461013262117, 'learning_rate': 0.008982196798158818}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 21/300, train loss : 0.6139134019613266, validation loss : 0.690922200679779
Early stopping...
Fold 0
Epoch: 1/300, train loss : 52.93872681260109, validation loss : 93.60176086425781
Early stop counter: 0
Epoch: 2/300, train loss : 160.3123335838318, validation loss : 25.159103393554688
Early stop counter: 0
Epoch: 3/300, train loss : 32.323227405548096, validation loss : 19.623918533325195
Early stop counter: 0
Epoch: 4/300, train loss : 30.697322130203247, validation loss : 192.18655395507812
Early stop counter: 1
Epoch: 5/300, train loss : 60.60475444793701, validation loss : 35.3841552734375
Early stop counter: 2
Epoch: 6/300, train loss : 37.960838317871094, validation loss : 57.08452224731445
Early stop counter: 3
Epoch: 7/300, train loss : 33.62522745132446, validation loss : 17.694175720214844
Early stop counter: 0
Epoch: 8/300, train loss : 17.03683090209961, validation loss : 3.752601385116577
Early stop counter: 0
Epoch: 9/300, train loss : 6.787857532501221, validat

[32m[I 2023-03-15 13:38:36,830][0m Trial 20 finished with value: 0.6844094395637512 and parameters: {'num_layers': 3, 'hidden_size': 343, 'n_heads': 4, 'dropout': 0.15008791745238684, 'learning_rate': 0.008906632093027085}. Best is trial 7 with value: 0.6388066053390503.[0m


Epoch: 28/300, train loss : 0.6854101270437241, validation loss : 0.7147173285484314
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6842018961906433, validation loss : 0.6441923379898071
Early stop counter: 0
Epoch: 2/300, train loss : 0.690427377820015, validation loss : 0.6561754941940308
Early stop counter: 1
Epoch: 3/300, train loss : 0.6868928819894791, validation loss : 0.6869698762893677
Early stop counter: 2
Epoch: 4/300, train loss : 0.6797659397125244, validation loss : 0.6572797298431396
Early stop counter: 3
Epoch: 5/300, train loss : 0.6803489476442337, validation loss : 0.6422352194786072
Early stop counter: 0
Epoch: 6/300, train loss : 0.6756560355424881, validation loss : 0.6835928559303284
Early stop counter: 1
Epoch: 7/300, train loss : 0.6728513985872269, validation loss : 0.6932381987571716
Early stop counter: 2
Epoch: 8/300, train loss : 0.6683404743671417, validation loss : 0.6406204104423523
Early stop counter: 0
Epoch: 9/300, train loss : 0.66512677073478

[I 2023-03-15 13:38:46,315] Trial 21 finished with value: 0.6541055679321289 and parameters: {'num_layers': 2, 'hidden_size': 106, 'n_heads': 1, 'dropout': 0.1986944098279184, 'learning_rate': 0.006708965250752004}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 13/300, train loss : 0.6308934837579727, validation loss : 0.6947129368782043
Early stop counter: 8
Epoch: 14/300, train loss : 0.6321040987968445, validation loss : 0.725161612033844
Early stop counter: 9
Epoch: 15/300, train loss : 0.6196838915348053, validation loss : 0.7628792524337769
Early stop counter: 10
Epoch: 16/300, train loss : 0.6354147791862488, validation loss : 0.7088392972946167
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.6948620975017548, validation loss : 0.7014429569244385
Early stop counter: 0
Epoch: 2/300, train loss : 0.6921496838331223, validation loss : 0.7030450701713562
Early stop counter: 1
Epoch: 3/300, train loss : 0.6902584731578827, validation loss : 0.6945787072181702
Early stop counter: 0
Epoch: 4/300, train loss : 0.6831294596195221, validation loss : 0.6883543133735657
Early stop counter: 0
Epoch: 5/300, train loss : 0.676461935043335, validation loss : 0.6918581128120422
Early stop counter: 1
Epoch: 6/300, train loss : 0.67374981939

[I 2023-03-15 13:38:55,005] Trial 22 finished with value: 0.6457862257957458 and parameters: {'num_layers': 2, 'hidden_size': 69, 'n_heads': 1, 'dropout': 0.14698425856084318, 'learning_rate': 0.007305617092639484}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 29/300, train loss : 0.5736320465803146, validation loss : 0.7474451661109924
Early stop counter: 8
Epoch: 30/300, train loss : 0.5790301859378815, validation loss : 0.7199790477752686
Early stop counter: 9
Epoch: 31/300, train loss : 0.5782779157161713, validation loss : 0.7652970552444458
Early stop counter: 10
Epoch: 32/300, train loss : 0.578010767698288, validation loss : 0.72678142786026
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7285263240337372, validation loss : 0.665948212146759
Early stop counter: 0
Epoch: 2/300, train loss : 0.6939485818147659, validation loss : 0.6826227903366089
Early stop counter: 1
Epoch: 3/300, train loss : 0.6972006112337112, validation loss : 0.8103639483451843
Early stop counter: 2
Epoch: 4/300, train loss : 0.7035742700099945, validation loss : 0.7217844128608704
Early stop counter: 3
Epoch: 5/300, train loss : 0.6890199333429337, validation loss : 0.7226536273956299
Early stop counter: 4
Epoch: 6/300, train loss : 0.6836880892515

[I 2023-03-15 13:39:10,027] Trial 23 finished with value: 0.6523139953613282 and parameters: {'num_layers': 3, 'hidden_size': 160, 'n_heads': 2, 'dropout': 0.13560532529643912, 'learning_rate': 0.00517539959169132}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 20/300, train loss : 0.6537051349878311, validation loss : 0.6916532516479492
Early stop counter: 10
Epoch: 21/300, train loss : 0.6597466617822647, validation loss : 0.6998510360717773
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7722149193286896, validation loss : 0.6962931752204895
Early stop counter: 0
Epoch: 2/300, train loss : 0.7882560193538666, validation loss : 0.6742398738861084
Early stop counter: 0
Epoch: 3/300, train loss : 0.6889995634555817, validation loss : 0.6822812557220459
Early stop counter: 1
Epoch: 4/300, train loss : 0.6834715753793716, validation loss : 0.7097233533859253
Early stop counter: 2
Epoch: 5/300, train loss : 0.678507998585701, validation loss : 0.6766567826271057
Early stop counter: 3
Epoch: 6/300, train loss : 0.6808218657970428, validation loss : 0.6711516380310059
Early stop counter: 0
Epoch: 7/300, train loss : 0.6854702830314636, validation loss : 1.241399884223938
Early stop counter: 1
Epoch: 8/300, train loss : 0.7452743798494

[I 2023-03-15 13:39:20,475] Trial 24 finished with value: 0.6598753333091736 and parameters: {'num_layers': 3, 'hidden_size': 250, 'n_heads': 1, 'dropout': 0.19860830489764658, 'learning_rate': 0.007352774376606142}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 13/300, train loss : 0.6793427765369415, validation loss : 0.7234784960746765
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.697537750005722, validation loss : 0.690963864326477
Early stop counter: 0
Epoch: 2/300, train loss : 0.6908047646284103, validation loss : 0.7042104601860046
Early stop counter: 1
Epoch: 3/300, train loss : 0.6887959539890289, validation loss : 0.6676627993583679
Early stop counter: 0
Epoch: 4/300, train loss : 0.6874692440032959, validation loss : 0.6736916303634644
Early stop counter: 1
Epoch: 5/300, train loss : 0.6787645369768143, validation loss : 0.6917144656181335
Early stop counter: 2
Epoch: 6/300, train loss : 0.6761723607778549, validation loss : 0.6483314037322998
Early stop counter: 0
Epoch: 7/300, train loss : 0.6725194007158279, validation loss : 0.6806256771087646
Early stop counter: 1
Epoch: 8/300, train loss : 0.6620793640613556, validation loss : 0.6440823078155518
Early stop counter: 0
Epoch: 9/300, train loss : 0.654073745012283

[I 2023-03-15 13:39:29,401] Trial 25 finished with value: 0.6457571506500244 and parameters: {'num_layers': 2, 'hidden_size': 65, 'n_heads': 2, 'dropout': 0.14243725380438138, 'learning_rate': 0.006063044803717367}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 13/300, train loss : 0.6570235788822174, validation loss : 0.6918418407440186
Early stop counter: 10
Epoch: 14/300, train loss : 0.6479796320199966, validation loss : 0.7010891437530518
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7490988224744797, validation loss : 0.6976571083068848
Early stop counter: 0
Epoch: 2/300, train loss : 0.6925603449344635, validation loss : 0.6447926163673401
Early stop counter: 0
Epoch: 3/300, train loss : 0.6923410892486572, validation loss : 0.6574537754058838
Early stop counter: 1
Epoch: 4/300, train loss : 0.687159538269043, validation loss : 0.6356141567230225
Early stop counter: 0
Epoch: 5/300, train loss : 0.6724632084369659, validation loss : 0.7456170320510864
Early stop counter: 1
Epoch: 6/300, train loss : 0.6814150214195251, validation loss : 0.7010384798049927
Early stop counter: 2
Epoch: 7/300, train loss : 0.6735461205244064, validation loss : 0.645534098148346
Early stop counter: 3
Epoch: 8/300, train loss : 0.6705424785614

[I 2023-03-15 13:39:40,511] Trial 26 finished with value: 0.6502277135849 and parameters: {'num_layers': 2, 'hidden_size': 121, 'n_heads': 2, 'dropout': 0.12142598501157725, 'learning_rate': 0.008952094366901827}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 17/300, train loss : 0.6278534233570099, validation loss : 0.7247651219367981
Early stop counter: 9
Epoch: 18/300, train loss : 0.624152883887291, validation loss : 0.7477579712867737
Early stop counter: 10
Epoch: 19/300, train loss : 0.6148919612169266, validation loss : 0.7841691374778748
Early stopping...
Fold 0
Epoch: 1/300, train loss : 13.105727538466454, validation loss : 3.6648027896881104
Early stop counter: 0
Epoch: 2/300, train loss : 3.561271220445633, validation loss : 7.050837993621826
Early stop counter: 1
Epoch: 3/300, train loss : 4.92425474524498, validation loss : 0.8378349542617798
Early stop counter: 0
Epoch: 4/300, train loss : 1.6293724328279495, validation loss : 1.7360886335372925
Early stop counter: 1
Epoch: 5/300, train loss : 1.11514550447464, validation loss : 0.9356493353843689
Early stop counter: 2
Epoch: 6/300, train loss : 0.9098172187805176, validation loss : 1.1551874876022339
Early stop counter: 3
Epoch: 7/300, train loss : 0.8083591610193253,

[I 2023-03-15 13:40:35,265] Trial 27 finished with value: 0.67371985912323 and parameters: {'num_layers': 3, 'hidden_size': 391, 'n_heads': 3, 'dropout': 0.16473975828619153, 'learning_rate': 0.0060880866335189954}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 19/300, train loss : 0.6836668848991394, validation loss : 0.7863814234733582
Early stopping...
Fold 0
Epoch: 1/300, train loss : 1.9879090338945389, validation loss : 0.8197359442710876
Early stop counter: 0
Epoch: 2/300, train loss : 1.2265285402536392, validation loss : 0.7473933696746826
Early stop counter: 0
Epoch: 3/300, train loss : 0.7065126746892929, validation loss : 0.7280492186546326
Early stop counter: 0
Epoch: 4/300, train loss : 0.6934950947761536, validation loss : 0.7240045666694641
Early stop counter: 0
Epoch: 5/300, train loss : 0.6856424063444138, validation loss : 0.6393471956253052
Early stop counter: 0
Epoch: 6/300, train loss : 0.6893102675676346, validation loss : 0.6526715755462646
Early stop counter: 1
Epoch: 7/300, train loss : 0.6750496923923492, validation loss : 0.7918513417243958
Early stop counter: 2
Epoch: 8/300, train loss : 0.6787150502204895, validation loss : 0.6242523193359375
Early stop counter: 0
Epoch: 9/300, train loss : 0.6754160672426

[I 2023-03-15 13:41:06,782] Trial 28 finished with value: 0.6542031526565552 and parameters: {'num_layers': 2, 'hidden_size': 470, 'n_heads': 2, 'dropout': 0.12554599791946428, 'learning_rate': 0.005205407038137495}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 19/300, train loss : 0.6419005542993546, validation loss : 0.723512589931488
Early stopping...
Fold 0
Epoch: 1/300, train loss : 0.7029806971549988, validation loss : 0.6954423785209656
Early stop counter: 0
Epoch: 2/300, train loss : 0.6893594115972519, validation loss : 0.692119836807251
Early stop counter: 0
Epoch: 3/300, train loss : 0.6818856596946716, validation loss : 0.7109052538871765
Early stop counter: 1
Epoch: 4/300, train loss : 0.6780722290277481, validation loss : 0.6662260293960571
Early stop counter: 0
Epoch: 5/300, train loss : 0.6663489043712616, validation loss : 0.6305520534515381
Early stop counter: 0
Epoch: 6/300, train loss : 0.6646373420953751, validation loss : 0.7014908194541931
Early stop counter: 1
Epoch: 7/300, train loss : 0.6531979292631149, validation loss : 0.6502487659454346
Early stop counter: 2
Epoch: 8/300, train loss : 0.6591545641422272, validation loss : 0.6054357886314392
Early stop counter: 0
Epoch: 9/300, train loss : 0.664098829030990

[I 2023-03-15 13:41:19,339] Trial 29 finished with value: 0.6483298778533936 and parameters: {'num_layers': 2, 'hidden_size': 146, 'n_heads': 1, 'dropout': 0.10040528927501444, 'learning_rate': 0.00785459603234657}. Best is trial 7 with value: 0.6388066053390503.


Epoch: 15/300, train loss : 0.6446306854486465, validation loss : 0.7243225574493408
Early stopping...
best trial:
[0.6388066053390503]
Best parameters: {'num_layers': 2, 'hidden_size': 439, 'n_heads': 1, 'dropout': 0.269754753387312, 'learning_rate': 0.007890910361468965}
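For reference, the best parameters printed above could be stored in config.py as params_gt roughly as follows. This is a sketch that assumes params_gt is a plain dict with the same keys Optuna reports (the keys run_training and run_testing read); the actual config.py may be organised differently. If the Optuna study object is available, its `best_params` property returns a dict of this shape.

```python
# config.py (sketch): assumes params_gt is a plain dict keyed like Optuna's best parameters
params_gt = {
    'num_layers': 2,
    'hidden_size': 439,
    'n_heads': 1,
    'dropout': 0.269754753387312,
    'learning_rate': 0.007890910361468965,
}
```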


## 3. Train/validate/test Model 

1. After hyperparameter tuning, the best parameters were saved to config.py as params_gt.
2. Using these parameters, the model is trained/validated/tested with 5-fold CV, and the whole process is repeated 5 times. The results are then averaged and reported.
3. Ensure that the data folder, the from_scratch_trained_models folder and this notebook are in the same working directory.
4. To repeat the training process with run_training, create a new folder and point path_to_save_trained_model to it, so that the retrained models are saved to and loaded from that folder and the provided saved models are not overwritten.
5. Otherwise, comment out the run_training call and run only the testing step to reproduce the results reported in the paper.
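The evaluation protocol in item 2 can be sketched in plain Python. The hypothetical helper kfold_indices mimics the documented behaviour of scikit-learn's KFold with its default shuffle=False: folds are contiguous blocks, and the first n_samples % n_splits folds hold one extra sample. Because no randomness is involved, every repetition of the loop below sees the same five folds; and since seed_everything(SEED_NO) is called with the same seed before every fold, differences between repetitions would presumably come only from nondeterministic operations (for example some GPU kernels), assuming seed_everything seeds every random number generator involved. The second hypothetical helper, mean_sd, shows the aggregation step; note that np.std defaults to ddof=0, so the reported ± values are population standard deviations. The numbers in the usage lines are toy values, not results.

```python
import numpy as np


def kfold_indices(n_samples, n_splits=5):
    """Yield (train_idx, valid_idx) pairs like an unshuffled K-fold split (hypothetical helper).

    Folds are contiguous blocks; the first n_samples % n_splits folds get one extra
    sample. No randomness is involved, so repeated calls yield identical folds.
    """
    indices = np.arange(n_samples)
    fold_sizes = np.full(n_splits, n_samples // n_splits, dtype=int)
    fold_sizes[: n_samples % n_splits] += 1
    start = 0
    for size in fold_sizes:
        stop = start + size
        valid_idx = indices[start:stop]
        train_idx = np.concatenate([indices[:start], indices[stop:]])
        yield train_idx, valid_idx
        start = stop


def mean_sd(values, ddof=0):
    """Return (mean, standard deviation); ddof=0 matches np.std's default (population SD)."""
    arr = np.asarray(values, dtype=float)
    return arr.mean(), arr.std(ddof=ddof)


# Toy example: 12 samples split into 5 folds
folds = [valid.tolist() for _, valid in kfold_indices(12, n_splits=5)]
print(folds)  # [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9], [10, 11]]

# Toy example: aggregate three scores the same way the notebook prints them
m, sd = mean_sd([0.70, 0.72, 0.74])
print(f'{m:.3f}±{sd:.3f}')  # 0.720±0.016
```

Passing ddof=1 instead would report the sample standard deviation, which is slightly larger for a small number of runs.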

In [2]:
def run_training(train_loader, valid_loader, params, trained_model_path):
    model = GraphTrans(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_layers=params['num_layers'], hidden_size=params['hidden_size'], n_heads=params['n_heads'], dropout=params['dropout'], edge_dim=EDGE_DIM)
    model.to(DEVICE)
    optimizer=torch.optim.Adam(model.parameters(),lr = params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    best_loss = np.inf
    early_stopping_iter = PATIENCE  # number of consecutive non-improving epochs tolerated
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        valid_loss = eng.validate(valid_loader)
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        if valid_loss < best_loss:
            # validation loss improved: keep this checkpoint and reset the counter
            best_loss = valid_loss
            early_stopping_counter = 0
            print('Saving model...')
            torch.save(model.state_dict(), trained_model_path)
        else:
            early_stopping_counter += 1

        # stop once the validation loss has failed to improve for more than PATIENCE epochs in a row
        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')
    
    return best_loss

def run_testing(test_loader, params, trained_model_path):
    model = GraphTrans(num_features=NUM_FEATURES, num_targets=NUM_TARGET, num_layers=params['num_layers'], hidden_size=params['hidden_size'], n_heads=params['n_heads'], dropout=params['dropout'], edge_dim=EDGE_DIM)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa)
    model.load_state_dict(torch.load(trained_model_path, map_location=DEVICE))
    model.to(DEVICE)
    # the optimizer is not needed for testing, but EngineHOB is constructed with one
    optimizer=torch.optim.Adam(model.parameters(),lr = params['learning_rate'])
    eng = EngineHOB(model, optimizer, device=DEVICE)

    print('Begin testing...')
    bce, acc, f1, roc_auc= eng.test(test_loader)
    print('Test completed!')
    print(f'bce:{bce}, acc:{acc}, f1: {f1}, roc_auc: {roc_auc}')
    return bce, acc, f1, roc_auc
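To make the two helpers above more concrete, the sketch below replays a sequence of validation losses through the same stopping rule as run_training (reset the counter on improvement, stop once it exceeds the patience), and computes binary-classification versions of the four metrics run_testing prints. The patience of 10 is an assumption that matches the logs above, where "Early stopping..." appears on the epoch after the counter reaches 10. How EngineHOB.test actually computes its metrics is not shown in this notebook, so binary_metrics assumes 0/1 labels, predicted probabilities and a 0.5 threshold; both function names are hypothetical.

```python
import numpy as np


def stopping_epoch(valid_losses, patience=10):
    """Return the 1-based epoch at which run_training's rule stops (or the last epoch)."""
    best_loss = np.inf
    counter = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            counter = 0  # improvement: run_training would save a checkpoint here
        else:
            counter += 1
        if counter > patience:  # strictly greater, as in run_training
            return epoch
    return len(valid_losses)


def binary_metrics(y_true, y_prob, threshold=0.5, eps=1e-7):
    """Return (bce, acc, f1, roc_auc); assumes 0/1 labels and both classes present."""
    y = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(y_prob, dtype=float), eps, 1 - eps)
    bce = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

    pred = (p >= threshold).astype(float)
    acc = np.mean(pred == y)
    tp = np.sum((pred == 1) & (y == 1))
    fp = np.sum((pred == 1) & (y == 0))
    fn = np.sum((pred == 0) & (y == 1))
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn > 0 else 0.0

    # ROC AUC = probability that a random positive outscores a random negative (ties count 1/2)
    pos, neg = p[y == 1], p[y == 0]
    diff = pos[:, None] - neg[None, :]
    roc_auc = np.mean((diff > 0) + 0.5 * (diff == 0))
    return bce, acc, f1, roc_auc


# two improving epochs followed by 11 non-improving ones
print(stopping_epoch([0.90, 0.64] + [0.70] * 11))  # 13

bce, acc, f1, roc_auc = binary_metrics([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(f'acc={acc:.2f}, f1={f1:.3f}, roc_auc={roc_auc:.2f}')  # acc=0.75, f1=0.667, roc_auc=0.75
```

Because the check is `counter > patience`, a run stops PATIENCE + 1 epochs after the last improvement, and the checkpoint kept on disk is the one from that last improving epoch.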

In [4]:
params = params_gt
n_repetitions = 5
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test/'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
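path_to_save_trained_model = './from_scratch_trained_models/GT/'
os.makedirs(path_to_save_trained_model, exist_ok=True)  # torch.save does not create missing folders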

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

#load dataset 
dataset_for_cv = LoadHOBDataset(root=train_data_root_path, raw_filename=train_data_raw_filename)
test_dataset = LoadHOBDataset(root=test_data_root_path, raw_filename=test_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

for repeat in range(n_repetitions):
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        seed_everything(SEED_NO)
        # load the processed molecule graphs that belong to this fold
        processed_dir = os.path.join(train_data_root_path, 'processed')
        train_dataset = [torch.load(os.path.join(processed_dir, f'molecule_{t_idx}.pt')) for t_idx in train_idx]
        valid_dataset = [torch.load(os.path.join(processed_dir, f'molecule_{v_idx}.pt')) for v_idx in valid_idx]

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        model_path = os.path.join(path_to_save_trained_model, f'gt_repeat_{repeat}_fold_{fold_no}.pt')
        run_training(train_loader, valid_loader, params, model_path)  # comment out to only test the saved models
        bce, acc, f1, roc_auc = run_testing(test_loader, params, model_path)
        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean= np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean= np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean= np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

Epoch: 1/300, train loss : 1.0818932354450226, validation loss : 0.9042943120002747
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.8742576390504837, validation loss : 0.6444023251533508
Saving model...
Early stop counter: 0
Epoch: 3/300, train loss : 0.706589013338089, validation loss : 0.7645893692970276
Early stop counter: 1
Epoch: 4/300, train loss : 0.7244753837585449, validation loss : 0.697129487991333
Early stop counter: 2
Epoch: 5/300, train loss : 0.750020831823349, validation loss : 0.7543023228645325
Early stop counter: 3
Epoch: 6/300, train loss : 0.6954939365386963, validation loss : 0.7170941829681396
Early stop counter: 4
Epoch: 7/300, train loss : 0.6909188777208328, validation loss : 0.67427659034729
Early stop counter: 5
Epoch: 8/300, train loss : 0.6818986088037491, validation loss : 0.6691092252731323
Early stop counter: 6
Epoch: 9/300, train loss : 0.6825099289417267, validation loss : 0.6500078439712524
Early stop counter: 7
Epoch: 10/300, trai