In [1]:
# %pip install -r requirements.txt

In [2]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch 
import easydict
from torch import Tensor, device, dtype, nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from sksurv.metrics import concordance_index_ipcw, brier_score
import torchtuples as tt

from pycox.datasets import metabric, support
from pycox.models import LogisticHazard, DeepHit
from pycox.preprocessing.feature_transforms import OrderedCategoricalLong
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox.evaluation import EvalSurv
from pycox.models.loss import NLLPCHazardLoss
from pycox.preprocessing.discretization import (make_cuts, IdxDiscUnknownC, _values_if_series,
    DiscretizeUnknownC, Duration2Idx)


In [3]:
from survtrace.dataset import load_data
from survtrace.evaluate_utils import Evaluator
from survtrace.utils import set_random_seed
from survtrace.model import SurvTraceMulti, SurvTraceSingle
from survtrace.train_utils import Trainer
from survtrace.config import STConfig

In [4]:
# Seed NumPy's and PyTorch's global RNGs with the same fixed value so random
# draws made through either library repeat from run to run.
np.random.seed(42)
_ = torch.manual_seed(42)  # bind the return value so the cell prints nothing

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## DeepHit Specifications

In [5]:
class SimpleMLP(torch.nn.Module):
    """One MLP whose flat output is viewed as a block of values per risk.

    The wrapped ``tt.practical.MLPVanilla`` is built with
    ``num_risks * out_features`` outputs; ``forward`` views that flat output as
    ``(batch, num_risks, -1)``.

    Args:
        in_features: Forwarded to ``MLPVanilla`` as its first argument.
        num_nodes: Forwarded to ``MLPVanilla`` as its second argument.
        num_risks: Number of risks; size of dimension 1 of the forward output.
        out_features: Number of outputs per risk.
        batch_norm: Forwarded unchanged to ``MLPVanilla``.
        dropout: Forwarded unchanged to ``MLPVanilla``.
    """

    def __init__(self, in_features, num_nodes, num_risks, out_features, batch_norm=True,
                 dropout=None):
        super().__init__()
        self.num_risks = num_risks
        # A single network emits the values for every risk side by side.
        flat_width = num_risks * out_features
        self.mlp = tt.practical.MLPVanilla(
            in_features, num_nodes, flat_width, batch_norm, dropout,
        )

    def forward(self, input):
        """Run the MLP and split its last dimension into per-risk blocks."""
        flat = self.mlp(input)
        batch_size = flat.size(0)
        return flat.view(batch_size, self.num_risks, -1)

class CauseSpecificNet(torch.nn.Module):
    """Shared trunk followed by one independent sub-network per risk.

    The trunk is an ``MLPVanilla`` whose hidden sizes are
    ``num_nodes_shared[:-1]`` and whose output size is ``num_nodes_shared[-1]``.
    Each risk gets its own ``MLPVanilla`` that reads the trunk output and emits
    ``out_features`` values. ``forward`` stacks the per-risk outputs along
    dimension 1.

    Args:
        in_features: Input width of the shared trunk.
        num_nodes_shared: Sequence of trunk layer sizes; the last entry is the
            trunk's output width.
        num_nodes_indiv: Hidden layer sizes of each per-risk network.
        num_risks: Number of per-risk networks to create.
        out_features: Output width of each per-risk network.
        batch_norm: Forwarded unchanged to every ``MLPVanilla``.
        dropout: Forwarded unchanged to every ``MLPVanilla``.
    """

    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_risks,
                 out_features, batch_norm=True, dropout=None):
        super().__init__()
        trunk_hidden = num_nodes_shared[:-1]
        trunk_width = num_nodes_shared[-1]
        self.shared_net = tt.practical.MLPVanilla(
            in_features, trunk_hidden, trunk_width, batch_norm, dropout,
        )
        # One head per risk, created in risk order after the trunk.
        self.risk_nets = torch.nn.ModuleList(
            tt.practical.MLPVanilla(
                trunk_width, num_nodes_indiv, out_features, batch_norm, dropout,
            )
            for _ in range(num_risks)
        )

    def forward(self, input):
        """Apply the trunk, then every risk head, and stack the results on dim 1."""
        shared = self.shared_net(input)
        per_risk = [head(shared) for head in self.risk_nets]
        return torch.stack(per_risk, dim=1)

In [6]:
# Settings for the SEER run, written into the shared STConfig in this order.
_seer_settings = {
    'data': 'seer',
    'num_hidden_layers': 2,
    'hidden_size': 16,
    'intermediate_size': 64,
    'num_attention_heads': 2,
    'initializer_range': .02,
    'early_stop_patience': 5,
}
for _key, _value in _seer_settings.items():
    STConfig[_key] = _value

# load_data returns the full frame plus train / test / validation splits.
(df,
 df_train, df_y_train,
 df_test, df_y_test,
 df_val, df_y_val) = load_data(STConfig)
train_set, val_set = (df_train, df_y_train), (df_val, df_y_val)




In [7]:
# Summary statistics for the first 13 columns of df, each value formatted to two decimals.
df.iloc[:, :13].describe().apply(lambda column: column.map('{0:.2f}'.format))

Unnamed: 0,Sex,Year of diagnosis,"Race recode (W, B, AI, API)",Histologic Type ICD-O-3,Laterality,Sequence number,ER Status Recode Breast Cancer (1990+),PR Status Recode Breast Cancer (1990+),Summary stage 2000 (1998-2017),RX Summ--Surg Prim Site (1998+),Reason no cancer-directed surgery,First malignant primary indicator,Diagnostic Confirmation
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,0.01,6.34,2.72,1.43,2.48,4.4,1.71,1.59,1.24,8.74,4.55,0.82,3.0
std,0.09,2.79,0.62,4.58,1.5,2.44,0.56,0.61,0.56,8.04,1.38,0.38,0.18
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,5.0,3.0,0.0,1.0,1.0,2.0,1.0,1.0,4.0,5.0,1.0,3.0
50%,0.0,7.0,3.0,0.0,1.0,6.0,2.0,2.0,1.0,5.0,5.0,1.0,3.0
75%,0.0,9.0,3.0,1.0,4.0,6.0,2.0,2.0,2.0,13.0,5.0,1.0,3.0
max,1.0,10.0,3.0,74.0,4.0,6.0,2.0,2.0,2.0,47.0,5.0,1.0,5.0


In [8]:
# Summary statistics for columns 14-24 of df, each value formatted to two decimals.
# NOTE(review): the preceding summary cell stops at column 12 and this one starts
# at column 14, so column 13 appears in neither table — confirm that is intended.
df.iloc[:, 14:25].describe().apply(lambda column: column.map('{0:.2f}'.format))


Unnamed: 0,Regional nodes examined (1988+),CS tumor size (2004-2015),Total number of benign/borderline tumors for patient,Total number of in situ/malignant tumors for patient,duration,event_heart,event_breast,event_0,event_1,duration_disc,proportion
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,-0.0,0.0,0.0,0.0,67.41,0.05,0.18,0.18,0.05,2.36,0.51
std,1.0,1.0,1.0,1.0,31.45,0.21,0.39,0.39,0.21,0.97,0.29
min,-0.48,-0.38,-0.08,-0.57,1.0,0.0,0.0,0.0,0.0,0.0,0.02
25%,-0.43,-0.34,-0.08,-0.57,48.0,0.0,0.0,0.0,0.0,2.0,0.25
50%,-0.32,-0.3,-0.08,-0.57,69.0,0.0,0.0,0.0,0.0,3.0,0.52
75%,-0.0,-0.25,-0.08,0.94,92.0,0.0,0.0,0.0,0.0,3.0,0.76
max,4.83,3.44,57.86,28.18,121.0,1.0,1.0,1.0,1.0,3.0,1.0


## DeepHit - SEER

In [9]:



# df_train, df_y_train = train_set
# x_train = df_train.values.astype('float32')
# y_train = (df_y_train['duration'].values.astype('int64'), (df_y_train['event_0'].values + df_y_train['event_1'].values * 2).astype('int64'))

# df_val, df_y_val = val_set
# x_val = df_val.values.astype('float32')
# y_val = (df_y_val['duration'].values.astype('int64'), (df_y_val['event_0'].values + df_y_val['event_1'].values * 2).astype('int64'))
# val = (x_val, y_val)

# in_features = x_train.shape[1]
# num_nodes_shared = [64, 64]
# num_nodes_indiv = [32]
# num_risks = 2
# out_features = len(STConfig['labtrans'].cuts)
# batch_norm = True
# dropout = 0.1

# net = CauseSpecificNet(in_features, num_nodes_shared, num_nodes_indiv, num_risks,
#                     out_features, batch_norm, dropout).to(DEVICE)
# optimizer = tt.optim.AdamWR(lr=0.01, decoupled_weight_decay=0.01,
#                         cycle_eta_multiplier=0.8)
# DeepHitModel = DeepHit(net, optimizer, alpha=0.2, sigma=0.1,
#             duration_index=STConfig['labtrans'].cuts)
# epochs = 50
# batch_size = 256
# callbacks = [tt.callbacks.EarlyStoppingCycle()]
# verbose = True 

# x_train, y_train = torch.tensor(x_train).to(DEVICE), (torch.tensor(y_train[0]).to(DEVICE), torch.tensor(y_train[1]).to(DEVICE))
# x_val, y_val = torch.tensor(x_val).to(DEVICE), (torch.tensor(y_val[0]).to(DEVICE), torch.tensor(y_val[1]).to(DEVICE))
# val = (x_val, y_val)

# log = DeepHitModel.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val)
# log = log.to_pandas()




In [10]:
# predictions = DeepHitModel.predict(x_val)
# event_types = y_val[0].cpu().numpy()
# event_times = y_val[1].cpu().numpy()

# cif = DeepHitModel.predict_cif(x_val)
# cif1 = pd.DataFrame(cif[0].cpu().numpy(), DeepHitModel.duration_index)
# ev1 = EvalSurv(cif1, event_times, event_types > 0, censor_surv='km')
# ev1.concordance_td()

# cif1 = pd.DataFrame(cif[0].cpu().numpy(), DeepHitModel.duration_index)
# cif2 = pd.DataFrame(cif[1].cpu().numpy(), DeepHitModel.duration_index)
# ev1 = EvalSurv(1-cif1, y_val_duration, y_val_event == 1, censor_surv='km')
# ev2 = EvalSurv(1-cif2, y_val_duration, y_val_event == 2, censor_surv='km')

# cis.append(concordance_index_ipcw(et_train, et_test, risk[:, i+1].to("cpu").numpy(), times[i])[0])  

## SurvTRACE - SEER

In [11]:
# Optimisation settings; the keys match the keyword arguments of Trainer.fit
# used below, so the dict is expanded straight into the call.
hparams = dict(
    batch_size=1024,
    weight_decay=0,
    learning_rate=1e-4,
    epochs=100,
)

# Build the multi-event model on DEVICE and train it on the full SEER split.
SurvTraceSeer = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeer_trainer = Trainer(SurvTraceSeer)
train_loss, val_loss = SurvTraceSeer_trainer.fit(
    (df_train, df_y_train),
    (df_val, df_y_val),
    val_batch_size=10000,
    **hparams,
)


use pytorch-cuda for training.


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ..\torch\csrc\utils\python_arg_parser.cpp:1420.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


[Train-0]: 2.9760685688784334
[Val-0]: 2.303372383117676
[Train-1]: 0.7216343528964893
[Val-1]: 0.6906883716583252
[Train-2]: 0.6844168733171865
[Val-2]: 0.6782993078231812
[Train-3]: 0.6789033678518671
[Val-3]: 0.6754992008209229
[Train-4]: 0.677108012089113
[Val-4]: 0.6724416017532349
[Train-5]: 0.6746655793822541
[Val-5]: 0.6721491813659668
[Train-6]: 0.6736950274227428
[Val-6]: 0.6704277992248535
[Train-7]: 0.6725244483574718
[Val-7]: 0.6695846319198608
[Train-8]: 0.6718956274645669
[Val-8]: 0.6687475442886353
[Train-9]: 0.6711800416716102
[Val-9]: 0.6684685945510864
[Train-10]: 0.6706843775551335
[Val-10]: 0.6687781810760498
EarlyStopping counter: 1 out of 5
[Train-11]: 0.6701725078683322
[Val-11]: 0.668524980545044
EarlyStopping counter: 2 out of 5
[Train-12]: 0.6699310697260357
[Val-12]: 0.6672793626785278
[Train-13]: 0.6695130887080212
[Val-13]: 0.6674829125404358
EarlyStopping counter: 1 out of 5
[Train-14]: 0.6689976824789631
[Val-14]: 0.6685466766357422
EarlyStopping counter

In [35]:
# Evaluate the full-data SEER model on the test split. The Evaluator receives
# the complete frame and the index of the training rows.
evaluator = Evaluator(
    df,
    df_train.index,
)
evaluator.eval(
    SurvTraceSeer,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9033432086254986
Brier Score: 0.03593080123945568
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8820621444269189
Brier Score: 0.06045311869879083
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8649312485483952
Brier Score: 0.08189422303173148
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7862581957187452
Brier Score: 0.007403492846374771
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7870278852379744
Brier Score: 0.016000015042350945
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.776352535848915
Brier Score: 0.027978165449578166


defaultdict(list,
            {'0.25_ipcw_0': 0.9033432086254986,
             '0.25_brier_0': 0.03593080123945568,
             '0.5_ipcw_0': 0.8820621444269189,
             '0.5_brier_0': 0.06045311869879083,
             '0.75_ipcw_0': 0.8649312485483952,
             '0.75_brier_0': 0.08189422303173148,
             '0.25_ipcw_1': 0.7862581957187452,
             '0.25_brier_1': 0.007403492846374771,
             '0.5_ipcw_1': 0.7870278852379744,
             '0.5_brier_1': 0.016000015042350945,
             '0.75_ipcw_1': 0.776352535848915,
             '0.75_brier_1': 0.027978165449578166})

## SurvTRACE - SUPPORT

In [13]:
# Point the shared config at the SUPPORT dataset with a single event type.
# NOTE(review): only these two keys are changed; the other STConfig entries set
# by the earlier SEER cell are left as they were — confirm that is intended.
STConfig['data'], STConfig['num_event'] = 'support', 1

hparams = dict(
    batch_size=128,
    weight_decay=0,
    learning_rate=1e-3,
    epochs=20,
)

# Reload every split for the newly selected dataset.
(df,
 df_train, df_y_train,
 df_test, df_y_test,
 df_val, df_y_val) = load_data(STConfig)

# Two-decimal summary of every column.
df.describe().apply(lambda column: column.map('{0:.2f}'.format))



Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,duration,event
count,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0
mean,62.63,0.44,1.87,1.25,0.19,0.03,0.94,84.53,97.48,23.36,37.11,137.57,12.35,1.78,478.64,0.68
std,15.62,0.5,1.34,0.62,0.4,0.18,0.58,27.82,31.7,9.63,1.26,6.07,9.27,1.69,560.83,0.47
min,18.04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,31.7,110.0,0.0,0.1,3.0,0.0
25%,52.75,0.0,1.0,1.0,0.0,0.0,1.0,63.0,72.0,18.0,36.2,134.0,7.0,0.9,26.0,0.0
50%,64.83,0.0,2.0,1.0,0.0,0.0,1.0,77.0,100.0,24.0,36.7,137.0,10.6,1.2,231.0,1.0
75%,74.03,1.0,3.0,1.0,0.0,0.0,1.0,107.0,120.0,28.0,38.2,141.0,15.3,1.9,763.0,1.0
max,101.85,1.0,9.0,5.0,1.0,1.0,2.0,195.0,300.0,90.0,41.7,181.0,200.0,21.5,2029.0,1.0


In [14]:
# Single-event model for SUPPORT.
# NOTE(review): unlike the SEER cells, the model is not moved to DEVICE here;
# presumably Trainer handles device placement — confirm.
model = SurvTraceSingle(STConfig)
trainer = Trainer(model)

# hparams holds exactly the four keyword arguments passed to fit.
train_loss, val_loss = trainer.fit(
    (df_train, df_y_train),
    (df_val, df_y_val),
    **hparams,
)

use pytorch-cuda for training.
[Train-0]: 69.49956834316254
[Val-0]: 1.3877462148666382
[Train-1]: 58.55283069610596
[Val-1]: 1.349938154220581
[Train-2]: 58.08263158798218
[Val-2]: 1.370281457901001
EarlyStopping counter: 1 out of 5
[Train-3]: 57.78693330287933
[Val-3]: 1.3436068296432495
[Train-4]: 57.56906247138977
[Val-4]: 1.3464226722717285
EarlyStopping counter: 1 out of 5
[Train-5]: 57.29101920127869
[Val-5]: 1.3460872173309326
EarlyStopping counter: 2 out of 5
[Train-6]: 57.27627182006836
[Val-6]: 1.34273099899292
[Train-7]: 57.16512668132782
[Val-7]: 1.3605945110321045
EarlyStopping counter: 1 out of 5
[Train-8]: 56.99452233314514
[Val-8]: 1.3547546863555908
EarlyStopping counter: 2 out of 5
[Train-9]: 56.875293254852295
[Val-9]: 1.346466064453125
EarlyStopping counter: 3 out of 5
[Train-10]: 56.77872693538666
[Val-10]: 1.3448930978775024
EarlyStopping counter: 4 out of 5
[Train-11]: 56.64685416221619
[Val-11]: 1.3438853025436401
EarlyStopping counter: 5 out of 5
early stops a

In [15]:
# Evaluate the SUPPORT model on the test split; the Evaluator receives the
# complete frame and the index of the training rows.
evaluator = Evaluator(
    df,
    df_train.index,
)
evaluator.eval(
    model,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************
For 0.25 quantile,
TD Concordance Index - IPCW: 0.6521908560171059
Brier Score: 0.13451221546819953
For 0.5 quantile,
TD Concordance Index - IPCW: 0.6228850608765065
Brier Score: 0.20879744735506525
For 0.75 quantile,
TD Concordance Index - IPCW: 0.6015098129932681
Brier Score: 0.2330643214526076


defaultdict(list,
            {'0.25_ipcw': 0.6521908560171059,
             '0.25_brier': 0.13451221546819953,
             '0.5_ipcw': 0.6228850608765065,
             '0.5_brier': 0.20879744735506525,
             '0.75_ipcw': 0.6015098129932681,
             '0.75_brier': 0.2330643214526076})

## SurvTRACE - METABRIC

In [16]:

# Point the shared config at the METABRIC dataset with a single event type.
# NOTE(review): only these two keys are changed; all other STConfig entries keep
# whatever earlier cells stored in them — confirm that is intended.
STConfig['data'], STConfig['num_event'] = 'metabric', 1

hparams = dict(
    batch_size=128,
    weight_decay=0,
    learning_rate=1e-3,
    epochs=20,
)

# Reload every split for the newly selected dataset.
(df,
 df_train, df_y_train,
 df_test, df_y_test,
 df_val, df_y_val) = load_data(STConfig)

# Two-decimal summary of every column.
df.describe().apply(lambda column: column.map('{0:.2f}'.format))




Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,duration,event
count,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0
mean,6.21,6.24,10.77,5.87,0.62,0.6,0.21,0.76,61.09,125.03,0.58
std,0.86,1.02,1.36,0.34,0.49,0.49,0.41,0.43,12.98,76.33,0.49
min,5.16,4.86,6.37,5.1,0.0,0.0,0.0,0.0,21.93,0.0,0.0
25%,5.69,5.41,9.97,5.62,0.0,0.0,0.0,1.0,51.38,60.82,0.0
50%,5.95,5.88,10.53,5.82,1.0,1.0,0.0,1.0,61.77,114.9,1.0
75%,6.46,6.9,11.16,6.06,1.0,1.0,0.0,1.0,70.59,184.47,1.0
max,14.44,9.93,14.64,7.66,1.0,1.0,1.0,1.0,96.29,355.2,1.0


In [17]:

# Single-event model for METABRIC; this rebinds `model` and `trainer` from the
# SUPPORT section.
model = SurvTraceSingle(STConfig)
trainer = Trainer(model)

# hparams holds exactly the four keyword arguments passed to fit.
train_loss, val_loss = trainer.fit(
    (df_train, df_y_train),
    (df_val, df_y_val),
    **hparams,
)

use pytorch-cuda for training.
[Train-0]: 17.016679644584656
[Val-0]: 1.3102798461914062
[Train-1]: 15.152567625045776
[Val-1]: 1.1598601341247559
[Train-2]: 13.263692021369934
[Val-2]: 1.145556092262268
[Train-3]: 11.608664631843567
[Val-3]: 1.1200928688049316
[Train-4]: 10.87654435634613
[Val-4]: 1.124851107597351
EarlyStopping counter: 1 out of 5
[Train-5]: 10.790985822677612
[Val-5]: 1.1358267068862915
EarlyStopping counter: 2 out of 5
[Train-6]: 10.58698844909668
[Val-6]: 1.1241111755371094
EarlyStopping counter: 3 out of 5
[Train-7]: 10.750947892665863
[Val-7]: 1.0970479249954224
[Train-8]: 10.778635561466217
[Val-8]: 1.101806402206421
EarlyStopping counter: 1 out of 5
[Train-9]: 10.622157514095306
[Val-9]: 1.1121330261230469
EarlyStopping counter: 2 out of 5
[Train-10]: 10.509716629981995
[Val-10]: 1.125728726387024
EarlyStopping counter: 3 out of 5
[Train-11]: 10.431023359298706
[Val-11]: 1.0957859754562378
[Train-12]: 10.583075225353241
[Val-12]: 1.1066266298294067
EarlyStoppi

In [18]:
# Evaluate the METABRIC model on the test split; the Evaluator receives the
# complete frame and the index of the training rows.
evaluator = Evaluator(
    df,
    df_train.index,
)
evaluator.eval(
    model,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************
For 0.25 quantile,
TD Concordance Index - IPCW: 0.7135106078431047
Brier Score: 0.11492162956562897
For 0.5 quantile,
TD Concordance Index - IPCW: 0.658611861785142
Brier Score: 0.19810274877896733
For 0.75 quantile,
TD Concordance Index - IPCW: 0.6412925627953555
Brier Score: 0.23389756773737294


defaultdict(list,
            {'0.25_ipcw': 0.7135106078431047,
             '0.25_brier': 0.11492162956562897,
             '0.5_ipcw': 0.658611861785142,
             '0.5_brier': 0.19810274877896733,
             '0.75_ipcw': 0.6412925627953555,
             '0.75_brier': 0.23389756773737294})

## SurvTRACE - SEER - Loss Function Ablation

In [19]:
# Restore the SEER settings for the ablation runs (same values as the first
# SEER section), written into the shared STConfig in this order.
_seer_settings = {
    'data': 'seer',
    'num_hidden_layers': 2,
    'hidden_size': 16,
    'intermediate_size': 64,
    'num_attention_heads': 2,
    'initializer_range': .02,
    'early_stop_patience': 5,
}
for _key, _value in _seer_settings.items():
    STConfig[_key] = _value

# Reload every split; this rebinds df and the train / test / validation frames.
(df,
 df_train, df_y_train,
 df_test, df_y_test,
 df_val, df_y_val) = load_data(STConfig)
train_set, val_set = (df_train, df_y_train), (df_val, df_y_val)


In [20]:
from survtrace.losses import NLLLogistiHazardLoss


# Same optimisation settings as the baseline SEER run.
hparams = dict(
    batch_size=1024,
    weight_decay=0,
    learning_rate=1e-4,
    epochs=100,
)

SurvTraceSeerLossAblation = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerLossAblation_trainer = Trainer(SurvTraceSeerLossAblation)

# Replace the trainer's metrics with the logistic-hazard NLL before fitting.
# NOTE(review): the [loss_instance, name] layout is presumably what Trainer
# expects for `metrics` — confirm against survtrace.train_utils.
SurvTraceSeerLossAblation_trainer.metrics = [NLLLogistiHazardLoss(), 'NLLLogistiHazardLoss']

train_loss, val_loss = SurvTraceSeerLossAblation_trainer.fit(
    (df_train, df_y_train),
    (df_val, df_y_val),
    val_batch_size=10000,
    **hparams,
)


use pytorch-cuda for training.
[Train-0]: 3.3905065165085047
[Val-0]: 2.419238328933716
[Train-1]: 0.7743622794443247
[Val-1]: 0.7256957292556763
[Train-2]: 0.7306951503364407
[Val-2]: 0.7213267683982849
[Train-3]: 0.7266384052987002
[Val-3]: 0.7198609113693237
[Train-4]: 0.7237218723816126
[Val-4]: 0.7145017981529236
[Train-5]: 0.721240260163132
[Val-5]: 0.7131412625312805
[Train-6]: 0.7201806927213863
[Val-6]: 0.7134824991226196
EarlyStopping counter: 1 out of 5
[Train-7]: 0.7191109663369705
[Val-7]: 0.7112871408462524
[Train-8]: 0.7186617099103474
[Val-8]: 0.7109307050704956
[Train-9]: 0.7181141800215455
[Val-9]: 0.7124578952789307
EarlyStopping counter: 1 out of 5
[Train-10]: 0.7176786667230178
[Val-10]: 0.710012674331665
[Train-11]: 0.7173169743852551
[Val-11]: 0.7101588845252991
EarlyStopping counter: 1 out of 5
[Train-12]: 0.7170577596644966
[Val-12]: 0.7096244096755981
[Train-13]: 0.7162235186619013
[Val-13]: 0.7095814943313599
[Train-14]: 0.7161229287280517
[Val-14]: 0.7097368

In [21]:
# Evaluate the loss-ablation model on the test split; the Evaluator receives
# the complete frame and the index of the training rows.
evaluator = Evaluator(
    df,
    df_train.index,
)
evaluator.eval(
    SurvTraceSeerLossAblation,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9042126639737661
Brier Score: 0.03588776533595865
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8826602563884307
Brier Score: 0.06029968900816516
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8652698868042122
Brier Score: 0.08207141165758451
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7888281552189829
Brier Score: 0.007393128898380241
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7888654719993595
Brier Score: 0.015968243131562723
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7780779250731416
Brier Score: 0.027951526741008745


defaultdict(list,
            {'0.25_ipcw_0': 0.9042126639737661,
             '0.25_brier_0': 0.03588776533595865,
             '0.5_ipcw_0': 0.8826602563884307,
             '0.5_brier_0': 0.06029968900816516,
             '0.75_ipcw_0': 0.8652698868042122,
             '0.75_brier_0': 0.08207141165758451,
             '0.25_ipcw_1': 0.7888281552189829,
             '0.25_brier_1': 0.007393128898380241,
             '0.5_ipcw_1': 0.7888654719993595,
             '0.5_brier_1': 0.015968243131562723,
             '0.75_ipcw_1': 0.7780779250731416,
             '0.75_brier_1': 0.027951526741008745})

## SurvTRACE - SEER - Transformer Ablation

In [22]:
from survtrace.modeling_bert import BertEncoderLame


class SurvTraceMultiAblation(SurvTraceMulti):
    """``SurvTraceMulti`` variant whose ``encoder`` is a ``BertEncoderLame``.

    NOTE(review): the parent constructor runs first and the ``encoder``
    attribute is overwritten afterwards; whether any weight initialisation done
    inside the parent also needs to be applied to the replacement encoder
    cannot be seen from this cell — confirm.
    """

    def __init__(self, config: STConfig):
        super(SurvTraceMultiAblation, self).__init__(config)
        # Swap in the ablated encoder, built from the same config.
        self.encoder = BertEncoderLame(config)

# Same optimisation settings as the baseline SEER run.
hparams = dict(
    batch_size=1024,
    weight_decay=0,
    learning_rate=1e-4,
    epochs=100,
)

# Train the encoder-ablated variant on the full SEER split.
SurvTraceSeerEncoderAblation = SurvTraceMultiAblation(STConfig).to(DEVICE)
SurvTraceSeerEncoderAblation_trainer = Trainer(SurvTraceSeerEncoderAblation)
train_loss, val_loss = SurvTraceSeerEncoderAblation_trainer.fit(
    (df_train, df_y_train),
    (df_val, df_y_val),
    val_batch_size=10000,
    **hparams,
)


use pytorch-cuda for training.
[Train-0]: 2.9514876125621146
[Val-0]: 2.0919079780578613
[Train-1]: 0.7364038932485645
[Val-1]: 0.6968504190444946
[Train-2]: 0.6962061210149
[Val-2]: 0.6863769292831421
[Train-3]: 0.690154683022272
[Val-3]: 0.6833231449127197
[Train-4]: 0.6876527539321354
[Val-4]: 0.6816166043281555
[Train-5]: 0.6863483973100882
[Val-5]: 0.6807183027267456
[Train-6]: 0.6856074501462535
[Val-6]: 0.6799182891845703
[Train-7]: 0.6851262519148742
[Val-7]: 0.6792891025543213
[Train-8]: 0.6843384473907704
[Val-8]: 0.6787375211715698
[Train-9]: 0.683669133453953
[Val-9]: 0.6785733699798584
[Train-10]: 0.683705289753116
[Val-10]: 0.6776801347732544
[Train-11]: 0.6831158571908263
[Val-11]: 0.6792948246002197
EarlyStopping counter: 1 out of 5
[Train-12]: 0.6824237957698147
[Val-12]: 0.6772336959838867
[Train-13]: 0.682305167726919
[Val-13]: 0.6776301264762878
EarlyStopping counter: 1 out of 5
[Train-14]: 0.6821443456776288
[Val-14]: 0.6770136952400208
[Train-15]: 0.68113999585716

In [23]:
# Evaluate the encoder-ablation model on the test split; the Evaluator receives
# the complete frame and the index of the training rows.
evaluator = Evaluator(
    df,
    df_train.index,
)
evaluator.eval(
    SurvTraceSeerEncoderAblation,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9038148991138182
Brier Score: 0.035724348473204964
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8820874001469994
Brier Score: 0.06017103687127706
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8650528721139195
Brier Score: 0.08165567512332446
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7806379671444411
Brier Score: 0.007409649776910549
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7814116178153309
Brier Score: 0.015987608899990718
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7716941103998275
Brier Score: 0.027990005860478295


defaultdict(list,
            {'0.25_ipcw_0': 0.9038148991138182,
             '0.25_brier_0': 0.035724348473204964,
             '0.5_ipcw_0': 0.8820874001469994,
             '0.5_brier_0': 0.06017103687127706,
             '0.75_ipcw_0': 0.8650528721139195,
             '0.75_brier_0': 0.08165567512332446,
             '0.25_ipcw_1': 0.7806379671444411,
             '0.25_brier_1': 0.007409649776910549,
             '0.5_ipcw_1': 0.7814116178153309,
             '0.5_brier_1': 0.015987608899990718,
             '0.75_ipcw_1': 0.7716941103998275,
             '0.75_brier_1': 0.027990005860478295})

## SurvTRACE - SEER - Small Training Set Ablation

In [44]:
def _sample_paired(features, labels, frac, seed):
    """Draw the same random fraction of rows from a feature frame and its label frame.

    Both frames are sampled with ``DataFrame.sample(frac=frac, random_state=seed)``,
    exactly as the individual calls did before. Pairing features with labels
    this way relies on both draws picking the same row positions, which
    requires the two frames to have the same number of rows; that precondition
    is checked up front so a mismatch fails loudly instead of silently
    producing misaligned feature/label pairs.

    Args:
        features: DataFrame of model inputs.
        labels: DataFrame of targets, assumed to be in the same row order as
            ``features`` (TODO confirm against ``load_data``).
        frac: Fraction of rows to keep.
        seed: Value passed as ``random_state`` to both draws.

    Returns:
        Tuple ``(features_subset, labels_subset)``.

    Raises:
        ValueError: If the two frames have different lengths.
    """
    if len(features) != len(labels):
        raise ValueError(
            f"features and labels must have the same number of rows, "
            f"got {len(features)} and {len(labels)}"
        )
    return (
        features.sample(frac=frac, random_state=seed),
        labels.sample(frac=frac, random_state=seed),
    )


# 50 % subset
df_train_subset, df_y_train_subset = _sample_paired(df_train, df_y_train, frac=0.5, seed=2)
df_val_subset, df_y_val_subset = _sample_paired(df_val, df_y_val, frac=0.5, seed=2)

# 75 % subset
df_train_subset_75, df_y_train_subset_75 = _sample_paired(df_train, df_y_train, frac=0.75, seed=3)
df_val_subset_75, df_y_val_subset_75 = _sample_paired(df_val, df_y_val, frac=0.75, seed=3)

# 25 % subset
df_train_subset_25, df_y_train_subset_25 = _sample_paired(df_train, df_y_train, frac=0.25, seed=4)
df_val_subset_25, df_y_val_subset_25 = _sample_paired(df_val, df_y_val, frac=0.25, seed=4)


In [41]:

# Same optimisation settings as the baseline SEER run.
hparams = dict(
    batch_size=1024,
    weight_decay=0,
    learning_rate=1e-4,
    epochs=100,
)

# Train a fresh model on the 50 % training / validation subsets.
SurvTraceSeerSubsetAblation = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerSubsetAblation_trainer = Trainer(SurvTraceSeerSubsetAblation)
train_loss, val_loss = SurvTraceSeerSubsetAblation_trainer.fit(
    (df_train_subset, df_y_train_subset),
    (df_val_subset, df_y_val_subset),
    val_batch_size=10000,
    **hparams,
)

use pytorch-cuda for training.
[Train-0]: 3.3984248670590977
[Val-0]: 2.906320095062256
[Train-1]: 0.7680039182812178
[Val-1]: 0.6917750835418701
[Train-2]: 0.6900182260947973
[Val-2]: 0.6807222962379456
[Train-3]: 0.6851043802540319
[Val-3]: 0.6772527694702148
[Train-4]: 0.6820293494633266
[Val-4]: 0.6744533777236938
[Train-5]: 0.6790072070498044
[Val-5]: 0.6721703410148621
[Train-6]: 0.676766502613924
[Val-6]: 0.6706598997116089
[Train-7]: 0.6753144925143443
[Val-7]: 0.6688163876533508
[Train-8]: 0.6739655623630602
[Val-8]: 0.6676222681999207
[Train-9]: 0.6733619566677379
[Val-9]: 0.6671689748764038
[Train-10]: 0.6723835058763724
[Val-10]: 0.6659232378005981
[Train-11]: 0.671696810089812
[Val-11]: 0.6658376455307007
[Train-12]: 0.6713159088374806
[Val-12]: 0.6649394035339355
[Train-13]: 0.6708504079150505
[Val-13]: 0.6643069982528687
[Train-14]: 0.6701690376210375
[Val-14]: 0.6646916270256042
EarlyStopping counter: 1 out of 5
[Train-15]: 0.669948722229523
[Val-15]: 0.6634805202484131

In [45]:
# Same optimisation settings as the baseline SEER run.
hparams = dict(
    batch_size=1024,
    weight_decay=0,
    learning_rate=1e-4,
    epochs=100,
)

# --- 25 % subset ---------------------------------------------------------
SurvTraceSeerSubsetAblation_25 = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerSubsetAblation_trainer = Trainer(SurvTraceSeerSubsetAblation_25)
train_loss, val_loss = SurvTraceSeerSubsetAblation_trainer.fit(
    (df_train_subset_25, df_y_train_subset_25),
    (df_val_subset_25, df_y_val_subset_25),
    val_batch_size=10000,
    **hparams,
)

# --- 75 % subset ---------------------------------------------------------
# The trainer name and the loss histories are rebound; only the two model
# objects keep distinct names for the evaluation cells.
SurvTraceSeerSubsetAblation_75 = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerSubsetAblation_trainer = Trainer(SurvTraceSeerSubsetAblation_75)
train_loss, val_loss = SurvTraceSeerSubsetAblation_trainer.fit(
    (df_train_subset_75, df_y_train_subset_75),
    (df_val_subset_75, df_y_val_subset_75),
    val_batch_size=10000,
    **hparams,
)

use pytorch-cuda for training.
[Train-0]: 3.646382599263578
[Val-0]: 3.556924819946289
[Train-1]: 0.9127130935320983
[Val-1]: 0.716437816619873
[Train-2]: 0.701504003357243
[Val-2]: 0.7012441158294678
[Train-3]: 0.6913884129073169
[Val-3]: 0.6946576833724976
[Train-4]: 0.6889209417072503
[Val-4]: 0.6913391351699829
[Train-5]: 0.6848790355630823
[Val-5]: 0.6890085339546204
[Train-6]: 0.6813528602187698
[Val-6]: 0.6873162388801575
[Train-7]: 0.6807550868472537
[Val-7]: 0.6859121322631836
[Train-8]: 0.6793314991770564
[Val-8]: 0.6851811408996582
[Train-9]: 0.6778452307791323
[Val-9]: 0.6848824620246887
[Train-10]: 0.6776262030408189
[Val-10]: 0.6835635900497437
[Train-11]: 0.6753134002556672
[Val-11]: 0.6846587657928467
EarlyStopping counter: 1 out of 5
[Train-12]: 0.6757978805013605
[Val-12]: 0.6835731267929077
EarlyStopping counter: 2 out of 5
[Train-13]: 0.6745332713062698
[Val-13]: 0.6822166442871094
[Train-14]: 0.6740285891133386
[Val-14]: 0.6820202469825745
[Train-15]: 0.67340188171

In [42]:
# Evaluate the 50 %-subset model on the full test split; the Evaluator receives
# the complete frame and the index of the sampled training rows.
evaluator = Evaluator(
    df,
    df_train_subset.index,
)
evaluator.eval(
    SurvTraceSeerSubsetAblation,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9038359682551832
Brier Score: 0.03593204847357938
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8823638666104949
Brier Score: 0.060375314796369615
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8652652132084128
Brier Score: 0.08182501308646699
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.787177709951976
Brier Score: 0.0073964978265886475
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.787493956021642
Brier Score: 0.015952080213538077
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7777010202721919
Brier Score: 0.02785079575228471


defaultdict(list,
            {'0.25_ipcw_0': 0.9038359682551832,
             '0.25_brier_0': 0.03593204847357938,
             '0.5_ipcw_0': 0.8823638666104949,
             '0.5_brier_0': 0.060375314796369615,
             '0.75_ipcw_0': 0.8652652132084128,
             '0.75_brier_0': 0.08182501308646699,
             '0.25_ipcw_1': 0.787177709951976,
             '0.25_brier_1': 0.0073964978265886475,
             '0.5_ipcw_1': 0.787493956021642,
             '0.5_brier_1': 0.015952080213538077,
             '0.75_ipcw_1': 0.7777010202721919,
             '0.75_brier_1': 0.02785079575228471})

In [46]:
# 25 %-subset model on the full test split. This call is not the last
# expression of the cell, so its return value is not displayed.
evaluator = Evaluator(
    df,
    df_train_subset_25.index,
)
evaluator.eval(
    SurvTraceSeerSubsetAblation_25,
    (df_test, df_y_test),
)

# 75 %-subset model on the full test split; this return value is the cell output.
evaluator = Evaluator(
    df,
    df_train_subset_75.index,
)
evaluator.eval(
    SurvTraceSeerSubsetAblation_75,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9016558653395782
Brier Score: 0.03613093562318194
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8798473499404043
Brier Score: 0.06095187328171079
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.862724534166732
Brier Score: 0.08280528714561804
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7787431869606857
Brier Score: 0.007401464124342016
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7803661757224769
Brier Score: 0.016004959140329532
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7719474008016112
Brier Score: 0.027924827540093934
******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9041698958297258
Brier Score: 0.03580182422832825
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8824849124471975
Brier Score: 0.060174879070127675
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8654156482245564
Brier Score: 0.08161266098226704
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7892109637818422
Brier Score: 0.007387233154820435
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7893960795110768
Brier Score: 0.015946164683142947
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7791416239397158
Brier Score: 0.02786030881202452


defaultdict(list,
            {'0.25_ipcw_0': 0.9041698958297258,
             '0.25_brier_0': 0.03580182422832825,
             '0.5_ipcw_0': 0.8824849124471975,
             '0.5_brier_0': 0.060174879070127675,
             '0.75_ipcw_0': 0.8654156482245564,
             '0.75_brier_0': 0.08161266098226704,
             '0.25_ipcw_1': 0.7892109637818422,
             '0.25_brier_1': 0.007387233154820435,
             '0.5_ipcw_1': 0.7893960795110768,
             '0.5_brier_1': 0.015946164683142947,
             '0.75_ipcw_1': 0.7791416239397158,
             '0.75_brier_1': 0.02786030881202452})

In [47]:
# The feature and label frames are sampled separately with the same seed.
# Pairing them this way relies on both draws picking the same row positions,
# which requires equal lengths, so fail loudly if that does not hold.
assert len(df_train) == len(df_y_train), "train features/labels differ in length"
assert len(df_val) == len(df_y_val), "validation features/labels differ in length"

# 5 % subset of the training and validation splits.
df_train_subset_5 = df_train.sample(frac=0.05, random_state=5)
df_y_train_subset_5 = df_y_train.sample(frac=0.05, random_state=5)
df_val_subset_5 = df_val.sample(frac=0.05, random_state=5)
df_y_val_subset_5 = df_y_val.sample(frac=0.05, random_state=5)

In [48]:
# Same optimisation settings as the baseline SEER run.
hparams = dict(
    batch_size=1024,
    weight_decay=0,
    learning_rate=1e-4,
    epochs=100,
)

# Train a fresh model on the 5 % training / validation subsets.
SurvTraceSeerSubsetAblation_5 = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerSubsetAblation_trainer = Trainer(SurvTraceSeerSubsetAblation_5)
train_loss, val_loss = SurvTraceSeerSubsetAblation_trainer.fit(
    (df_train_subset_5, df_y_train_subset_5),
    (df_val_subset_5, df_y_val_subset_5),
    val_batch_size=10000,
    **hparams,
)

use pytorch-cuda for training.
[Train-0]: 3.956625493367513
[Val-0]: 3.662135362625122
[Train-1]: 1.9544437328974407
[Val-1]: 0.9048177003860474
[Train-2]: 0.820472264289856
[Val-2]: 0.7868235111236572
[Train-3]: 0.746699579556783
[Val-3]: 0.7416931390762329
[Train-4]: 0.7195817430814108
[Val-4]: 0.7261303067207336
[Train-5]: 0.7051557898521423
[Val-5]: 0.7168315649032593
[Train-6]: 0.6978320956230164
[Val-6]: 0.7104278802871704
[Train-7]: 0.6912192185719808
[Val-7]: 0.7059135437011719
[Train-8]: 0.6869211316108703
[Val-8]: 0.7037538886070251
[Train-9]: 0.6849209507306416
[Val-9]: 0.7012452483177185
[Train-10]: 0.6815810759862264
[Val-10]: 0.7000185251235962
[Train-11]: 0.6804095665613811
[Val-11]: 0.7001855969429016
EarlyStopping counter: 1 out of 5
[Train-12]: 0.679466970761617
[Val-12]: 0.6978452205657959
[Train-13]: 0.6781292875607808
[Val-13]: 0.697655200958252
[Train-14]: 0.6766066392262776
[Val-14]: 0.6967000961303711
[Train-15]: 0.6748671968777974
[Val-15]: 0.6985398530960083
E

In [49]:
# Evaluate the 5 %-subset model on the full test split; the Evaluator receives
# the complete frame and the index of the sampled training rows.
evaluator = Evaluator(
    df,
    df_train_subset_5.index,
)
evaluator.eval(
    SurvTraceSeerSubsetAblation_5,
    (df_test, df_y_test),
)

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.895294404539345
Brier Score: 0.0365494025573316
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8755890882415777
Brier Score: 0.06175855179417038
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8590420989809229
Brier Score: 0.08424169508865847
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7393104633892094
Brier Score: 0.007458268406463028
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7467688668986358
Brier Score: 0.016183820924302324
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7419850426788586
Brier Score: 0.028463274932793683


defaultdict(list,
            {'0.25_ipcw_0': 0.895294404539345,
             '0.25_brier_0': 0.0365494025573316,
             '0.5_ipcw_0': 0.8755890882415777,
             '0.5_brier_0': 0.06175855179417038,
             '0.75_ipcw_0': 0.8590420989809229,
             '0.75_brier_0': 0.08424169508865847,
             '0.25_ipcw_1': 0.7393104633892094,
             '0.25_brier_1': 0.007458268406463028,
             '0.5_ipcw_1': 0.7467688668986358,
             '0.5_brier_1': 0.016183820924302324,
             '0.75_ipcw_1': 0.7419850426788586,
             '0.75_brier_1': 0.028463274932793683})