In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from torch.optim import Adam

from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import scale, StandardScaler
from sklearn.metrics import roc_auc_score

from tqdm import tqdm, trange
from math import floor

import sys,os
sys.path.append('../src/')

from lifelines.utils import concordance_index
from torchmtlr import MTLRCR, mtlr_neg_log_likelihood, mtlr_risk, mtlr_survival
from torchmtlr.utils import make_time_bins, encode_survival, reset_parameters

# Load/preprocess rotterdam Data

In [4]:
def multiple_events(row):
    '''
    Censor = 0
    Recurrence = 1
    Death = 2
    '''
    event        = row["event"]
    
    if event==0:
        return 0
    elif row['rtime'] < row['dtime']:
        return 1
    else:
        return 2

import config as cfg
df = pd.read_csv(f'{cfg.DATA_DIR}/rotterdam.csv')
df['time'] = np.minimum(df['rtime'], df['dtime'])
df['event'] = df['recur'] | df['death']


size_mapping = {
    '<=20': 10,
    '20-50': 35,
    '>50': 75
}

# Apply mapping
df['size_mapped'] = df['size'].replace(size_mapping)

In [5]:
time_bins = make_time_bins(df["time"], event=df["event"])
multi_events = df.apply(lambda x: multiple_events(x), axis=1)

'''
Normalize the Data
'''

temp_X_df = df.drop(['pid', 'size', 'rtime', 'recur', 'dtime', 'death', 'time', 'event'], axis=1)
scaler = StandardScaler()
temp_X_df = pd.DataFrame(scaler.fit_transform(temp_X_df), columns=temp_X_df.columns)

y = encode_survival(df["time"], multi_events, time_bins)
X = torch.tensor(temp_X_df.values, dtype=torch.float)

full_indices = range(len(df))
train_indices, test_indices = train_test_split(full_indices, test_size=0.2) # just train & Test
# train_indices, val_indices = train_test_split(train_indices, test_size=0.1)

X_train, X_test = X[train_indices], X[test_indices]
y_train, y_test = y[train_indices], y[test_indices]

# X_train, X_val, X_test = X[train_indices], X[val_indices], X[test_indices]
# y_train, y_val, y_test = y[train_indices], y[val_indices], y[test_indices]

df_test = df.iloc[test_indices]

# Train MTLR

In [6]:
def make_optimizer(opt_cls, model, **kwargs):
    """Creates a PyTorch optimizer for MTLR training."""
    params_dict = dict(model.named_parameters())
    weights = [v for k, v in params_dict.items() if "mtlr" not in k and "bias" not in k]
    biases = [v for k, v in params_dict.items() if "bias" in k]
    mtlr_weights = [v for k, v in params_dict.items() if "mtlr_weight" in k]
    # Don't use weight decay on the biases and MTLR parameters, which have
    # their own separate L2 regularization
    optimizer = opt_cls([
        {"params": weights},
        {"params": biases, "weight_decay": 0.},
        {"params": mtlr_weights, "weight_decay": 0.},
    ], **kwargs)
    return optimizer

def train_mtlr(x, y, model, time_bins,
               num_epochs=1000, lr=.01, weight_decay=0.,
               C1=1., batch_size=None,
               verbose=True, device="cpu"):
    """Trains the MTLR model using minibatch gradient descent.
    
    Parameters
    ----------
    model : torch.nn.Module
        MTLR model to train.
    data_train : pd.DataFrame
        The training dataset. Must contain a `time` column with the
        event time for each sample and an `event` column containing
        the event indicator.
    num_epochs : int
        Number of training epochs.
    lr : float
        The learning rate.
    weight_decay : float
        Weight decay strength for all parameters *except* the MTLR
        weights. Only used for Deep MTLR training.
    C1 : float
        L2 regularization (weight decay) strenght for MTLR parameters.
    batch_size : int
        The batch size.
    verbose : bool
        Whether to display training progress.
    device : str
        Device name or ID to use for training.
        
    Returns
    -------
    torch.nn.Module
        The trained model.
    """
    optimizer = make_optimizer(Adam, model, lr=lr, weight_decay=weight_decay)
    reset_parameters(model)
    print(x.shape, y.shape)
    model = model.to(device)
    model.train()
    train_loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)
    
    pbar =  trange(num_epochs, disable=not verbose)
    for i in pbar:
        for xi, yi in train_loader:
            xi, yi = xi.to(device), yi.to(device)
            y_pred = model(xi)
            loss = mtlr_neg_log_likelihood(y_pred, yi, model, C1, average=True)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pbar.set_description(f"[epoch {i+1: 4}/{num_epochs}]")
        pbar.set_postfix_str(f"loss = {loss.item():.4f}")
    model.eval()
    return model

In [7]:
device = "cpu"
num_time_bins = len(time_bins)+1
in_features = X_train.shape[1]

# fit MTLR model 
mtlr = MTLRCR(in_features=in_features, num_time_bins=num_time_bins, num_events=2) # here is 2 competing risk event            
mtlr = train_mtlr(X_train, y_train, mtlr, time_bins, num_epochs=350, 
                  lr=1e-3, batch_size=64, verbose=True, device=device, C1=1.)

torch.Size([2385, 10]) torch.Size([2385, 86])


[epoch  350/350]: 100%|██████████| 350/350 [00:48<00:00,  7.26it/s, loss = 2.6984]


In [8]:
pred_prob       = mtlr(X_test)
survival_recur  = mtlr_survival(pred_prob[:,:num_time_bins]).detach().numpy()
survival_death = mtlr_survival(pred_prob[:,num_time_bins:]).detach().numpy()

In [9]:
pred_risk = mtlr_risk(pred_prob, 2).detach().numpy()

ci_recur  = concordance_index(df_test["time"], -pred_risk[:, 0], event_observed=df_test["recur"]) # 1 is rec
ci_death = concordance_index(df_test["time"], -pred_risk[:, 1], event_observed=df_test["death"]) # 2 is death

print ('Recur C-index:', ci_recur)
print ('death C-index:', ci_death)

Recur C-index: 0.6735140106986389
death C-index: 0.5449568857561988
