In [1]:
import numpy as np
import pandas as pd
import warnings
import scipy
import itertools
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

from sklearn.linear_model import LogisticRegression
from IPython.display import display
from prediction_utils.pytorch_utils.metrics import (
    StandardEvaluator
)
from prediction_utils.util import patient_split_cv

from prediction_utils.pytorch_utils.survival import (
    DiscreteTimeArrayLoaderGenerator,
    DiscreteTimeNNet,
)


from lifelines import KaplanMeierFitter

In [2]:
def simulation(
    n=10000,
    num_features=50,
    survival_group_means=[4, 15],
    censoring_group_means=[5, 5],
    binary_event_horizon=10,
    max_weight=100,
    seed=None,
):
    """Simulate right-censored survival data with group-dependent means.

    Each row is assigned uniformly at random to one of the groups. Survival and
    censoring times are drawn from exponential distributions whose means are a
    group-specific linear function of the features plus a group-specific shift
    (clipped below at a small positive constant). The same linear predictor is
    shared by the survival and the censoring process; only the shifts differ.

    Args:
        n: Number of rows to simulate.
        num_features: Number of standard-normal features per row.
        survival_group_means: Per-group shift added to the survival mean.
        censoring_group_means: Per-group shift added to the censoring mean.
            Must have the same length as ``survival_group_means``.
        binary_event_horizon: Time horizon defining the binary outcome
            ``survival_times < binary_event_horizon``.
        max_weight: Upper bound applied to the inverse-probability-of-censoring
            weights.
        seed: Optional seed forwarded to ``np.random.default_rng``. ``None``
            (the default) draws fresh entropy, matching the previous behavior.

    Returns:
        A tuple ``(df, features)`` where ``df`` is a DataFrame with one row per
        simulated individual and ``features`` is the ``(n, num_features)``
        feature array.
    """
    eps = 1e-6
    # A seeded generator makes the whole simulation reproducible.
    generator = np.random.default_rng(seed)

    # Convert inputs to numpy so they can be indexed by the group array.
    survival_group_means = np.array(survival_group_means)
    censoring_group_means = np.array(censoring_group_means)
    num_groups = len(survival_group_means)
    assert num_groups == len(censoring_group_means)

    # NOTE: the order of the random draws below is kept unchanged so that a
    # given generator state yields the same sequence of values as before.
    group = generator.integers(num_groups, size=n)

    features = generator.standard_normal(size=(n, num_features))
    weights = generator.standard_normal(size=(num_features, num_groups))

    # Each row uses the weight column of its own group; this predictor is
    # shared by the survival and the censoring means, so compute it once.
    linear_predictor = np.dot(features, weights)[np.arange(n), group].reshape(-1)

    survival_mean_shifts = survival_group_means[group]
    survival_means = np.maximum(linear_predictor + survival_mean_shifts, eps)
    survival_times = generator.exponential(survival_means)

    censoring_mean_shifts = censoring_group_means[group]
    censoring_means = np.maximum(linear_predictor + censoring_mean_shifts, eps)
    censoring_times = generator.exponential(censoring_means)

    # The observed times
    observed_times = np.minimum(survival_times, censoring_times)

    # The observed times when also censored by the binary event horizon
    binary_observed_times = np.minimum(observed_times, binary_event_horizon)

    # Indicators for whether the observed time is an event or censoring
    event_indicator = survival_times <= censoring_times
    censored_indicator = 1 - event_indicator

    # Exponential survival function of the event process at the observed time
    survival_function = np.exp(-observed_times / survival_means)

    # The survival function for the censoring process
    censoring_survival_function = np.exp(-observed_times / censoring_means)

    # The survival function for the censoring process when also censored by the binary event horizon
    binary_censoring_survival_function = np.exp(-binary_observed_times / censoring_means)

    # The true value of the binary outcome Y
    true_labels = survival_times < binary_event_horizon
    labels = true_labels

    # The event was observed before the horizon
    observed_labels = (observed_times < binary_event_horizon) & event_indicator

    # An indicator for whether the binary outcome is observed: either the event
    # was seen before the horizon, or follow-up reached the horizon.
    binary_event_indicator = observed_labels | (binary_event_horizon <= observed_times)

    df = pd.DataFrame(
        {
            'row_id': np.arange(n),
            'survival_times': survival_times,
            'censoring_times': censoring_times,
            'observed_times': observed_times,
            'binary_observed_times': binary_observed_times,
            'group': group,
            'event_indicator': event_indicator,
            'censored_indicator': censored_indicator,
            'binary_event_indicator': binary_event_indicator,
            'true_labels': true_labels,
            'labels': labels,
            'observed_labels': observed_labels,
            'survival_function': survival_function,
            'censoring_survival_function': censoring_survival_function,
            'censoring_weight': np.minimum(1 / censoring_survival_function, max_weight),
            'binary_censoring_weight': np.minimum(1 / binary_censoring_survival_function, max_weight)
        }
    )

    return df, features
    
# Simulate the cohort with all non-default settings spelled out by keyword.
df, features = simulation(
    n=100000,
    survival_group_means=[5, 12],
    censoring_group_means=[20, 15],
    max_weight=1e6,
)

# Fold passed as `fold_id` when fitting the censoring model further down.
fold_id = '1'

# Attach fold assignments, splitting on `row_id`.
df = patient_split_cv(df, patient_col='row_id')

In [3]:
# Group-stratified KM: fit one Kaplan-Meier estimator per group with the event
# indicator flipped (1 - event_indicator), using only rows whose fold is
# neither "test" nor "eval". Each row's weight is the reciprocal of its group's
# fitted curve evaluated at `binary_observed_times`.
km_fitters = {}
group_df_dict = {}
for _group, group_rows in df.groupby('group'):
    km_fitter = KaplanMeierFitter()
    df_group_train = group_rows.query('fold_id != "test" & fold_id != "eval"')
    km_fitters[_group] = km_fitter.fit(
        df_group_train.observed_times,
        1 - df_group_train.event_indicator,
    )
    group_curve_at_times = km_fitters[_group].survival_function_at_times(
        group_rows.binary_observed_times.values
    ).values
    group_df_dict[_group] = group_rows.assign(km_weights_group=1 / group_curve_at_times)

# Stitch the per-group frames back together in row_id order.
df_group = (
    pd.concat(group_df_dict)
    .sort_values('row_id')
    .reset_index(drop=True)
)
df_group

Unnamed: 0,row_id,survival_times,censoring_times,observed_times,binary_observed_times,group,event_indicator,censored_indicator,binary_event_indicator,true_labels,labels,observed_labels,survival_function,censoring_survival_function,censoring_weight,binary_censoring_weight,fold_id,km_weights_group
0,0,5.193076,26.239883,5.193076,5.193076,1,True,0,True,True,True,True,0.647141,0.706268,1.415894,1.415894,2,1.428036
1,1,12.663557,24.521198,12.663557,10.000000,1,True,0,True,False,False,False,0.362335,0.441150,2.266803,1.908353,9,1.882718
2,2,0.572053,11.181660,0.572053,0.572053,0,True,0,True,True,True,True,0.738564,0.966693,1.034454,1.034454,2,1.027892
3,3,35.114930,2.966007,2.966007,2.966007,1,False,1,False,False,False,False,0.747098,0.798389,1.252522,1.252522,eval,1.246246
4,4,1.362833,3.420894,1.362833,1.362833,1,True,0,True,True,True,True,0.925098,0.935696,1.068723,1.068723,9,1.121121
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,99995,2.361355,16.958943,2.361355,2.361355,0,True,0,True,True,True,True,0.662207,0.892333,1.120658,1.120658,7,1.109053
99996,99996,8.458053,32.011096,8.458053,8.458053,0,True,0,True,True,True,True,0.268481,0.673920,1.483856,1.483856,test,1.408729
99997,99997,4.952322,0.258784,0.258784,0.258784,1,False,1,False,True,True,False,0.919396,0.958326,1.043487,1.043487,4,1.035500
99998,99998,11.101113,144.777931,11.101113,10.000000,1,True,0,True,False,False,False,0.607742,0.644725,1.551050,1.484970,10,1.882718


In [4]:
# df_group already carries every cohort column plus `km_weights_group`, sorted
# by row_id with a fresh index. Merging it with a subset of its own columns was
# a no-op self-join (implicitly keyed on the float `km_weights_group` column as
# well as `row_id`), so take an explicit copy instead.
df = df_group.copy()

In [5]:
# Overall KM: a single Kaplan-Meier fit with the event indicator flipped
# (1 - event_indicator), pooled across groups.
# NOTE(review): this fit uses every row of `df`, whereas the group-stratified
# fit excludes the "test"/"eval" folds -- confirm the difference is intended.
km_fitter = KaplanMeierFitter()
km_fitter.fit(df.observed_times, 1 - df.event_indicator)

# Weight = reciprocal of the fitted curve at each row's horizon-capped time.
overall_curve_at_times = km_fitter.survival_function_at_times(
    df.binary_observed_times.values
).values
df = df.assign(km_weights=1 / overall_curve_at_times)

In [6]:
# Per-group reference quantities computed from the simulated columns:
# mean survival time, the mean of the true binary outcome, and the fraction of
# rows whose binary outcome is not observed.
group_reference_summary = df.groupby('group').agg(
    mean_survival=('survival_times', 'mean'),
    prob_y=('true_labels', 'mean'),
    prob_censored=('binary_event_indicator', lambda x: 1 - x.mean()),
)
display(group_reference_summary)


def _km_weighted_vs_unweighted(group_rows):
    """Return a one-row frame with the KM-weighted and unweighted outcome means."""
    return pd.DataFrame(
        {
            'prob_y_km': np.average(group_rows.observed_labels, weights=group_rows.km_weights),
            'prob_y_biased': np.average(group_rows.observed_labels),
        },
        index=[group_rows.name],
    )


# Restrict to rows whose binary outcome is observed, then compare estimators.
rows_with_observed_outcome = df.query('binary_event_indicator == 1')
display(
    rows_with_observed_outcome
    .groupby('group')
    .apply(_km_weighted_vs_unweighted)
    .reset_index(level=-1, drop=True)
)

Unnamed: 0_level_0,mean_survival,prob_y,prob_censored
group,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,6.112219,0.804061,0.134945
1,12.177618,0.616108,0.313957


Unnamed: 0_level_0,prob_y_km,prob_y_biased
group,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.781966,0.844421
1,0.600699,0.681778


In [7]:
# Append the group id as one extra trailing feature column.
# Rows of `features` and `df` are paired positionally here.
group_column = df.group.values.reshape(-1, 1)
features_w_group = np.concatenate([features, group_column], axis=1)

In [8]:
def fit_and_eval_discrete_time_model(
    features, cohort, time_horizon, censoring_model=True, **config_dict
):
    """Train a DiscreteTimeNNet and return its predictions.

    Args:
        features: Feature array handed to ``DiscreteTimeArrayLoaderGenerator``.
        cohort: DataFrame handed to ``DiscreteTimeArrayLoaderGenerator``. It is
            copied first, so the caller's frame is never modified.
        time_horizon: Forwarded to ``model.predict`` as ``time_horizon``.
        censoring_model: When True, the event indicator column is flipped
            (``1 - indicator``) before training, so the model is fit to the
            censoring process rather than the event process.
        **config_dict: Keyword configuration forwarded to the loader generator,
            to the model constructor and to ``model.train``. If it contains
            ``event_indicator_var_name``, that column is the one flipped;
            otherwise ``"event_indicator"`` is used.

    Returns:
        The ``"outputs"`` entry of the dict returned by ``model.predict`` for
        the ``"val"``, ``"eval"`` and ``"test"`` phases.
    """
    cohort = cohort.copy()
    if censoring_model:
        # Flip the column named in the config rather than a hard-coded one, so
        # the flip stays in sync with the column name passed in the config.
        event_col = config_dict.get("event_indicator_var_name", "event_indicator")
        cohort[event_col] = 1 - cohort[event_col]
    loader_generator = DiscreteTimeArrayLoaderGenerator(
        features=features, cohort=cohort, **config_dict,
    )
    loaders = loader_generator.init_loaders()
    # The output dimension is derived from the bin edges stored on the
    # transformer that the loader generator placed in its config.
    bin_transformer = loader_generator.config_dict["bin_transformer"]
    config_dict["output_dim"] = len(bin_transformer.kw_args["bins"]) - 1
    model = DiscreteTimeNNet(transformer=bin_transformer, **config_dict)
    model.train(loaders, **config_dict)
    output_dict = model.predict(
        loader_generator.init_loaders_predict(),
        phases=["val", "eval", "test"],
        time_horizon=time_horizon,
    )
    return output_dict["outputs"]

In [9]:
# Keyword configuration unpacked into fit_and_eval_discrete_time_model below.
config_dict = dict(
    # Cohort columns and fold handling
    label_col="observed_times",
    event_indicator_var_name="event_indicator",
    num_bins=20,
    row_id_col="row_id",
    fold_id_test=["eval", "test"],
    disable_metric_logging=True,
    # Network and optimisation settings
    num_hidden=1,
    hidden_dim=128,
    lr=1e-3,
    num_epochs=20,
    batch_size=512,
    early_stopping=True,
    early_stopping_patience=10,
    # One input per feature column, including the appended group column
    input_dim=features_w_group.shape[1],
    sparse_mode=None,
)

In [10]:
# Fit the censoring model (event indicator flipped inside the helper) on the
# float32 feature matrix and request predictions at a horizon of 10.
features_w_group_float32 = features_w_group.astype(np.float32)
censoring_prob_df = fit_and_eval_discrete_time_model(
    features_w_group_float32,
    df,
    time_horizon=10,
    censoring_model=True,
    fold_id=fold_id,
    **config_dict,
)

cuda
Epoch 0/19
----------
Phase: train:
  metric  performance
0   loss     3.178398
Phase: val:
  metric  performance
0   loss     1.393801
Best model updated
Epoch 1/19
----------
Phase: train:
  metric  performance
0   loss     1.295392
Phase: val:
  metric  performance
0   loss     1.266229
Best model updated
Epoch 2/19
----------
Phase: train:
  metric  performance
0   loss     1.237616
Phase: val:
  metric  performance
0   loss     1.248923
Best model updated
Epoch 3/19
----------
Phase: train:
  metric  performance
0   loss     1.217943
Phase: val:
  metric  performance
0   loss     1.238964
Best model updated
Epoch 4/19
----------
Phase: train:
  metric  performance
0   loss     1.207182
Phase: val:
  metric  performance
0   loss     1.232368
Best model updated
Epoch 5/19
----------
Phase: train:
  metric  performance
0   loss     1.197948
Phase: val:
  metric  performance
0   loss     1.229234
Best model updated
Epoch 6/19
----------
Phase: train:
  metric  performance
0   los

In [11]:
# Keep the test-phase predictions, treat them as the modelled censoring
# survival function, and turn them into inverse-probability weights.
test_phase_predictions = (
    censoring_prob_df
    .query('phase == "test"')
    .rename(columns={'pred_probs': 'censoring_survival_function_model'})
)
# The merge joins on whichever columns the two frames have in common.
df_with_weights = test_phase_predictions.assign(
    ipcw_weight_model=1 / test_phase_predictions['censoring_survival_function_model']
).merge(df)


def _compare_outcome_estimators(group_rows):
    """Return a one-row frame of outcome means under each weighting scheme."""
    outcomes = group_rows.observed_labels
    return pd.DataFrame(
        {
            'prob_y_estimated_true': np.average(outcomes, weights=group_rows.binary_censoring_weight),
            'prob_y_estimated_model': np.average(outcomes, weights=group_rows.ipcw_weight_model),
            'prob_y_km': np.average(outcomes, weights=group_rows.km_weights),
            'prob_y_km_group': np.average(outcomes, weights=group_rows.km_weights_group),
            'prob_y_biased': np.average(outcomes),
        },
        index=[group_rows.name],
    )


# Only rows whose binary outcome is observed contribute to the estimates.
display(
    df_with_weights
    .query('binary_event_indicator == 1')
    .groupby('group')
    .apply(_compare_outcome_estimators)
    .reset_index(level=-1, drop=True)
)

Unnamed: 0_level_0,prob_y_estimated_true,prob_y_estimated_model,prob_y_km,prob_y_km_group,prob_y_biased
group,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.819349,0.809663,0.799153,0.815589,0.857614
1,0.623819,0.602257,0.607011,0.594417,0.688023
