## 1 Introduction
This demo shows how to train the MENSA model on single-event, competing-risks, and multi-event data. It uses synthetic data so that the results are easy to reproduce.

In [7]:
# 3rd party
import pandas as pd
import numpy as np
import config as cfg
import torch
import random
import warnings
import argparse
import os
from SurvivalEVAL.Evaluator import LifelinesEvaluator

# Local
from data_loader import (SingleEventSyntheticDataLoader,
                         CompetingRiskSyntheticDataLoader,
                         MultiEventSyntheticDataLoader)
from utility.survival import (make_time_bins, convert_to_structured,
                              compute_l1_difference, predict_survival_function)
from utility.data import dotdict
from utility.config import load_config
from utility.evaluation import global_C_index, local_C_index
from mensa.model import MENSA

# Suppress warnings whose message mentions the 'nopython' keyword.
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")

# Seed every random number generator used in this notebook with the same value
# so that repeated runs produce the same numbers.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

# Setup precision: new floating-point tensors default to float64.
dtype = torch.float64
torch.set_default_dtype(dtype)

# Setup device: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## 2 Synthetic single-event

We generate a synthetic single-event dataset (K=2: the event of interest plus censoring) from a Weibull data-generating process (DGP) with no dependence (k_tau=0) and linear risk.

In [8]:
# Load synthetic data for single-event case.
# The DGP settings are read from the YAML config; linear=True and k_tau=0 select
# the linear-risk, no-dependence setup used in this section.
data_config = load_config(cfg.DGP_CONFIGS_DIR, "synthetic_se.yaml")
dl = SingleEventSyntheticDataLoader().load_data(data_config=data_config,
                                                linear=True, copula_name="",
                                                k_tau=0, device=device, dtype=dtype)

# 70% / 10% / 20% train / validation / test split with a fixed random_state.
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)

# Size of the second axis of X; passed to the model as its number of features.
n_features = train_dict['X'].shape[1]

In [9]:
# Make time bins from the training times, then prepend t=0.
# The zero is created with the bins' own dtype and device, so torch.cat does not
# have to promote an int64 tensor to the floating-point dtype of the bins.
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_zero = torch.zeros(1, dtype=time_bins.dtype, device=time_bins.device)
time_bins = torch.cat((time_zero, time_bins))

# Load model configuration and make model
config = load_config(cfg.MENSA_CONFIGS_DIR, "synthetic.yaml")
n_epochs = config['n_epochs']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
# n_events=2 matches the K=2 setup of this section.
model = MENSA(n_features, layers=layers, n_events=2, device=device)
lr_dict = {'network': lr}

# Train the model. Training stops early when the validation loss stops
# improving (patience=10), so fewer than n_epochs epochs may be run.
model.fit(train_dict, valid_dict, optimizer='adam', verbose=True,
          n_epochs=n_epochs, patience=10, batch_size=batch_size, lr_dict=lr_dict)

[Epoch  149/1000]:  15%|█▍        | 148/1000 [00:48<04:40,  3.04it/s, Training loss = 2.4492, Validation loss = 2.4631]

Early stopping at iteration 148, best valid loss: 2.459958369040445





In [10]:
# Predict survival curves for the single event; the DataFrame columns are the
# time bins.
# NOTE(review): this cell uses risk=0 while the competing-risks cell uses
# risk=i+1 for its events -- confirm the index convention of MENSA.predict.
surv_curves = model.predict(test_dict['X'].to(device), time_bins, risk=0)
model_preds = pd.DataFrame(surv_curves, columns=time_bins.cpu().numpy())

# Times are used as stored in the splits; event indicators are cast to float.
y_train_time, y_train_event = train_dict['T'], train_dict['E'] * 1.0
y_test_time, y_test_event = test_dict['T'], test_dict['E'] * 1.0

# Use SurvivalEVAL package to calculate popular metrics. The predictions are
# transposed so that the time bins form the index.
lifelines_eval = LifelinesEvaluator(model_preds.T, y_test_time, y_test_event,
                                    y_train_time, y_train_event)

ci = lifelines_eval.concordance()[0]
ibs = lifelines_eval.integrated_brier_score()
mae_hinge, mae_margin, mae_pseudo = (
    lifelines_eval.mae(method=mae_method)
    for mae_method in ("Hinge", "Margin", "Pseudo_obs")
)
d_calib = lifelines_eval.d_calibration()[0]

# Order: CI, IBS, MAE (hinge), MAE (margin), MAE (pseudo-obs), D-calibration.
metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, d_calib]
print(f"E1: {metrics}")

E1: [0.6698866820771986, 0.1028138761033593, 1.3681644611295425, 1.6488549808110948, 1.7436583083292598, 0.5193912618049151]


## 3 Synthetic competing risks

We generate a synthetic competing-risks dataset (K=3: two competing events plus censoring) from a Weibull DGP with no dependence (k_tau=0) and linear risk.

In [11]:
# Load synthetic data for competing risks case.
# The DGP settings are read from the YAML config; linear=True and k_tau=0 select
# the linear-risk, no-dependence setup used in this section.
data_config = load_config(cfg.DGP_CONFIGS_DIR, "synthetic_cr.yaml")
dl = CompetingRiskSyntheticDataLoader().load_data(data_config, k_tau=0, copula_name="",
                                                  linear=True, device=device, dtype=dtype)

# 70% / 10% / 20% train / validation / test split with a fixed random_state.
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)

# Size of the second axis of X; passed to the model as its number of features.
n_features = train_dict['X'].shape[1]
# Number of events reported by the data loader; drives the prediction and
# evaluation loops later in this section.
n_events = dl.n_events

In [12]:
# Make time bins from the training times, then prepend t=0.
# The zero is created with the bins' own dtype and device, so torch.cat does not
# have to promote an int64 tensor to the floating-point dtype of the bins.
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_zero = torch.zeros(1, dtype=time_bins.dtype, device=time_bins.device)
time_bins = torch.cat((time_zero, time_bins))

# Load model configuration and make model
config = load_config(cfg.MENSA_CONFIGS_DIR, "synthetic.yaml")
n_epochs = config['n_epochs']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
# The model gets one output more than the number of events reported by the
# loader (2 events -> 3 here, matching K=3). The extra output is presumably
# censoring, since predictions for the events use risk=i+1 -- TODO confirm.
model = MENSA(n_features, layers=layers, n_events=n_events + 1, device=device)
lr_dict = {'network': lr}

# Train the model. Training stops early when the validation loss stops
# improving (patience=10), so fewer than n_epochs epochs may be run.
model.fit(train_dict, valid_dict, optimizer='adam', verbose=True,
          n_epochs=n_epochs, patience=10, batch_size=batch_size, lr_dict=lr_dict)

[Epoch   83/1000]:   8%|▊         | 82/1000 [02:20<26:15,  1.72s/it, Training loss = 2.9794, Validation loss = 3.0258]

Early stopping at iteration 82, best valid loss: 3.0250137994749773





In [13]:
# Make predictions for competing risks.
# Events are labelled 1..n_events in test_dict['E'] (see the indicator columns
# below), hence risk=i+1; index 0 is presumably censoring -- TODO confirm.
test_features = test_dict['X'].to(device)   # loop-invariant, moved once
bin_columns = time_bins.cpu().numpy()       # loop-invariant column labels
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_features, time_bins, risk=i+1)
    model_preds = pd.DataFrame(model_preds, columns=bin_columns)
    all_preds.append(model_preds)

# Calculate local and global CI.
# Both arrays have one column per event: the observed time is repeated for
# every event, and the indicator column for event i is 1.0 where E == i+1.
# They get their own names so they are not confused with the per-event
# y_test_time / y_test_event used in the loop below.
ci_test_time = np.stack([test_dict['T'] for _ in range(n_events)], axis=1)
ci_test_event = np.stack([np.array((test_dict['E'] == i+1)*1.0) for i in range(n_events)], axis=1)
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, ci_test_time, ci_test_event)
local_ci = local_C_index(all_preds_arr, ci_test_time, ci_test_event)

# Make evaluation for each event.
# Sample counts do not depend on the event, so compute them once.
n_train_samples = len(train_dict['X'])
n_test_samples = len(test_dict['X'])
for event_id, surv_preds in enumerate(all_preds):
    # Each event is scored against its own time column T{k} with every sample
    # marked as an event (indicator 1). Presumably T{k} holds the uncensored
    # event times from the synthetic DGP -- confirm against the data loader.
    y_train_time = train_dict[f'T{event_id+1}']
    y_train_event = np.ones(n_train_samples, dtype=int)
    y_test_time = test_dict[f'T{event_id+1}']
    y_test_event = np.ones(n_test_samples, dtype=int)
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)

    ci = lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae = lifelines_eval.mae(method='Uncensored')
    d_calib = lifelines_eval.d_calibration()[0]

    # Order: CI, IBS, MAE, D-calibration, global CI, local CI. The last two are
    # computed over all events and are therefore the same on every line.
    metrics = [ci, ibs, mae, d_calib, global_ci, local_ci]
    print(f'E{event_id+1}: {metrics}')

E1: [0.6451154845105632, 0.05835831089749532, 1.7810479189557282, 5.166257262716706e-105, 0.6266510637837749, 0.6771743258640335]
E2: [0.6169488622155539, 0.08223018811164312, 2.4177082183863488, 0.0, 0.6266510637837749, 0.6771743258640335]


## 4 Synthetic multi-event

We generate a synthetic multi-event dataset with K=3 events from a Weibull DGP with no dependence (k_taus=[0, 0, 0]) and linear risk.

In [14]:
# Load and split data.
# The DGP settings are read from the YAML config; linear=True and all-zero
# k_taus select the linear-risk, no-dependence setup used in this section.
data_config = load_config(cfg.DGP_CONFIGS_DIR, "synthetic_me.yaml")
dl = MultiEventSyntheticDataLoader().load_data(data_config, k_taus=[0, 0, 0], copula_names=[],
                                               linear=True, device=device, dtype=dtype)

# 70% / 10% / 20% train / validation / test split with a fixed random_state.
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)

# Size of the second axis of X; passed to the model as its number of features.
n_features = train_dict['X'].shape[1]
# Number of events reported by the data loader; drives the prediction and
# evaluation loops later in this section.
n_events = dl.n_events

In [15]:
# Make time bins from the training times, then prepend t=0.
# The zero is created with the bins' own dtype and device, so torch.cat does not
# have to promote an int64 tensor to the floating-point dtype of the bins.
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_zero = torch.zeros(1, dtype=time_bins.dtype, device=time_bins.device)
time_bins = torch.cat((time_zero, time_bins))

# Load model configuration and make model
config = load_config(cfg.MENSA_CONFIGS_DIR, "synthetic.yaml")
n_epochs = config['n_epochs']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
# Use the event count reported by the data loader (3 for this dataset) instead
# of repeating it as a literal.
model = MENSA(n_features, layers=layers, n_events=n_events, device=device)
lr_dict = {'network': lr}

# Train the model. Training stops early when the validation loss stops
# improving (patience=10), so fewer than n_epochs epochs may be run.
model.fit(train_dict, valid_dict, optimizer='adam', verbose=True,
          n_epochs=n_epochs, patience=10, batch_size=batch_size, lr_dict=lr_dict)

  0%|          | 0/1000 [00:00<?, ?it/s]

[Epoch   71/1000]:   7%|▋         | 70/1000 [02:36<34:42,  2.24s/it, Training loss = 5.4547, Validation loss = 5.5427]

Early stopping at iteration 70, best valid loss: 5.54063224715259





In [16]:
# Make predictions for multi-event: one DataFrame of survival curves per event,
# with the time bins as columns. Here the risk index is the event index itself.
test_features = test_dict['X'].to(device)   # loop-invariant, moved once
bin_columns = time_bins.cpu().numpy()       # loop-invariant column labels
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_features, time_bins, risk=i)
    model_preds = pd.DataFrame(model_preds, columns=bin_columns)
    all_preds.append(model_preds)

# Calculate local and global CI on the full time and event arrays (converted to
# NumPy once and shared by both calls).
all_preds_arr = [df.to_numpy() for df in all_preds]
test_times_np = test_dict['T'].numpy()
test_events_np = test_dict['E'].numpy()
global_ci = global_C_index(all_preds_arr, test_times_np, test_events_np)
local_ci = local_C_index(all_preds_arr, test_times_np, test_events_np)

# Make evaluation for each event.
# Sample counts do not depend on the event, so compute them once.
n_train_samples = len(train_dict['X'])
n_test_samples = len(test_dict['X'])
for event_id, surv_preds in enumerate(all_preds):
    # Column event_id of T holds this event's times; every sample is marked as
    # an event (indicator 1) for the evaluator.
    y_train_time = train_dict['T'][:,event_id]
    y_train_event = np.ones(n_train_samples, dtype=int)
    y_test_time = test_dict['T'][:,event_id]
    y_test_event = np.ones(n_test_samples, dtype=int)
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)

    ci = lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae = lifelines_eval.mae(method='Uncensored')
    d_calib = lifelines_eval.d_calibration()[0]

    # Order: CI, IBS, MAE, D-calibration, global CI, local CI. The last two are
    # computed over all events and are therefore the same on every line.
    metrics = [ci, ibs, mae, d_calib, global_ci, local_ci]
    print(f'E{event_id+1}: {metrics}')

E1: [0.6437182971777223, 0.09675261740522718, 1.496732421395548, 6.168807708841666e-174, 0.6373376724813675, 0.6449748275219094]
E2: [0.6408499633529109, 0.1030200286345988, 1.5152797329605348, 4.8233355062202835e-54, 0.6373376724813675, 0.6449748275219094]
E3: [0.6258876777920004, 0.08720759817958376, 1.7690591442505441, 0.0, 0.6373376724813675, 0.6449748275219094]
