# Conditional Kaplan Meier Estimation

Before running the notebook make sure that:

1. You have created a python/conda environment according to the specifications
2. You have set your working directory (`project_dir`) to the directory where you unpacked the .zip file
3. You might want to add a shebang line to the scripts in ./utils (however, using them from a notebook with a specified environment should work as well)

The notebook generates data, trains models and saves them into the ./data directory. Note that at times you will need to overwrite existing files manually (this is to ensure that you do not accidentally overwrite files that took a long time to produce, e.g. model weights).

See the R scripts for visualisations (by default, all visualisations should already be in the zip folder).

In [1]:
import os

# Root directory of the unpacked project; every relative path in this notebook is
# resolved against it once it becomes the working directory.
# Set the COND_KM_PROJECT_DIR environment variable to point at your own copy
# instead of editing the notebook; the original author's path is kept as fallback.
project_dir = os.environ.get(
    'COND_KM_PROJECT_DIR',
    '/Users/philippratz/Documents/Uni/PhD/UQAM/research/joint_estimation/cond_km/',
)

In [2]:
# Housekeeping
import os
# Make `project_dir` (set in the previous cell) the working directory, so that the
# relative paths used further down (e.g. 'data/application/...') resolve against it.
os.chdir(project_dir)

# Wrangling
import numpy as np
import pandas as pd
from itertools import product as itp
import pickle

# Stats
from lifelines import KaplanMeierFitter, CoxPHFitter
from sklearn.preprocessing import StandardScaler

# TF & Keras
# Only `tf` is imported here: Keras is reachable as `tf.keras`, there is no
# standalone `keras` name in scope.
import tensorflow as tf

In [3]:
# Reload imported modules automatically before each cell execution, so edits to
# the project's own modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Project-local helpers: the model wrapper and the concordance metric.
from main.utils.conditional_km import DeepKaplanMeier
from main.utils.metrics import calculate_concordance_surv
# Training switch: cells guarded by `if train_:` only run when this is True;
# with False, the final model is loaded from disk instead of being retrained.
train_ = False

## Load train and test data

In [13]:
# NOTE(review): these paths point at a sibling `../kaplan_meier` directory rather
# than at `data/application/` inside the project, and the cell right below assigns
# the same variables again from the in-project files - confirm which is intended.
train_path = '../kaplan_meier/data/application/data_econ_train.csv'
test_path = '../kaplan_meier/data/application/data_econ_test.csv'

# Training split: durations cast to int, plus a boolean array that is True
# wherever `event` is not 1.
df = pd.read_csv(train_path)
y_target = df['time'].astype(int)
y_censoring = np.where(df['event'] != 1, True, False)

# Validation split, prepared the same way.
df_val = pd.read_csv(test_path)
y_target_val = df_val['time'].astype(int)
y_censoring_val = np.where(df_val['event'] != 1, True, False)

In [6]:
# All inputs live under data/application/ relative to the working directory.
train_path = 'data/application/data_econ_train.csv'
test_path = 'data/application/data_econ_test.csv'
example_path = 'data/application/example_econ.csv'

# Training split: durations cast to int, plus a boolean array that is True
# wherever `event` is not 1.
df = pd.read_csv(train_path)
y_target = df['time'].astype(int)
y_censoring = np.where(df['event'] != 1, True, False)

# Validation split, prepared the same way.
df_val = pd.read_csv(test_path)
y_target_val = df_val['time'].astype(int)
y_censoring_val = np.where(df_val['event'] != 1, True, False)

# Example rows; predictions for them are exported at the end of the notebook.
df_example = pd.read_csv(example_path)


In [14]:
# Scale features

# The numeric covariates are picked by position from the TRAINING frame and the
# same labels are then reused for every split, so validation and example inputs
# always carry exactly the columns (and column order) the scaler was fitted on.
# NOTE(review): assumes columns 4..-2 are the numeric covariates - confirm against
# the CSV schema.
numeric_cols = df.columns[4:-1]


def _binary_ui(frame):
    """Return an (n, 1) int array: 1 where `fac_ui == 'yes'`, 0 otherwise."""
    return np.where(frame.fac_ui == 'yes', 1, 0).reshape(-1, 1)


numeric_ = df.loc[:, numeric_cols]
binary_ = _binary_ui(df)

numeric_val = df_val.loc[:, numeric_cols]
binary_val = _binary_ui(df_val)

numeric_example = df_example.loc[:, numeric_cols]
binary_example = _binary_ui(df_example)

# Fit the scaler on the training split only and apply it unchanged to the other
# splits, so no validation/example statistics enter the transformation.
scaler = StandardScaler()
scaler.fit(numeric_)

numeric_trans = scaler.transform(numeric_)
numeric_trans_val = scaler.transform(numeric_val)
numeric_trans_example = scaler.transform(numeric_example)

# Final design matrices: scaled numeric columns followed by the binary indicator.
features = np.concatenate([numeric_trans, binary_], axis=1)
features_val = np.concatenate([numeric_trans_val, binary_val], axis=1)
features_example = np.concatenate([numeric_trans_example, binary_example], axis=1)

In [15]:
# Define model inputs
# Largest observed duration, cast to a plain Python int so that it can be used
# directly as an array dimension (`total_periods + 1`) by later cells, even if the
# `time` column happens to be stored as float.
total_periods = int(df.time.max())
# One input unit per column of the design matrix.
input_dim = (features.shape[1], )
input_shape = input_dim

In [16]:
# Instantiate the DeepKaplanMeier helper, passing the number of periods as an int.
# The object is used below to prepare the survival targets (`prepare_survival`)
# and to build the networks (`return_comp_graph_raw`).
init_mod = DeepKaplanMeier(int(total_periods))

In [17]:
# Turn durations and censoring flags into target and weight arrays.
# NOTE(review): the layout returned by `prepare_survival` is not visible here; both
# arrays are reshaped below to `total_periods + 1` columns.
n_columns = total_periods + 1

y_matrix, weight_matrix = init_mod.prepare_survival(list(y_target), list(y_censoring))
y_matrix_val, weight_matrix_val = init_mod.prepare_survival(list(y_target_val), list(y_censoring_val))

# Each dataset element is a 4-tuple: (features, labels, weights, duration).
train_dataset = tf.data.Dataset.from_tensor_slices((
    features,
    y_matrix.reshape(-1, n_columns),
    weight_matrix.reshape(-1, n_columns),
    y_target,
))
val_dataset = tf.data.Dataset.from_tensor_slices((
    features_val,
    y_matrix_val.reshape(-1, n_columns),
    weight_matrix_val.reshape(-1, n_columns),
    y_target_val,
))

Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



2022-10-20 18:40:42.418433: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-10-20 18:40:42.418547: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [18]:
# Container for the tuning results: one entry per trial, keyed "trial_<n>",
# filled by the hyper-parameter search cell.
trial_dict = dict()

In [10]:
if train_:
    # Random hyper-parameter search.
    #
    # For every trial a configuration is drawn at random (three hidden layer
    # sizes, the pre-output layer size, the learning rate and the batch size),
    # a network is trained for up to `max_epochs` epochs and scored after every
    # epoch with the concordance on the validation split. Results are collected
    # in `trial_dict` and pickled after each trial so an interrupted search does
    # not lose the finished trials.
    input_shape = input_dim

    total_trials = 150
    max_epochs = 75
    patience = 10            # non-improving epochs tolerated before stopping
    save_threshold = 0.705   # only checkpoint models scoring above this

    # Optimize layer 1-3 plus additional layer
    # Optimize learning rate
    # Optimize batchsize

    layer_sizes = np.arange(3,13, 3)
    learning_rate = [5e-02, 1e-02, 5e-03, 1e-03]
    batch_sizes = [32,64,128,256]

    np.random.seed(123)
    for trial_ in range(total_trials):
        trial_key = f"trial_{trial_}"

        # Six uniform draws mapped onto indices 0..3, one per tuned quantity
        # (every candidate list above has exactly four entries).
        check_indices = np.ceil(np.random.uniform(size=6)*4)-1
        check_indices = check_indices.astype(int)

        hp_unit_1 = layer_sizes[check_indices[0]]
        hp_unit_2 = layer_sizes[check_indices[1]]
        hp_unit_3 = layer_sizes[check_indices[2]]
        hp_unit_4 = layer_sizes[check_indices[3]]

        hp_learning_rate = learning_rate[check_indices[4]]
        hp_batch = batch_sizes[check_indices[5]]

        trial_dict[trial_key] = {
            'hp_unit_1': hp_unit_1,
            'hp_unit_2': hp_unit_2,
            'hp_unit_3': hp_unit_3,
            'hp_unit_4': hp_unit_4,
            'hp_learning_rate': hp_learning_rate,
            'hp_batch': hp_batch,
            'epochs': {},
        }

        print(trial_dict[trial_key])

        model_raw = init_mod.return_comp_graph_raw(input_shape=input_shape, 
                                                hidden_units=[hp_unit_1, hp_unit_2, hp_unit_3],
                                                preoutput=hp_unit_4)

        # Per-output loss weights growing linearly from 1 to 1 + upper_bound.
        # NOTE(review): 36 is hard-coded; zip() in the training step silently
        # truncates if the model has a different number of outputs - confirm.
        upper_bound = np.random.uniform(1,2.5)
        trial_dict[trial_key]['weight'] = upper_bound
        weight_array = 1 + np.linspace(0,upper_bound,36)

        #### Update variables ####

        batch_size = hp_batch
        # Keras is only available through `tf.keras` in this notebook.
        optimizer = tf.keras.optimizers.Adam(learning_rate=hp_learning_rate)
        loss_objective = tf.keras.losses.BinaryCrossentropy()

        train_ds = train_dataset.batch(batch_size)

        best_epoch_loss = 0   # best validation concordance so far (higher is better)
        epoch_checker = 0     # consecutive epochs without a new (or tied) best

        for epoch in range(max_epochs):
            print(f"Epoch: {epoch}")

            for features_, _labels, _weights, _ in train_ds:
                # Only the forward pass and the loss need to be recorded.
                with tf.GradientTape() as tape:
                    preds_ = model_raw(features_, training=True)
                    losses = [loss_objective(_labels[:,i], preds_[i], sample_weight=tf.reshape(_weights[:,i], (-1,1))) for i in range(len(preds_))]
                    losses = [tf.multiply(loss_, weight_) for loss_, weight_ in zip(losses, weight_array)]
                gradients = tape.gradient(losses, model_raw.trainable_variables)
                optimizer.apply_gradients(zip(gradients, model_raw.trainable_variables))

            # Score on the validation split: the predictions must come from the
            # same rows as `y_target_val` / `y_censoring_val`.
            preds_test = model_raw(features_val)
            preds_test = pd.DataFrame(list(map(np.ravel, preds_test)))

            nom_, denom_, val_ = calculate_concordance_surv(y_target_val, 
                                                            y_censoring_val, 
                                                            preds_test)

            trial_dict[trial_key]['epochs'][f'epoch_{epoch}'] = val_

            # An epoch counts as stale when it scores strictly below the best
            # value seen before it.
            if val_ < best_epoch_loss:
                epoch_checker += 1
            else:
                epoch_checker = 0

            best_epoch_loss = max(best_epoch_loss, val_)
            print(f'best validation result is: {best_epoch_loss}')

            # Checkpoint only when this epoch is the (tied) best and good enough.
            if best_epoch_loss > save_threshold and best_epoch_loss == val_:
                model_raw.save_weights(f'./data/models/check_unemp_{trial_}')

            # Early stopping: leave the epoch loop instead of idling through it.
            if epoch_checker > patience:
                break

        trial_dict[trial_key]['best_validation_result'] = best_epoch_loss
        # Persist after every trial so finished trials survive an interruption.
        with open('./model_tuning/manual_gridsearch/tuning_unemployment_data.pkl', 'wb') as con:
            pickle.dump(trial_dict, con)

In [19]:
# Fixed hyper-parameters (the same values are used by the `if train_:` training
# cell further down).
hp_unit_1, hp_unit_2, hp_unit_3, hp_unit_4 = 12, 6, 12, 9
hp_batch, hp_learning_rate = 128, 0.005
hp_epoch, hp_weight = 20, 0.15

# Build the network with three hidden layers plus the pre-output layer.
hidden_layout = [hp_unit_1, hp_unit_2, hp_unit_3]
model_raw = init_mod.return_comp_graph_raw(
    input_shape=input_shape,
    hidden_units=hidden_layout,
    preoutput=hp_unit_4,
)


In [43]:
# Restore stored weights into `model_raw`.
# NOTE(review): the checkpoint lives in a sibling `../kaplan_meier` directory, i.e.
# outside `project_dir` - confirm it ships with the project before relying on it.
model_raw.load_weights('../kaplan_meier/data/models/model_weights_unemployment_applications')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x105be7040>

In [14]:
if not train_:
    # Re-use the stored model instead of retraining it.
    model_econ = tf.keras.models.load_model('data/model_weights/model_unemp_application.keras')

else:
    # Train the final model from scratch with a fixed seed.
    tf.keras.utils.set_random_seed(42)

    hp_unit_1, hp_unit_2, hp_unit_3, hp_unit_4 = 12, 6, 12, 9
    hp_batch, hp_learning_rate = 128, 0.005
    hp_epoch, hp_weight = 20, 0.15

    model_raw = init_mod.return_comp_graph_raw(
        input_shape=input_shape,
        hidden_units=[hp_unit_1, hp_unit_2, hp_unit_3],
        preoutput=hp_unit_4,
    )

    # Per-output loss weights growing linearly from 1 to 1 + hp_weight.
    # NOTE(review): 36 is hard-coded; zip() below truncates if the number of
    # model outputs differs - confirm.
    weight_array = weight_decrease = 1 + np.linspace(0, hp_weight, 36)

    batch_size = hp_batch
    optimizer = tf.keras.optimizers.Adam(learning_rate=hp_learning_rate)
    loss_objective = tf.keras.losses.BinaryCrossentropy()
    epoch_loss_metric = tf.keras.metrics.Mean()  # created but not updated in this cell
    train_ds = train_dataset.batch(batch_size)

    for epoch in range(hp_epoch):
        for features_, _labels, _weights, _ in train_ds:
            # Record the forward pass and the weighted per-output losses.
            with tf.GradientTape() as tape:
                preds_ = model_raw(features_, training=True)
                per_output = [
                    loss_objective(
                        _labels[:, i],
                        preds_[i],
                        sample_weight=tf.reshape(_weights[:, i], (-1, 1)),
                    )
                    for i in range(len(preds_))
                ]
                losses = [tf.multiply(loss_, weight_) for loss_, weight_ in zip(per_output, weight_array)]
            gradients = tape.gradient(losses, model_raw.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model_raw.trainable_variables))

    model_econ = model_raw



In [20]:
# Train a network with an alternative configuration and print the validation
# concordance after every epoch.
# NOTE: this cell is not guarded by `train_`, so it always trains, and it rebinds
# `model_raw` to the newly built network.

# Seed Python, NumPy and TensorFlow so weight initialisation - and therefore the
# printed concordance curve - is reproducible on a fresh run (same seed as the
# `if train_:` training cell).
tf.keras.utils.set_random_seed(42)

hp_unit_1  = 12
hp_unit_2  = 6
hp_unit_3  = 12
hp_unit_4  = 9
hp_batch =  128
hp_learning_rate = 0.001
hp_epoch = 40
hp_weight = 1.5

model_raw = init_mod.return_comp_graph_raw(input_shape=input_shape, 
                                        hidden_units=[hp_unit_1, hp_unit_2, hp_unit_3],
                                        preoutput=hp_unit_4)

# Per-output loss weights growing linearly from 1 to 1 + hp_weight.
# NOTE(review): 36 is hard-coded; zip() below truncates if the number of model
# outputs differs - confirm.
weight_decrease = 1 + np.linspace(0,hp_weight,36)
weight_array = weight_decrease

batch_size=hp_batch
optimizer = tf.keras.optimizers.Adam(learning_rate=hp_learning_rate)
loss_objective = tf.keras.losses.BinaryCrossentropy()
epoch_loss_metric = tf.keras.metrics.Mean()  # created but not updated in this cell
train_ds = train_dataset.batch(batch_size)

for epoch in range(hp_epoch):
    for step, (features_, _labels, _weights, _) in enumerate(train_ds):
        # Only the forward pass and the loss need to be recorded on the tape.
        with tf.GradientTape() as tape:
            preds_ = model_raw(features_, training=True)
            losses = [loss_objective(_labels[:,i], preds_[i], sample_weight=tf.reshape(_weights[:,i], (-1,1))) for i in range(len(preds_))]
            losses = [tf.multiply(loss_, weight_) for loss_, weight_ in zip(losses, weight_array)]
        gradients = tape.gradient(losses, model_raw.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model_raw.trainable_variables))

    # Concordance on the validation split after this epoch.
    preds_test = model_raw(features_val)
    preds_test = pd.DataFrame(list(map(np.ravel, preds_test)))
    nom_, denom_, val_ = calculate_concordance_surv(y_target_val, 
                                                    y_censoring_val, 
                                                    preds_test)
    print(epoch, val_)

0 0.4796301562648116
1 0.4809864444104157
2 0.47586025856977227
3 0.4778801070446773
4 0.48280941234805563
5 0.488300191776227
6 0.49301803279883916
7 0.49834839104849826
8 0.5027089303553329
9 0.5071569721231743
10 0.5112914634057416
11 0.5151634473052888
12 0.5187145888478113
13 0.5217480074960441
14 0.5237751478426997
15 0.5253501921408207
16 0.5274356674614807
17 0.5290909223488577
18 0.5300169900611787
19 0.5308993065429966
20 0.5321316328688411
21 0.5330358249659105
22 0.5335170885014474
23 0.534005643908735
24 0.53407127075449
25 0.5343410700092607
26 0.5346619123662854
27 0.5349462953645572
28 0.5352015108758267
29 0.5351723433888245
30 0.5352817214650829
31 0.5359452817943838
32 0.5361057029728961
33 0.5359234061791321
34 0.5361713298186511
35 0.5365286315344285
36 0.5370463544287183
37 0.5374765748620013
38 0.5375786610665092
39 0.5377463741167721


In [44]:
# Concordance of `model_econ` on the validation split.
# Each model output is flattened into one row of the prediction frame.
val_outputs = model_econ(features_val)
preds_test = pd.DataFrame([np.ravel(out) for out in val_outputs])

nom_, denom_, val_ = calculate_concordance_surv(
    y_target_val,
    y_censoring_val,
    preds_test,
)
print(f"Concordance is: {np.round(val_,2)}")

Concordance is: 0.49


In [None]:
# Frames for the Cox baseline: drop the identifier columns and recode `fac_ui`
# from 'yes'/other to 1/0, for both the training and the validation split.
id_columns = ['rowid', 'pid']

df_cox = (
    df
    .drop(columns=id_columns)
    .assign(fac_ui=lambda frame: np.where(frame.fac_ui == 'yes', 1, 0))
)

df_cox_test = (
    df_val
    .drop(columns=id_columns)
    .assign(fac_ui=lambda frame: np.where(frame.fac_ui == 'yes', 1, 0))
)

In [None]:
# Cox proportional hazards baseline, fitted on the prepared training frame with
# `time` as the duration column and `event` as the event column.
cph = CoxPHFitter()
cph.fit(df_cox, duration_col='time', event_col='event')

In [None]:
# Survival curves predicted by the Cox baseline for the validation rows, scored
# with the same concordance function as the neural model.
preds_cox = cph.predict_survival_function(df_cox_test)

cox_scores = calculate_concordance_surv(y_target_val, y_censoring_val, preds_cox)
nom_, denom_, val_ = cox_scores
print(f"Concordance is: {np.round(val_,2)}")

Concordance is: 0.69


In [None]:
# Predict for the example rows and export the result: each model output is
# flattened into one row of the frame, which is written without the index.
example_outputs = model_econ(features_example)
preds_example = pd.DataFrame([np.ravel(out) for out in example_outputs])

example_csv = 'data/application/econ_predictions_examples.csv'
preds_example.to_csv(example_csv, index=False)
