# Run SurvTRACE on the GBSG, METABRIC, and SUPPORT datasets

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os, sys
# Make the SurvTRACE checkout and the parent project importable from this notebook.
# Paths already on sys.path are skipped so re-running this cell does not keep
# appending duplicates (an earlier entry already wins import resolution anyway).
for _rel_path in ('../SurvTRACE', '..'):
    _abs_path = os.path.abspath(_rel_path)
    if _abs_path not in sys.path:
        sys.path.append(_abs_path)

In [3]:
import pdb
from collections import defaultdict
import matplotlib.pyplot as plt

from survtrace.dataset import load_data
from survtrace.evaluate_utils import Evaluator
from survtrace.utils import set_random_seed
from survtrace.model import SurvTraceSingle
from survtrace.train_utils import Trainer
from survtrace.config import STConfig
import prettytable as pt

In [4]:
# Hyper-parameters for transformer training; each value is forwarded to Trainer.fit
# by run_experiment under the same keyword name.
hparams = dict(
    batch_size=64,
    weight_decay=0,
    learning_rate=1e-3,
    epochs=20,
)

In [5]:
# Benchmark datasets, run in this order; each name is assigned to STConfig.data.
DATASETS = [
    'gbsg',
    'metabric',
    'support',
]

In [6]:
from btdsa.utils import create_logger
# Location handed to btdsa's create_logger; relative to the notebook's working directory.
LOG_DIR = './logs_survtrace'
logger = create_logger(LOG_DIR)

In [7]:
# Evaluation horizons come from the SurvTRACE config. headers/results are parallel
# lists filled by run_experiment: headers[i] is a dataset name, results[i] its
# multi-line metric summary for the table at the end.
horizons = STConfig.horizons
headers, results = [], []

In [8]:
def _format_result_summary(result_dict, horizons):
    """Format Evaluator output as a multi-line ``mean (lower,upper)`` summary string.

    ``result_dict`` is expected to map ``'C-td-full'`` and ``'<horizon>_<metric>'``
    keys to ``(mean, (lower, upper))`` tuples (the shape unpacked below). It is
    not modified. Returns one line for the full C-td index followed by one line
    per horizon, each terminated by a newline.
    """
    cindex_mean, (cindex_lower, cindex_upper) = result_dict['C-td-full']
    lines = [f"C-td (full): {cindex_mean:.6f} ({cindex_lower:.6f},{cindex_upper:.6f})"]
    for horizon in horizons:
        # Match the whole '<horizon>_' prefix so e.g. horizon 0.2 cannot pick up '0.25_*' keys.
        prefix = f"{horizon}_"
        # ':g' drops float noise: 0.25 -> '25', 0.07 -> '7' (not '7.000000000000001').
        msg = [f"[{horizon * 100:g}%]"]
        for key, value in result_dict.items():
            if not key.startswith(prefix):
                continue
            mean, (lower, upper) = value
            # Keep the full metric name: '0.25_Ctd_ipcw' -> 'Ctd_ipcw', not 'Ctd'.
            metric = key[len(prefix):]
            msg.append(f"{metric}: {mean:.6f} ({lower:.6f},{upper:.6f})")
        lines.append(" ".join(msg))
    return "\n".join(lines) + "\n"


def _plot_loss_curves(train_loss, val_loss):
    """Plot per-epoch training and validation loss curves on a fresh figure."""
    fig, ax = plt.subplots()
    ax.plot(train_loss, label='train')
    ax.plot(val_loss, label='val')
    ax.legend(fontsize=20)
    ax.set_xlabel('epoch', fontsize=20)
    ax.set_ylabel('loss', fontsize=20)
    plt.show()


def run_experiment(dataset, hparams, show_plot=False):
    """Train and evaluate SurvTRACE on one dataset and record its summary.

    Configures the shared ``STConfig`` for ``dataset``, loads the data, trains a
    ``SurvTraceSingle`` model, evaluates it with bootstrap confidence intervals and
    stores a formatted summary in the module-level ``headers``/``results`` lists
    used by the PrettyTable cell.

    Args:
        dataset: Name of the dataset; must be one of ``DATASETS``.
        hparams: Mapping with ``batch_size``, ``epochs``, ``learning_rate`` and
            ``weight_decay`` (forwarded to ``Trainer.fit``). An optional
            ``early_stop_patience`` overrides the default patience of 10.
        show_plot: If True, plot the training/validation loss curves.

    Raises:
        AssertionError: If ``dataset`` is not in ``DATASETS``.
    """
    assert dataset in DATASETS, f"unknown dataset {dataset!r}; expected one of {DATASETS}"
    logger.info(f"Running {dataset}...")

    # STConfig is shared module-level state: every call overwrites these fields.
    STConfig.data = dataset
    STConfig.early_stop_patience = hparams.get('early_stop_patience', 10)

    set_random_seed(STConfig.seed)

    df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)

    model = SurvTraceSingle(STConfig)
    trainer = Trainer(model)
    train_loss, val_loss = trainer.fit(
        (df_train, df_y_train), (df_val, df_y_val),
        batch_size=hparams['batch_size'],
        epochs=hparams['epochs'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],
    )

    evaluator = Evaluator(df, df_train.index)
    result_dict = evaluator.eval(model, (df_test, df_y_test), confidence=.95, nb_bootstrap=100)

    row_str = _format_result_summary(result_dict, horizons)
    # Record header and result together, only after the run succeeded, so the two
    # lists stay aligned. Re-running a dataset replaces its earlier entry instead of
    # adding a duplicate column (PrettyTable rejects duplicate field names).
    if dataset in headers:
        results[headers.index(dataset)] = row_str
    else:
        headers.append(dataset)
        results.append(row_str)

    if show_plot:
        _plot_loss_curves(train_loss, val_loss)
    print("done")

In [9]:
# Start from empty summaries so re-running this cell does not accumulate duplicate
# columns (which would make the PrettyTable cell below fail on duplicate field names).
headers.clear()
results.clear()
for dataset in DATASETS:
    run_experiment(dataset, hparams, show_plot=False)

Running gbsg...


use pytorch-cuda for training.
[Train-0]: 39.750836968421936
[Val-0]: 1.430763840675354
[Train-1]: 30.77250039577484
[Val-1]: 1.3748854398727417
[Train-2]: 27.913107752799988
[Val-2]: 1.3679455518722534
[Train-3]: 27.77131462097168
[Val-3]: 1.3514004945755005
[Train-4]: 27.594568848609924
[Val-4]: 1.3398962020874023
[Train-5]: 27.405004382133484
[Val-5]: 1.316502332687378
[Train-6]: 27.383405804634094
[Val-6]: 1.3220163583755493
EarlyStopping counter: 1 out of 10
[Train-7]: 27.330041527748108
[Val-7]: 1.3244298696517944
EarlyStopping counter: 2 out of 10
[Train-8]: 27.199955224990845
[Val-8]: 1.3601561784744263
EarlyStopping counter: 3 out of 10
[Train-9]: 27.118715286254883
[Val-9]: 1.3365825414657593
EarlyStopping counter: 4 out of 10
[Train-10]: 27.14840805530548
[Val-10]: 1.3214644193649292
EarlyStopping counter: 5 out of 10
[Train-11]: 27.093369007110596
[Val-11]: 1.3420758247375488
EarlyStopping counter: 6 out of 10
[Train-12]: 27.070682406425476
[Val-12]: 1.3670085668563843
Earl

Running metabric...


0.95 confidence C-td-full average: 0.42831365517917047
0.95 confidence C-td-full interval: (0.42454289137324275,0.4320844189850982)
0.95 confidence 0.25_Ctd_ipcw average: 0.757287186930771
0.95 confidence 0.25_Ctd_ipcw interval: (0.7527851351175945,0.7617892387439475)
0.95 confidence 0.25_brier average: 0.10685194902783396
0.95 confidence 0.25_brier interval: (0.10518068051829148,0.10852321753737644)
0.95 confidence 0.25_auroc average: 0.7733435646420301
0.95 confidence 0.25_auroc interval: (0.7684410341004999,0.7782460951835604)
0.95 confidence 0.5_Ctd_ipcw average: 0.709094023930823
0.95 confidence 0.5_Ctd_ipcw interval: (0.7058863150366925,0.7123017328249535)
0.95 confidence 0.5_brier average: 0.18027988023171676
0.95 confidence 0.5_brier interval: (0.1787034904203576,0.18185627004307592)
0.95 confidence 0.5_auroc average: 0.7339031591742499
0.95 confidence 0.5_auroc interval: (0.7302379233183384,0.7375683950301614)
0.95 confidence 0.75_Ctd_ipcw average: 0.6938200794316799
0.95 conf

Running support...


0.95 confidence C-td-full average: 0.3930680131427969
0.95 confidence C-td-full interval: (0.3881816309514039,0.39795439533418986)
0.95 confidence 0.25_Ctd_ipcw average: 0.7329977692480122
0.95 confidence 0.25_Ctd_ipcw interval: (0.7275875857725627,0.7384079527234617)
0.95 confidence 0.25_brier average: 0.10745381022805552
0.95 confidence 0.25_brier interval: (0.10534636896587672,0.10956125149023431)
0.95 confidence 0.25_auroc average: 0.7499634271531764
0.95 confidence 0.25_auroc interval: (0.7442645994216192,0.7556622548847337)
0.95 confidence 0.5_Ctd_ipcw average: 0.7101277790942627
0.95 confidence 0.5_Ctd_ipcw interval: (0.7064251920218643,0.7138303661666612)
0.95 confidence 0.5_brier average: 0.17446746571622027
0.95 confidence 0.5_brier interval: (0.17282453560773314,0.1761103958247074)
0.95 confidence 0.5_auroc average: 0.7359222391947706
0.95 confidence 0.5_auroc interval: (0.7314616797560952,0.740382798633446)
0.95 confidence 0.75_Ctd_ipcw average: 0.6852725413477013
0.95 conf

In [10]:
# Side-by-side summary: one column per dataset, a single row whose cells hold each
# dataset's multi-line metric text (headers[i] pairs with results[i]).
tb = pt.PrettyTable(title="SurvTrace")
tb.field_names = list(headers)
tb.add_row(list(results))
logger.info(tb)


+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                                                                         SurvTrace                                                                                                                                                                         |
+-------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------