## Introduction
Given an instance, a multi-event survival model predicts the time until that instance experiences each of several different events. These events are not mutually exclusive and there are often statistical dependencies between them. MENSA works by jointly learning the K event distributions as a convex combination of Weibull distributions. This approach leverages mutual information between events that may be lost in models that assume independence.

The data format is as follows:

For single-event: X = [x_1, x_2, ... x_n], T = [t_1, t_2, ..., t_i], E = [e_1, e_2, ... e_i]\
For competing risks: X = [x_1, x_2, ... x_n], T = [t_1, t_2, ..., t_i], E = [e_1, e_2, ... e_i]\
For multi-event: X = [x_1, x_2, ... x_n], T = [[t_i1, t_i2, ..., t_ik], ...], E = [[e_i1, e_i2, ..., e_ik], ...]

Here, $n$ is the number of covariates, $i$ is the subject and $k$ denotes the number of events.\
The demo uses a synthetic data generator (DGP) for better reproducibility.

In [77]:
# 3rd party
import pandas as pd
import numpy as np
import config as cfg
import torch
import random
from SurvivalEVAL.Evaluator import LifelinesEvaluator

# Local
from data_loader import (SingleEventSyntheticDataLoader,
                         CompetingRiskSyntheticDataLoader,
                         MultiEventSyntheticDataLoader)
from utility.survival import make_time_bins
from utility.config import load_config
from utility.evaluation import global_C_index, local_C_index
from mensa.model import MENSA

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Setup precision
dtype = torch.float64
torch.set_default_dtype(dtype)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')        

## Single-event prediction

We generate a synthetic single-event dataset from a Weibull DGP with no dependence (k_tau=0) and linear risk.\
This generates one event and a censoring event.

See the concrete implementation for details.

In [78]:
# Load synthetic data for single-event case
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_se.yaml")
dl = SingleEventSyntheticDataLoader().load_data(data_config=data_config,
                                                linear=True, copula_name="",
                                                k_tau=0, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [79]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
weight_decay = config['weight_decay']
dropout_rate = config['dropout_rate']
model = MENSA(n_features, layers=layers, dropout_rate=dropout_rate,
                n_events=n_events, n_dists=n_dists, device=device)

# Train the model
model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
          weight_decay=weight_decay, patience=10,
          batch_size=batch_size, verbose=True)

[Epoch   44/1000]:   4%|▍         | 43/1000 [00:51<19:09,  1.20s/it, Training loss = 2.8303, Validation loss = 2.9180]

Early stopping at iteration 43, best valid loss: 2.9059269795695584





In [80]:
# Get predictions for the single event
model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=1)
model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())

# Use SurvivalEVAL package to calculate popular metrics
y_train_time = train_dict['T']
y_train_event = (train_dict['E'])*1.0
y_test_time = test_dict['T']
y_test_event = (test_dict['E'])*1.0
lifelines_eval = LifelinesEvaluator(model_preds.T, y_test_time, y_test_event,
                                    y_train_time, y_train_event)

ci = lifelines_eval.concordance()[0]
ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
mae_hinge = lifelines_eval.mae(method="Hinge")
mae_margin = lifelines_eval.mae(method="Margin")
mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
d_calib = lifelines_eval.d_calibration()[0]

metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, d_calib]
print("Event 1: " + str(metrics))

Event 1: [0.820584836947118, 0.13181309379379055, 1.7259265116239435, 3.366090142802187, 6.369244971546377, 0.27877795442251696]


## Competing-risks prediction

We generate a synthetic competing risks (K=2) dataset from a Weibull DGP with no dependence (k_tau=0) and linear risk.\
This generates two actual competing events and a censoring event.

See the concrete implementation for details.

In [81]:
# Load synthetic data for competing risks case
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_cr.yaml")
dl = CompetingRiskSyntheticDataLoader().load_data(data_config, k_tau=0, copula_name="",
                                                  linear=True, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [82]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
weight_decay = config['weight_decay']
dropout_rate = config['dropout_rate']
model = MENSA(n_features, layers=layers, dropout_rate=dropout_rate,
              n_events=n_events, n_dists=n_dists, device=device)

# Train the model
model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
          weight_decay=weight_decay, patience=10,
          batch_size=batch_size, verbose=True)

[Epoch   65/1000]:   6%|▋         | 64/1000 [01:32<22:27,  1.44s/it, Training loss = 1.9101, Validation loss = 1.8936]

Early stopping at iteration 64, best valid loss: 1.888639416781049





In [83]:
# Make predictions for competing risks
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=i+1) # skip censoring event
    model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())
    all_preds.append(model_preds)
    
# Calculate local and global CI
y_test_time = np.stack([test_dict['T'] for _ in range(n_events)], axis=1)
y_test_event = np.stack([np.array((test_dict['E'] == i+1)*1.0) for i in range(n_events)], axis=1)
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, y_test_time, y_test_event)
local_ci = local_C_index(all_preds_arr, y_test_time, y_test_event)

# Use SurvivalEVAL package to calculate popular metrics
for event_id, surv_preds in enumerate(all_preds):
    n_train_samples = len(train_dict['X'])
    n_test_samples= len(test_dict['X'])
    y_train_time = train_dict['T']
    y_train_event = (train_dict['E'])*1.0
    y_test_time = test_dict['T']
    y_test_event = (test_dict['E'])*1.0
    
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)
    
    ci =  lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae_hinge = lifelines_eval.mae(method="Hinge")
    mae_margin = lifelines_eval.mae(method="Margin")
    mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
    d_calib = lifelines_eval.d_calibration()[0]
    
    metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, global_ci, local_ci, d_calib]
    print(f'Event {event_id+1}: ' + f'{metrics}')

Event 1: [0.6542917324907846, 0.053425780883329096, 0.06376680407782936, 5.308637727519044, 4.3206685346010785, 0.6661400737230121, 0.5384615384615384, 0.9999999999927013]
Event 2: [0.6614007372301212, 0.05114556084876717, 0.07307959210563922, 4.965387994837668, 3.9657969006068923, 0.6661400737230121, 0.5384615384615384, 0.999999995179888]


## Multi-event prediction

We generate a synthetic multi-event dataset from a Weibull DGP with no dependence (k_tau=0) and linear risk.\
This generates four events that are not mutually exclusive.

See the concrete implementation for details.

In [84]:
# Load and split data
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_me.yaml")
dl = MultiEventSyntheticDataLoader().load_data(data_config, k_taus=[0, 0, 0, 0], copula_names=[],
                                               linear=True, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [85]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
weight_decay = config['weight_decay']
dropout_rate = config['dropout_rate']
model = MENSA(n_features, layers=layers, dropout_rate=dropout_rate,
              n_events=n_events, n_dists=n_dists, device=device)

# Train the model
model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
          weight_decay=weight_decay, patience=10,
          batch_size=batch_size, verbose=True)

[Epoch   23/1000]:   2%|▏         | 22/1000 [00:43<32:15,  1.98s/it, Training loss = 4.8831, Validation loss = 4.6399] 

Early stopping at iteration 22, best valid loss: 4.1192310296907495





In [86]:
# Make predictions for multi-event
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=i+1) # skip censoring event
    model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())
    all_preds.append(model_preds)

# Calculate local and global CI
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, test_dict['T'].numpy(), test_dict['E'].numpy())
local_ci = local_C_index(all_preds_arr, test_dict['T'].numpy(), test_dict['E'].numpy())

# Use SurvivalEVAL package to calculate popular metrics
for event_id, surv_preds in enumerate(all_preds):
    n_train_samples = len(train_dict['X'])
    n_test_samples= len(test_dict['X'])
    y_train_time = train_dict['T'][:,event_id]
    y_train_event = train_dict['E'][:,event_id]
    y_test_time = test_dict['T'][:,event_id]
    y_test_event = test_dict['E'][:,event_id]
    
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)
    
    ci =  lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae_hinge = lifelines_eval.mae(method="Hinge")
    mae_margin = lifelines_eval.mae(method="Margin")
    mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
    d_calib = lifelines_eval.d_calibration()[0]
    
    metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, global_ci, local_ci, d_calib]
    print(f'Event {event_id+1}: ' + f'{metrics}')

Event 1: [0.8165445961316854, 0.058125933917466725, 3.909632098450562, 52.77514858680799, 44.61149061882991, 0.8245981936541406, 0.7030947775628626, 0.0031317579165855564]
Event 2: [0.8294284223861689, 0.0349372619218568, 4.387488412974439, 52.8410442064724, 48.90438511547283, 0.8245981936541406, 0.7030947775628626, 0.1447498868804959]
Event 3: [0.8238387670420866, 0.03402523956982753, 4.842466609322007, 61.50006834710062, 57.42028850610096, 0.8245981936541406, 0.7030947775628626, 0.41379230771236253]
Event 4: [0.8294652823883908, 0.035361355580257, 4.430020968922511, 62.14436841138883, 58.36433422241283, 0.8245981936541406, 0.7030947775628626, 0.08479103958546916]
