# Train Pretrained Parallel Model for oral bioavailability prediction

1. This notebook uses the ParallelGNN model pretrained on the high-similarity solubility dataset.
2. Ensure that the notebook, the data folder, config.py, engine.py, model.py and utils.py are in the same directory.
3. The whole process can be reproduced by running the cells in order.
4. Alternatively, download the saved models from Google Drive as described in README.md and comment out the run_training call to reproduce the results reported in the paper.

In [1]:
# import all required materials

import torch
import numpy as np
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold
import os

from model import ParallelGNN
from config import NUM_FEATURES, NUM_TARGET, EDGE_DIM, DEVICE, SEED_NO, PATIENCE, EPOCHS, NUM_GRAPHS_PER_BATCH, N_SPLITS, best_params_parallel
from engine import EngineHOB
from utils import seed_everything, LoadHOBDataset

First, we define a training function and a testing function to support transfer learning. Then, we evaluate how different factors affect the transfer learning results.

In [2]:
def run_training(method_tf, train_loader, valid_loader, params, es_trigger, path_to_pretrained_model, path_to_save_trained_model):
    
    '''
    Fine-tune a pretrained ParallelGNN and save the checkpoint with the best validation loss

    Args:
    method_tf (str): freeze --> freeze parameters of feature extraction block, fine_tune_5x -> fine tune at 5x slower learning rate, 
                    fine_tune_10x -> fine tune at 10x slower learning rate
    train_loader: DataLoader class from pytorch geometric containing train data
    valid_loader: DataLoader class from pytorch geometric containing validation data
    params (dict): dictionary containing the hyperparameters
    es_trigger (int): number of initial epochs that are always trained; checkpointing and the
                    early stopping counter only become active once epoch number > es_trigger
    path_to_pretrained_model (str): path to load the pretrained models
    path_to_save_trained_model: path to save the trained models

    Return:
    best loss: return best validation loss (np.inf if no checkpoint was ever saved)

    Raises:
    ValueError: if method_tf is not 'freeze', 'fine_tune_5x' or 'fine_tune_10x'
    '''

    # Divisor applied to the base learning rate of the gin_model and
    # graph_trans_model blocks for each fine-tuning strategy.
    lr_divisors = {'fine_tune_5x': 5, 'fine_tune_10x': 10}

    # Fail fast: previously an unknown strategy left `optimizer` unbound and
    # only surfaced as an UnboundLocalError after the model had been loaded.
    if method_tf != 'freeze' and method_tf not in lr_divisors:
        raise ValueError(
            f"Unknown method_tf '{method_tf}'; expected 'freeze', 'fine_tune_5x' or 'fine_tune_10x'"
        )

    model = ParallelGNN(
            num_features=NUM_FEATURES,
            num_targets=NUM_TARGET,
            num_gin_layers=params["num_gin_layers"],
            num_graph_trans_layers=params["num_graph_trans_layers"],
            hidden_size=params["hidden_size"],
            n_heads=params["n_heads"],
            dropout=params["dropout"],
            edge_dim=EDGE_DIM,
        )

    # map_location lets a checkpoint saved on another device (e.g. GPU) be
    # loaded on whatever DEVICE this run uses.
    model.load_state_dict(torch.load(path_to_pretrained_model, map_location=DEVICE))
    model.to(DEVICE)

    base_lr = params['learning_rate']
    if method_tf == 'freeze':
        for param in model.gin_model.parameters():
            param.requires_grad = False
        for param in model.graph_trans_model.parameters():
            param.requires_grad = False

        # Only parameters that still require gradients are handed to Adam.
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=base_lr)
    else:
        slow_lr = base_lr / lr_divisors[method_tf]
        # The `ro` group has no explicit lr, so it uses the base learning rate.
        optimizer = torch.optim.Adam([
            {'params': model.gin_model.parameters(), 'lr': slow_lr},
            {'params': model.graph_trans_model.parameters(), 'lr': slow_lr},
            {'params': model.ro.parameters()}
        ], lr=base_lr)

    eng = EngineHOB(model, optimizer, device=DEVICE)

    # torch.save does not create missing directories, so make sure the
    # destination folder exists before the first checkpoint is written.
    save_dir = os.path.dirname(path_to_save_trained_model)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        valid_loss = eng.validate(valid_loader)
        print(
            f"Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}"
        )
        # Checkpointing / early stopping are skipped for the first es_trigger epochs.
        if epoch+1 > es_trigger:
            if valid_loss < best_loss:
                best_loss = valid_loss
                early_stopping_counter = 0  # reset counter
                print("Saving model...")
                torch.save(model.state_dict(), path_to_save_trained_model)
            else:
                early_stopping_counter += 1

            # Stops once the counter exceeds PATIENCE (strictly greater than).
            if early_stopping_counter > early_stopping_iter:
                print("Early stopping...")
                break
            print(f"Early stop counter: {early_stopping_counter}")

    return best_loss

In [3]:
def run_testing(method_tf, test_loader, params, path_to_trained_model):
    
    '''
    Load a trained ParallelGNN checkpoint and evaluate it on the test data

    Args:
    method_tf (str): freeze --> freeze parameters of feature extraction block, fine_tune_5x -> fine tune at 5x slower learning rate, 
                    fine_tune_10x -> fine tune at 10x slower learning rate
    test_loader: DataLoader class from pytorch geometric containing test data
    params (dict): dictionary containing the hyperparameters
    path_to_trained_model: path to load the saved trained models

    Return:
    bce: logloss from pytorch geometric 
    acc: accuracy score
    f1: f1 score
    roc_auc: roc auc score

    Raises:
    ValueError: if method_tf is not 'freeze', 'fine_tune_5x' or 'fine_tune_10x'
    '''

    # Divisor applied to the base learning rate of the gin_model and
    # graph_trans_model blocks for each fine-tuning strategy.
    lr_divisors = {'fine_tune_5x': 5, 'fine_tune_10x': 10}

    # Fail fast: previously an unknown strategy left `optimizer` unbound and
    # only surfaced as an UnboundLocalError after the model had been loaded.
    if method_tf != 'freeze' and method_tf not in lr_divisors:
        raise ValueError(
            f"Unknown method_tf '{method_tf}'; expected 'freeze', 'fine_tune_5x' or 'fine_tune_10x'"
        )

    model = ParallelGNN(
            num_features=NUM_FEATURES,
            num_targets=NUM_TARGET,
            num_gin_layers=params["num_gin_layers"],
            num_graph_trans_layers=params["num_graph_trans_layers"],
            hidden_size=params["hidden_size"],
            n_heads=params["n_heads"],
            dropout=params["dropout"],
            edge_dim=EDGE_DIM,
        )

    # map_location lets a checkpoint saved on another device (e.g. GPU) be
    # loaded on whatever DEVICE this run uses.
    model.load_state_dict(torch.load(path_to_trained_model, map_location=DEVICE))
    model.to(DEVICE)

    # EngineHOB is constructed with an optimizer, so one is built here with the
    # same settings as in training even though only eng.test is called below.
    base_lr = params['learning_rate']
    if method_tf == 'freeze':
        for param in model.gin_model.parameters():
            param.requires_grad = False
        for param in model.graph_trans_model.parameters():
            param.requires_grad = False

        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=base_lr)
    else:
        slow_lr = base_lr / lr_divisors[method_tf]
        # The `ro` group has no explicit lr, so it uses the base learning rate.
        optimizer = torch.optim.Adam([
            {'params': model.gin_model.parameters(), 'lr': slow_lr},
            {'params': model.graph_trans_model.parameters(), 'lr': slow_lr},
            {'params': model.ro.parameters()}
        ], lr=base_lr)

    eng = EngineHOB(model, optimizer, device=DEVICE)
    bce, acc, f1, roc_auc = eng.test(test_loader)
    print(f"bce:{bce}, acc :{acc}, f1: {f1}, roc_auc: {roc_auc}")
    return bce, acc, f1, roc_auc

# Effect of number of pre-training epochs to transfer learning prediction performance
1. The weights of the feature extraction block are first frozen.
2. The model is then trained, and only the weights of the classifier block are updated.
3. The goal is to evaluate how the number of pre-training epochs affects the transfer learning prediction performance.

### Note on path
1. path_to_pretrained_model = path to the model pretrained on the solubility dataset
2. path_to_save_trained_model = path where the models trained on the oral bioavailability dataset are saved
3. path_to_trained_model = path from which the models trained on the oral bioavailability dataset are loaded for testing

In [4]:
# Repeated k-fold transfer-learning experiment that starts from the
# `pretrained_parallel_model_20_epoch.pt` checkpoint: every fold is fine-tuned
# with run_training, scored on the test set with run_testing, and the test
# metrics are summarised as mean±sd at the end.
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
n_repetitions = 5
method_tf = 'freeze'
params = best_params_parallel
es_trigger = 0
path_to_pretrained_model = './trf_learning_models/pretrained_models/parallel/high/'
path_to_save_trained_model = './trf_learning_models/trained_models/parallel/high/pretrained_20/'
path_to_trained_model = './trf_learning_models/trained_models/parallel/high/pretrained_20/'

pretrained_checkpoint = os.path.join(path_to_pretrained_model, 'pretrained_parallel_model_20_epoch.pt')
# Per-molecule graph files live under <train root>/processed/; derive the
# folder from train_data_root_path instead of repeating the literal path.
processed_dir = os.path.join(train_data_root_path, 'processed')

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadHOBDataset(train_data_root_path, train_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

# The test set does not depend on the repetition or fold, so build it once.
test_dataset = LoadHOBDataset(test_data_root_path, test_data_raw_filename)
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

for repeat in range(n_repetitions):
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        # NOTE(review): the same SEED_NO is re-applied for every fold of every
        # repetition and KFold is not shuffled, so the repetitions may simply
        # reproduce identical runs — confirm this is intended.
        seed_everything(SEED_NO)
        train_dataset = [
            torch.load(os.path.join(processed_dir, f"molecule_{t_idx}.pt"))
            for t_idx in train_idx
        ]
        valid_dataset = [
            torch.load(os.path.join(processed_dir, f"molecule_{v_idx}.pt"))
            for v_idx in valid_idx
        ]

        train_loader = DataLoader(
            train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
        )
        print(f'Rep no {repeat}, Fold no {fold_no}')

        # One checkpoint name shared by training (save) and testing (load).
        model_filename = f"trained_parallel_model_{method_tf}_repeat_{repeat}_fold_{fold_no}_{es_trigger}_es_trigger.pt"
        save_path = os.path.join(path_to_save_trained_model, model_filename)
        load_path = os.path.join(path_to_trained_model, model_filename)

        # Comment out the next line to evaluate previously saved models only.
        run_training(method_tf, train_loader, valid_loader, params, es_trigger, pretrained_checkpoint, save_path)

        bce, acc, f1, roc_auc = run_testing(method_tf, test_loader, params, load_path)

        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

# Summarise the test metrics over all repetitions and folds (population sd).
bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

print("Training Completed!")

Rep no 0, Fold no 0
Epoch: 1/300, train loss : 1.4411433041095734, validation loss : 0.6933850049972534
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.7869274914264679, validation loss : 0.6235814690589905
Saving model...
Early stop counter: 0
Epoch: 3/300, train loss : 0.7146792262792587, validation loss : 0.7609443068504333
Early stop counter: 1
Epoch: 4/300, train loss : 0.7037916779518127, validation loss : 0.6618936657905579
Early stop counter: 2
Epoch: 5/300, train loss : 0.6728964149951935, validation loss : 0.6165732145309448
Saving model...
Early stop counter: 0
Epoch: 6/300, train loss : 0.6725444346666336, validation loss : 0.6609098315238953
Early stop counter: 1
Epoch: 7/300, train loss : 0.6612101197242737, validation loss : 0.6612956523895264
Early stop counter: 2
Epoch: 8/300, train loss : 0.6547185331583023, validation loss : 0.6326295137405396
Early stop counter: 3
Epoch: 9/300, train loss : 0.6562199145555496, validation loss : 0.6407881379127502


In [5]:
# Repeated k-fold transfer-learning experiment that starts from the
# `pretrained_parallel_model_40_epoch.pt` checkpoint: every fold is fine-tuned
# with run_training, scored on the test set with run_testing, and the test
# metrics are summarised as mean±sd at the end.
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
n_repetitions = 5
method_tf = 'freeze'
params = best_params_parallel
es_trigger = 0
path_to_pretrained_model = './trf_learning_models/pretrained_models/parallel/high/'
path_to_save_trained_model = './trf_learning_models/trained_models/parallel/high/pretrained_40/'
path_to_trained_model = './trf_learning_models/trained_models/parallel/high/pretrained_40/'

pretrained_checkpoint = os.path.join(path_to_pretrained_model, 'pretrained_parallel_model_40_epoch.pt')
# Per-molecule graph files live under <train root>/processed/; derive the
# folder from train_data_root_path instead of repeating the literal path.
processed_dir = os.path.join(train_data_root_path, 'processed')

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadHOBDataset(train_data_root_path, train_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

# The test set does not depend on the repetition or fold, so build it once.
test_dataset = LoadHOBDataset(test_data_root_path, test_data_raw_filename)
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

for repeat in range(n_repetitions):
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        # NOTE(review): the same SEED_NO is re-applied for every fold of every
        # repetition and KFold is not shuffled, so the repetitions may simply
        # reproduce identical runs — confirm this is intended.
        seed_everything(SEED_NO)
        train_dataset = [
            torch.load(os.path.join(processed_dir, f"molecule_{t_idx}.pt"))
            for t_idx in train_idx
        ]
        valid_dataset = [
            torch.load(os.path.join(processed_dir, f"molecule_{v_idx}.pt"))
            for v_idx in valid_idx
        ]

        train_loader = DataLoader(
            train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
        )
        print(f'Rep no {repeat}, Fold no {fold_no}')

        # One checkpoint name shared by training (save) and testing (load).
        model_filename = f"trained_parallel_model_{method_tf}_repeat_{repeat}_fold_{fold_no}_{es_trigger}_es_trigger.pt"
        save_path = os.path.join(path_to_save_trained_model, model_filename)
        load_path = os.path.join(path_to_trained_model, model_filename)

        # Comment out the next line to evaluate previously saved models only.
        run_training(method_tf, train_loader, valid_loader, params, es_trigger, pretrained_checkpoint, save_path)

        bce, acc, f1, roc_auc = run_testing(method_tf, test_loader, params, load_path)

        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

# Summarise the test metrics over all repetitions and folds (population sd).
bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

print("Training Completed!")

Rep no 0, Fold no 0
Epoch: 1/300, train loss : 1.4577331244945526, validation loss : 0.6919055581092834
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.831117570400238, validation loss : 0.6503968238830566
Saving model...
Early stop counter: 0
Epoch: 3/300, train loss : 0.7152771502733231, validation loss : 0.789370059967041
Early stop counter: 1
Epoch: 4/300, train loss : 0.7187482416629791, validation loss : 0.6864807605743408
Early stop counter: 2
Epoch: 5/300, train loss : 0.6714933812618256, validation loss : 0.6136261224746704
Saving model...
Early stop counter: 0
Epoch: 6/300, train loss : 0.6821139603853226, validation loss : 0.624452531337738
Early stop counter: 1
Epoch: 7/300, train loss : 0.6609260439872742, validation loss : 0.6531277894973755
Early stop counter: 2
Epoch: 8/300, train loss : 0.6559162437915802, validation loss : 0.6500341892242432
Early stop counter: 3
Epoch: 9/300, train loss : 0.6552475541830063, validation loss : 0.6326161026954651
Ear

In [6]:
# Repeated k-fold transfer-learning experiment that starts from the
# `pretrained_parallel_model_60_epoch.pt` checkpoint: every fold is fine-tuned
# with run_training, scored on the test set with run_testing, and the test
# metrics are summarised as mean±sd at the end.
train_data_root_path = './data/graph_data/data_oral_avail_train/'
train_data_raw_filename = 'data_oral_avail_train_50.csv'
test_data_root_path = './data/graph_data/data_oral_avail_test'
test_data_raw_filename = 'data_oral_avail_test_1_50.csv'
n_repetitions = 5
method_tf = 'freeze'
params = best_params_parallel
es_trigger = 0
path_to_pretrained_model = './trf_learning_models/pretrained_models/parallel/high/'
path_to_save_trained_model = './trf_learning_models/trained_models/parallel/high/pretrained_60/'
path_to_trained_model = './trf_learning_models/trained_models/parallel/high/pretrained_60/'

pretrained_checkpoint = os.path.join(path_to_pretrained_model, 'pretrained_parallel_model_60_epoch.pt')
# Per-molecule graph files live under <train root>/processed/; derive the
# folder from train_data_root_path instead of repeating the literal path.
processed_dir = os.path.join(train_data_root_path, 'processed')

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadHOBDataset(train_data_root_path, train_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

# The test set does not depend on the repetition or fold, so build it once.
test_dataset = LoadHOBDataset(test_data_root_path, test_data_raw_filename)
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

for repeat in range(n_repetitions):
    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        # NOTE(review): the same SEED_NO is re-applied for every fold of every
        # repetition and KFold is not shuffled, so the repetitions may simply
        # reproduce identical runs — confirm this is intended.
        seed_everything(SEED_NO)
        train_dataset = [
            torch.load(os.path.join(processed_dir, f"molecule_{t_idx}.pt"))
            for t_idx in train_idx
        ]
        valid_dataset = [
            torch.load(os.path.join(processed_dir, f"molecule_{v_idx}.pt"))
            for v_idx in valid_idx
        ]

        train_loader = DataLoader(
            train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
        )
        print(f'Rep no {repeat}, Fold no {fold_no}')

        # One checkpoint name shared by training (save) and testing (load).
        model_filename = f"trained_parallel_model_{method_tf}_repeat_{repeat}_fold_{fold_no}_{es_trigger}_es_trigger.pt"
        save_path = os.path.join(path_to_save_trained_model, model_filename)
        load_path = os.path.join(path_to_trained_model, model_filename)

        # Comment out the next line to evaluate previously saved models only.
        run_training(method_tf, train_loader, valid_loader, params, es_trigger, pretrained_checkpoint, save_path)

        bce, acc, f1, roc_auc = run_testing(method_tf, test_loader, params, load_path)

        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

# Summarise the test metrics over all repetitions and folds (population sd).
bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

print("Training Completed!")

Rep no 0, Fold no 0
Epoch: 1/300, train loss : 1.4414126127958298, validation loss : 0.7238380312919617
Saving model...
Early stop counter: 0
Epoch: 2/300, train loss : 0.8446258902549744, validation loss : 0.6749630570411682
Saving model...
Early stop counter: 0
Epoch: 3/300, train loss : 0.7257793694734573, validation loss : 0.8153026103973389
Early stop counter: 1
Epoch: 4/300, train loss : 0.7144213169813156, validation loss : 0.6967652440071106
Early stop counter: 2
Epoch: 5/300, train loss : 0.6691636294126511, validation loss : 0.6222056746482849
Saving model...
Early stop counter: 0
Epoch: 6/300, train loss : 0.674372598528862, validation loss : 0.6329690217971802
Early stop counter: 1
Epoch: 7/300, train loss : 0.6516533344984055, validation loss : 0.6627249717712402
Early stop counter: 2
Epoch: 8/300, train loss : 0.641597256064415, validation loss : 0.6559787392616272
Early stop counter: 3
Epoch: 9/300, train loss : 0.6425077468156815, validation loss : 0.6346425414085388
Ea