## Introduction
Given an instance, a multi-event survival model predicts the time until that instance experiences each of several different events. These events are not mutually exclusive and there are often statistical dependencies between them. MENSA works by jointly learning the K event distributions as a convex combination of Weibull distributions. This approach leverages mutual information between events that may be lost in models that assume independence.

The data format is as follows:

For single-event: X = [x_1, x_2, ... x_n], T = [t_1, t_2, ..., t_i], E = [e_1, e_2, ... e_i]\
For competing risks: X = [x_1, x_2, ... x_n], T = [t_1, t_2, ..., t_i], E = [e_1, e_2, ... e_i]\
For multi-event: X = [x_1, x_2, ... x_n], T = [[t_i1, t_i2, ..., t_ik], ...], E = [[e_i1, e_i2, ..., e_ik], ...]

Here, $n$ is the number of covariates, $i$ is the subject and $k$ denotes the number of events.\
The demo uses a synthetic data generator (DGP) for better reproducibility.

In [1]:
# 3rd party
import pandas as pd
import numpy as np
import config as cfg
import torch
import random
from SurvivalEVAL.Evaluator import LifelinesEvaluator

# Local
from data_loader import (SingleEventSyntheticDataLoader,
                         CompetingRiskSyntheticDataLoader,
                         MultiEventSyntheticDataLoader)
from utility.survival import make_time_bins
from utility.config import load_config
from utility.evaluation import global_C_index, local_C_index
from mensa.model import MENSA

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Setup precision
dtype = torch.float64
torch.set_default_dtype(dtype)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')        

  from .autonotebook import tqdm as notebook_tqdm


## 2 Synthetic single-event

We generate a synthetic single-event dataset from a Weibull DGP with no dependence (k_tau=0) and linear risk.\
This generates one actual event (E1) and a censoring event, so K=2.

See the concrete implementation for details.

In [2]:
# Load synthetic data for single-event case
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_se.yaml")
dl = SingleEventSyntheticDataLoader().load_data(data_config=data_config,
                                                linear=True, copula_name="",
                                                k_tau=0, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]

In [3]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
n_epochs = 1000
lr = 0.001
batch_size = 32
layers = [32]
model = MENSA(n_features, layers=layers, n_events=2, device=device)

# Train the model
model.fit(train_dict, valid_dict, optimizer='adam', n_epochs=n_epochs,
          verbose=True, patience=10, batch_size=batch_size, learning_rate=lr)

[Epoch   25/1000]:   2%|▏         | 24/1000 [00:14<09:31,  1.71it/s, Training loss = 2.7674, Validation loss = 2.7579]

Early stopping at iteration 24, best valid loss: 2.7574658100447373





In [4]:
# Get predictions for the single event
model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=0)
model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())

# Use SurvivalEVAL package to calculate popular metrics
y_train_time = train_dict['T']
y_train_event = (train_dict['E'])*1.0
y_test_time = test_dict['T']
y_test_event = (test_dict['E'])*1.0
lifelines_eval = LifelinesEvaluator(model_preds.T, y_test_time, y_test_event,
                                    y_train_time, y_train_event)

ci = lifelines_eval.concordance()[0]
ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
mae_hinge = lifelines_eval.mae(method="Hinge")
mae_margin = lifelines_eval.mae(method="Margin")
mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
d_calib = lifelines_eval.d_calibration()[0]

metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, d_calib]
print("E1: " + str(metrics))

E1: [0.6438241344402896, 0.0827682825874046, 1.4229945782198683, 2.0510194644112607, 2.366014410048409, 0.6404361712054161]


## 3 Synthetic competing risks

We generate a synthetic competing risks (K=3) dataset from a Weibull DGP with no dependence (k_tau=0) and linear risk.\
This generates two actual competing events (E1 and E2) and a censoring event, so K=3.

See the concrete implementation for details.

In [5]:
# Load synthetic data for competing risks case
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_cr.yaml")
dl = CompetingRiskSyntheticDataLoader().load_data(data_config, k_tau=0, copula_name="",
                                                  linear=True, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [6]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = 1000
lr = 0.001
batch_size = 32
layers = [32]
model = MENSA(n_features, layers=layers, n_events=3, device=device)

# Train the model
model.fit(train_dict, valid_dict, optimizer='adam', verbose=True,
          n_epochs=n_epochs, patience=10, batch_size=batch_size, learning_rate=lr)

[Epoch   33/1000]:   3%|▎         | 32/1000 [01:36<48:42,  3.02s/it, Training loss = 3.0576, Validation loss = 3.0684]

Early stopping at iteration 32, best valid loss: 3.0583668206534242





In [7]:
# Make predictions for competing risks
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=i+1) # skip censoring event
    model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())
    all_preds.append(model_preds)
    
# Calculate local and global CI
y_test_time = np.stack([test_dict['T'] for _ in range(n_events)], axis=1)
y_test_event = np.stack([np.array((test_dict['E'] == i+1)*1.0) for i in range(n_events)], axis=1)
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, y_test_time, y_test_event)
local_ci = local_C_index(all_preds_arr, y_test_time, y_test_event)

# Use SurvivalEVAL package to calculate popular metrics
for event_id, surv_preds in enumerate(all_preds):
    n_train_samples = len(train_dict['X'])
    n_test_samples= len(test_dict['X'])
    y_train_time = train_dict['T']
    y_train_event = (train_dict['E'])*1.0
    y_test_time = test_dict['T']
    y_test_event = (test_dict['E'])*1.0
    
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)
    
    ci =  lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae_hinge = lifelines_eval.mae(method="Hinge")
    mae_margin = lifelines_eval.mae(method="Margin")
    mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
    d_calib = lifelines_eval.d_calibration()[0]
    
    metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, global_ci, local_ci, d_calib]
    print(f'E{event_id+1}: ' + f'{metrics}')

E1: [0.6233223899146121, 0.07862104796271611, 1.453079890217305, 1.8323388584045197, 2.0044932498816443, 0.6310468593107317, 0.7427236794825728, 0.07817974543059338]
E2: [0.6030719043844883, 0.1270125648583262, 2.357923647837968, 2.6548876807282413, 2.786660651046932, 0.6310468593107317, 0.7427236794825728, 0.0]


## 4 Synthetic multi-event

We generate a synthetic multi-event dataset from a Weibull DGP with no dependence (k_tau=0) and linear risk.\
This generates three actual events (E1, E2 and E3) that are not mutually exclusive.

See the concrete implementation for details.

In [2]:
# Load and split data
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_me.yaml")
dl = MultiEventSyntheticDataLoader().load_data(data_config, k_taus=[0, 0, 0], copula_names=[],
                                               linear=True, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [3]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = 1000
lr = 0.001
batch_size = 32
layers = [32]
model = MENSA(n_features, layers=layers, n_events=3, device=device)

# Train the model
model.fit(train_dict, valid_dict, optimizer='adam', verbose=True,
          n_epochs=n_epochs, patience=10, batch_size=batch_size, learning_rate=lr)

[Epoch   15/1000]:   1%|▏         | 14/1000 [00:38<44:51,  2.73s/it, Training loss = 5.7725, Validation loss = 5.7746]

Early stopping at iteration 14, best valid loss: 5.774478964792162





In [4]:
# Make predictions for multi-event
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=i)
    model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())
    all_preds.append(model_preds)

# Calculate local and global CI
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, test_dict['T'].numpy(), test_dict['E'].numpy())
local_ci = local_C_index(all_preds_arr, test_dict['T'].numpy(), test_dict['E'].numpy())

# Use SurvivalEVAL package to calculate popular metrics
for event_id, surv_preds in enumerate(all_preds):
    n_train_samples = len(train_dict['X'])
    n_test_samples= len(test_dict['X'])
    y_train_time = train_dict['T'][:,event_id]
    y_train_event = train_dict['E'][:,event_id]
    y_test_time = test_dict['T'][:,event_id]
    y_test_event = test_dict['E'][:,event_id]
    
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)
    
    ci =  lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae_hinge = lifelines_eval.mae(method="Hinge")
    mae_margin = lifelines_eval.mae(method="Margin")
    mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
    d_calib = lifelines_eval.d_calibration()[0]
    
    metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, global_ci, local_ci, d_calib]
    print(f'E{event_id+1}: ' + f'{metrics}')

E1: [0.6473869482805308, 0.09973084075206892, 1.4692220771505582, 1.789775818310862, 2.1172795637266755, 0.6492982528336503, 0.656479217603912, 0.5174783063076376]
E2: [0.6534588275188322, 0.09880397663715398, 1.4844745602293303, 1.92568067659878, 2.3769391677369254, 0.6492982528336503, 0.656479217603912, 0.15314234845264074]
E3: [0.6472322127790788, 0.10115955313497431, 1.4556439486280301, 1.480436909290556, 1.505255749855323, 0.6492982528336503, 0.656479217603912, 0.4807914868307198]
