In [None]:
import os
import pickle
import warnings
from copy import deepcopy
from time import time

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

from pydts.cross_validation import TwoStagesCV
from pydts.data_generation import EventTimesSampler
from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
from pydts.examples_utils.plots import add_panel_text
from pydts.examples_utils.plots import plot_events_occurrence
from pydts.examples_utils.plots import plot_example_pred_output
from pydts.fitters import TwoStagesFitter, DataExpansionFitter
pd.set_option("display.max_rows", 500)
warnings.filterwarnings('ignore')
%matplotlib inline
slicer = pd.IndexSlice

In [None]:
# Directory for saved figures and pickled cross-validation results. Set the PYDTS_OUTPUT_DIR
# environment variable to override the machine-specific default instead of editing this cell.
OUTPUT_DIR = os.environ.get('PYDTS_OUTPUT_DIR', '/home/tomer.me/DiscreteTimeSurvivalPenalization/output')

# Sampling data

In [None]:
# Simulation setup: n_cov covariates, of which only the first 5 have non-zero effects
# for each of the two event types.
n_cov = 100
beta1 = np.zeros(n_cov)
beta1[:5] = [1.2, 1.5, -1, -0.3, -1.2]
beta2 = np.zeros(n_cov)
beta2[:5] = [-1.2, 1, 1, -1, 1.4]

# True coefficients per event type: time-dependent alpha_j(t) and covariate effects beta_j.
real_coef_dict = {
    "alpha": {
        1: lambda t: -3.4 - 0.1 * np.log(t),
        2: lambda t: -3.4 - 0.2 * np.log(t)
    },
    "beta": {
        1: beta1,
        2: beta2
    }
}

n_patients = 10000
d_times = 15
j_events = 2

ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)

seed = 0
means_vector = np.zeros(n_cov)
covariance_matrix = 0.4*np.identity(n_cov)
clip_value = 1.5

covariates = [f'Z{i + 1}' for i in range(n_cov)]

# Draw covariates from a seeded generator so the whole simulation is reproducible (the global
# np.random state was never seeded). A dedicated seed keeps this stream distinct from the
# seeds handed to the sampler below (seed and seed + 1).
covariates_rng = np.random.default_rng(seed + 2)
patients_df = pd.DataFrame(
    data=covariates_rng.multivariate_normal(means_vector, covariance_matrix, size=n_patients),
    columns=covariates,
).clip(lower=-clip_value, upper=clip_value)
patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
# Loss-to-follow-up censoring with probability 0.01 at every time in ets.times.
patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones_like(ets.times),
                                                   seed=seed + 1)
patients_df = ets.update_event_or_lof(patients_df)
patients_df.index.name = 'pid'
patients_df = patients_df.reset_index()

In [None]:
# Overview plot of event occurrence in the simulated cohort (imported in the top imports cell).
plot_events_occurrence(patients_df);

In [None]:
# Number of rows per (X, J) combination.
event_time_type_counts = patients_df.groupby(by=['X', 'J']).size()
event_time_type_counts

In [None]:
# Cross-validate the two-stage model over a grid of log-penalizers; the value handed to the
# beta models is exp(penalizer).
step = 0.25
penalizers = np.arange(-9, -3.6, step=step)
n_splits = 5

cross_validators = {}

for idp, penalizer in enumerate(penalizers):
    print(f"Started Penalizer: {penalizer}, {idp+1}/{len(penalizers)}")
    # The penalty is passed through fit_beta_kwargs; l1_ratio=1 presumably selects a pure L1
    # penalty — confirm against the underlying fitter's documentation.
    fit_beta_kwargs = {'model_kwargs': {'penalizer': np.exp(penalizer), 'l1_ratio': 1}}
    start = time()
    # Register the validator before fitting, then run the CV on it.
    cross_validators[penalizer] = cv = TwoStagesCV()
    cv.cross_validate(full_df=patients_df, n_splits=n_splits, seed=seed, nb_workers=1,
                      fit_beta_kwargs=fit_beta_kwargs,
                      metrics=['PE', 'AUC', 'IAUC', 'GAUC'])
    end = time()
    print(f"Finished Penalizer: {penalizer}, {idp+1}/{len(penalizers)}, {int(end-start)} seconds")

In [None]:
# Unpenalized baseline: same data, folds and seed, but no fit_beta_kwargs.
start = time()
cross_validator_null = TwoStagesCV()
cross_validator_null.cross_validate(
    full_df=patients_df,
    n_splits=n_splits,
    seed=seed,
    nb_workers=1,
    metrics=['PE', 'AUC', 'IAUC', 'GAUC'],
)
end = time()
print(f"Finished {int(end-start)} seconds")

In [None]:
# Observation counts with one row per time X and one column per event type J (missing
# combinations filled with 0); used for the bar overlays in the bottom row.
counts = patients_df.groupby(['J', 'X']).size().unstack('J').fillna(0)

# Shared font sizes for the 3x3 figure.
ticksize, axes_title_fontsize, legend_size = 13, 15, 12

# Per-risk display settings: index 0 -> J=1, index 1 -> J=2.
risk_names = ['J=1', 'J=2']
risk_colors = ['tab:blue', 'tab:green']
risk_letters = list('defghi')
# Log-penalizer highlighted in every panel; must be a key of cross_validators.
chosen_lambda = -5.75

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(17, 14))

# Panel a: global AUC (fold mean +/- std) versus log(lambda), with the unpenalized mean as a reference line.
ax = axes[0, 0]
add_panel_text(ax, 'a')
for tick_kind in ('major', 'minor'):
    ax.tick_params(axis='both', which=tick_kind, labelsize=ticksize)
ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
ax.set_ylabel(r'Global AUC', fontsize=axes_title_fontsize)

# x-positions are the penalizers in ascending order; later panels reuse penalizers_x.
penalizers_x = sorted(cross_validators.keys())
mean_gauc, std_gauc = [], []
for penalizer in penalizers_x:
    gauc_per_fold = pd.Series(cross_validators[penalizer].global_auc)
    mean_gauc.append(gauc_per_fold.mean())
    std_gauc.append(gauc_per_fold.std())

ax.errorbar(penalizers_x, mean_gauc, yerr=std_gauc, fmt="o", color='g', alpha=0.5, label='With Penalization')
null_gauc_mean = pd.Series(cross_validator_null.global_auc).mean()
ax.axhline(null_gauc_mean, ls = '--', label='Without Penalization', color='tab:blue')
ax.axvline(chosen_lambda, color='brown', ls='-.', label=r'Chosen $\lambda$')
ax.legend(fontsize=legend_size)


# Panel b: integrated AUC per risk (fold mean +/- std) versus log(lambda), with the unpenalized
# per-risk means as reference lines.
ax = axes[0, 1]
add_panel_text(ax, 'b')
ax.tick_params(axis='both', which='major', labelsize=ticksize)
ax.tick_params(axis='both', which='minor', labelsize=ticksize)
ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
ax.set_ylabel(r'Integrated AUC', fontsize=axes_title_fontsize)

# One integrated-AUC frame per penalizer, in ascending order (matching penalizers_x). Rows are
# selected by risk via .loc[risk] below; presumably the columns are folds — TODO confirm.
iauc_by_penalizer = {p: pd.DataFrame.from_dict(cross_validators[p].integrated_auc)
                     for p in sorted(cross_validators.keys())}
# Each column is labelled with its own penalizer p.
fig_mean = pd.DataFrame({p: iauc_df.mean(axis=1) for p, iauc_df in iauc_by_penalizer.items()})
fig_std = pd.DataFrame({p: iauc_df.std(axis=1) for p, iauc_df in iauc_by_penalizer.items()})
null_iauc_mean = pd.DataFrame.from_dict(cross_validator_null.integrated_auc).mean(axis=1)

for risk in range(1, 3):
    ax.errorbar(penalizers_x, fig_mean.loc[risk], yerr=fig_std.loc[risk], fmt="o", color=risk_colors[risk-1],
                alpha=0.5, label=f'{risk_names[risk-1]} - With Penalization')
    ax.axhline(null_iauc_mean.loc[risk], ls = '--', label=f'{risk_names[risk-1]} - Without Penalization',
               color=risk_colors[risk-1])
ax.axvline(chosen_lambda, color='brown', ls='-.', label=r'Chosen $\lambda$')
ax.legend(loc='lower left', fontsize=legend_size)

for risk in range(1, 3):
    # Fold-averaged beta coefficients for this risk: one column per penalizer, in ascending
    # order so the columns line up with penalizers_x (the sorted keys used on the x-axis).
    # Presumably each params_ is a Series indexed by covariate — TODO confirm.
    mean_params_by_penalizer = []
    for penalizer in sorted(cross_validators.keys()):
        fold_params = pd.concat([cross_validators[penalizer].models[i_fold].beta_models[risk].params_
                                 for i_fold in range(n_splits)], axis=1)
        mean_params = fold_params.mean(axis=1)
        mean_params.name = penalizer
        mean_params_by_penalizer.append(mean_params)
    j1_params_df = pd.concat(mean_params_by_penalizer, axis=1)

    # Coefficient paths (panels d and e): one line per coefficient across penalizers.
    ax = axes[1, risk-1]
    add_panel_text(ax, risk_letters[risk-1])
    ax.tick_params(axis='both', which='major', labelsize=ticksize)
    ax.tick_params(axis='both', which='minor', labelsize=ticksize)
    for i in range(len(j1_params_df)):
        ax.plot(penalizers_x, j1_params_df.iloc[i].values, lw=1)
    if len(j1_params_df) > 0:
        ax.set_ylabel(f'{n_splits}-Fold Mean Coefficient Value', fontsize=axes_title_fontsize)
        ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
        ax.set_title(rf'$\beta_{risk}$', fontsize=axes_title_fontsize)
        ax.axvline(chosen_lambda, color='brown', ls='-.', label=r'Chosen $\lambda$')

    # Number of coefficients still non-zero after rounding to 3 decimals, per penalizer, drawn
    # in the right-hand column (row 0 for J=1, row 1 for J=2) as a single labelled series.
    ax = axes[risk-1, 2]
    nonzero_counts = (j1_params_df.round(3).abs() > 0).sum()
    ax.scatter(nonzero_counts.index, nonzero_counts.values, color=risk_colors[risk-1], alpha=0.8,
               marker='P', label=risk_names[risk-1])

# Decorate the two non-zero-count panels: (0, 2) is panel c, (1, 2) is panel f.
# NOTE(review): the 'True value' line uses alpha 0.5 on panel c but 0.3 on panel f; kept as-is.
for ax, panel_letter, true_line_alpha in ((axes[0, 2], 'c', 0.5), (axes[1, 2], 'f', 0.3)):
    add_panel_text(ax, panel_letter)
    ax.tick_params(axis='both', which='major', labelsize=ticksize)
    ax.tick_params(axis='both', which='minor', labelsize=ticksize)
    ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
    ax.set_ylabel('Number of Non-Zero Coefficients', fontsize=axes_title_fontsize)
    ax.axvline(chosen_lambda, color='brown', ls='-.', label=r'Chosen $\lambda$')
    # Each simulated beta has 5 non-zero entries.
    ax.axhline(5, color='k', alpha=true_line_alpha, label='True value', ls='--')
    ax.legend(loc='upper right', fontsize=legend_size)
    ax.set_ylim([0, 103])

# Bottom row, panels g and h: AUC(t) at the chosen log-penalizer (fold mean +/- std), with the
# per-time observation counts from `counts` as red bars on a secondary y-axis.
auc_time_ticks = list(range(1, d_times+1))
auc_yticks = np.arange(0, 1.1, 0.1)
for risk in range(1, 3):
    ax = axes[2, risk-1]
    add_panel_text(ax, risk_letters[3+risk-1])
    ax.tick_params(axis='both', which='major', labelsize=ticksize)
    ax.tick_params(axis='both', which='minor', labelsize=ticksize)

    # AUC rows for this risk (presumably one per fold); columns are the times.
    auc_rows = cross_validators[chosen_lambda].results.loc[slicer['AUC', :, risk]]
    mean_auc = auc_rows.mean()
    std_auc = auc_rows.std()
    ax.errorbar(mean_auc.index, mean_auc.values, yerr=std_auc.values, fmt="o", color=risk_colors[risk-1], alpha=0.8)
    ax.set_yticks(auc_yticks)
    ax.set_yticklabels([tick.round(1) for tick in auc_yticks])
    ax.set_xlabel(r'Time', fontsize=axes_title_fontsize)
    ax.set_ylabel('AUC (t)', fontsize=axes_title_fontsize)
    ax.set_title(fr'{risk_names[risk-1]}, Log ($\lambda$) = {chosen_lambda}', fontsize=axes_title_fontsize)
    ax.set_ylim([0,1])
    ax.axhline(0.5, ls='--', color='k', alpha=0.5)  # chance-level AUC
    ax.set_xticks(auc_time_ticks)
    ax.set_xticklabels(auc_time_ticks)

    ax2 = ax.twinx()
    ax2.bar(counts.index, counts[risk].values.squeeze(), color='r', alpha=0.8, width=0.4)
    ax2.set_ylabel('Number of observed events', fontsize=axes_title_fontsize, color='r')
    ax2.tick_params(axis='y', colors='r')
    ax2.set_ylim([0, 1800])
    for tick_kind in ('major', 'minor'):
        ax2.tick_params(axis='both', which=tick_kind, labelsize=ticksize)
    

# Panel i: PE(t) at the chosen log-penalizer for both risks (fold mean +/- std), then save the figure.
ax = axes[2,2]
add_panel_text(ax, 'i')
for pe_risk, (pe_label, pe_color) in enumerate(zip(risk_names, risk_colors), start=1):
    pe_rows = cross_validators[chosen_lambda].results.loc[slicer['PE', :, pe_risk]]
    mean_pe = pe_rows.mean()
    std_pe = pe_rows.std()
    ax.errorbar(mean_pe.index, mean_pe.values, yerr=std_pe.values, fmt="v", color=pe_color, alpha=0.8,
                label=pe_label)
ax.set_ylabel('PE (t)', fontsize=axes_title_fontsize)
ax.set_ylim([0, 0.17])
ax.tick_params(axis='both', which='major', labelsize=ticksize)
ax.tick_params(axis='both', which='minor', labelsize=ticksize)
ax.legend(loc='lower right', fontsize=legend_size)
pe_time_ticks = list(range(1, d_times+1))
ax.set_xticks(pe_time_ticks)
ax.set_xticklabels(pe_time_ticks)
ax.set_title(fr'Log ($\lambda$) = {chosen_lambda}', fontsize=axes_title_fontsize)
ax.set_xlabel(r'Time', fontsize=axes_title_fontsize)

fig.tight_layout()

fig.savefig(os.path.join(OUTPUT_DIR, 'regularization_sim.png'), dpi=300)

In [None]:
# Selection summary per risk: for every log-penalizer, how many truly non-zero coefficients are
# kept (True Positives) and how many truly-zero coefficients are kept (False Positives).
total_positives_df = pd.DataFrame()

for risk in range(1, j_events+1):
    # Fold-averaged coefficients, one column per penalizer in ascending order. Rounding to 4
    # decimals means coefficients that round to 0 count as not selected.
    mean_params_by_penalizer = []
    for penalizer in sorted(cross_validators.keys()):
        fold_params = pd.concat([cross_validators[penalizer].models[i_fold].beta_models[risk].params_
                                 for i_fold in range(n_splits)], axis=1)
        ser_1 = fold_params.mean(axis=1)
        ser_1.name = penalizer
        mean_params_by_penalizer.append(ser_1)
    j1_params_df = pd.concat(mean_params_by_penalizer, axis=1).round(4)

    # Number of truly non-zero coefficients for this risk. As in the simulation setup, the
    # non-zero entries are assumed to occupy the leading rows of the coefficient vector.
    n_true = int(np.count_nonzero(real_coef_dict['beta'][risk]))
    is_selected = j1_params_df.abs() > 0
    true_positives = is_selected.iloc[:n_true].sum()
    true_positives.name = 'True Positives'
    false_positives = is_selected.iloc[n_true:].sum()
    false_positives.name = 'False Positives'

    positives_df = pd.concat([true_positives, false_positives], axis=1)
    positives_df.index.name = r'Log ($\lambda$)'
    total_positives_df = pd.concat(
        [total_positives_df, pd.concat([positives_df], keys=[fr'$\beta_{risk}$'], axis=1)], axis=1)

total_positives_df

In [None]:
# LaTeX version of the selection table; escape=False keeps the $\beta$ / $\lambda$ math in the headers intact.
total_positives_latex = total_positives_df.to_latex(escape=False)
print(total_positives_latex)

In [None]:
# Coefficient comparison for risk J=1: true betas (first 5, the non-zero ones), fold-averaged
# unpenalized estimates and fold-averaged estimates at the chosen log-penalizer.
risk = 1
penalizer = chosen_lambda

# Penalized model at the chosen log-penalizer, averaged over folds.
tmp_j1_params_df = pd.concat([cross_validators[penalizer].models[i_fold].beta_models[risk].params_
                              for i_fold in range(n_splits)], axis=1)
ser_1 = tmp_j1_params_df.mean(axis=1).round(4)
ser_1.name = rf'Log ($\lambda$)={penalizer}'

true_values = pd.Series(beta1[:5], name='True Value', index=[f'Z{i}' for i in range(1,6)])

# Unpenalized model, averaged over folds. Built from its own frame so the penalized fold
# estimates above are not mixed into this average.
null_params_df = pd.concat([cross_validator_null.models[i_fold].beta_models[risk].params_
                            for i_fold in range(n_splits)], axis=1)
null_ser = null_params_df.mean(axis=1).round(4)
null_ser.name = 'Without Penalization'

# Rows align by index label; presumably params_ is indexed by covariate name (Z1, ...) — TODO confirm.
values_df = pd.concat([true_values, null_ser.iloc[:5], ser_1.iloc[:5]], axis=1)
values_df

In [None]:
# Coefficient comparison for risk J=2: true betas (first 5, the non-zero ones), fold-averaged
# unpenalized estimates and fold-averaged estimates at the chosen log-penalizer.
risk = 2
penalizer = chosen_lambda

# Penalized model at the chosen log-penalizer, averaged over folds.
tmp_j1_params_df = pd.concat([cross_validators[penalizer].models[i_fold].beta_models[risk].params_
                              for i_fold in range(n_splits)], axis=1)
ser_1 = tmp_j1_params_df.mean(axis=1).round(4)
ser_1.name = rf'Log ($\lambda$)={penalizer}'

true_values = pd.Series(beta2[:5], name='True Value', index=[f'Z{i}' for i in range(1,6)])

# Unpenalized model, averaged over folds. Built from its own frame so the penalized fold
# estimates above are not mixed into this average.
null_params_df = pd.concat([cross_validator_null.models[i_fold].beta_models[risk].params_
                            for i_fold in range(n_splits)], axis=1)
null_ser = null_params_df.mean(axis=1).round(4)
null_ser.name = 'Without Penalization'

# Rows align by index label; presumably params_ is indexed by covariate name (Z1, ...) — TODO confirm.
values_df = pd.concat([true_values, null_ser.iloc[:5], ser_1.iloc[:5]], axis=1)
values_df

In [None]:
# with open(os.path.join(OUTPUT_DIR, 'reg_cross_validators.pkl'), 'wb') as f:
#     pickle.dump(cross_validators, f)

In [None]:
# Reload previously pickled cross-validators; this replaces the in-memory `cross_validators` dict.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
cross_validators_path = os.path.join(OUTPUT_DIR, 'reg_cross_validators.pkl')
with open(cross_validators_path, 'rb') as pickle_file:
    cross_validators = pickle.load(pickle_file)

In [None]:
# Number of CV folds read by the cells below; presumably must match the fold count used when the
# loaded cross-validators were fit — verify if the pickle was produced with different settings.
n_splits = 5

In [None]:
# Fold-averaged alpha_jt estimates for every penalizer: rows follow the (J, X) index of each
# fold's alpha_df, one column per penalizer, rounded to 4 decimals.
alpha_means_by_penalizer = []
for penalizer in cross_validators.keys():
    alpha_by_fold = pd.concat(
        [cross_validators[penalizer].models[i_fold].alpha_df.set_index(['J', 'X'])['alpha_jt']
         for i_fold in range(n_splits)],
        axis=1,
    )
    mean_alpha = alpha_by_fold.mean(axis=1)
    mean_alpha.name = penalizer
    alpha_means_by_penalizer.append(mean_alpha)

# Rounding once at the end yields the same values as rounding after each column is added.
j1_params_df = (pd.concat(alpha_means_by_penalizer, axis=1).round(4)
                if alpha_means_by_penalizer else pd.DataFrame())

In [None]:
# Fold-averaged alpha_jt estimates (from j1_params_df) for a subset of log-penalizers, against
# the true alpha_j(t) used in the simulation; one panel per risk.
fig, axes = plt.subplots(1, 2, figsize=(11, 4))

s = 12
_plot_pen = [-9, -8, -7, -6, -5, -4]
alpha_time_points = list(range(1, d_times + 1))

for ax, risk, panel_letter in zip(axes, (1, 2), ('a', 'b')):
    add_panel_text(ax, panel_letter)
    for penalizer in _plot_pen:
        # Assumes the (risk, :) slice has exactly d_times rows, ordered X = 1..d_times.
        ax.scatter(alpha_time_points, j1_params_df.loc[slicer[risk, :], penalizer].values,
                   label=str(penalizer), s=s)
    ax.plot(alpha_time_points, [real_coef_dict['alpha'][risk](t) for t in alpha_time_points],
            ls='--', color='k', label='True')
    ax.set_xticks(alpha_time_points)
    ax.set_xticklabels([str(t) for t in alpha_time_points])
    ax.set_ylabel(rf'$\alpha_{{{risk}t}}$', fontsize=15)
    ax.set_xlabel('Time', fontsize=15)
    ax.legend(title=r"Log $\lambda$")
    ax.set_ylim([-4.1, -2.2])

fig.tight_layout()

fig.savefig(os.path.join(OUTPUT_DIR, 'regularization_alpha_sim.png'), dpi=300)