# Introduction

MENSA is a deep learning model that jointly models flexible time-to-event distributions for multiple events, whether competing or co-occurring. The data format is as follows:

For single-event: X = [x_1, x_2, ... x_n], T = [t_1, t_2, ..., t_i], E = [e_1, e_2, ... e_i]\
For competing risks: X = [x_1, x_2, ... x_n], T = [t_1, t_2, ..., t_i], E = [e_1, e_2, ... e_i]\
For multi-event: X = [x_1, x_2, ... x_n], T = [[t_i1, t_i2, ..., t_ik], ...], E = [[e_i1, e_i2, ..., e_ik], ...]

Here, $n$ is the number of covariates, $i$ is the subject and $k$ denotes the number of events.\
The demo uses a synthetic data generator (DGP) for better reproducibility.

In [1]:
# 3rd party
import pandas as pd
import numpy as np
import config as cfg
import torch
import random
from SurvivalEVAL.Evaluator import LifelinesEvaluator

# Local
from data_loader import (SingleEventSyntheticDataLoader,
                         CompetingRiskSyntheticDataLoader,
                         MultiEventSyntheticDataLoader)
from utility.survival import make_time_bins
from utility.config import load_config
from utility.evaluation import global_C_index, local_C_index
from mensa.model import MENSA

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Setup precision
dtype = torch.float32
torch.set_default_dtype(dtype)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')        

  from .autonotebook import tqdm as notebook_tqdm


# Single-event prediction

We generate a synthetic single-event dataset from a linear Weibull DGP.\
This generates one event and a censoring event.

See the concrete implementation for details.

In [2]:
# Load synthetic data for single-event case
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_se.yaml")
dl = SingleEventSyntheticDataLoader().load_data(data_config=data_config,
                                                linear=True, copula_name="",
                                                k_tau=0, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [3]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
weight_decay = config['weight_decay']
dropout_rate = config['dropout_rate']
model = MENSA(n_features, layers=layers, dropout_rate=dropout_rate,
              n_events=n_events, n_dists=n_dists, device=device)

# Train the model
model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
          weight_decay=weight_decay, patience=20,
          batch_size=batch_size, verbose=True)

[Epoch   65/1000]:   6%|▋         | 64/1000 [02:09<31:32,  2.02s/it, Training loss = 1.6314, Validation loss = 1.6341]

Early stopping at iteration 64, best valid loss: 1.6266281269490719





In [4]:
# Get predictions for the single event
model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=1)
model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())

# Use SurvivalEVAL package to calculate popular metrics
y_train_time = train_dict['T']
y_train_event = (train_dict['E'])*1.0
y_test_time = test_dict['T']
y_test_event = (test_dict['E'])*1.0
lifelines_eval = LifelinesEvaluator(model_preds.T, y_test_time, y_test_event,
                                    y_train_time, y_train_event)

ci = lifelines_eval.concordance()[0]
ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
mae_hinge = lifelines_eval.mae(method="Hinge")
mae_margin = lifelines_eval.mae(method="Margin")
mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
d_calib = lifelines_eval.d_calibration()[0]

metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, d_calib]
print("Event 1: " + str(metrics))

Event 1: [0.8264203032510422, 0.14102957628774732, 0.7430959016409477, 4.057748994853522, 8.675915396351094, 0.9919812897555466]


# Competing-risks prediction

We generate a synthetic competing risks (K=2) dataset from a linear Weibull DGP.\
This generates two actual competing events and a censoring event.

See the concrete implementation for details.

In [5]:
# Load synthetic data for competing risks case
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_cr.yaml")
dl = CompetingRiskSyntheticDataLoader().load_data(data_config, k_tau=0, copula_name="",
                                                  linear=True, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [6]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
weight_decay = config['weight_decay']
dropout_rate = config['dropout_rate']
model = MENSA(n_features, layers=layers, dropout_rate=dropout_rate,
              n_events=n_events, n_dists=n_dists, device=device)

# Train the model
model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
          weight_decay=weight_decay, patience=20,
          batch_size=batch_size, verbose=True)

[Epoch   53/1000]:   5%|▌         | 52/1000 [02:28<45:06,  2.86s/it, Training loss = 3.0400, Validation loss = 3.0619] 

Early stopping at iteration 52, best valid loss: 3.037816122174263





In [7]:
# Make predictions for competing risks
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=i+1) # skip censoring event
    model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())
    all_preds.append(model_preds)
    
# Calculate local and global CI
y_test_time = np.stack([test_dict['T'] for _ in range(n_events)], axis=1)
y_test_event = np.stack([np.array((test_dict['E'] == i+1)*1.0) for i in range(n_events)], axis=1)
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, y_test_time, y_test_event)
local_ci = local_C_index(all_preds_arr, y_test_time, y_test_event)

# Use SurvivalEVAL package to calculate popular metrics
for event_id, surv_preds in enumerate(all_preds):
    n_train_samples = len(train_dict['X'])
    n_test_samples= len(test_dict['X'])
    y_train_time = train_dict['T']
    y_train_event = (train_dict['E'])*1.0
    y_test_time = test_dict['T']
    y_test_event = (test_dict['E'])*1.0
    
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)
    
    ci =  lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae_hinge = lifelines_eval.mae(method="Hinge")
    mae_margin = lifelines_eval.mae(method="Margin")
    mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
    d_calib = lifelines_eval.d_calibration()[0]
    
    metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, global_ci, local_ci, d_calib]
    print(f'Event {event_id+1}: ' + f'{metrics}')

Event 1: [0.8043287834385147, 0.05231590006247374, 2.309306383357845, 3.02639914112114, 3.5447934641750756, 0.8050830495826518, 0.5, 1.160900001417501e-48]
Event 2: [0.8050663704787273, 0.05061734354618599, 2.227111067519106, 2.9408706216240192, 3.4670379015696806, 0.8050830495826518, 0.5, 8.201324032156303e-41]


# Multi-event prediction

We generate a synthetic multi-event dataset from a linaer Weibull DGP.\
This generates four events that are not mutually exclusive.

See the concrete implementation for details.

In [8]:
# Load and split data
data_config = load_config(cfg.DGP_CONFIGS_DIR, f"synthetic_me.yaml")
dl = MultiEventSyntheticDataLoader().load_data(data_config, k_taus=[0, 0, 0, 0], copula_names=[],
                                               linear=True, device=device, dtype=dtype)
train_dict, valid_dict, test_dict = dl.split_data(train_size=0.7, valid_size=0.1, test_size=0.2,
                                                  random_state=0)
n_features = train_dict['X'].shape[1]
n_events = dl.n_events

In [9]:
# Make time bins
time_bins = make_time_bins(train_dict['T'], event=None, dtype=dtype).to(device)
time_bins = torch.cat((torch.tensor([0]).to(device), time_bins))

# Define the model
config = load_config(cfg.MENSA_CONFIGS_DIR, f"synthetic.yaml")
n_epochs = config['n_epochs']
n_dists = config['n_dists']
lr = config['lr']
batch_size = config['batch_size']
layers = config['layers']
weight_decay = config['weight_decay']
dropout_rate = config['dropout_rate']
model = MENSA(n_features, layers=layers, dropout_rate=dropout_rate,
              n_events=n_events, n_dists=n_dists, device=device)

# Train the model
model.fit(train_dict, valid_dict, learning_rate=lr, n_epochs=n_epochs,
          weight_decay=weight_decay, patience=20,
          batch_size=batch_size, verbose=True)

  0%|          | 0/1000 [00:00<?, ?it/s]

[Epoch   53/1000]:   5%|▌         | 52/1000 [03:13<58:43,  3.72s/it, Training loss = 1.0020, Validation loss = 1.1343]  

Early stopping at iteration 52, best valid loss: 1.1037923842668533





In [10]:
# Make predictions for multi-event
all_preds = []
for i in range(n_events):
    model_preds = model.predict(test_dict['X'].to(device), time_bins, risk=i+1) # skip censoring event
    model_preds = pd.DataFrame(model_preds, columns=time_bins.cpu().numpy())
    all_preds.append(model_preds)

# Calculate local and global CI
all_preds_arr = [df.to_numpy() for df in all_preds]
global_ci = global_C_index(all_preds_arr, test_dict['T'].numpy(), test_dict['E'].numpy())
local_ci = local_C_index(all_preds_arr, test_dict['T'].numpy(), test_dict['E'].numpy())

# Use SurvivalEVAL package to calculate popular metrics
for event_id, surv_preds in enumerate(all_preds):
    n_train_samples = len(train_dict['X'])
    n_test_samples= len(test_dict['X'])
    y_train_time = train_dict['T'][:,event_id]
    y_train_event = train_dict['E'][:,event_id]
    y_test_time = test_dict['T'][:,event_id]
    y_test_event = test_dict['E'][:,event_id]
    
    lifelines_eval = LifelinesEvaluator(surv_preds.T, y_test_time, y_test_event,
                                        y_train_time, y_train_event)
    
    ci =  lifelines_eval.concordance()[0]
    ibs = lifelines_eval.integrated_brier_score(num_points=len(time_bins))
    mae_hinge = lifelines_eval.mae(method="Hinge")
    mae_margin = lifelines_eval.mae(method="Margin")
    mae_pseudo = lifelines_eval.mae(method="Pseudo_obs")
    d_calib = lifelines_eval.d_calibration()[0]
    
    metrics = [ci, ibs, mae_hinge, mae_margin, mae_pseudo, global_ci, local_ci, d_calib]
    print(f'Event {event_id+1}: ' + f'{metrics}')

Event 1: [0.5455040220418196, 0.10427038499911316, 3.2357458968808084, 3.246547542048791, 3.2573994274212748, 0.6113270885405223, 0.9583267059010657, 5.881630695454767e-51]
Event 2: [0.8654050593507538, 0.005448711502378367, 9.722952147316132, 436954.7114455938, 436878.8745664984, 0.6113270885405223, 0.9583267059010657, 0.999999736778676]
Event 3: [0.8770241386520456, 0.005398366164207229, 6.828169383327557, 54796.13677355559, 54713.190123198525, 0.6113270885405223, 0.9583267059010657, 0.9999352277056895]
Event 4: [0.84814308118389, 0.0063197004172375094, 26.005477683199928, 116247.00586593169, 116177.28249421371, 0.6113270885405223, 0.9583267059010657, 0.9998765266341876]
