In [1]:
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
import pickle
import warnings
import time
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

from itertools import product
from torch import Tensor
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

from pycox import models
from pycox import datasets
from pycox.datasets import metabric, gbsg, support, flchain, nwtco
from pycox.models import CoxCC, CoxPH, CoxTime, DeepHitSingle, PCHazard
from pycox.models.loss import CoxPHLoss
from pycox.models.cox_time import MLPVanillaCoxTime
from pycox.evaluation import EvalSurv

### Training step

In [2]:
models = [CoxPH]

lst_layers = [4, 2, 1]
lst_nodes_per_layer = [512, 256, 128, 64]
lst_dropout = [0.0, 0.1, 0.2, 0.5, 0.7]
lst_weight_decay = [0.4, 0.2, 0.1, 0.05, 0.02, 0.01, 0.0]

parameters = [models, lst_layers, lst_nodes_per_layer, lst_dropout,
              lst_weight_decay]

parameters = list(product(*parameters))

In [3]:
datasets = {"metabric" : [metabric,
                           [["x0", "x1", "x2", "x3", "x8"],
                            ["x4", "x5", "x6", "x7"]]],
             "gbsg" : [gbsg,
                           [["x3", "x4", "x5", "x6"],
                            ["x0", "x1", "x2"]]],
             "support" : [support,
                           [["x0", "x7", "x8", "x9", "x10", "x11", "x12", "x13"],
                            ["x1", "x2", "x3", "x4", "x5", "x6"]]],
             "flchain" : [flchain,
                           [["x0", "x2", "x3", "x4", "x6"],
                            ["x1", "x5", "x7"],
                            ["x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "duration", "event"]]],
             "nwtco" : [nwtco,
                           [["x1"],
                            ["x0", "x2", "x3", "x4", "x5"],
                            ["x0", "x1", "x2", "x3", "x4", "x5", "duration", "event"]]]
            }

dataset_names = ["support", "metabric", "flchain", "nwtco", "gbsg"]
num_experiments = 5

In [4]:
start = time.time()

results = []

for name in dataset_names:
    print("Dataset: {}".format(name))
    i = 0
    np.random.seed(1905)
    df_train = datasets[name][0].read_df()
    if name == "flchain" or name == "nwtco":
        df_train.columns = datasets[name][1][2]
    df_test = df_train.sample(frac = 0.2)
    df_train = df_train.drop(df_test.index)
    df_val = df_train.sample(frac = 0.2)
    df_train = df_train.drop(df_val.index)
    
    cols_standardize = datasets[name][1][0]
    cols_leave = datasets[name][1][1]
    
    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = [(col, None) for col in cols_leave]
    
    x_mapper = DataFrameMapper(standardize + leave)

    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.fit_transform(df_test).astype('float32')
    
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = get_target(df_train)
    y_val = get_target(df_val)
    durations_test, events_test = get_target(df_test)
    val = tt.tuplefy(x_val, y_val)
        
    train_batch = df_train.shape[0]
    val_batch = df_val.shape[0]

    for params in parameters:
        i += 1
        params = list(params)
        cs = []
        for j in range(num_experiments):
            
            in_features = x_train.shape[1]
            out_features = 1
            batch_norm = True
            output_bias = False
            epochs = 512
            callbacks = [tt.callbacks.EarlyStopping(checkpoint_model = False, load_best = False)]
            verbose = False
            
            model_, layers, nodes_per_layer, dropout, wd = params
            num_nodes = [nodes_per_layer] * layers
                        
            net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                                          dropout, output_bias = output_bias)
            model = model_(net, tt.optim.AdamWR(decoupled_weight_decay = wd))
            
            lr = model.lr_finder(x_train, y_train, train_batch, tolerance = 5).get_best_lr()
            model.optimizer.set_lr(lr)
            model.fit(x_train, y_train, train_batch, epochs, callbacks, verbose,
                      val_data = val, val_batch_size = val_batch)
            _ = model.compute_baseline_hazards()
            cs.append(EvalSurv(model.predict_surv_df(x_test),
                               durations_test, events_test, censor_surv='km').concordance_td())

        c = np.median(cs)
        c_std = np.std(cs)
        result = [name] + params + [c, c_std, cs]
        results.append(result)
        print("iterations: {}/{}, {}".format(i, len(parameters),
                                             result[2:6] + [round(result[6], 4), round(result[7], 4)]))
        
end = time.time()

print("wall time: {}".format(end - start))

Dataset: support
iterations: 1/420, [4, 512, 0.0, 0.4, 0.5758, 0.0024]
iterations: 2/420, [4, 512, 0.0, 0.2, 0.5797, 0.0076]
iterations: 3/420, [4, 512, 0.0, 0.1, 0.6205, 0.002]
iterations: 4/420, [4, 512, 0.0, 0.05, 0.6162, 0.0024]
iterations: 5/420, [4, 512, 0.0, 0.02, 0.6079, 0.0073]
iterations: 6/420, [4, 512, 0.0, 0.01, 0.605, 0.0053]
iterations: 7/420, [4, 512, 0.0, 0.0, 0.5965, 0.0025]
iterations: 8/420, [4, 512, 0.1, 0.4, 0.5814, 0.0037]
iterations: 9/420, [4, 512, 0.1, 0.2, 0.587, 0.0089]
iterations: 10/420, [4, 512, 0.1, 0.1, 0.6189, 0.017]
iterations: 11/420, [4, 512, 0.1, 0.05, 0.619, 0.0018]
iterations: 12/420, [4, 512, 0.1, 0.02, 0.608, 0.0037]
iterations: 13/420, [4, 512, 0.1, 0.01, 0.6049, 0.0021]
iterations: 14/420, [4, 512, 0.1, 0.0, 0.5967, 0.0109]
iterations: 15/420, [4, 512, 0.2, 0.4, 0.581, 0.0042]
iterations: 16/420, [4, 512, 0.2, 0.2, 0.5847, 0.0114]
iterations: 17/420, [4, 512, 0.2, 0.1, 0.6198, 0.0022]
iterations: 18/420, [4, 512, 0.2, 0.05, 0.6139, 0.0374]
it

In [5]:
columns = ["Dataset", "Model", "Layers", "# of Nodes", "DORate", "Weight Decay",
           "C-Index (median)", "C-Index (std)", "C-Index (listed)"]
pd.DataFrame(results, columns = columns).to_csv("results_CoxMLP.csv", index = False)

with open('results_CoxMLP.pickle', 'wb') as handle:
    pickle.dump(results, handle, protocol = pickle.HIGHEST_PROTOCOL)

handle.close()