In [23]:
import tensorflow as tf
import numpy as np
import pandas as pd
import random

from sklearn.model_selection import train_test_split
matplotlib_style = 'fivethirtyeight'
import matplotlib.pyplot as plt; plt.style.use(matplotlib_style)

from tools.baysurv_trainer import Trainer
from utility.config import load_config
from utility.training import get_data_loader, scale_data, split_time_event
from tools.baysurv_builder import make_mlp_model, make_vi_model, make_mcd_model, make_sngp_model
from utility.risk import InputFunction
from utility.loss import CoxPHLoss, CoxPHLossGaussian
from pathlib import Path
import paths as pt
from utility.survival import (calculate_event_times, calculate_percentiles, convert_to_structured,
                              compute_deterministic_survival_curve, compute_nondeterministic_survival_curve)
from utility.training import make_stratified_split
from time import time
from tools.evaluator import LifelinesEvaluator
from pycox.evaluation import EvalSurv
import math
from utility.survival import coverage
from scipy.stats import chisquare
import torch
from utility.survival import survival_probability_calibration
from tools.Evaluations.util import make_monotonic, check_monotonicity
from utility.survival import make_time_bins
from utility.loss import cox_nll_tf

class _TFColor(object):
    """Palette of named hex colors (as used in the TF docs) with cyclic integer lookup."""
    red = '#F15854'
    blue = '#5DA5DA'
    orange = '#FAA43A'
    green = '#60BD68'
    pink = '#F17CB0'
    brown = '#B2912F'
    purple = '#B276B2'
    yellow = '#DECF3F'
    gray = '#4D4D4D'

    def __getitem__(self, i):
        """Return the color at position ``i`` of the cycle, wrapping around at the end.

        The cycle order (red, orange, green, blue, ...) differs from the order in
        which the attributes are declared above.
        """
        cycle_order = [
            'red', 'orange', 'green', 'blue', 'pink',
            'brown', 'purple', 'yellow', 'gray',
        ]
        color_name = cycle_order[i % len(cycle_order)]
        return getattr(self, color_name)


# Shared palette instance used for indexing, e.g. TFColor[0].
TFColor = _TFColor()

# Module-level sample counts. No cell visible in this notebook references them; the
# configuration cell reads its own n_samples_* values from the YAML config instead.
# NOTE(review): presumably the number of stochastic forward passes drawn per input
# at train / test time - confirm against the code that consumes these constants.
N_SAMPLES_TRAIN = 10
N_SAMPLES_TEST = 1000

In [24]:
# Dataset selector; its lower-cased form is used as the YAML config file name.
dataset_name = "MIMIC"

# Load training parameters
# The MLP config directory is the source of the hyperparameters bound below.
config = load_config(pt.MLP_CONFIGS_DIR, f"{dataset_name.lower()}.yaml")
# Rebuild a Keras optimizer object from its serialized form in the config.
optimizer = tf.keras.optimizers.deserialize(config['optimizer'])
# NOTE(review): the key is spelled 'activiation_fn' (sic). Presumably this matches
# the key used in the YAML files - confirm there before "fixing" the spelling here.
activation_fn = config['activiation_fn']
layers = config['network_layers']
l2_reg = config['l2_reg']
batch_size = config['batch_size']
early_stop = config['early_stop']
patience = config['patience']
# Sample counts per phase, read from the config.
n_samples_train = config['n_samples_train']
n_samples_valid = config['n_samples_valid']
n_samples_test = config['n_samples_test']

# Load data
dl = get_data_loader(dataset_name).load_data()
# Two collections of column names; they are concatenated with `+` below to select
# the model inputs from each split.
num_features, cat_features = dl.get_features()
df = dl.get_data()

# Split data
# 70% train / 10% validation / 20% test with a fixed random_state of 0.
# NOTE(review): stratify_colname='both' presumably stratifies on time and event
# jointly - confirm in make_stratified_split.
df_train, df_valid, df_test = make_stratified_split(df, stratify_colname='both', frac_train=0.7,
                                                frac_valid=0.1, frac_test=0.2, random_state=0)
X_train = df_train[cat_features+num_features]
X_valid = df_valid[cat_features+num_features]
X_test = df_test[cat_features+num_features]
# Targets are built from the "time" and "event" columns of each split.
y_train = convert_to_structured(df_train["time"], df_train["event"])
y_valid = convert_to_structured(df_valid["time"], df_valid["event"])
y_test = convert_to_structured(df_test["time"], df_test["event"])

# Scale data
X_train, X_valid, X_test = scale_data(X_train, X_valid, X_test, cat_features, num_features)

# Convert to array
X_train = np.array(X_train)
X_valid = np.array(X_valid)
X_test = np.array(X_test)

# Make time/event split
t_train, e_train = split_time_event(y_train)
t_valid, e_valid = split_time_event(y_valid)
t_test, e_test = split_time_event(y_test)

# Make event times
# Bins are derived from the training times and events only.
time_bins = make_time_bins(t_train, event=e_train)

# Calculate quantiles
event_times_pct = calculate_percentiles(time_bins)

# Make data loaders
# Each InputFunction instance is called immediately to obtain the dataset object.
# Only the training loader is built with drop_last=True; the test loader is built
# without the shuffle argument.
train_ds = InputFunction(X_train, t_train, e_train, batch_size=batch_size, drop_last=True, shuffle=True)()
valid_ds = InputFunction(X_valid, t_valid, e_valid, batch_size=batch_size, shuffle=True)() # shuffle=True to avoid NaNs
test_ds = InputFunction(X_test, t_test, e_test, batch_size=batch_size)()

# Make models
# Build every architecture once and report its number of trainable parameters.
# `model`, `dropout_rate` and `n_params` are rebound on each pass, so after the
# loop they describe the last entry of `models`.
models = ["mlp", "sngp", "vi", "mcd1", "mcd2", "mcd3"]

# Recipe per model name:
#   (builder, output_dim, dropout override or None to use config['dropout_rate'],
#    whether the builder additionally takes n_train_samples)
model_recipes = {
    "mlp": (make_mlp_model, 1, None, False),
    "sngp": (make_sngp_model, 1, None, False),
    "vi": (make_vi_model, 2, None, True),
    "mcd1": (make_mcd_model, 2, 0.1, False),
    "mcd2": (make_mcd_model, 2, 0.2, False),
    "mcd3": (make_mcd_model, 2, 0.5, False),
}

for model_name in models:
    recipe = model_recipes.get(model_name)
    if recipe is None:
        raise ValueError("Model not found")
    build_fn, out_dim, dropout_override, needs_train_size = recipe

    # The MC-dropout variants pin their own rate; every other model uses the
    # rate from the loaded config.
    dropout_rate = config['dropout_rate'] if dropout_override is None else dropout_override

    build_kwargs = dict(input_shape=X_train.shape[1:], output_dim=out_dim,
                        layers=layers, activation_fn=activation_fn,
                        dropout_rate=dropout_rate, regularization_pen=l2_reg)
    if needs_train_size:
        build_kwargs["n_train_samples"] = len(X_train)
    model = build_fn(**build_kwargs)

    # Total number of scalar entries across all trainable variables.
    n_params = np.sum([np.prod(weights.shape) for weights in model.trainable_variables])
    print(f"Model name: {model_name} - #Params: {n_params}")



Model name: mlp - #Params: 2849
Model name: sngp - #Params: 3776


  loc = add_variable_fn(
  untransformed_scale = add_variable_fn(


Model name: vi - #Params: 5700
Model name: mcd1 - #Params: 2882
Model name: mcd2 - #Params: 2882
Model name: mcd3 - #Params: 2882


In [39]:
from tools.sota_builder import make_baymtlr_model, make_baycox_model
class dotdict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    ``d.key`` returns ``d['key']``; reading a key that is absent returns ``None``
    instead of raising. ``d.key = v`` stores ``d['key'] = v`` and ``del d.key``
    removes the key (raising ``KeyError`` if it is absent, like ``dict``).

    Missing double-underscore names (``__deepcopy__``, ``__getstate__``, ...) raise
    ``AttributeError`` rather than returning ``None``. Those names are protocol
    probes issued by ``hasattr``, ``copy`` and ``pickle``; answering them with
    ``None`` makes the caller believe the hook exists and then fail when it tries
    to call it.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        try:
            return self[name]
        except KeyError:
            if name.startswith('__') and name.endswith('__'):
                raise AttributeError(name) from None
            return None

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

# Count the trainable parameters of the BayMTLR model for this dataset.
# NOTE(review): `num_features` is rebound here from the feature-name collection
# returned by dl.get_features() to an integer column count; the BayCox cell that
# follows reads this integer.
num_features = X_train.shape[1]
# Wrap the loaded config in dotdict so its keys are also reachable as attributes.
config = dotdict(load_config(pt.BAYMTLR_CONFIGS_DIR, f"{dataset_name.lower()}.yaml"))
# Force the hidden size to 32 regardless of what the YAML file specifies.
config['hidden_size'] = 32
model = make_baymtlr_model(num_features, time_bins, config)
# Sum the element counts of every parameter that requires gradients.
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

10718


In [41]:
# Count the trainable parameters of the BayCox model for this dataset.
# Relies on `dotdict`, `make_baycox_model` and the integer `num_features` defined
# in the BayMTLR cell, so that cell must have been run first.
config = dotdict(load_config(pt.BAYCOX_CONFIGS_DIR, f"{dataset_name.lower()}.yaml"))
# Force the hidden size to 32 regardless of what the YAML file specifies.
config['hidden_size'] = 32
model = make_baycox_model(num_features, config)
# Sum the element counts of every parameter that requires gradients.
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

5570
