In [None]:
%pip install -r requirements.txt

In [18]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch 
import easydict
from torch import Tensor, device, dtype, nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from sksurv.metrics import concordance_index_ipcw, brier_score
import torchtuples as tt

from pycox.datasets import metabric, support
from pycox.models import LogisticHazard, DeepHit
from pycox.preprocessing.feature_transforms import OrderedCategoricalLong
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox.evaluation import EvalSurv
from pycox.models.loss import NLLPCHazardLoss
from pycox.preprocessing.discretization import (make_cuts, IdxDiscUnknownC, _values_if_series,
    DiscretizeUnknownC, Duration2Idx)


In [19]:
from survtrace.dataset import load_data
from survtrace.evaluate_utils import Evaluator
from survtrace.utils import set_random_seed
from survtrace.model import SurvTraceMulti, SurvTraceSingle
from survtrace.train_utils import Trainer
from survtrace.config import STConfig

In [20]:
np.random.seed(42)
_ = torch.manual_seed(42)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class SimpleMLP(torch.nn.Module):
    def __init__(self, in_features, num_nodes, num_risks, out_features, batch_norm=True,
                 dropout=None):
        super().__init__()
        self.num_risks = num_risks
        self.mlp = tt.practical.MLPVanilla(
            in_features, num_nodes, num_risks * out_features,
            batch_norm, dropout,
        )
        
    def forward(self, input):
        out = self.mlp(input)
        return out.view(out.size(0), self.num_risks, -1)

class CauseSpecificNet(torch.nn.Module):
    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_risks,
                 out_features, batch_norm=True, dropout=None):
        super().__init__()
        self.shared_net = tt.practical.MLPVanilla(
            in_features, num_nodes_shared[:-1], num_nodes_shared[-1],
            batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_risks):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1], num_nodes_indiv, out_features,
                batch_norm, dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        out = self.shared_net(input)
        out = [net(out) for net in self.risk_nets]
        out = torch.stack(out, dim=1)
        return out

In [36]:
STConfig['data'] = 'seer'
STConfig['num_hidden_layers'] = 2
STConfig['hidden_size'] = 16
STConfig['intermediate_size'] = 64
STConfig['num_attention_heads'] = 2
STConfig['initializer_range'] = .02
STConfig['early_stop_patience'] = 5

df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)
train_set = (df_train, df_y_train)
val_set = (df_val, df_y_val)




Unnamed: 0,Regional nodes examined (1988+),CS tumor size (2004-2015),Total number of benign/borderline tumors for patient,Total number of in situ/malignant tumors for patient,duration,event_heart,event_breast,event_0,event_1,duration_disc,proportion
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,-0.0,0.0,0.0,0.0,67.41,0.05,0.18,0.18,0.05,2.36,0.51
std,1.0,1.0,1.0,1.0,31.45,0.21,0.39,0.39,0.21,0.97,0.29
min,-0.48,-0.38,-0.08,-0.57,1.0,0.0,0.0,0.0,0.0,0.0,0.02
25%,-0.43,-0.34,-0.08,-0.57,48.0,0.0,0.0,0.0,0.0,2.0,0.25
50%,-0.32,-0.3,-0.08,-0.57,69.0,0.0,0.0,0.0,0.0,3.0,0.52
75%,-0.0,-0.25,-0.08,0.94,92.0,0.0,0.0,0.0,0.0,3.0,0.76
max,4.83,3.44,57.86,28.18,121.0,1.0,1.0,1.0,1.0,3.0,1.0


In [37]:
df.iloc[:, 0:13].describe().apply(lambda s: s.apply('{0:.2f}'.format))

Unnamed: 0,Sex,Year of diagnosis,"Race recode (W, B, AI, API)",Histologic Type ICD-O-3,Laterality,Sequence number,ER Status Recode Breast Cancer (1990+),PR Status Recode Breast Cancer (1990+),Summary stage 2000 (1998-2017),RX Summ--Surg Prim Site (1998+),Reason no cancer-directed surgery,First malignant primary indicator,Diagnostic Confirmation
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,0.01,6.34,2.72,1.43,2.48,4.4,1.71,1.59,1.24,8.74,4.55,0.82,3.0
std,0.09,2.79,0.62,4.58,1.5,2.44,0.56,0.61,0.56,8.04,1.38,0.38,0.18
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,5.0,3.0,0.0,1.0,1.0,2.0,1.0,1.0,4.0,5.0,1.0,3.0
50%,0.0,7.0,3.0,0.0,1.0,6.0,2.0,2.0,1.0,5.0,5.0,1.0,3.0
75%,0.0,9.0,3.0,1.0,4.0,6.0,2.0,2.0,2.0,13.0,5.0,1.0,3.0
max,1.0,10.0,3.0,74.0,4.0,6.0,2.0,2.0,2.0,47.0,5.0,1.0,5.0


In [38]:
df.iloc[:, 14:25].describe().apply(lambda s: s.apply('{0:.2f}'.format))


Unnamed: 0,Regional nodes examined (1988+),CS tumor size (2004-2015),Total number of benign/borderline tumors for patient,Total number of in situ/malignant tumors for patient,duration,event_heart,event_breast,event_0,event_1,duration_disc,proportion
count,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0,476746.0
mean,-0.0,0.0,0.0,0.0,67.41,0.05,0.18,0.18,0.05,2.36,0.51
std,1.0,1.0,1.0,1.0,31.45,0.21,0.39,0.39,0.21,0.97,0.29
min,-0.48,-0.38,-0.08,-0.57,1.0,0.0,0.0,0.0,0.0,0.0,0.02
25%,-0.43,-0.34,-0.08,-0.57,48.0,0.0,0.0,0.0,0.0,2.0,0.25
50%,-0.32,-0.3,-0.08,-0.57,69.0,0.0,0.0,0.0,0.0,3.0,0.52
75%,-0.0,-0.25,-0.08,0.94,92.0,0.0,0.0,0.0,0.0,3.0,0.76
max,4.83,3.44,57.86,28.18,121.0,1.0,1.0,1.0,1.0,3.0,1.0


## DeepHit - SEER

In [None]:



df_train, df_y_train = train_set
x_train = df_train.values.astype('float32')
y_train = (df_y_train['duration'].values.astype('int64'), (df_y_train['event_0'].values + df_y_train['event_1'].values * 2).astype('int64'))

df_val, df_y_val = val_set
x_val = df_val.values.astype('float32')
y_val = (df_y_val['duration'].values.astype('int64'), (df_y_val['event_0'].values + df_y_val['event_1'].values * 2).astype('int64'))
val = (x_val, y_val)

in_features = x_train.shape[1]
num_nodes_shared = [64, 64]
num_nodes_indiv = [32]
num_risks = 2
out_features = len(STConfig['labtrans'].cuts)
batch_norm = True
dropout = 0.1

net = CauseSpecificNet(in_features, num_nodes_shared, num_nodes_indiv, num_risks,
                    out_features, batch_norm, dropout).to(DEVICE)
optimizer = tt.optim.AdamWR(lr=0.01, decoupled_weight_decay=0.01,
                        cycle_eta_multiplier=0.8)
DeepHitModel = DeepHit(net, optimizer, alpha=0.2, sigma=0.1,
            duration_index=STConfig['labtrans'].cuts)
epochs = 50
batch_size = 256
callbacks = [tt.callbacks.EarlyStoppingCycle()]
verbose = True # set to True if you want printout

x_train, y_train = torch.tensor(x_train).to(DEVICE), (torch.tensor(y_train[0]).to(DEVICE), torch.tensor(y_train[1]).to(DEVICE))
x_val, y_val = torch.tensor(x_val).to(DEVICE), (torch.tensor(y_val[0]).to(DEVICE), torch.tensor(y_val[1]).to(DEVICE))
val = (x_val, y_val)

log = DeepHitModel.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val)
log = log.to_pandas()




## SurvTRACE - SEER

In [None]:
hparams = {
    'batch_size': 1024,
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'epochs': 100,
}
SurvTraceSeer = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeer_trainer = Trainer(SurvTraceSeer)
train_loss, val_loss = SurvTraceSeer_trainer.fit((df_train, df_y_train), (df_val, df_y_val),
    batch_size=hparams['batch_size'],
    epochs=hparams['epochs'],
    learning_rate=hparams['learning_rate'],
    weight_decay=hparams['weight_decay'],
    val_batch_size=10000,)


## SurvTRACE - SUPPORT

In [30]:
# define the setup parameters
STConfig['data'] = 'support'
STConfig['num_event'] = 1

hparams = {
    'batch_size': 128,
    'weight_decay': 0,
    'learning_rate': 1e-3,
    'epochs': 20,
}
df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)

df.describe().apply(lambda s: s.apply('{0:.2f}'.format))



Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,duration,event
count,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0,8873.0
mean,62.63,0.44,1.87,1.25,0.19,0.03,0.94,84.53,97.48,23.36,37.11,137.57,12.35,1.78,478.64,0.68
std,15.62,0.5,1.34,0.62,0.4,0.18,0.58,27.82,31.7,9.63,1.26,6.07,9.27,1.69,560.83,0.47
min,18.04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,31.7,110.0,0.0,0.1,3.0,0.0
25%,52.75,0.0,1.0,1.0,0.0,0.0,1.0,63.0,72.0,18.0,36.2,134.0,7.0,0.9,26.0,0.0
50%,64.83,0.0,2.0,1.0,0.0,0.0,1.0,77.0,100.0,24.0,36.7,137.0,10.6,1.2,231.0,1.0
75%,74.03,1.0,3.0,1.0,0.0,0.0,1.0,107.0,120.0,28.0,38.2,141.0,15.3,1.9,763.0,1.0
max,101.85,1.0,9.0,5.0,1.0,1.0,2.0,195.0,300.0,90.0,41.7,181.0,200.0,21.5,2029.0,1.0


In [None]:
model = SurvTraceSingle(STConfig)
trainer = Trainer(model)
train_loss, val_loss = trainer.fit((df_train, df_y_train), (df_val, df_y_val),
        batch_size=hparams['batch_size'],
        epochs=hparams['epochs'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],)

## SurvTRACE - METABRIC

In [29]:

# define the setup parameters
STConfig['data'] = 'metabric'
STConfig['num_event'] = 1

hparams = {
    'batch_size': 128,
    'weight_decay': 0,
    'learning_rate': 1e-3,
    'epochs': 20,
}
df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(STConfig)

df.describe().apply(lambda s: s.apply('{0:.2f}'.format))




Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,duration,event
count,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0
mean,6.21,6.24,10.77,5.87,0.62,0.6,0.21,0.76,61.09,125.03,0.58
std,0.86,1.02,1.36,0.34,0.49,0.49,0.41,0.43,12.98,76.33,0.49
min,5.16,4.86,6.37,5.1,0.0,0.0,0.0,0.0,21.93,0.0,0.0
25%,5.69,5.41,9.97,5.62,0.0,0.0,0.0,1.0,51.38,60.82,0.0
50%,5.95,5.88,10.53,5.82,1.0,1.0,0.0,1.0,61.77,114.9,1.0
75%,6.46,6.9,11.16,6.06,1.0,1.0,0.0,1.0,70.59,184.47,1.0
max,14.44,9.93,14.64,7.66,1.0,1.0,1.0,1.0,96.29,355.2,1.0


In [None]:

model = SurvTraceSingle(STConfig)
trainer = Trainer(model)
train_loss, val_loss = trainer.fit((df_train, df_y_train), (df_val, df_y_val),
        batch_size=hparams['batch_size'],
        epochs=hparams['epochs'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],)

## SurvTRACE - SEER - Loss Function Ablation

In [None]:
from survtrace.losses import NLLLogistiHazardLoss


hparams = {
    'batch_size': 1024,
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'epochs': 100,
}
SurvTraceSeerLossAblation = SurvTraceMulti(STConfig).to(DEVICE)
SurvTraceSeerLossAblation_trainer = Trainer(SurvTraceSeerLossAblation, metrics=NLLLogistiHazardLoss())
train_loss, val_loss = SurvTraceSeerLossAblation_trainer.fit((df_train, df_y_train), (df_val, df_y_val),
    batch_size=hparams['batch_size'],
    epochs=hparams['epochs'],
    learning_rate=hparams['learning_rate'],
    weight_decay=hparams['weight_decay'],
    val_batch_size=10000,)


## SurvTRACE - SEER - Transformer Ablation

In [None]:
from survtrace.modeling_bert import BertEncoderLame


class SurvTraceMultiAblation(SurvTraceMulti):
    def __init__(self, config: STConfig):
        self.encoder = BertEncoderLame(config)

hparams = {
    'batch_size': 1024,
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'epochs': 100,
}
SurvTraceSeerEncoderAblation = SurvTraceMultiAblation(STConfig).to(DEVICE)
SurvTraceSeerEncoderAblation_trainer = Trainer(SurvTraceSeerEncoderAblation)
train_loss, val_loss = SurvTraceSeerEncoderAblation_trainer.fit((df_train, df_y_train), (df_val, df_y_val),
    batch_size=hparams['batch_size'],
    epochs=hparams['epochs'],
    learning_rate=hparams['learning_rate'],
    weight_decay=hparams['weight_decay'],
    val_batch_size=10000,)
