In [1]:
# %pip install -r requirements.txt

In [2]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch 
import easydict
from torch import Tensor, device, dtype, nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from sksurv.metrics import concordance_index_ipcw, brier_score
import torchtuples as tt

from pycox.datasets import metabric, support
from pycox.models import LogisticHazard, DeepHit
from pycox.preprocessing.feature_transforms import OrderedCategoricalLong
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox.evaluation import EvalSurv
from pycox.models.loss import NLLPCHazardLoss
from pycox.preprocessing.discretization import (make_cuts, IdxDiscUnknownC, _values_if_series,
    DiscretizeUnknownC, Duration2Idx)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from survtrace.dataset import load_data
from survtrace.evaluate_utils import Evaluator
from survtrace.utils import set_random_seed
from survtrace.model import SurvTraceMulti, SurvTraceSingle
from survtrace.train_utils import Trainer
from survtrace.config import STConfig

In [4]:
np.random.seed(42)
_ = torch.manual_seed(42)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## DeepHit Specifications

In [5]:
class SimpleMLP(torch.nn.Module):
    def __init__(self, in_features, num_nodes, num_risks, out_features, batch_norm=True,
                 dropout=None):
        super().__init__()
        self.num_risks = num_risks
        self.mlp = tt.practical.MLPVanilla(
            in_features, num_nodes, num_risks * out_features,
            batch_norm, dropout,
        )
        
    def forward(self, input):
        out = self.mlp(input)
        return out.view(out.size(0), self.num_risks, -1)

class CauseSpecificNet(torch.nn.Module):
    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_risks,
                 out_features, batch_norm=True, dropout=None):
        super().__init__()
        self.shared_net = tt.practical.MLPVanilla(
            in_features, num_nodes_shared[:-1], num_nodes_shared[-1],
            batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_risks):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1], num_nodes_indiv, out_features,
                batch_norm, dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        out = self.shared_net(input)
        out = [net(out) for net in self.risk_nets]
        out = torch.stack(out, dim=1)
        return out

In [6]:
STConfig['data'] = 'seer'
STConfig['num_hidden_layers'] = 2
STConfig['hidden_size'] = 16
STConfig['intermediate_size'] = 64
STConfig['num_attention_heads'] = 2
STConfig['initializer_range'] = .02
STConfig['early_stop_patience'] = 5

df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)
train_set = (df_train, df_y_train)
val_set = (df_val, df_y_val)




In [7]:
df.iloc[:, 0:13].describe().apply(lambda s: s.apply('{0:.2f}'.format))

Unnamed: 0,Sex,Year of diagnosis,"Race recode (W, B, AI, API)",Histologic Type ICD-O-3,Laterality,Sequence number,ER Status Recode Breast Cancer (1990+),PR Status Recode Breast Cancer (1990+),Summary stage 2000 (1998-2017),RX Summ--Surg Prim Site (1998+),Reason no cancer-directed surgery,First malignant primary indicator,Diagnostic Confirmation
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,0.01,6.34,2.72,1.43,2.48,4.4,1.71,1.59,1.24,8.74,4.55,0.82,3.0
std,0.09,2.79,0.62,4.58,1.5,2.44,0.56,0.61,0.56,8.04,1.38,0.38,0.18
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,5.0,3.0,0.0,1.0,1.0,2.0,1.0,1.0,4.0,5.0,1.0,3.0
50%,0.0,7.0,3.0,0.0,1.0,6.0,2.0,2.0,1.0,5.0,5.0,1.0,3.0
75%,0.0,9.0,3.0,1.0,4.0,6.0,2.0,2.0,2.0,13.0,5.0,1.0,3.0
max,1.0,10.0,3.0,74.0,4.0,6.0,2.0,2.0,2.0,47.0,5.0,1.0,5.0


In [8]:
df.iloc[:, 14:25].describe().apply(lambda s: s.apply('{0:.2f}'.format))


Unnamed: 0,Regional nodes examined (1988+),CS tumor size (2004-2015),Total number of benign/borderline tumors for patient,Total number of in situ/malignant tumors for patient,duration,event_heart,event_breast,event_0,event_1,duration_disc,proportion
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,-0.0,0.0,0.0,0.0,67.41,0.05,0.18,0.18,0.05,2.36,0.51
std,1.0,1.0,1.0,1.0,31.45,0.21,0.39,0.39,0.21,0.97,0.29
min,-0.48,-0.38,-0.08,-0.57,1.0,0.0,0.0,0.0,0.0,0.0,0.02
25%,-0.43,-0.34,-0.08,-0.57,48.0,0.0,0.0,0.0,0.0,2.0,0.25
50%,-0.32,-0.3,-0.08,-0.57,69.0,0.0,0.0,0.0,0.0,3.0,0.52
75%,-0.0,-0.25,-0.08,0.94,92.0,0.0,0.0,0.0,0.0,3.0,0.76
max,4.83,3.44,57.86,28.18,121.0,1.0,1.0,1.0,1.0,3.0,1.0


## DeepHit - SEER

In [9]:



# df_train, df_y_train = train_set
# x_train = df_train.values.astype('float32')
# y_train = (df_y_train['duration'].values.astype('int64'), (df_y_train['event_0'].values + df_y_train['event_1'].values * 2).astype('int64'))

# df_val, df_y_val = val_set
# x_val = df_val.values.astype('float32')
# y_val = (df_y_val['duration'].values.astype('int64'), (df_y_val['event_0'].values + df_y_val['event_1'].values * 2).astype('int64'))
# val = (x_val, y_val)

# in_features = x_train.shape[1]
# num_nodes_shared = [64, 64]
# num_nodes_indiv = [32]
# num_risks = 2
# out_features = len(STConfig['labtrans'].cuts)
# batch_norm = True
# dropout = 0.1

# net = CauseSpecificNet(in_features, num_nodes_shared, num_nodes_indiv, num_risks,
#                     out_features, batch_norm, dropout).to(DEVICE)
# optimizer = tt.optim.AdamWR(lr=0.01, decoupled_weight_decay=0.01,
#                         cycle_eta_multiplier=0.8)
# DeepHitModel = DeepHit(net, optimizer, alpha=0.2, sigma=0.1,
#             duration_index=STConfig['labtrans'].cuts)
# epochs = 50
# batch_size = 256
# callbacks = [tt.callbacks.EarlyStoppingCycle()]
# verbose = True 

# x_train, y_train = torch.tensor(x_train).to(DEVICE), (torch.tensor(y_train[0]).to(DEVICE), torch.tensor(y_train[1]).to(DEVICE))
# x_val, y_val = torch.tensor(x_val).to(DEVICE), (torch.tensor(y_val[0]).to(DEVICE), torch.tensor(y_val[1]).to(DEVICE))
# val = (x_val, y_val)

# log = DeepHitModel.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val)
# log = log.to_pandas()




In [10]:
# predictions = DeepHitModel.predict(x_val)
# event_types = y_val[0].cpu().numpy()
# event_times = y_val[1].cpu().numpy()

# cif = DeepHitModel.predict_cif(x_val)
# cif1 = pd.DataFrame(cif[0].cpu().numpy(), DeepHitModel.duration_index)
# ev1 = EvalSurv(cif1, event_times, event_types > 0, censor_surv='km')
# ev1.concordance_td()

# cif1 = pd.DataFrame(cif[0].cpu().numpy(), DeepHitModel.duration_index)
# cif2 = pd.DataFrame(cif[1].cpu().numpy(), DeepHitModel.duration_index)
# ev1 = EvalSurv(1-cif1, y_val_duration, y_val_event == 1, censor_surv='km')
# ev2 = EvalSurv(1-cif2, y_val_duration, y_val_event == 2, censor_surv='km')

# cis.append(concordance_index_ipcw(et_train, et_test, risk[:, i+1].to("cpu").numpy(), times[i])[0])  

## SurvTRACE - SEER

In [11]:
hparams = {
    'batch_size': 1024,
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'epochs': 100,
}
SurvTraceSeer = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeer_trainer = Trainer(SurvTraceSeer)
train_loss, val_loss = SurvTraceSeer_trainer.fit((df_train, df_y_train), (df_val, df_y_val),
    batch_size=hparams['batch_size'],
    epochs=hparams['epochs'],
    learning_rate=hparams['learning_rate'],
    weight_decay=hparams['weight_decay'],
    val_batch_size=10000,)


use pytorch-cuda for training.


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\utils\python_arg_parser.cpp:1420.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


[Train-0]: 2.9780880585819687
[Val-0]: 2.2568306922912598
[Train-1]: 0.7210916637563382
[Val-1]: 0.6916519403457642
[Train-2]: 0.6842737139082279
[Val-2]: 0.6783857345581055
[Train-3]: 0.6787194909692622
[Val-3]: 0.6755419969558716
[Train-4]: 0.6768748395296992
[Val-4]: 0.6720520853996277
[Train-5]: 0.674347017087093
[Val-5]: 0.6719754934310913
[Train-6]: 0.6734441117364534
[Val-6]: 0.6702817678451538
[Train-7]: 0.6724797316959926
[Val-7]: 0.6696440577507019
[Train-8]: 0.6720320247873968
[Val-8]: 0.6692485213279724
[Train-9]: 0.6713364744672969
[Val-9]: 0.66905277967453
[Train-10]: 0.6709403488911739
[Val-10]: 0.6695612072944641
EarlyStopping counter: 1 out of 5
[Train-11]: 0.6704962669181175
[Val-11]: 0.669444739818573
EarlyStopping counter: 2 out of 5
[Train-12]: 0.6702545980612437
[Val-12]: 0.667793333530426
[Train-13]: 0.6698257327079773
[Val-13]: 0.6681599617004395
EarlyStopping counter: 1 out of 5
[Train-14]: 0.6694166062235021
[Val-14]: 0.6693049073219299
EarlyStopping counter: 

In [12]:
evaluator = Evaluator(df, df_train.index)
evaluator.eval(SurvTraceSeer, (df_test, df_y_test))

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9071106652704147
Brier Score: 0.035204254881749886
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8830064721164833
Brier Score: 0.05992459698803965
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8661684381548437
Brier Score: 0.08137640118267309
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.8025370923614724
Brier Score: 0.007705245996202905
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7968190891471266
Brier Score: 0.01634627550036855
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7845127556544388
Brier Score: 0.028127188531967287


defaultdict(list,
            {'0.25_ipcw_0': 0.9071106652704147,
             '0.25_brier_0': 0.035204254881749886,
             '0.5_ipcw_0': 0.8830064721164833,
             '0.5_brier_0': 0.05992459698803965,
             '0.75_ipcw_0': 0.8661684381548437,
             '0.75_brier_0': 0.08137640118267309,
             '0.25_ipcw_1': 0.8025370923614724,
             '0.25_brier_1': 0.007705245996202905,
             '0.5_ipcw_1': 0.7968190891471266,
             '0.5_brier_1': 0.01634627550036855,
             '0.75_ipcw_1': 0.7845127556544388,
             '0.75_brier_1': 0.028127188531967287})

## SurvTRACE - SUPPORT

In [13]:
# define the setup parameters
STConfig['data'] = 'support'
STConfig['num_event'] = 1

hparams = {
    'batch_size': 128,
    'weight_decay': 0,
    'learning_rate': 1e-3,
    'epochs': 20,
}
df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)

df.describe().apply(lambda s: s.apply('{0:.2f}'.format))



Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,duration,event
count,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0
mean,62.63,0.44,1.87,1.25,0.19,0.03,0.94,84.53,97.48,23.36,37.11,137.57,12.35,1.78,478.64,0.68
std,15.62,0.5,1.34,0.62,0.4,0.18,0.58,27.82,31.7,9.63,1.26,6.07,9.27,1.69,560.83,0.47
min,18.04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,31.7,110.0,0.0,0.1,3.0,0.0
25%,52.75,0.0,1.0,1.0,0.0,0.0,1.0,63.0,72.0,18.0,36.2,134.0,7.0,0.9,26.0,0.0
50%,64.83,0.0,2.0,1.0,0.0,0.0,1.0,77.0,100.0,24.0,36.7,137.0,10.6,1.2,231.0,1.0
75%,74.03,1.0,3.0,1.0,0.0,0.0,1.0,107.0,120.0,28.0,38.2,141.0,15.3,1.9,763.0,1.0
max,101.85,1.0,9.0,5.0,1.0,1.0,2.0,195.0,300.0,90.0,41.7,181.0,200.0,21.5,2029.0,1.0


In [14]:
model = SurvTraceSingle(STConfig)
trainer = Trainer(model)
train_loss, val_loss = trainer.fit((df_train, df_y_train), (df_val, df_y_val),
        batch_size=hparams['batch_size'],
        epochs=hparams['epochs'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],)

use pytorch-cuda for training.
[Train-0]: 68.95697116851807
[Val-0]: 1.3649463653564453
[Train-1]: 58.374926924705505
[Val-1]: 1.3431566953659058
[Train-2]: 57.908249855041504
[Val-2]: 1.344058871269226
EarlyStopping counter: 1 out of 5
[Train-3]: 57.66807961463928
[Val-3]: 1.3433791399002075
EarlyStopping counter: 2 out of 5
[Train-4]: 57.475478172302246
[Val-4]: 1.3496466875076294
EarlyStopping counter: 3 out of 5
[Train-5]: 57.27681064605713
[Val-5]: 1.3396204710006714
[Train-6]: 57.104907393455505
[Val-6]: 1.3400529623031616
EarlyStopping counter: 1 out of 5
[Train-7]: 57.05213725566864
[Val-7]: 1.3393604755401611
[Train-8]: 57.03586256504059
[Val-8]: 1.3361797332763672
[Train-9]: 56.90569519996643
[Val-9]: 1.3340507745742798
[Train-10]: 56.91401648521423
[Val-10]: 1.3327220678329468
[Train-11]: 56.735639214515686
[Val-11]: 1.3318065404891968
[Train-12]: 56.621888279914856
[Val-12]: 1.3339364528656006
EarlyStopping counter: 1 out of 5
[Train-13]: 56.546706676483154
[Val-13]: 1.3298

In [15]:
evaluator = Evaluator(df, df_train.index)
evaluator.eval(model, (df_test, df_y_test))

******************************
start evaluation
******************************
For 0.25 quantile,
TD Concordance Index - IPCW: 0.6692126949042597
Brier Score: 0.13339263883792654
For 0.5 quantile,
TD Concordance Index - IPCW: 0.6356009266231335
Brier Score: 0.20725806032760985
For 0.75 quantile,
TD Concordance Index - IPCW: 0.6108089107374024
Brier Score: 0.2340467896197376


defaultdict(list,
            {'0.25_ipcw': 0.6692126949042597,
             '0.25_brier': 0.13339263883792654,
             '0.5_ipcw': 0.6356009266231335,
             '0.5_brier': 0.20725806032760985,
             '0.75_ipcw': 0.6108089107374024,
             '0.75_brier': 0.2340467896197376})

## SurvTRACE - METABRIC

In [16]:

# define the setup parameters
STConfig['data'] = 'metabric'
STConfig['num_event'] = 1

hparams = {
    'batch_size': 128,
    'weight_decay': 0,
    'learning_rate': 1e-3,
    'epochs': 20,
}
df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)

df.describe().apply(lambda s: s.apply('{0:.2f}'.format))




Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,duration,event
count,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0
mean,6.21,6.24,10.77,5.87,0.62,0.6,0.21,0.76,61.09,125.03,0.58
std,0.86,1.02,1.36,0.34,0.49,0.49,0.41,0.43,12.98,76.33,0.49
min,5.16,4.86,6.37,5.1,0.0,0.0,0.0,0.0,21.93,0.0,0.0
25%,5.69,5.41,9.97,5.62,0.0,0.0,0.0,1.0,51.38,60.82,0.0
50%,5.95,5.88,10.53,5.82,1.0,1.0,0.0,1.0,61.77,114.9,1.0
75%,6.46,6.9,11.16,6.06,1.0,1.0,0.0,1.0,70.59,184.47,1.0
max,14.44,9.93,14.64,7.66,1.0,1.0,1.0,1.0,96.29,355.2,1.0


In [17]:

model = SurvTraceSingle(STConfig)
trainer = Trainer(model)
train_loss, val_loss = trainer.fit((df_train, df_y_train), (df_val, df_y_val),
        batch_size=hparams['batch_size'],
        epochs=hparams['epochs'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],)

use pytorch-cuda for training.
[Train-0]: 17.17668056488037
[Val-0]: 1.514079213142395
[Train-1]: 15.400452971458435
[Val-1]: 1.3295152187347412
[Train-2]: 13.65306842327118
[Val-2]: 1.2120444774627686
[Train-3]: 11.666026830673218
[Val-3]: 1.2717303037643433
EarlyStopping counter: 1 out of 5
[Train-4]: 10.727657496929169
[Val-4]: 1.2192248106002808
EarlyStopping counter: 2 out of 5
[Train-5]: 10.635212540626526
[Val-5]: 1.2077102661132812
[Train-6]: 10.72372305393219
[Val-6]: 1.2327027320861816
EarlyStopping counter: 1 out of 5
[Train-7]: 10.655511856079102
[Val-7]: 1.197699785232544
[Train-8]: 10.600127220153809
[Val-8]: 1.1914997100830078
[Train-9]: 10.635278284549713
[Val-9]: 1.1874351501464844
[Train-10]: 10.51487821340561
[Val-10]: 1.1954270601272583
EarlyStopping counter: 1 out of 5
[Train-11]: 10.441945016384125
[Val-11]: 1.2029811143875122
EarlyStopping counter: 2 out of 5
[Train-12]: 10.413191556930542
[Val-12]: 1.17203950881958
[Train-13]: 10.462559700012207
[Val-13]: 1.1851

In [18]:
evaluator = Evaluator(df, df_train.index)
evaluator.eval(model, (df_test, df_y_test))

******************************
start evaluation
******************************
For 0.25 quantile,
TD Concordance Index - IPCW: 0.6863301655978569
Brier Score: 0.12507131130402735
For 0.5 quantile,
TD Concordance Index - IPCW: 0.6627624695022662
Brier Score: 0.20305402740709985
For 0.75 quantile,
TD Concordance Index - IPCW: 0.6447654173252123
Brier Score: 0.21919318639141072


defaultdict(list,
            {'0.25_ipcw': 0.6863301655978569,
             '0.25_brier': 0.12507131130402735,
             '0.5_ipcw': 0.6627624695022662,
             '0.5_brier': 0.20305402740709985,
             '0.75_ipcw': 0.6447654173252123,
             '0.75_brier': 0.21919318639141072})

## SurvTRACE - SEER - Loss Function Ablation

In [19]:
STConfig['data'] = 'seer'
STConfig['num_hidden_layers'] = 2
STConfig['hidden_size'] = 16
STConfig['intermediate_size'] = 64
STConfig['num_attention_heads'] = 2
STConfig['initializer_range'] = .02
STConfig['early_stop_patience'] = 5

df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)
train_set = (df_train, df_y_train)
val_set = (df_val, df_y_val)


In [20]:
from survtrace.losses import NLLLogistiHazardLoss


hparams = {
    'batch_size': 1024,
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'epochs': 100,
}
SurvTraceSeerLossAblation = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerLossAblation_trainer = Trainer(SurvTraceSeerLossAblation)
SurvTraceSeerLossAblation_trainer.metrics = [NLLLogistiHazardLoss(),'NLLLogistiHazardLoss']
train_loss, val_loss = SurvTraceSeerLossAblation_trainer.fit((df_train, df_y_train), (df_val, df_y_val),
    batch_size=hparams['batch_size'],
    epochs=hparams['epochs'],
    learning_rate=hparams['learning_rate'],
    weight_decay=hparams['weight_decay'],
    val_batch_size=10000,)


use pytorch-cuda for training.
[Train-0]: 3.3970360382884537
[Val-0]: 2.522803783416748
[Train-1]: 0.7769602123571901
[Val-1]: 0.7481459379196167
[Train-2]: 0.730033027274268
[Val-2]: 0.739301323890686
[Train-3]: 0.7259159136791619
[Val-3]: 0.735472559928894
[Train-4]: 0.7234300008030976
[Val-4]: 0.734266996383667
[Train-5]: 0.7212400310704498
[Val-5]: 0.7328154444694519
[Train-6]: 0.7193970743085252
[Val-6]: 0.7304965853691101
[Train-7]: 0.7188311186777491
[Val-7]: 0.729701042175293
[Train-8]: 0.718085430511812
[Val-8]: 0.7295058369636536
[Train-9]: 0.7172548233651791
[Val-9]: 0.7291771769523621
[Train-10]: 0.716700033468454
[Val-10]: 0.7289347648620605
[Train-11]: 0.7164326277719874
[Val-11]: 0.7282321453094482
[Train-12]: 0.7160614530245463
[Val-12]: 0.7281951904296875
[Train-13]: 0.715598137403021
[Val-13]: 0.7271037101745605
[Train-14]: 0.7155112492389419
[Val-14]: 0.7268155813217163
[Train-15]: 0.7152814877276518
[Val-15]: 0.7281519770622253
EarlyStopping counter: 1 out of 5
[Tra

In [21]:
evaluator = Evaluator(df, df_train.index)
evaluator.eval(SurvTraceSeerLossAblation, (df_test, df_y_test))

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.9045096863773099
Brier Score: 0.0355474182510648
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8836975483293649
Brier Score: 0.06023112527910501
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.8658446910676755
Brier Score: 0.08214365418864937
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7937244498172766
Brier Score: 0.007614440575865029
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7898887675517042
Brier Score: 0.016022714526674274
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7780782187694347
Brier Score: 0.027647309811844236


defaultdict(list,
            {'0.25_ipcw_0': 0.9045096863773099,
             '0.25_brier_0': 0.0355474182510648,
             '0.5_ipcw_0': 0.8836975483293649,
             '0.5_brier_0': 0.06023112527910501,
             '0.75_ipcw_0': 0.8658446910676755,
             '0.75_brier_0': 0.08214365418864937,
             '0.25_ipcw_1': 0.7937244498172766,
             '0.25_brier_1': 0.007614440575865029,
             '0.5_ipcw_1': 0.7898887675517042,
             '0.5_brier_1': 0.016022714526674274,
             '0.75_ipcw_1': 0.7780782187694347,
             '0.75_brier_1': 0.027647309811844236})

## SurvTRACE - SEER - Transformer Ablation

In [22]:
from survtrace.modeling_bert import BertEncoderLame


class SurvTraceMultiAblation(SurvTraceMulti):
    def __init__(self, config: STConfig):
        super().__init__(config)
        self.encoder = BertEncoderLame(config)

hparams = {
    'batch_size': 1024,
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'epochs': 100,
}
SurvTraceSeerEncoderAblation = SurvTraceMultiAblation(STConfig).to(DEVICE)
SurvTraceSeerEncoderAblation_trainer = Trainer(SurvTraceSeerEncoderAblation)
train_loss, val_loss = SurvTraceSeerEncoderAblation_trainer.fit((df_train, df_y_train), (df_val, df_y_val),
    batch_size=hparams['batch_size'],
    epochs=hparams['epochs'],
    learning_rate=hparams['learning_rate'],
    weight_decay=hparams['weight_decay'],
    val_batch_size=10000,)


use pytorch-cuda for training.
[Train-0]: 2.954394212385424
[Val-0]: 2.2478389739990234
[Train-1]: 0.7364065981641108
[Val-1]: 0.7110079526901245
[Train-2]: 0.6954682753199622
[Val-2]: 0.702724039554596
[Train-3]: 0.6896277810035109
[Val-3]: 0.7003171443939209
[Train-4]: 0.6869831052767176
[Val-4]: 0.6983098387718201
[Train-5]: 0.6858178993471625
[Val-5]: 0.6974062323570251
[Train-6]: 0.6851338736054038
[Val-6]: 0.6967763900756836
[Train-7]: 0.6843079723873917
[Val-7]: 0.6968233585357666
EarlyStopping counter: 1 out of 5
[Train-8]: 0.6839590267259248
[Val-8]: 0.6965789198875427
[Train-9]: 0.6835425686674054
[Val-9]: 0.6956456899642944
[Train-10]: 0.6830308210282099
[Val-10]: 0.6957533955574036
EarlyStopping counter: 1 out of 5
[Train-11]: 0.6827638741253185
[Val-11]: 0.6952235102653503
[Train-12]: 0.682495502793059
[Val-12]: 0.6944617629051208
[Train-13]: 0.6821299811609748
[Val-13]: 0.6947831511497498
EarlyStopping counter: 1 out of 5
[Train-14]: 0.6820245561551075
[Val-14]: 0.6945630

In [23]:
evaluator = Evaluator(df, df_train.index)
evaluator.eval(SurvTraceSeerEncoderAblation, (df_test, df_y_test))

******************************
start evaluation
******************************


  concordant += n_con
  concordant += n_con
  concordant += n_con


Event: 0 For 0.25 quantile,
TD Concordance Index - IPCW: 0.8985903252229017
Brier Score: 0.035641749256750566
Event: 0 For 0.5 quantile,
TD Concordance Index - IPCW: 0.8760477899505578
Brier Score: 0.06079497180867453
Event: 0 For 0.75 quantile,
TD Concordance Index - IPCW: 0.856480772074978
Brier Score: 0.08310501579729272
Event: 1 For 0.25 quantile,
TD Concordance Index - IPCW: 0.7745272121961704
Brier Score: 0.007611286488831452
Event: 1 For 0.5 quantile,
TD Concordance Index - IPCW: 0.7754123624863947
Brier Score: 0.016022296753487878
Event: 1 For 0.75 quantile,
TD Concordance Index - IPCW: 0.7674821014893217
Brier Score: 0.02767031395998046


defaultdict(list,
            {'0.25_ipcw_0': 0.8985903252229017,
             '0.25_brier_0': 0.035641749256750566,
             '0.5_ipcw_0': 0.8760477899505578,
             '0.5_brier_0': 0.06079497180867453,
             '0.75_ipcw_0': 0.856480772074978,
             '0.75_brier_0': 0.08310501579729272,
             '0.25_ipcw_1': 0.7745272121961704,
             '0.25_brier_1': 0.007611286488831452,
             '0.5_ipcw_1': 0.7754123624863947,
             '0.5_brier_1': 0.016022296753487878,
             '0.75_ipcw_1': 0.7674821014893217,
             '0.75_brier_1': 0.02767031395998046})