In [1]:
import itertools
import math
import os
import time

import lifelines
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from lifelines.utils import concordance_index

from LFSurv import LFSurv
from utils import sort_data, load_data, R_set, neg_par_log_likelihood, c_index, EarlyStopping
from train_LFSurv import train_LFSurv

# Tensor type passed to load_data() and train_LFSurv() further down.
dtype = torch.FloatTensor

# Every path in this notebook is relative to the project directory, so switch
# into it first. The location can be overridden with the AUTOSURV_DIR
# environment variable; the default is the original hard-coded path.
os.chdir(os.environ.get('AUTOSURV_DIR', '/cloud-home/U1039935/Autosurv/autostat'))

In [2]:
# Read the four pre-split, min-max normalised gene tables:
#   * tune/train + tune/valid  -> used for hyper-parameter search
#   * overall train + test     -> used for the final fit and evaluation
# Note that `data_valid` is read from the *test* file of the overall split.
data_tune_train, data_tune_valid, data_train, data_valid = [
    pd.read_csv(csv_path)
    for csv_path in (
        "processed_data_example/TCGA_BRCA/tune/minmax_normalized/data_train_gene_minmax_tune.csv",
        "processed_data_example/TCGA_BRCA/tune/minmax_normalized/data_valid_gene_minmax_tune.csv",
        "processed_data_example/TCGA_BRCA/train_test_split/minmax_normalized/data_train_gene_minmax_overall.csv",
        "processed_data_example/TCGA_BRCA/train_test_split/minmax_normalized/data_test_gene_minmax_overall.csv",
    )
]

In [3]:
# Sanity check: (rows, columns) of the overall training table before column selection.
data_train.shape

(857, 2706)

In [4]:
# Columns kept for modelling: 7 leading non-gene columns followed by
# 16 ENSG* gene columns (23 in total). Order matters, because the selected
# frames are written to CSV and re-read by load_data().
select_gene = [
    # identifier, outcome and clinical columns
    'patient_id', 'OS.time', 'OS',
    'age', 'stage_i', 'stage_ii', 'race_white',
    # gene columns
    'ENSG00000169764', 'ENSG00000115380', 'ENSG00000168216', 'ENSG00000131080',
    'ENSG00000154734', 'ENSG00000184005', 'ENSG00000154727', 'ENSG00000185010',
    'ENSG00000069431', 'ENSG00000171056', 'ENSG00000137693', 'ENSG00000140285',
    'ENSG00000138686', 'ENSG00000138640', 'ENSG00000005020', 'ENSG00000157557',
]

In [5]:
# Reduce every split to the selected columns; each selection is wrapped in a
# fresh DataFrame object, as before.
tune_tr = pd.DataFrame(data_tune_train[select_gene])
tune_val = pd.DataFrame(data_tune_valid[select_gene])
tr = pd.DataFrame(data_train[select_gene])
val = pd.DataFrame(data_valid[select_gene])

# Show the resulting (rows, columns) of all four splits on one line.
print(tune_tr.shape, tune_val.shape, tr.shape, val.shape)

(669, 23) (188, 23) (857, 23) (201, 23)


In [97]:
# Persist the reduced splits in the working directory; load_data() reads
# these same files back in the next cell. `val` (the overall test split)
# is stored under the 'tes_' prefix.
for frame, csv_name in (
    (tune_tr, 'tune_tr_z_2omics.csv'),
    (tune_val, "tune_val_z_2omics.csv"),
    (tr, 'tr_z_2omics.csv'),
    (val, 'tes_z_2omics.csv'),
):
    frame.to_csv(csv_name, index=False)

In [7]:
# ---- Training budget -------------------------------------------------------
epoch_num = 500   # maximum number of epochs handed to train_LFSurv
patience = 200    # early-stopping patience handed to train_LFSurv

# ---- Fixed model input size ------------------------------------------------
# modify input_n based on optim model
input_n = 16

# ---- Hyper-parameter grid (every combination is tried) ---------------------
level_2_dim = [8, 16, 32]
Dropout_rate_1 = [0.1, 0.3, 0.5]
Dropout_rate_2 = [0.1, 0.3, 0.5]
Initial_Learning_Rate = [0.05, 0.01, 0.0075, 0.005, 0.0025]
L2_Lambda = [0.001, 0.00075, 0.0005, 0.00025, 0.0001]
# Smaller alternative grids, kept for quick runs:
# Initial_Learning_Rate = [0.05, 0.0075, 0.0025]
# L2_Lambda = [0.001, 0.0005, 0.0001]

# Overwritten by the grid search with the best epoch of the winning configuration.
best_epoch_num = 0


# load_data() returns eight objects per file, unpacked here in the order
# (patient ids, x, ytime, yevent, age, stage_i, stage_ii, race_white).

# Tuning split: training part.
(
    patient_id_train,
    x_train,
    ytime_train,
    yevent_train,
    age_train,
    stage_i_train,
    stage_ii_train,
    race_white_train,
) = load_data("tune_tr_z_2omics.csv", dtype)

# Tuning split: validation part.
(
    patient_id_valid,
    x_valid,
    ytime_valid,
    yevent_valid,
    age_valid,
    stage_i_valid,
    stage_ii_valid,
    race_white_valid,
) = load_data("tune_val_z_2omics.csv", dtype)

# Overall split: full training set.
(
    patient_id_train_overall,
    x_train_overall,
    ytime_train_overall,
    yevent_train_overall,
    age_train_overall,
    stage_i_train_overall,
    stage_ii_train_overall,
    race_white_train_overall,
) = load_data("tr_z_2omics.csv", dtype)

# Overall split: held-out test set.
(
    patient_id_test_overall,
    x_test_overall,
    ytime_test_overall,
    yevent_test_overall,
    age_test_overall,
    stage_i_test_overall,
    stage_ii_test_overall,
    race_white_test_overall,
) = load_data("tes_z_2omics.csv", dtype)

In [8]:
# Sanity check on the loaded tuning features; the second dimension is
# presumably what input_n has to match (verify against train_LFSurv).
x_train.shape

torch.Size([669, 16])

In [9]:
import time

In [10]:
# Exhaustive grid search: train one LFSurv model per hyper-parameter
# combination on the tuning split and keep the combination with the highest
# validation C-index. Ties keep the first combination found (strict '>').
start_time = time.time()

# Best configuration so far; zeros mean "nothing selected yet".
opt_l2 = 0
opt_lr = 0
opt_dim = 0
opt_dr1 = 0
opt_dr2 = 0

# Kept as numpy scalars: the reporting cell calls .round(4) on these, and a
# plain Python float has no .round method.
opt_cindex_va = np.float64(0)
opt_cindex_tr = np.float64(0)

# itertools.product varies the last factor fastest, i.e. it visits the
# combinations in exactly the order of the equivalent nested for-loops.
search_grid = itertools.product(
    L2_Lambda, Initial_Learning_Rate, level_2_dim, Dropout_rate_1, Dropout_rate_2
)
for l2, lr, dim, dr1, dr2 in search_grid:
    # Every run reuses the same checkpoint file.
    _, _, cindex_train, cindex_valid, best_epoch_num_tune = train_LFSurv(
        x_train, age_train, stage_i_train, stage_ii_train, race_white_train, ytime_train, yevent_train,
        x_valid, age_valid, stage_i_valid, stage_ii_valid, race_white_valid, ytime_valid, yevent_valid,
        input_n, dim, dr1, dr2, lr, l2, epoch_num, patience, dtype,
        path = "sup_checkpoint_tune.pt")

    if cindex_valid > opt_cindex_va:
        opt_l2 = l2
        opt_lr = lr
        opt_dim = dim
        opt_dr1 = dr1
        opt_dr2 = dr2
        opt_cindex_tr = cindex_train
        opt_cindex_va = cindex_valid
        best_epoch_num = best_epoch_num_tune
    print("L2: %s," %l2, "LR: %s." %lr, "dim: %s," %dim, "dr1: %s," %dr1, "dr2: %s." %dr2)
    print("Training C-index: %s," %cindex_train.round(4), "validation C-index: %s." %cindex_valid.round(4))
end_time = time.time()
print("--- %s seconds ---" % (end_time - start_time))

EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
EarlyStopping counter: 60 out of 200
Training C-index: 0.7568, validation C-index: 0.7159.
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 100 out of 200
EarlyStopping counter: 120 out of 200
EarlyStopping counter: 140 out of 200
EarlyStopping counter: 160 out of 200
Training C-index: 0.7606, validation C-index: 0.6844.
EarlyStopping counter: 180 out of 200
EarlyStopping counter: 200 out of 200
Early stopping, number of epochs:  226
Save model of Epoch 27
Loading model, best epoch: 27.
Final training C-index: 0.7633, final validation C-index: 0.7517.
Total time elapse: 3.89.
L2: 0.001, LR: 0.05. dim: 8, dr1: 0.1, dr2: 0.1.
Training C-index: 0.7633, validation C-index: 0.7517.
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
EarlyStopping counter: 60 out of 200
Training C-index: 0.765, validation C-index: 0.7296.
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 100 o

In [11]:
# Show the grid-search wall-clock time again, on its own, because the
# previous cell's output is dominated by training logs.
elapsed_seconds = end_time - start_time
print("--- %s seconds ---" % elapsed_seconds)

--- 1375.6601901054382 seconds ---


In [13]:
# Final fit: retrain on the full training set with the hyper-parameters chosen
# by the grid search, evaluating against the held-out test set.
# NOTE(review): train_LFSurv's log labels its second data set "validation" and
# prints early-stopping counters, so checkpoint selection here appears to be
# driven by the test split -- confirm the reported test C-index is not
# optimistically biased by that.
(
    train_y_pred,
    test_y_pred,
    cindex_train,
    cindex_test,
    best_epoch_num_overall,
) = train_LFSurv(
    # training data
    x_train_overall, age_train_overall, stage_i_train_overall, stage_ii_train_overall,
    race_white_train_overall, ytime_train_overall, yevent_train_overall,
    # evaluation data (test split)
    x_test_overall, age_test_overall, stage_i_test_overall, stage_ii_test_overall,
    race_white_test_overall, ytime_test_overall, yevent_test_overall,
    # architecture, optimiser settings and training budget
    input_n, opt_dim, opt_dr1, opt_dr2, opt_lr, opt_l2, epoch_num, patience, dtype,
    path = "sup_checkpoint_overall.pt",
)

# Report the selected hyper-parameters ...
print(
    "Optimal L2: %s," %opt_l2,
    "optimal LR: %s," %opt_lr,
    "optimal dim: %s," %opt_dim,
    "optimal dr1: %s," %opt_dr1,
    "optimal dr2: %s," %opt_dr2,
    "best epoch number in tuning: %s." %best_epoch_num,
)
# ... their scores on the tuning split ...
print(
    "Optimal training C-index: %s," %opt_cindex_tr.round(4),
    "optimal validation C-index: %s." %opt_cindex_va.round(4),
)
# ... and the scores of the final model.
print(
    "Testing phase: training C-index: %s," %cindex_train.round(4),
    "testing C-index: %s." %cindex_test.round(4),
)

EarlyStopping counter: 20 out of 200
Training C-index: 0.7494, validation C-index: 0.6628.
EarlyStopping counter: 40 out of 200
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
Training C-index: 0.7513, validation C-index: 0.6655.
EarlyStopping counter: 60 out of 200
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 20 out of 200
EarlyStopping counter: 40 out of 200
Training C-index: 0.7567, validation C-index: 0.6652.
EarlyStopping counter: 60 out of 200
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 100 out of 200
EarlyStopping counter: 20 out of 200
Training C-index: 0.7635, validation C-index: 0.6754.
EarlyStopping counter: 40 out of 200
EarlyStopping counter: 60 out of 200
EarlyStopping counter: 80 out of 200
EarlyStopping counter: 100 out of 200
EarlyStopping counter: 120 out of 200
Training C-index: 0.7664, validation C-index: 0.6745.
Loading model, best epoch: 362.
Final training C-index: 0.7643, final validation C-index: 0.6865.
