# Run SurvTRACE on GBSG, Metabric, Support  datasets

In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
import os, sys
sys.path.append(os.path.abspath('../SurvTRACE'))
sys.path.append(os.path.abspath(".."))

In [4]:
import pdb
from collections import defaultdict
import matplotlib.pyplot as plt

from survtrace.dataset import load_data
from survtrace.evaluate_utils import Evaluator
from survtrace.utils import set_random_seed
from survtrace.model import SurvTraceSingle
from survtrace.train_utils import Trainer
from survtrace.config import STConfig
import prettytable as pt

In [5]:
# Hyper params for transformer training
hparams = {
    'batch_size': 128,
    'weight_decay': 1e-4,
    'learning_rate': 1e-3,
    'epochs': 20,
}

In [6]:
DATASETS = ['gbsg', 'metabric', 'support']

In [7]:
from btdsa.utils import create_logger
logger = create_logger('./logs_survtrace')

In [8]:
horizons = STConfig.horizons
headers = []
results = []

In [9]:
def run_experiment(dataset, hparams, show_plot=False):
    
    assert dataset in DATASETS
    logger.info(f"Running {dataset}...")
    headers.append(dataset)

    # define the setup parameters
    STConfig.data = dataset
    STConfig.early_stop_patience = 10

    seed = STConfig.seed # 1234
    set_random_seed(seed)
    
    # load data
    df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)

    # get model
    model = SurvTraceSingle(STConfig)

    # initialize a trainer
    trainer = Trainer(model)
    train_loss, val_loss = trainer.fit((df_train, df_y_train), (df_val, df_y_val),
            batch_size=hparams['batch_size'],
            epochs=hparams['epochs'],
            learning_rate=hparams['learning_rate'],
            weight_decay=hparams['weight_decay'],)
    
    # evaluate model
    evaluator = Evaluator(df, df_train.index)
    result_dict = evaluator.eval(model, (df_test, df_y_test), confidence=.95, nb_bootstrap=100)
    
    # Messages for pretty table summary
    cindex_mean, (cindex_lower, cindex_upper) = result_dict.pop('C-td-full')
    row_str = f"C-td (full): {cindex_mean:.6f} ({cindex_lower:.6f},{cindex_upper:.6f})\n"
    
    for horizon in horizons:
        keys = [ k for k in result_dict.keys() if k.startswith(str(horizon)) ]
        results_at_horizon = [result_dict[k] for k in keys]
        msg = [f"[{horizon*100}%]"]
        for k,res in zip(keys,results_at_horizon):
            metric = k.split('_')[1]
            mean, (lower, upper) = res
            msg.append(f"{metric}: {mean:.6f} ({lower:.6f},{upper:.6f})")
        row_str += (" ".join(msg) + "\n")
    results.append(row_str)
        
    if show_plot:
        # show training curves
        plt.plot(train_loss, label='train')
        plt.plot(val_loss, label='val')
        plt.legend(fontsize=20)
        plt.xlabel('epoch',fontsize=20)
        plt.ylabel('loss', fontsize=20)
        plt.show()
    print("done")

In [10]:
for dataset in DATASETS:
    run_experiment(dataset, hparams, show_plot=False)

Running gbsg...


GPU not found! will use cpu for training!
[Train-0]: 21.013157963752747
[Val-0]: 1.783380389213562
[Train-1]: 18.84564232826233
[Val-1]: 1.4817368984222412
[Train-2]: 16.127905249595642
[Val-2]: 1.5033963918685913
EarlyStopping counter: 1 out of 10
[Train-3]: 14.461576700210571
[Val-3]: 1.2597471475601196
[Train-4]: 14.117198586463928
[Val-4]: 1.4088094234466553
EarlyStopping counter: 1 out of 10
[Train-5]: 14.085430979728699
[Val-5]: 1.3812425136566162
EarlyStopping counter: 2 out of 10
[Train-6]: 14.063504219055176
[Val-6]: 1.2597274780273438
[Train-7]: 13.964786410331726
[Val-7]: 1.2904506921768188
EarlyStopping counter: 1 out of 10
[Train-8]: 13.884294152259827
[Val-8]: 1.2595832347869873
[Train-9]: 13.853302955627441
[Val-9]: 1.2500712871551514
[Train-10]: 13.820022940635681
[Val-10]: 1.248732089996338
[Train-11]: 13.830458879470825
[Val-11]: 1.2691510915756226
EarlyStopping counter: 1 out of 10
[Train-12]: 13.7943834066391
[Val-12]: 1.2588516473770142
EarlyStopping counter: 2 out

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
Running metabric...


0.95 confidence C-td-full average: 0.4260096754936201
0.95 confidence C-td-full interval: (0.42199513579176623,0.430024215195474)
0.95 confidence 0.25_Ctd_ipcw average: 0.7508380816661646
0.95 confidence 0.25_Ctd_ipcw interval: (0.7458837703473616,0.7557923929849676)
0.95 confidence 0.25_brier average: 0.10417683520961689
0.95 confidence 0.25_brier interval: (0.1027112077984259,0.10564246262080788)
0.95 confidence 0.25_auroc average: 0.7689572671388553
0.95 confidence 0.25_auroc interval: (0.7637469935920419,0.7741675406856686)
0.95 confidence 0.5_Ctd_ipcw average: 0.7081889670124568
0.95 confidence 0.5_Ctd_ipcw interval: (0.7046607584681115,0.7117171755568021)
0.95 confidence 0.5_brier average: 0.17456263927551036
0.95 confidence 0.5_brier interval: (0.17309251784737026,0.17603276070365045)
0.95 confidence 0.5_auroc average: 0.7364485871018399
0.95 confidence 0.5_auroc interval: (0.7324879942659903,0.7404091799376895)
0.95 confidence 0.75_Ctd_ipcw average: 0.6885352706552641
0.95 conf

Running support...


0.95 confidence C-td-full average: 0.39128188779228706
0.95 confidence C-td-full interval: (0.38691744546764656,0.39564633011692757)
0.95 confidence 0.25_Ctd_ipcw average: 0.7000377332878753
0.95 confidence 0.25_Ctd_ipcw interval: (0.6942692190869936,0.7058062474887571)
0.95 confidence 0.25_brier average: 0.11823745637105594
0.95 confidence 0.25_brier interval: (0.11650664508003035,0.11996826766208153)
0.95 confidence 0.25_auroc average: 0.7108541725946422
0.95 confidence 0.25_auroc interval: (0.7047078869602688,0.7170004582290157)
0.95 confidence 0.5_Ctd_ipcw average: 0.6761681488175313
0.95 confidence 0.5_Ctd_ipcw interval: (0.6720759295280795,0.6802603681069831)
0.95 confidence 0.5_brier average: 0.18966948681306728
0.95 confidence 0.5_brier interval: (0.18806409564185503,0.19127487798427953)
0.95 confidence 0.5_auroc average: 0.7013993590668668
0.95 confidence 0.5_auroc interval: (0.696356079661839,0.7064426384718946)
0.95 confidence 0.75_Ctd_ipcw average: 0.6478075544590226
0.95 c

In [11]:
tb = pt.PrettyTable(title="SurvTrace")
tb.field_names = headers
tb.add_row(results)
logger.info(tb)


+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                                                                         SurvTrace                                                                                                                                                                         |
+-------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------

In [12]:
model = SurvTraceSingle(STConfig)
model

SurvTraceSingle(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(25, 16)
    (LayerNorm): LayerNorm((16,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=16, out_features=16, bias=True)
            (key): Linear(in_features=16, out_features=16, bias=True)
            (value): Linear(in_features=16, out_features=16, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=16, out_features=16, bias=True)
            (LayerNorm): LayerNorm((16,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=16, out_features=64

In [13]:
import numpy as np
import pandas as pd

In [14]:
data = np.random.randint(0, 100, size=(100, 5))
df = pd.DataFrame(data)
df.head()

Unnamed: 0,0,1,2,3,4
0,48,74,77,54,33
1,2,10,49,67,99
2,86,34,66,79,10
3,80,46,78,48,69
4,67,61,32,34,46


In [15]:
from sklearn.model_selection import train_test_split

In [16]:
seed = 10

label_df = df.drop([4], axis=1)
del df[4]
train_df, test_df, train_label_df, test_label_df = train_test_split(df, label_df, test_size=0.3, random_state=seed)

label_data = data[:, 4].copy()
data = data[:, :4]
train_data, test_data, train_label_data, test_label_data = train_test_split(data, label_data, test_size=0.3, random_state=seed)

In [17]:
set_random_seed(seed)
test_df_res = test_df.sample(10, replace=True)
test_label_df_res = test_df_res.loc[test_df_res.index]
test_label_df_res.head()

Unnamed: 0,0,1,2,3
68,63,18,42,29
63,30,75,7,6
63,30,75,7,6
66,78,72,4,89
74,68,42,71,3


In [18]:
set_random_seed(seed)
test_data_res = pd.DataFrame(test_data).sample(10, replace=True)
test_label_data_res = test_data_res.loc[test_data_res.index]
test_label_data_res.head()

Unnamed: 0,0,1,2,3
9,63,18,42,29
29,30,75,7,6
29,30,75,7,6
4,78,72,4,89
15,68,42,71,3
