In [12]:
import itertools
import math
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score
# Tensor type handed to load_data and train_binary_classifier in the cells below.
dtype = torch.FloatTensor
# Relative paths used later (the CSV splits and the saved_models/ checkpoints)
# resolve against this directory. Set the AUTOBIN_DIR environment variable to
# run from a different location; the default keeps the original path.
PROJECT_DIR = os.environ.get("AUTOBIN_DIR", "/cloud-home/U1039935/Autosurv/autobin")
os.chdir(PROJECT_DIR)
# Define the new model for binary classification
class BinaryClassifier(nn.Module):
    """Two-layer fully connected binary classifier over latent features plus four clinical inputs.

    ``forward`` applies, in order:
    optional dropout on ``latent_features`` -> BatchNorm -> column-wise concat with
    ``c1..c4`` -> Linear(input_n + 4, level_2_dim) -> tanh -> optional dropout ->
    BatchNorm -> Linear(level_2_dim, 1) -> sigmoid.

    The submodule attribute names (``bn_input``, ``fc1``, ``bn2``, ``fc2``) are the
    ``state_dict`` keys, so renaming them would break loading saved checkpoints.
    """

    def __init__(self, input_n, level_2_dim, Dropout_Rate_1, Dropout_Rate_2):
        """
        Args:
            input_n: number of columns in ``latent_features``.
            level_2_dim: width of the hidden layer.
            Dropout_Rate_1: dropout probability applied to ``latent_features``.
            Dropout_Rate_2: dropout probability applied to the hidden activations.
        """
        super().__init__()
        self.input_n = input_n
        self.level_2_dim = level_2_dim
        self.tanh = nn.Tanh()

        # fc1 sees the normalised latent block plus 4 extra clinical columns.
        fc1_in = input_n + 4
        self.bn_input = nn.BatchNorm1d(input_n)
        self.fc1 = nn.Linear(fc1_in, level_2_dim)
        self.bn2 = nn.BatchNorm1d(level_2_dim)
        self.fc2 = nn.Linear(level_2_dim, 1)

        self.dropout_1 = nn.Dropout(Dropout_Rate_1)
        self.dropout_2 = nn.Dropout(Dropout_Rate_2)

    def forward(self, latent_features, c1, c2, c3, c4, s_dropout=False):
        """Return class-1 probabilities with shape ``(batch, 1)``.

        Args:
            latent_features: 2-D ``(batch, input_n)`` tensor.
            c1, c2, c3, c4: clinical tensors concatenated along dim 1; together they
                must add exactly 4 columns (presumably one each; the call sites
                pass age, stage_i, stage_ii and race_white).
            s_dropout: when True, route through ``dropout_1`` / ``dropout_2``.
                ``nn.Dropout`` is a no-op in eval mode, so this only matters under ``train()``.
        """
        features = self.dropout_1(latent_features) if s_dropout else latent_features
        combined = torch.cat([self.bn_input(features), c1, c2, c3, c4], dim=1)

        hidden = self.tanh(self.fc1(combined))
        if s_dropout:
            hidden = self.dropout_2(hidden)

        logits = self.fc2(self.bn2(hidden))
        return torch.sigmoid(logits)

def _binary_accuracy(y_true, y_prob):
    """Accuracy of the rounded probabilities ``y_prob`` against labels ``y_true`` (both tensors).

    Both tensors are detached and copied to host memory before scoring. NumPy's
    ``round`` rounds half to even, so a probability of exactly 0.5 is scored as 0.
    """
    return accuracy_score(y_true.detach().cpu().numpy(), y_prob.detach().cpu().numpy().round())


def _predict(net, inputs):
    """Forward ``inputs`` through ``net`` in eval mode without recording autograd history.

    Leaves ``net`` in eval mode; the training loop switches back with ``net.train()``.
    """
    net.eval()
    with torch.no_grad():
        return net(*inputs, s_dropout=False)


def train_binary_classifier(train_x1, train_age, train_stage_i, train_stage_ii, train_race_white, train_yevent,
                            eval_x1, eval_age, eval_stage_i, eval_stage_ii, eval_race_white, eval_yevent,
                            input_n, level_2_dim, Dropout_Rate_1, Dropout_Rate_2, Learning_Rate, L2, epoch_num, patience, dtype,
                            path="saved_model/binary_classifier_checkpoint.pt"):
    """Train a BinaryClassifier full-batch with Adam and checkpoint-based early stopping.

    Each epoch takes one full-batch gradient step on the training tensors (dropout
    on). It then computes evaluation-set accuracy (dropout off, no grad) and passes
    it to ``EarlyStopping``. After the loop the checkpoint at ``path`` is reloaded,
    and final train/eval predictions and accuracies are computed from it.

    Args:
        train_x1 / eval_x1: latent-feature tensors with ``input_n`` columns.
        train_age, train_stage_i, train_stage_ii, train_race_white (and eval_*):
            clinical tensors forwarded as ``c1..c4`` of ``BinaryClassifier.forward``.
        train_yevent / eval_yevent: 0/1 float label tensors.
        input_n, level_2_dim, Dropout_Rate_1, Dropout_Rate_2: model hyper-parameters.
        Learning_Rate, L2: Adam learning rate and ``weight_decay``.
        epoch_num: maximum number of epochs; must be >= 1.
        patience: forwarded to ``EarlyStopping``.
        dtype: unused; kept for call-site compatibility.
        path: checkpoint file handed to ``EarlyStopping`` and reloaded with ``torch.load``.

    Returns:
        (train_y_pred, eval_y_pred, train_accuracy, eval_accuracy, best_epoch_num).
        The two prediction tensors are on the CPU and do not require grad.

    Raises:
        ValueError: if ``epoch_num < 1``.

    NOTE(review): ``EarlyStopping`` is defined elsewhere and is fed *accuracy*, so it
    must treat larger values as better. Confirm this: a val-loss-style implementation
    would keep the worst epoch.
    """
    if epoch_num < 1:
        # EarlyStopping (the only thing given `path`) would never be called, so the
        # reload below would fail or silently pick up a stale checkpoint.
        raise ValueError("epoch_num must be at least 1, got %r" % (epoch_num,))

    # Model AND inputs must share one device; moving only the model to CUDA makes
    # CPU input tensors fail with a device-mismatch error.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_inputs = [t.to(device) for t in (train_x1, train_age, train_stage_i, train_stage_ii, train_race_white)]
    eval_inputs = [t.to(device) for t in (eval_x1, eval_age, eval_stage_i, eval_stage_ii, eval_race_white)]
    train_yevent = train_yevent.to(device)
    eval_yevent = eval_yevent.to(device)

    net = BinaryClassifier(input_n, level_2_dim, Dropout_Rate_1, Dropout_Rate_2).to(device)
    early_stopping = EarlyStopping(patience=patience, verbose=False, path=path)
    opt = optim.Adam(net.parameters(), lr=Learning_Rate, weight_decay=L2)

    start_time = time.time()
    for epoch in range(epoch_num):
        net.train()
        opt.zero_grad()

        y_pred = net(*train_inputs, s_dropout=True)
        # reshape(-1) rather than squeeze(): squeeze turns a batch of one into a 0-d tensor.
        loss = F.binary_cross_entropy(y_pred.reshape(-1), train_yevent.reshape(-1))
        loss.backward()
        opt.step()

        eval_y_pred = _predict(net, eval_inputs)
        eval_accuracy = _binary_accuracy(eval_yevent, eval_y_pred)

        early_stopping(eval_accuracy, net)
        if early_stopping.early_stop:
            print("Early stopping, number of epochs: ", epoch)
            print('Save model of Epoch {:d}'.format(early_stopping.best_epoch_num))
            break
        if (epoch + 1) % 100 == 0:
            train_y_pred = _predict(net, train_inputs)
            train_accuracy = _binary_accuracy(train_yevent, train_y_pred)
            print("Training Accuracy: %s," % train_accuracy, "validation Accuracy: %s." % eval_accuracy)

    print("Loading model, best epoch: %s." % early_stopping.best_epoch_num)
    # map_location='cpu' makes the file loadable whatever device it was saved from;
    # load_state_dict then copies the values onto net's current device.
    net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))

    train_y_pred = _predict(net, train_inputs)
    train_accuracy = _binary_accuracy(train_yevent, train_y_pred)
    eval_y_pred = _predict(net, eval_inputs)
    eval_accuracy = _binary_accuracy(eval_yevent, eval_y_pred)

    print("Final training Accuracy: %s," % train_accuracy, "final validation Accuracy: %s." % eval_accuracy)
    time_elapse = np.array(time.time() - start_time).round(2)
    print("Total time elapse: %s." % time_elapse)

    return (train_y_pred.cpu(), eval_y_pred.cpu(), train_accuracy, eval_accuracy, early_stopping.best_epoch_num)

In [13]:
start_time = time.time()

# ---- Configuration ---------------------------------------------------------
# Seed the RNGs so weight initialisation and dropout masks (and therefore the
# selected hyper-parameters) are reproducible on Restart & Run All.
# torch.manual_seed also seeds CUDA devices.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# input_n must equal the number of columns in x_* (BatchNorm1d(input_n) is applied to it).
input_n = 16
level_2_dim = [8, 16, 32]
epoch_num = 500
patience = 200
Initial_Learning_Rate = [0.05, 0.01, 0.0075, 0.005, 0.0025]
L2_Lambda = [0.001, 0.00075, 0.0005, 0.00025, 0.0001]
Dropout_rate_1 = [0.1, 0.3, 0.5]
Dropout_rate_2 = [0.1, 0.3, 0.5]

best_epoch_num = 0

# The checkpoint paths below live in saved_models/. They are handed to
# EarlyStopping (which presumably writes them) and re-read with torch.load.
# Create the directory so a fresh checkout does not fail.
os.makedirs("saved_models", exist_ok=True)

# ---- Data ------------------------------------------------------------------
# Tuning split (train/validation) selects hyper-parameters; the overall split
# (train/test) fits and scores the final model.
patient_id_train, x_train, ytime_train, yevent_train, age_train, stage_i_train, stage_ii_train, race_white_train = load_data("tune_tr_z_2omics.csv", dtype)
patient_id_valid, x_valid, ytime_valid, yevent_valid, age_valid, stage_i_valid, stage_ii_valid, race_white_valid = load_data("tune_val_z_2omics.csv", dtype)

patient_id_train_overall, x_train_overall, ytime_train_overall, yevent_train_overall, age_train_overall, stage_i_train_overall, stage_ii_train_overall, race_white_train_overall = load_data("tr_z_2omics.csv", dtype)
patient_id_test_overall, x_test_overall, ytime_test_overall, yevent_test_overall, age_test_overall, stage_i_test_overall, stage_ii_test_overall, race_white_test_overall = load_data("tes_z_2omics.csv", dtype)

# ---- Grid search on the tuning split ---------------------------------------
opt_l2 = 0
opt_lr = 0
opt_dim = 0
opt_dr1 = 0
opt_dr2 = 0

opt_accuracy_va = float(0)
opt_accuracy_tr = float(0)
# itertools.product visits configurations in the same order as five nested loops
# over these lists (the last list varies fastest).
for l2, lr, dim, dr1, dr2 in itertools.product(L2_Lambda, Initial_Learning_Rate, level_2_dim,
                                               Dropout_rate_1, Dropout_rate_2):
    _, _, accuracy_train, accuracy_valid, best_epoch_num_tune = train_binary_classifier(
        x_train, age_train, stage_i_train, stage_ii_train, race_white_train, yevent_train,
        x_valid, age_valid, stage_i_valid, stage_ii_valid, race_white_valid, yevent_valid,
        input_n, dim, dr1, dr2, lr, l2, epoch_num, patience, dtype,
        path="saved_models/binary_classifier_checkpoint_tune.pt")

    # Strict '>' keeps the earliest grid point among ties.
    if accuracy_valid > opt_accuracy_va:
        opt_l2 = l2
        opt_lr = lr
        opt_dim = dim
        opt_dr1 = dr1
        opt_dr2 = dr2
        opt_accuracy_tr = accuracy_train
        opt_accuracy_va = accuracy_valid
        best_epoch_num = best_epoch_num_tune
    print("L2: %s," % l2, "LR: %s." % lr, "dim: %s," % dim, "dr1: %s," % dr1, "dr2: %s." % dr2)
    print("Training Accuracy: %s," % accuracy_train, "validation Accuracy: %s." % accuracy_valid)

if opt_dim == 0:
    # No configuration beat the initial 0.0 accuracy, so the opt_* placeholders were
    # never replaced; a zero-width hidden layer cannot be trained.
    raise RuntimeError("No hyper-parameter configuration exceeded 0 validation accuracy; cannot fit the final model.")

# ---- Final fit on the overall split ----------------------------------------
# NOTE(review): the overall *test* set is passed as the evaluation set. Early
# stopping and checkpoint reloading inside train_binary_classifier are therefore
# driven by test accuracy, so the reported testing accuracy is optimistically
# biased. Consider training for a fixed number of epochs (e.g. best_epoch_num from
# tuning) or early-stopping on a held-out validation split instead.
train_y_pred, test_y_pred, accuracy_train, accuracy_test, best_epoch_num_overall = train_binary_classifier(
    x_train_overall, age_train_overall, stage_i_train_overall, stage_ii_train_overall, race_white_train_overall, yevent_train_overall,
    x_test_overall, age_test_overall, stage_i_test_overall, stage_ii_test_overall, race_white_test_overall, yevent_test_overall,
    input_n, opt_dim, opt_dr1, opt_dr2, opt_lr, opt_l2, epoch_num, patience, dtype,
    path="saved_models/binary_classifier_checkpoint_overall.pt")
print("Optimal L2: %s," % opt_l2, "optimal LR: %s," % opt_lr, "optimal dim: %s," % opt_dim, "optimal dr1: %s," % opt_dr1, "optimal dr2: %s," % opt_dr2, "best epoch number in tuning: %s." % best_epoch_num)
print("Optimal training Accuracy: %s," % opt_accuracy_tr, "optimal validation Accuracy: %s." % opt_accuracy_va)
print("Testing phase: training Accuracy: %s," % accuracy_train, "testing Accuracy: %s." % accuracy_test)
end_time = time.time()
print("--- %s seconds ---" % (end_time - start_time))

EarlyStopping counter: 20 out of 200
EarlyStopping counter: 20 out of 200
Training Accuracy: 0.8295964125560538, validation Accuracy: 0.7712765957446809.
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
EarlyStopping counter: 60 out of 200
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 100 out of 200
Training Accuracy: 0.8684603886397608, validation Accuracy: 0.8191489361702128.
EarlyStopping counter: 120 out of 200
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
Training Accuracy: 0.8699551569506726, validation Accuracy: 0.8191489361702128.
EarlyStopping counter: 60 out of 200
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
Training Accuracy: 0.866965620328849, validation Accuracy: 0.8138297872340425.
EarlyStopping counter: 60 out of 200
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 100 out of 200
EarlyStopping counter: 120 out of 200
EarlyStoppi

In [None]:
# Re-display the wall-clock duration recorded by the tuning cell (start_time / end_time are set there).
elapsed_seconds = end_time - start_time
print("--- %s seconds ---" % elapsed_seconds)

In [None]:
# Observed runtime of the hyper-parameter tuning cell: roughly 25-30 mins.