In [None]:
import pandas as pd
import sys
from matplotlib import pyplot as plt
import os
%matplotlib inline
sys.path.append('../')
from src.plots import add_panel_text
import numpy as np
from sklearn.preprocessing import StandardScaler
from src.constants import *
from pydts.examples_utils.plots import plot_example_pred_output
from pydts.examples_utils.plots import add_panel_text
from pydts.fitters import TwoStagesFitter, DataExpansionFitter
from pydts.examples_utils.plots import plot_events_occurrence
from pydts.cross_validation import TwoStagesCV
import pickle
from tableone import TableOne
from time import time

# Shorthand for MultiIndex slicing, e.g. df.loc[slicer[1, :, :]].
slicer = pd.IndexSlice

# Paths default to the container layout but can be overridden through environment
# variables, so the notebook runs elsewhere without editing code.
OUTPUT_DIR = os.environ.get('MIMIC_OUTPUT_DIR', '/app/output')
DATA_DIR = os.environ.get('MIMIC_DATA_DIR', '/app/data/mimic-iv-2.0/')

# Figures and cached CSVs are written to OUTPUT_DIR below; make sure it exists so the
# first savefig / to_csv call does not fail on a fresh environment.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load Data

In [None]:
# All MIMIC-IV tables used in this notebook live in the `hosp` module directory.
patients_file, admissions_file, lab_file, lab_meta_file = [
    os.path.join(DATA_DIR, 'hosp', table_name)
    for table_name in ('patients.csv.gz', 'admissions.csv.gz', 'labevents.csv.gz', 'd_labitems.csv.gz')
]

In [None]:
# Patient-level table (demographics / anchor year information).
patients_df = pd.read_csv(filepath_or_buffer=patients_file, compression='gzip')
patients_df.head(n=5)

In [None]:
# Columns of the patients table that are not used downstream.
COLUMNS_TO_DROP = ['dod']
patients_df = patients_df.drop(columns=COLUMNS_TO_DROP)

In [None]:
print(patients_df.shape[0])  # number of patients loaded

In [None]:
# Inspect the column dtypes of the patients table.
patients_df.dtypes

In [None]:
# Admission-level table; every timestamp column is parsed as datetime on load.
admissions_df = pd.read_csv(
    admissions_file,
    compression='gzip',
    parse_dates=[ADMISSION_TIME_COL, DISCHARGE_TIME_COL, DEATH_TIME_COL, ED_REG_TIME, ED_OUT_TIME],
)
admissions_df.head(n=5)

In [None]:
# Columns of the admissions table that are not used downstream.
COLUMNS_TO_DROP = ['hospital_expire_flag', 'edouttime', 'edregtime', 'deathtime', 'language']
admissions_df = admissions_df.drop(columns=COLUMNS_TO_DROP)

In [None]:
# Attach patient-level columns to every admission (inner join on subject id).
admissions_df = pd.merge(admissions_df, patients_df, on=[SUBJECT_ID_COL])
admissions_df.shape

# Calculate Age at Admission and Group of Admission Year

Based on the MIMIC-IV documentation example: https://mimic.mit.edu/docs/iv/modules/hosp/patients/

In [None]:
# Years elapsed between the patient's anchor year and this admission.
admissions_df[ADMISSION_YEAR_COL] = admissions_df[ADMISSION_TIME_COL].dt.year.sub(admissions_df['anchor_year'])

# Age at admission = age at the anchor year + years elapsed since then.
admissions_df[ADMISSION_AGE_COL] = admissions_df[AGE_COL].add(admissions_df[ADMISSION_YEAR_COL])

# Shift onto the lower bound of the anchor year group: the group string is split on a
# space and its first token parsed as the lower-bound year.
year_group_lower_bound = admissions_df[YEAR_GROUP_COL].str.split(' ').str[0].astype(int)
admissions_df[ADMISSION_YEAR_COL] = admissions_df[ADMISSION_YEAR_COL] + year_group_lower_bound

In [None]:
# Number of admissions per admission-year group (lower bound). At this point
# admissions_df still holds every admission, so the bars count admissions, not patients.
fig, ax = plt.subplots(1,1,dpi=100)
admissions_df[ADMISSION_YEAR_COL].value_counts().sort_index().plot.bar(ax=ax)
ax.set_ylabel('Number of Admissions', fontsize=font_sz)
ax.set_xlabel('Admission Year (lower bound)', fontsize=font_sz)
# Write the count just above each bar.
for p in ax.patches:
    ax.annotate(str(p.get_height()), (p.get_x(), p.get_height() * 1.01))

In [None]:
# Age-at-admission distribution by sex over all admissions loaded so far; a patient
# with several admissions is counted once per admission.
fig, ax = plt.subplots(1,1,figsize=(8,4))
tmp = admissions_df[[ADMISSION_AGE_COL, GENDER_COL]]
tmp.groupby([ADMISSION_AGE_COL, GENDER_COL]).size().unstack().plot(kind='bar', ax=ax)
ax.set_xlabel('Age at Admission [years]', fontsize=font_sz)
ax.set_ylabel('Number of Admissions', fontsize=font_sz)
ax.set_title(f'Total Population, N={len(tmp)}', fontsize=font_sz)
# Legend order assumes the unstacked sex columns are 'F' then 'M' — TODO confirm values.
ax.legend(labels=['Female', 'Male'], title="Sex")
# Hide every second x tick label so the ages do not overlap.
plt.setp(ax.get_xticklabels()[1::2], visible=False)
plt.show()

# Calculate LOS (exact, and rounded up to whole days) and a night-admission indicator

In [None]:
# Exact length of stay (discharge minus admission time).
admissions_df[LOS_EXACT_COL] = admissions_df[DISCHARGE_TIME_COL].sub(admissions_df[ADMISSION_TIME_COL])

# Night admission: admission hour >= 20 or < 8.
admission_hour = admissions_df[ADMISSION_TIME_COL].dt.hour
is_late_evening = admission_hour >= 20
is_early_morning = admission_hour < 8
admissions_df[NIGHT_ADMISSION_FLAG] = (is_late_evening | is_early_morning).values

# Length of stay rounded up to whole days.
admissions_df[LOS_DAYS_COL] = admissions_df[LOS_EXACT_COL].dt.ceil('1d')
print(f"Mean night admissions flag: {admissions_df[NIGHT_ADMISSION_FLAG].mean():.3f}")

In [None]:
# Number of admissions per admission type (all types, before the cohort restriction).
fig, ax = plt.subplots(1,1,dpi=100)
admissions_df[ADMISSION_TYPE_COL].value_counts().plot.bar(ax=ax)
ax.set_ylabel('Number of Admissions', fontsize=font_sz)
ax.set_xlabel('Admission Type', fontsize=font_sz)
# Write the count just above each bar.
for p in ax.patches:
    ax.annotate(str(p.get_height()), (p.get_x(), p.get_height() * 1.01))

In [None]:
max_clip_days = 28


def plot_los_distribution(ax, admission_type, max_days=max_clip_days):
    """Bar plot of length of stay (whole days) for one admission type.

    LOS is clipped to [1, max_days] days, so the last bar also holds every stay of
    max_days or longer. Every day 1..max_days gets a bar (zero when no admission has
    that LOS) instead of failing when a day is missing from the counts.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    admission_type : value of ADMISSION_TYPE_COL to select; also used as the title.
    max_days : upper clipping bound in days (defaults to max_clip_days).
    """
    los = admissions_df.loc[admissions_df[ADMISSION_TYPE_COL] == admission_type, LOS_DAYS_COL]
    clipped_days = (
        los.clip(pd.to_timedelta('1d'), pd.to_timedelta(f'{max_days}d'))
        .dt.days.dropna().astype(int)
    )
    los_bar = clipped_days.value_counts().reindex(np.arange(1, max_days + 1), fill_value=0)
    los_bar.plot.bar(ax=ax)
    # Counts are per admission: a patient may still contribute several admissions here.
    ax.set_ylabel('Number of Admissions', fontsize=font_sz)
    ax.set_xlabel('LOS (Days)', fontsize=font_sz)
    ax.grid(axis='y')
    ax.set_title(admission_type, fontsize=font_sz)


fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for ax, admission_type in zip(axes, ['URGENT', 'EW EMER.', 'DIRECT EMER.']):
    plot_los_distribution(ax, admission_type)

fig.tight_layout()

# Keep only SPECIFIC_ADMISSION_TYPE admissions from here on

In [None]:
# Admission types kept in the cohort from here on.
SPECIFIC_ADMISSION_TYPE = ['DIRECT EMER.', 'EW EMER.']

In [None]:
# Restrict to the selected admission types, printing the row count before and after.
print(admissions_df.shape[0])
keep_type = admissions_df[ADMISSION_TYPE_COL].isin(SPECIFIC_ADMISSION_TYPE)
admissions_df = admissions_df[keep_type]
print(admissions_df.shape[0])

In [None]:
# Binary indicator: 1 for 'DIRECT EMER.' admissions, 0 otherwise. It is only created when
# that type is part of the cohort (later cells use DIRECT_IND_COL unconditionally).
if 'DIRECT EMER.' in SPECIFIC_ADMISSION_TYPE:
    is_direct = admissions_df[ADMISSION_TYPE_COL].eq('DIRECT EMER.')
    admissions_df[DIRECT_IND_COL] = is_direct.astype(int)

# Count SPECIFIC_ADMISSION_TYPE admissions per patient

In [None]:
# Number of distinct (cohort) admissions per patient.
number_of_admissions = (
    admissions_df.groupby(SUBJECT_ID_COL)[ADMISSION_ID_COL]
    .nunique()
    .rename(ADMISSION_COUNT_COL)
)
number_of_admissions.head()

In [None]:
# Distribution of the number of cohort admissions per patient (log-scaled counts).
fig, ax = plt.subplots(1,1,dpi=100)
number_of_admissions.value_counts().sort_index().plot.bar(ax=ax, logy=True)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
ax.set_xlabel('Number of Admissions', fontsize=font_sz)
# Axes.grid's first positional parameter is `visible`, so the axis must be given by
# keyword: faint horizontal grid lines at the minor (log-scale) ticks.
ax.grid(True, axis='y', which='minor', alpha=0.4)
# Write the count just above each bar.
for p in ax.patches:
    ax.annotate(str(p.get_height()), (p.get_x(), p.get_height() * 1.01))

In [None]:
# Attach the per-patient admission count to every admission row.
admissions_df = pd.merge(admissions_df, number_of_admissions, on=SUBJECT_ID_COL)
admissions_df.head()

# Add a recurrent-admissions group per patient (based on the patient's number of admissions)

In [None]:
# Bins (1, 1.5] -> '1', (1.5, 2.5] -> '2', (2.5, 5000] -> '3up'; include_lowest keeps a
# count of exactly 1 inside the first bin.
ADMISSION_COUNT_BINS = [1, 1.5, 2.5, 5000]
ADMISSION_COUNT_LABELS = ['1', '2', '3up']

admissions_df[ADMISSION_COUNT_GROUP_COL] = pd.cut(
    x=admissions_df[ADMISSION_COUNT_COL],
    bins=ADMISSION_COUNT_BINS,
    labels=ADMISSION_COUNT_LABELS,
    include_lowest=True,
)
admissions_df.head()

# Add an indicator for whether the patient's last admission began within 30 days of a previous discharge

In [None]:
# Flag a patient's last admission when it started within `indicator_diff` of the
# discharge from that patient's previous admission.
indicator_diff = pd.to_timedelta('30d')

tmp_admissions = admissions_df[admissions_df[ADMISSION_COUNT_COL] > 1]
print(tmp_admissions.shape)

# Vectorised per-patient shift instead of groupby().apply(): apply's result index depends
# on the pandas version (group_keys handling) and warns in pandas >= 2.2. The first
# admission of each patient gets a NaT previous discharge, and NaT comparisons are False.
sorted_admissions = tmp_admissions.sort_values(by=[SUBJECT_ID_COL, ADMISSION_TIME_COL])
previous_discharge = sorted_admissions.groupby(SUBJECT_ID_COL)[DISCHARGE_TIME_COL].shift(1)
ind_ser = (sorted_admissions[ADMISSION_TIME_COL] - previous_discharge) <= indicator_diff
ind_ser.index = pd.Index(sorted_admissions[SUBJECT_ID_COL].values, name=SUBJECT_ID_COL)
ind_ser.name = PREV_ADMISSION_IND_COL

# Rows are time-ordered within each patient, so the last row is the latest admission.
ind_ser = ind_ser[~ind_ser.index.duplicated(keep='last')]
ind_ser

In [None]:
# Attach the last-admission readmission flag to every admission of the patient; patients
# with a single admission have no flag and get 0.
admissions_df = admissions_df.merge(ind_ser.astype(int), left_on=SUBJECT_ID_COL, right_index=True, how='outer')
# Assign instead of fillna(..., inplace=True) on a column selection: that is chained
# assignment and does not reliably update the frame (never under copy-on-write).
admissions_df[PREV_ADMISSION_IND_COL] = admissions_df[PREV_ADMISSION_IND_COL].fillna(0)
admissions_df

In [None]:
# Example: all admissions of patients whose last admission was flagged.
flagged = admissions_df[PREV_ADMISSION_IND_COL].eq(1)
admissions_df.loc[flagged].sort_values(by=[SUBJECT_ID_COL, ADMISSION_TIME_COL])

# Keep only last admission per patient

In [None]:
# Keep each patient's most recent admission (last row after sorting by admission time).
only_last_admission = (
    admissions_df
    .sort_values(by=[ADMISSION_TIME_COL])
    .drop_duplicates(subset=[SUBJECT_ID_COL], keep='last')
)
len(only_last_admission)

# Keep only patients whose last admission year is MINIMUM_YEAR or later

In [None]:
# Keep last admissions whose admission-year lower bound is MINIMUM_YEAR or later.
MINIMUM_YEAR = 2014
print(only_last_admission.shape[0])
recent_enough = only_last_admission[ADMISSION_YEAR_COL].ge(MINIMUM_YEAR)
only_last_admission = only_last_admission[recent_enough]
print(only_last_admission.shape[0])

In [None]:
# Fraction of included patients whose last admission was flagged by the previous-admission indicator.
only_last_admission[PREV_ADMISSION_IND_COL].mean()

In [None]:
# Patient and admission ids of the cohort, used to filter the lab-events table.
pids = only_last_admission[SUBJECT_ID_COL].drop_duplicates()
adm_ids = only_last_admission[ADMISSION_ID_COL].drop_duplicates()
for id_series in (pids, adm_ids):
    print(len(id_series))

# Load relevant lab tests

In [None]:
# Lab-events columns to read. The store time uses STORE_TIME_COL, the same constant the
# reader passes to parse_dates, so the two cannot drift apart.
LOAD_SPECIFIC_COLUMNS = [SUBJECT_ID_COL, ADMISSION_ID_COL, ITEM_ID_COL, STORE_TIME_COL, 'flag']

In [None]:
# Stream the lab-events table in chunks. Keep only rows of the cohort's (last) admissions
# and join them with their admission-level columns.
chunksize = 10 ** 6
# Loop-invariant: the cohort admissions joined onto every chunk.
tmp_adms = only_last_admission[only_last_admission[SUBJECT_ID_COL].isin(pids) & only_last_admission[ADMISSION_ID_COL].isin(adm_ids)]
filtered_chunks = []
n_rows = 0
with pd.read_csv(lab_file, chunksize=chunksize, compression='gzip', parse_dates=[STORE_TIME_COL], usecols=LOAD_SPECIFIC_COLUMNS) as reader:
    for chunk in reader:
        tmp_chunk = chunk[chunk[SUBJECT_ID_COL].isin(pids) & chunk[ADMISSION_ID_COL].isin(adm_ids)]
        tmp_chunk = tmp_chunk.merge(tmp_adms, on=[SUBJECT_ID_COL, ADMISSION_ID_COL])
        filtered_chunks.append(tmp_chunk)
        n_rows += len(tmp_chunk)
        print(n_rows)  # cumulative number of kept rows

# Concatenate once at the end; concatenating inside the loop is quadratic in the row count.
full_df = pd.concat(filtered_chunks) if filtered_chunks else pd.DataFrame()
full_df.head()

# Restrict patients_df and admissions_df to the patients and admissions present in full_df

In [None]:
# Keep only patients / admissions that have at least one lab result in full_df.
pids = full_df[SUBJECT_ID_COL].drop_duplicates().values
adms_ids = full_df[ADMISSION_ID_COL].drop_duplicates().values

print(patients_df.shape[0])
patients_df = patients_df.loc[patients_df[SUBJECT_ID_COL].isin(pids)]
print(patients_df.shape[0])

print(admissions_df.shape[0])
admissions_df = admissions_df.loc[admissions_df[ADMISSION_ID_COL].isin(adms_ids)]
print(admissions_df.shape[0])

In [None]:
full_df.shape[0]  # number of lab rows kept

In [None]:
# Admission-location distribution of the cohort (one last admission per patient).
location_counts = admissions_df[ADMISSION_LOCATION_COL].value_counts()
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=100)
location_counts.plot.bar(ax=ax)
ax.set_xlabel('Admission Location', fontsize=font_sz)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
# Print each bar's count just above it.
for bar in ax.patches:
    bar_height = bar.get_height()
    ax.annotate(f'{bar_height}', (bar.get_x(), bar_height * 1.01))

In [None]:
# Discharge-location distribution before regrouping.
discharge_counts = admissions_df[DISCHARGE_LOCATION_COL].value_counts()
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=100)
discharge_counts.plot.bar(ax=ax)
ax.set_xlabel('Discharge Location', fontsize=font_sz)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
# Print each bar's count just above it.
for bar in ax.patches:
    bar_height = bar.get_height()
    ax.annotate(f'{bar_height}', (bar.get_x(), bar_height * 1.01))

# Regroup discharge locations

In [None]:
# Display the discharge-location regrouping as a two-column table.
discharge_regrouping_df = (
    pd.Series(DISCHARGE_REGROUPING_DICT)
    .rename_axis('Original Group')
    .to_frame(name='Regrouped')
)
discharge_regrouping_df

In [None]:
# Collapse detailed discharge locations into the regrouped categories in both tables.
# Assign the result: replace(..., inplace=True) on a column selection is chained
# assignment and does not reliably update the parent frame (never under copy-on-write).
admissions_df[DISCHARGE_LOCATION_COL] = admissions_df[DISCHARGE_LOCATION_COL].replace(DISCHARGE_REGROUPING_DICT)
full_df[DISCHARGE_LOCATION_COL] = full_df[DISCHARGE_LOCATION_COL].replace(DISCHARGE_REGROUPING_DICT)

In [None]:
# Discharge-location distribution after regrouping.
regrouped_counts = admissions_df[DISCHARGE_LOCATION_COL].value_counts()
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=100)
regrouped_counts.plot.bar(ax=ax)
ax.set_xlabel('Discharge Location', fontsize=font_sz)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
# Print each bar's count just above it.
for bar in ax.patches:
    bar_height = bar.get_height()
    ax.annotate(f'{bar_height}', (bar.get_x(), bar_height * 1.01))

In [None]:
# Age-at-admission distribution by sex for the final cohort; saved to OUTPUT_DIR.
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=100)
tmp = admissions_df[[ADMISSION_AGE_COL, GENDER_COL]]
age_sex_counts = tmp.groupby([ADMISSION_AGE_COL, GENDER_COL]).size().unstack()
age_sex_counts.plot(kind='bar', ax=ax)
ax.set_xlabel('Age at Admission [years]', fontsize=font_sz)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
ax.set_title(f'Total Population, N={len(tmp)}', fontsize=font_sz)
ax.legend(labels=['Female', 'Male'], title="Sex")
# Hide every second x tick label so the ages do not overlap.
for tick_label in ax.get_xticklabels()[1::2]:
    tick_label.set_visible(False)
fig.savefig(os.path.join(OUTPUT_DIR, 'age_gender_admissions_subset.png'), dpi=300)

# Regroup Race

In [None]:
# Display the race regrouping as a two-column table.
race_regrouping_df = (
    pd.Series(RACE_REGROUPING_DICT)
    .rename_axis('Original Group')
    .to_frame(name='Regrouped')
)
race_regrouping_df

In [None]:
# Collapse detailed race values into the regrouped categories in both tables.
# Assign the result instead of a chained replace(..., inplace=True) on a column
# selection, which does not reliably update the parent frame.
admissions_df[RACE_COL] = admissions_df[RACE_COL].replace(RACE_REGROUPING_DICT)
full_df[RACE_COL] = full_df[RACE_COL].replace(RACE_REGROUPING_DICT)

In [None]:
# Race distribution of the cohort after regrouping.
race_counts = admissions_df[RACE_COL].value_counts()
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=100)
race_counts.plot.bar(ax=ax)
ax.set_xlabel('Race', fontsize=font_sz)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
# Print each bar's count just above it.
for bar in ax.patches:
    bar_height = bar.get_height()
    ax.annotate(f'{bar_height}', (bar.get_x(), bar_height * 1.01))

In [None]:
# Insurance distribution of the cohort.
insurance_counts = admissions_df[INSURANCE_COL].value_counts()
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=100)
insurance_counts.plot.bar(ax=ax)
ax.set_xlabel('Insurance', fontsize=font_sz)
ax.set_ylabel('Number of Patients', fontsize=font_sz)
# Print each bar's count just above it.
for bar in ax.patches:
    bar_height = bar.get_height()
    ax.annotate(f'{bar_height}', (bar.get_x(), bar_height * 1.01))

# Keep only lab results stored up to 24 hours after admission

In [None]:
full_df.head(n=5)  # preview lab rows joined with admission columns

In [None]:
# Time from admission until the lab result was stored.
full_df[ADMISSION_TO_RESULT_COL] = full_df[STORE_TIME_COL].sub(full_df[ADMISSION_TIME_COL])

In [None]:
# Keep results stored at most 24 hours after admission (results stored before the
# admission time, i.e. negative deltas, are kept as well).
within_first_day = full_df[ADMISSION_TO_RESULT_COL] <= pd.Timedelta(days=1)
full_df = full_df[within_first_day]
full_df.head()

In [None]:
# Keep only the latest stored result per (patient, admission, lab item).
# The original called drop_duplicates(inplace=True) on the temporary returned by
# sort_values, which discarded the result and left full_df unchanged; reassign instead.
print(len(full_df))
full_df = full_df.sort_values(by=[ADMISSION_TIME_COL, STORE_TIME_COL]).drop_duplicates(
    subset=[SUBJECT_ID_COL, ADMISSION_ID_COL, ITEM_ID_COL], keep='last')
print(len(full_df))

# Most common lab tests upon arrival

In [None]:
# Lab item dictionary; its 'label' column is used later to name the lab features.
lab_meta_df = pd.read_csv(filepath_or_buffer=lab_meta_file, compression='gzip')
lab_meta_df

In [None]:
# A lab item is kept only if it appears in more than this many distinct admissions.
threshold = 25000

In [None]:
# Distinct admissions per lab item, most common first; keep items above the threshold
# together with their metadata.
common_tests = (
    full_df.groupby(ITEM_ID_COL)[ADMISSION_ID_COL]
    .nunique()
    .sort_values(ascending=False)
)
above_threshold = common_tests[common_tests > threshold]
included_in_threshold = above_threshold.to_frame().merge(lab_meta_df, on=ITEM_ID_COL)
included_in_threshold

In [None]:
# Keep only rows of the common lab items.
print(full_df.shape[0])
is_common_item = full_df[ITEM_ID_COL].isin(included_in_threshold[ITEM_ID_COL].values)
full_df = full_df[is_common_item]
print(full_df.shape[0])

In [None]:
# Least frequent lab item that passed the threshold: common_tests is sorted descending and
# the inner merge keeps that order, so the last row of included_in_threshold is the rarest item.
minimal_item_id = included_in_threshold.iloc[-1][ITEM_ID_COL]
minimal_item_id

In [None]:
# Restrict the cohort to admissions that have a result for the least frequent kept item.
has_minimal_item = full_df[ITEM_ID_COL] == minimal_item_id
pids = full_df.loc[has_minimal_item, SUBJECT_ID_COL].drop_duplicates().values
adms_ids = full_df.loc[has_minimal_item, ADMISSION_ID_COL].drop_duplicates().values

print(len(patients_df))
patients_df = patients_df[patients_df[SUBJECT_ID_COL].isin(pids)]
print(len(patients_df))

print(len(admissions_df))
admissions_df = admissions_df[admissions_df[ADMISSION_ID_COL].isin(adms_ids)]
print(len(admissions_df))

# Report the lab table size around its own filter (previously admissions_df was printed again).
print(len(full_df))
full_df = full_df[full_df[SUBJECT_ID_COL].isin(pids) & full_df[ADMISSION_ID_COL].isin(adms_ids)]
print(len(full_df))

In [None]:
full_df.head(n=5)  # preview after restricting to the final cohort

In [None]:
# Treat a missing flag as 'normal' and encode normal/abnormal as 0/1.
# Assigning (instead of chained fillna/replace with inplace=True) makes the update
# reliable, and the explicit int cast avoids pandas' deprecated silent downcast in replace.
full_df['flag'] = full_df['flag'].fillna('normal').replace({'normal': 0, 'abnormal': 1}).astype(int)
full_df['flag'].value_counts()

In [None]:
# Keep only the most recently stored result per (patient, admission, lab item).
latest_first_day = full_df.sort_values(by=[ADMISSION_TIME_COL, STORE_TIME_COL])
full_df = latest_first_day.drop_duplicates(
    subset=[SUBJECT_ID_COL, ADMISSION_ID_COL, ITEM_ID_COL],
    keep='last',
)
full_df

In [None]:
# Wide table: one row per (patient, admission), one column per lab item holding its 0/1
# abnormal flag; (patient, item) pairs without a result become NaN.
# aggfunc='sum' replaces np.sum, which pandas 2.x deprecates as a pivot_table aggfunc.
tmp = full_df[[SUBJECT_ID_COL, ADMISSION_ID_COL, ITEM_ID_COL, 'flag']]
fitters_table = pd.pivot_table(tmp, values=['flag'], index=[SUBJECT_ID_COL, ADMISSION_ID_COL],
                               columns=[ITEM_ID_COL], aggfunc='sum')
fitters_table

In [None]:
# Index rows by patient only (one admission per patient remains) and drop the
# outer 'flag' level of the column MultiIndex.
fitters_table = fitters_table.droplevel(ADMISSION_ID_COL, axis=0)
fitters_table = fitters_table.droplevel(0, axis=1)
fitters_table

In [None]:
# One row of admission-level covariates per patient, indexed by subject id.
first_row_per_patient = full_df.drop_duplicates(subset=[SUBJECT_ID_COL])
dummies_df = first_row_per_patient.set_index(keys=SUBJECT_ID_COL)
dummies_df

In [None]:
# Release the large intermediate tables; only fitters_table and dummies_df are needed below.
del full_df, admissions_df, patients_df

# Standardize age

In [None]:
# Z-score AGE_COL over the cohort for use as a model covariate.
# NOTE(review): this standardizes AGE_COL, while Table 1 reports ADMISSION_AGE_COL — confirm intended.
scaler = StandardScaler()
scaler.fit(dummies_df[[AGE_COL]])
dummies_df[STANDARDIZED_AGE_COL] = scaler.transform(dummies_df[[AGE_COL]])

In [None]:
# Event-type coding for the (regrouped) discharge location: J=1 home, J=2 further
# treatment, J=3 died, J=0 censored.
J_DICT = {'HOME': 1, 'FURTHER TREATMENT': 2, 'DIED': 3, 'CENSORED': 0} 
# Binary sex encoding used as a covariate (F=1, M=0).
GENDER_DICT = {'F': 1, 'M': 0}

In [None]:
# Encode sex as 0/1 via GENDER_DICT (values absent from the dict are left unchanged by replace).
dummies_df[GENDER_COL] = dummies_df[GENDER_COL].replace(GENDER_DICT)

# Table 1

In [None]:
# Compact lab labels (spaces and commas removed) and map item id -> label for renaming.
included_in_threshold['label'] = (
    included_in_threshold['label']
    .str.replace(' ', '', regex=False)
    .str.replace(',', '', regex=False)
)
RENAME_ITEMS_DICT = dict(zip(included_in_threshold[ITEM_ID_COL], included_in_threshold['label']))
RENAME_ITEMS_DICT

In [None]:
# Table 1 inputs, one row per patient: lab flags, binary indicators and admission age as
# ints, categorical covariates, LOS in whole days, and the outcome mapped through J_DICT.
table1_binary_cols = dummies_df[[NIGHT_ADMISSION_FLAG, GENDER_COL, DIRECT_IND_COL,
                                 PREV_ADMISSION_IND_COL, ADMISSION_AGE_COL]].astype(int)
table1_categorical_cols = dummies_df[[INSURANCE_COL, MARITAL_STATUS_COL, RACE_COL,
                                      ADMISSION_COUNT_GROUP_COL]]
table1_los_days = dummies_df[LOS_DAYS_COL].dt.days
table1_outcome = dummies_df[DISCHARGE_LOCATION_COL].dropna().replace(J_DICT).astype(int)

table1 = pd.concat([fitters_table.copy(), table1_binary_cols, table1_categorical_cols,
                    table1_los_days, table1_outcome], axis=1)
table1 = table1.rename(RENAME_ITEMS_DICT, axis=1)
table1 = table1.dropna()
table1

In [None]:
# Stays longer than the administrative horizon are censored (outcome 0) at day horizon + 1.
ADMINISTRATIVE_CENSORING = 28
is_admin_censored = table1[LOS_DAYS_COL] > ADMINISTRATIVE_CENSORING
table1.loc[is_admin_censored, DISCHARGE_LOCATION_COL] = 0
table1.loc[is_admin_censored, LOS_DAYS_COL] = ADMINISTRATIVE_CENSORING + 1

In [None]:
# Replace coded values with human-readable labels for Table 1, then rename the columns.
# Each column is reassigned rather than edited via chained replace(..., inplace=True),
# which does not reliably update table1 (and never under copy-on-write).
table1[GENDER_COL] = table1[GENDER_COL].replace(table1_rename_sex)
table1[RACE_COL] = table1[RACE_COL].replace(table1_rename_race)
table1[MARITAL_STATUS_COL] = table1[MARITAL_STATUS_COL].replace(table1_rename_marital)
table1[DIRECT_IND_COL] = table1[DIRECT_IND_COL].replace(table1_rename_yes_no)
table1[NIGHT_ADMISSION_FLAG] = table1[NIGHT_ADMISSION_FLAG].replace(table1_rename_yes_no)
table1[PREV_ADMISSION_IND_COL] = table1[PREV_ADMISSION_IND_COL].replace(table1_rename_yes_no)
table1[DISCHARGE_LOCATION_COL] = table1[DISCHARGE_LOCATION_COL].replace(table1_rename_discharge)
# The admission-count group comes from pd.cut (categorical); rename the category instead
# of using replace, which pandas deprecates for categorical dtype.
table1[ADMISSION_COUNT_GROUP_COL] = table1[ADMISSION_COUNT_GROUP_COL].cat.rename_categories({'3up': '3+'})
table1 = table1.rename(table1_rename_columns, axis=1)

In [None]:
# Demographics Table 1 stratified by discharge outcome. The raw names below are keys of
# table1_rename_columns, so their spelling (e.g. 'direct_emrgency_flag') must match it.
columns = ['gender', 'admission_age', 'race', 'insurance', 'marital_status',
           'direct_emrgency_flag', 'night_admission', 'last_less_than_diff',
           'admissions_count_group', 'LOS days', 'discharge_location']
columns = [table1_rename_columns[c] for c in columns]
categorical = ['gender', 'race', 'insurance', 'marital_status',
               'direct_emrgency_flag', 'night_admission', 'last_less_than_diff',
               'admissions_count_group', 'discharge_location']
categorical = [table1_rename_columns[c] for c in categorical]
table1 = table1.dropna()
# Pass options by keyword: newer tableone versions put `continuous` before `groupby`, so a
# positional groupby would be read as the continuous-variable list. `groupby` is a single
# column name, as documented.
groupby = table1_rename_columns[DISCHARGE_LOCATION_COL]
mytable = TableOne(table1, columns=columns, categorical=categorical, groupby=groupby, missing=False)
mytable

In [None]:
demographics_latex = mytable.tableone.round(3).to_latex()  # LaTeX version of Table 1
print(demographics_latex)

In [None]:
# Lab-flag table stratified by discharge outcome; every listed column is categorical
# (normal/abnormal), so `categorical` is the same list as `columns`.
columns = [DISCHARGE_LOCATION_COL, 'AnionGap', 'Bicarbonate', 'CalciumTotal', 'Chloride', 'Creatinine',
           'Glucose', 'Magnesium', 'Phosphate', 'Potassium', 'Sodium',
           'UreaNitrogen', 'Hematocrit', 'Hemoglobin', 'MCH', 'MCHC', 'MCV',
           'PlateletCount', 'RDW', 'RedBloodCells', 'WhiteBloodCells']
columns = [table1_rename_columns[c] for c in columns]
categorical = list(columns)
# Pass options by keyword so `groupby` cannot land in tableone's `continuous` slot.
groupby = table1_rename_columns[DISCHARGE_LOCATION_COL]
mytable = TableOne(table1.dropna().replace(table1_rename_normal_abnormal),
                   columns=columns, categorical=categorical, groupby=groupby, missing=False)
mytable

In [None]:
lab_flags_latex = mytable.tableone.round(3).to_latex()  # LaTeX version of the lab-flag table
print(lab_flags_latex)

In [None]:
# Model design matrix, one row per patient: lab flags, one-hot categorical covariates
# (first level dropped as reference), binary indicators, standardized age, LOS in whole
# days and the outcome mapped through J_DICT.
fitters_parts = [fitters_table.copy()]
for _dummy_col, _dummy_prefix in [(INSURANCE_COL, 'Insurance'), (MARITAL_STATUS_COL, 'Marital'),
                                  (RACE_COL, 'Ethnicity'), (ADMISSION_COUNT_GROUP_COL, 'AdmsCount')]:
    fitters_parts.append(pd.get_dummies(dummies_df[_dummy_col], prefix=_dummy_prefix, drop_first=True))
fitters_parts += [
    dummies_df[[NIGHT_ADMISSION_FLAG, GENDER_COL, DIRECT_IND_COL, PREV_ADMISSION_IND_COL]].astype(int),
    dummies_df[STANDARDIZED_AGE_COL],
    dummies_df[LOS_DAYS_COL].dt.days,
    dummies_df[DISCHARGE_LOCATION_COL].dropna().replace(J_DICT).astype(int),
]
fitters_table = pd.concat(fitters_parts, axis=1)

fitters_table

In [None]:
# Inspect the columns of the assembled design matrix.
fitters_table.columns

In [None]:
# Drop incomplete rows and keep only patients that are also in Table 1.
print(fitters_table.shape[0])
fitters_table = fitters_table.dropna()
fitters_table = fitters_table.loc[fitters_table.index.isin(table1.index)]
print(fitters_table.shape[0])

In [None]:
# Rename to the column names used by the fitters and plots: 'pid' (patient), 'J' (event
# type) and 'X' (event/censoring day); lab item ids become readable labels.
fitters_table = (
    fitters_table
    .reset_index()
    .rename({DISCHARGE_LOCATION_COL: 'J', LOS_DAYS_COL: 'X', SUBJECT_ID_COL: 'pid'}, axis=1)
    .rename(RENAME_ITEMS_DICT, axis=1)
)

In [None]:
# Keep stays with positive length, then apply administrative censoring: stays longer than
# ADMINISTRATIVE_CENSORING days become censored (J=0) at X = ADMINISTRATIVE_CENSORING + 1.
# Take an explicit copy so the .loc writes below target an independent frame rather than
# a filtered slice (SettingWithCopy / chained-assignment hazard).
fitters_table = fitters_table[fitters_table['X'] > 0].copy()
# J must be updated before X, because both updates select on the original X values.
fitters_table.loc[fitters_table.X > ADMINISTRATIVE_CENSORING, 'J'] = 0
fitters_table.loc[fitters_table.X > ADMINISTRATIVE_CENSORING, 'X'] = ADMINISTRATIVE_CENSORING + 1
fitters_table['J'] = fitters_table['J'].astype(int)

plot_events_occurrence(fitters_table)

In [None]:
# Fit three models on the same data: two-step, Lee et al. (data expansion), and a
# regularized two-step. Their estimates are cached to CSV so later cells can be re-run
# without refitting.
case = f'mimic_final_'
two_step_timing = []
lee_timing = []

# --- Two-step fitter ---
new_fitter = TwoStagesFitter()
print(f'Starting two-step')
two_step_start = time()
new_fitter.fit(df=fitters_table, nb_workers=1)
two_step_end = time()
print(f'Finished two-step: {two_step_end-two_step_start}sec')
two_step_timing.append(two_step_end - two_step_start)

# --- Lee et al. fitter ---
print(f'Starting Lee et al.')
lee_fitter = DataExpansionFitter()
lee_start = time()
lee_fitter.fit(df=fitters_table)
lee_end = time()
print(f'Finished lee: {lee_end-lee_start}sec')
lee_timing.append(lee_end - lee_start)

# --- Regularized two-step fitter: penalizer exp(-7) with l1_ratio=1 (presumably pure L1) ---
reg_fitter = TwoStagesFitter()
print(f'Starting regularized two-step')
fit_beta_kwargs = {'model_kwargs': {'penalizer': np.exp(-7), 'l1_ratio': 1}}
reg_two_step_start = time()
reg_fitter.fit(df=fitters_table, nb_workers=1, fit_beta_kwargs=fit_beta_kwargs)
reg_two_step_end = time()
print(f'Finished two-step: {reg_two_step_end-reg_two_step_start}sec')

# Lee et al.: keep only the coefficient and standard-error columns.
coef_and_se = slicer[:, [COEF_COL, STDERR_COL]]
lee_alpha_ser = lee_fitter.get_alpha_df().loc[:, coef_and_se].unstack().sort_index()
lee_beta_ser = lee_fitter.get_beta_SE().loc[:, coef_and_se].unstack().sort_index()

two_step_alpha_k_results = new_fitter.alpha_df[['J', 'X', 'alpha_jt']]
two_step_beta_k_results = new_fitter.get_beta_SE().unstack().to_frame()

reg_two_step_alpha_k_results = reg_fitter.alpha_df[['J', 'X', 'alpha_jt']]
reg_two_step_beta_k_results = reg_fitter.get_beta_SE().unstack().to_frame()

lee_alpha_k_results = lee_alpha_ser.to_frame()
lee_beta_k_results = lee_beta_ser.to_frame()

# Cache results as OUTPUT_DIR/{case}_{suffix}.csv
cached_results = {
    'two_step_alpha': two_step_alpha_k_results,
    'two_step_beta': two_step_beta_k_results,
    'reg_two_step_alpha': reg_two_step_alpha_k_results,
    'reg_two_step_beta': reg_two_step_beta_k_results,
    'lee_alpha': lee_alpha_k_results,
    'lee_beta': lee_beta_k_results,
}
for result_suffix, result_frame in cached_results.items():
    result_frame.to_csv(os.path.join(OUTPUT_DIR, f'{case}_{result_suffix}.csv'))

In [None]:
# Model covariates: every column except the patient id, event type J and time X.
non_covariates = {'pid', 'J', 'X'}
covariates = [col for col in fitters_table.columns if col not in non_covariates]
covariates

In [None]:
# Reload the cached estimates so this cell can run without refitting the models.
two_step_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_two_step_alpha.csv'),
                                       index_col=['J', 'X'])
two_step_beta_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_two_step_beta.csv'),
                                      index_col=[0, 1])
reg_two_step_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_reg_two_step_alpha.csv'),
                                       index_col=['J', 'X'])
reg_two_step_beta_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_reg_two_step_beta.csv'),
                                      index_col=[0, 1])
lee_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_lee_alpha.csv'),
                                  index_col=[0,1,2])
lee_beta_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_lee_beta.csv'),
                                 index_col=[0, 1, 2])

# Two-step estimates per risk j. The positional column picks [1,0], [3,2], [5,4] assume the
# unstacked columns come in (SE, estimate) pairs for j=1,2,3; the ('Estimate',
# 'Estimated SE') labels assigned below rely on that order — TODO confirm.
twostep_beta1_summary = two_step_beta_k_results.mean(axis=1).unstack([0]).round(3).iloc[:, [1,0]]
twostep_beta1_summary.index = [f'{iii.replace(" ", "")}_1' for iii in twostep_beta1_summary.index]
twostep_beta2_summary = two_step_beta_k_results.mean(axis=1).unstack([0]).round(3).iloc[:, [3,2]]
twostep_beta2_summary.index = [f'{iii.replace(" ", "")}_2' for iii in twostep_beta2_summary.index]
twostep_beta3_summary = two_step_beta_k_results.mean(axis=1).unstack([0]).round(3).iloc[:, [5,4]]
twostep_beta3_summary.index = [f'{iii.replace(" ", "")}_3' for iii in twostep_beta3_summary.index]

reg_twostep_beta1_summary = reg_two_step_beta_k_results.mean(axis=1).unstack([0]).round(3).iloc[:, [1,0]]
reg_twostep_beta1_summary.index = [f'{iii.replace(" ", "")}_1' for iii in reg_twostep_beta1_summary.index]
reg_twostep_beta2_summary = reg_two_step_beta_k_results.mean(axis=1).unstack([0]).round(3).iloc[:, [3,2]]
reg_twostep_beta2_summary.index = [f'{iii.replace(" ", "")}_2' for iii in reg_twostep_beta2_summary.index]
reg_twostep_beta3_summary = reg_two_step_beta_k_results.mean(axis=1).unstack([0]).round(3).iloc[:, [5,4]]
reg_twostep_beta3_summary.index = [f'{iii.replace(" ", "")}_3' for iii in reg_twostep_beta3_summary.index]

# Lee et al. estimates per risk j: the first index level is j; the remaining columns are
# assumed ordered (coefficient, SE) — TODO confirm against COEF_COL / STDERR_COL sorting.
lee_beta1_summary = lee_beta_k_results.mean(axis=1).loc[slicer[1,:,:]].unstack([0]).round(3)
lee_beta1_summary.index = [f'{iii.replace(" ", "")}_1' for iii in lee_beta1_summary.index]
lee_beta2_summary = lee_beta_k_results.mean(axis=1).loc[slicer[2,:,:]].unstack([0]).round(3)
lee_beta2_summary.index = [f'{iii.replace(" ", "")}_2' for iii in lee_beta2_summary.index]
lee_beta3_summary = lee_beta_k_results.mean(axis=1).loc[slicer[3,:,:]].unstack([0]).round(3)
lee_beta3_summary.index = [f'{iii.replace(" ", "")}_3' for iii in lee_beta3_summary.index]

lee_beta1_summary.columns = pd.MultiIndex.from_tuples([('Lee et al.', 'Estimate'), ('Lee et al.', 'Estimated SE')])
lee_beta2_summary.columns = pd.MultiIndex.from_tuples([('Lee et al.', 'Estimate'), ('Lee et al.', 'Estimated SE')])
lee_beta3_summary.columns = pd.MultiIndex.from_tuples([('Lee et al.', 'Estimate'), ('Lee et al.', 'Estimated SE')])

beta_summary_comparison = pd.concat([lee_beta1_summary, lee_beta2_summary, lee_beta3_summary], axis=0)

twostep_beta1_summary.columns = pd.MultiIndex.from_tuples([('two-step', 'Estimate'), ('two-step', 'Estimated SE')])
twostep_beta2_summary.columns = pd.MultiIndex.from_tuples([('two-step', 'Estimate'), ('two-step', 'Estimated SE')])
twostep_beta3_summary.columns = pd.MultiIndex.from_tuples([('two-step', 'Estimate'), ('two-step', 'Estimated SE')])

reg_twostep_beta1_summary.columns = pd.MultiIndex.from_tuples([('Regularized two-step', 'Estimate'), ('Regularized two-step', 'Estimated SE')])
reg_twostep_beta2_summary.columns = pd.MultiIndex.from_tuples([('Regularized two-step', 'Estimate'), ('Regularized two-step', 'Estimated SE')])
reg_twostep_beta3_summary.columns = pd.MultiIndex.from_tuples([('Regularized two-step', 'Estimate'), ('Regularized two-step', 'Estimated SE')])

tmp = pd.concat([twostep_beta1_summary.round(3), twostep_beta2_summary.round(3), twostep_beta3_summary.round(3)], axis=0)
tmp2 = pd.concat([reg_twostep_beta1_summary.round(3), reg_twostep_beta2_summary.round(3), reg_twostep_beta3_summary.round(3)], axis=0)

# Side-by-side comparison: rows are covariate_j (j = 1, 2, 3 stacked), columns per model.
beta_summary_comparison = pd.concat([beta_summary_comparison, tmp, tmp2], axis=1)
# Replace the index first and name it afterwards: assigning a new list-based index drops
# any existing index name, so naming it first lost the LaTeX header.
beta_summary_comparison.index = [c.replace("_", " ") for c in beta_summary_comparison.index]
beta_summary_comparison.index.name = r'$\beta_{jk}$'
beta_summary_comparison

In [None]:
# Risk j=1 rows are the first third of the comparison table (index labels end in ' 1').
n_rows_per_risk = len(beta_summary_comparison) // 3
risk1_rename_index_dict = {f'{k} 1': v for k, v in rename_beta_index.items()}
risk1 = beta_summary_comparison.iloc[:n_rows_per_risk].rename(risk1_rename_index_dict, axis=0)
print(risk1.to_latex(escape=False))

In [None]:
# Risk j=2 rows are the middle third of the comparison table (index labels end in ' 2').
n_rows_per_risk = len(beta_summary_comparison) // 3
risk2_rename_index_dict = {f'{k} 2': v for k, v in rename_beta_index.items()}
risk2 = beta_summary_comparison.iloc[n_rows_per_risk:2 * n_rows_per_risk].rename(risk2_rename_index_dict, axis=0)
print(risk2.to_latex(escape=False))

In [None]:
# Risk j=3 rows are the last third of the comparison table (index labels end in ' 3').
n_rows_per_risk = len(beta_summary_comparison) // 3
risk3_rename_index_dict = {f'{k} 3': v for k, v in rename_beta_index.items()}
risk3 = beta_summary_comparison.iloc[2 * n_rows_per_risk:].rename(risk3_rename_index_dict, axis=0)
print(risk3.to_latex(escape=False))

In [None]:
# Compare the baseline alpha_jt estimates of Lee et al. and the two-step model per risk j,
# overlaid on bar counts of observed events per day; saved to OUTPUT_DIR/filename.
filename = 'mimic_summary_.png'

first_model_name = 'Lee et al.'
second_model_name = 'two-step'
# NOTE(review): `times` and `true_colors` are not used in this cell.
times = range(1, ADMINISTRATIVE_CENSORING+1)

lee_colors = ['tab:blue', 'tab:green', 'tab:red']
two_step_colors = ['navy', 'darkgreen', 'tab:brown']
true_colors = ['tab:blue', 'tab:green', 'tab:red']

fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Number of patients per (event type J, day X); columns are J values, missing pairs -> 0.
counts = fitters_table.groupby(['J', 'X'])['pid'].count().unstack('J').fillna(0)

# Reload the cached alpha estimates written by the fitting cell.
two_step_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_two_step_alpha.csv'), 
                                         index_col=['J', 'X'])

lee_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_lee_alpha.csv'),
                                   index_col=[0,1,2])

ax.tick_params(axis='both', which='major', labelsize=15)
ax.tick_params(axis='both', which='minor', labelsize=15)

for j in [1, 2, 3]:

    # Lee et al.: coefficient rows for risk j; the time t is parsed from each parameter
    # label as the integer between ')[' and the following ']'.
    tmp_alpha = lee_alpha_k_results.loc[slicer[j, COEF_COL, :]].mean(axis=1)
    tmp_alpha.index = [int(idx.split(')[')[1].split(']')[0]) for idx in tmp_alpha.index]
    tmp_alpha = pd.Series(tmp_alpha.values.squeeze().astype(float), index=tmp_alpha.index)

    ax.scatter(tmp_alpha.index, tmp_alpha.values,
       label=f'J={j} ({first_model_name})', color=lee_colors[j-1], marker='o', alpha=0.4, s=40)

    # Two-step: alpha_jt column for risk j, indexed by X.
    tmp_alpha = two_step_alpha_k_results.loc[slicer[j, 'alpha_jt']]
    ax.scatter(tmp_alpha.index, tmp_alpha.values,
       label=f'J={j} ({second_model_name})', color=two_step_colors[j-1], marker='*', alpha=0.7, s=20)

    ax.set_xlabel(r'Time', fontsize=18)
    ax.set_ylabel(r'$\alpha_{jt}$', fontsize=18)
    ax.legend(loc='upper right', fontsize=12)

ax.set_ylim([-13, 3])

# Secondary y-axis: observed event counts per day. J=2 bars sit right of each tick
# (align='edge', positive width), J=3 bars left of it (negative width); J=1 is centered.
ax2 = ax.twinx()
ax2.bar(counts.index, counts[1].values.squeeze(), label='J=1', color='navy', alpha=0.4, width=0.4)
ax2.bar(counts.index, counts[2].values.squeeze(), label='J=2', color='darkgreen', alpha=0.4, align='edge',
        width=0.4)
ax2.bar(counts.index, counts[3].values.squeeze(), label='J=3', color='tab:red', alpha=0.6, align='edge',
        width=-0.4)
ax2.legend(loc='upper center', fontsize=12)
ax2.set_ylabel('Number of observed events', fontsize=16, color='red')
ax2.tick_params(axis='y', colors='red')
ax2.set_ylim([0, 8000])
ax2.tick_params(axis='both', which='major', labelsize=15)
ax2.tick_params(axis='both', which='minor', labelsize=15)
    
fig.tight_layout()

if filename is not None:
    fig.savefig(os.path.join(OUTPUT_DIR, filename), dpi=300)

In [None]:
# Cross-validate the regularized two-step model (l1_ratio=1) over a grid of log-penalizers.
step = 1
penalizers = np.arange(-14, -3.9, step=step)  # -14, -13, ..., -4
n_splits = 4
seed = 1
cross_validators = {}

for idp, penalizer in enumerate(penalizers):
    print(f"Started Penalizer: {penalizer}, {idp+1}/{len(penalizers)}")
    fit_beta_kwargs = {'model_kwargs': {'penalizer': np.exp(penalizer), 'l1_ratio': 1}}
    start = time()
    # Register the validator before running it, keyed by the log-penalizer.
    cross_validators[penalizer] = penalized_cv = TwoStagesCV()
    penalized_cv.cross_validate(full_df=fitters_table, n_splits=n_splits, seed=seed, nb_workers=1,
                                fit_beta_kwargs=fit_beta_kwargs,
                                metrics=['PE', 'AUC', 'IAUC', 'GAUC'])
    end = time()
    print(f"Finished Penalizer: {penalizer}, {idp+1}/{len(penalizers)}, {int(end-start)} seconds")

In [None]:
# Unpenalized baseline, cross-validated with the same n_splits and seed as the grid above.
start = time()
cross_validator_null = TwoStagesCV()
cross_validator_null.cross_validate(
    full_df=fitters_table,
    n_splits=n_splits,
    seed=seed,
    nb_workers=1,
    metrics=['PE', 'AUC', 'IAUC', 'GAUC'],
)
end = time()
print(f"Finished {int(end-start)} seconds")

In [None]:
# Percentage of rows per censoring type and per outcome J.
# Censored rows (J == 0) with X <= ADMINISTRATIVE_CENSORING go to `lof_censoring`,
# those with X > ADMINISTRATIVE_CENSORING to `adm_censoring`.
_n_rows = len(fitters_table)
_is_censored = fitters_table['J'] == 0
lof_censoring = 100 * int((_is_censored & (fitters_table['X'] <= ADMINISTRATIVE_CENSORING)).sum()) / _n_rows
adm_censoring = 100 * int((_is_censored & (fitters_table['X'] > ADMINISTRATIVE_CENSORING)).sum()) / _n_rows

# Share of each J value among all grouped rows, in percent, rounded to one decimal.
risks = fitters_table.groupby('J').size().pipe(lambda sizes: (100 * sizes / sizes.sum()).round(1))

print(f"LOF censoring: {lof_censoring:.1f}%, Administrative censoring: {adm_censoring:.1f}%, Home: {risks.loc[1]}%, Further treatment: {risks.loc[2]}%, Death: {risks.loc[3]}%")

In [None]:
# Regularization figure (3x3 grid):
#   a   - global AUC vs. log(lambda), with the un-penalized baseline
#   b   - integrated AUC per risk vs. log(lambda), with the un-penalized baselines
#   c   - number of fold-mean coefficients that are non-zero after rounding to 3 decimals
#   d-f - fold-mean coefficient paths per risk
#   g-i - AUC(t) per risk at the chosen lambda, over bars of observed events per time
ticksize = 15
axes_title_fontsize = 17
legend_size = 13

risk_names = ['Home', 'Further Treatment', 'Death']
risk_colors = ['tab:blue', 'tab:green', 'tab:red']
risk_letters = ['d', 'e', 'f', 'g', 'h', 'i']
chosen_lambda = -7  # log(lambda); must be one of the grid values cross-validated above
assert chosen_lambda in cross_validators, f"chosen_lambda={chosen_lambda} is not in the cross-validated grid"

# One sorted x-axis shared by every lambda panel, so that data columns and x-values always line up.
penalizers_x = sorted(cross_validators.keys())

fig, axes = plt.subplots(3, 3, figsize=(20, 17))

# --- a: global AUC ------------------------------------------------------------
ax = axes[0, 0]
add_panel_text(ax, 'a')
ax.tick_params(axis='both', which='both', labelsize=ticksize)
ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
ax.set_ylabel(r'Global AUC', fontsize=axes_title_fontsize)

gauc_by_penalizer = [pd.Series(cross_validators[p].global_auc) for p in penalizers_x]
mean_gauc = [ser.mean() for ser in gauc_by_penalizer]
std_gauc = [ser.std() for ser in gauc_by_penalizer]

ax.errorbar(penalizers_x, mean_gauc, yerr=std_gauc, fmt="o", color='g', alpha=0.5, label='With Penalization')
ax.axhline(pd.Series(cross_validator_null.global_auc).mean(), ls='--', label='Without Penalization', color='tab:blue')
ax.axvline(chosen_lambda, color='brown', ls='-.', label=r'Chosen $\lambda$')
ax.legend(fontsize=legend_size)
ax.set_ylim([0.53, 0.78])

# --- b: integrated AUC per risk -----------------------------------------------
ax = axes[0, 1]
add_panel_text(ax, 'b')
ax.tick_params(axis='both', which='both', labelsize=ticksize)
ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
ax.set_ylabel(r'Integrated AUC', fontsize=axes_title_fontsize)

# Each integrated_auc frame has rows indexed by risk (columns presumably one per fold).
# fig_mean / fig_std: rows = risk, columns = penalizer in penalizers_x order.
iauc_by_penalizer = {p: pd.DataFrame.from_dict(cross_validators[p].integrated_auc) for p in penalizers_x}
fig_mean = pd.DataFrame({p: iauc.mean(axis=1) for p, iauc in iauc_by_penalizer.items()})
fig_std = pd.DataFrame({p: iauc.std(axis=1) for p, iauc in iauc_by_penalizer.items()})
null_iauc_mean = pd.DataFrame.from_dict(cross_validator_null.integrated_auc).mean(axis=1)

for risk in range(1, 4):
    risk_name, risk_color = risk_names[risk-1], risk_colors[risk-1]
    ax.errorbar(penalizers_x, fig_mean.loc[risk], yerr=fig_std.loc[risk], fmt="o", color=risk_color, alpha=0.5,
                label=f'{risk_name} - With Penalization')
    ax.axhline(null_iauc_mean.loc[risk], ls='--', label=f'{risk_name} - Without Penalization', color=risk_color)
ax.set_ylim([0.48, 0.78])
ax.axvline(chosen_lambda, color='brown', ls='-.', label=r'Chosen $\lambda$')
ax.legend(loc='lower left', fontsize=legend_size)


# --- c + d-f: coefficient sparsity and coefficient paths ------------------------
def _fold_mean_coefficients(cross_validator, risk):
    """Mean over the CV folds of the beta coefficients fitted for `risk` (one value per covariate)."""
    per_fold = [cross_validator.models[i_fold].beta_models[risk].params_ for i_fold in range(n_splits)]
    return pd.concat(per_fold, axis=1).mean(axis=1)


ax_nonzero = axes[0, 2]
for risk in range(1, 4):
    risk_name, risk_color = risk_names[risk-1], risk_colors[risk-1]

    # Rows = covariates, columns = penalizer in penalizers_x order.
    j1_params_df = pd.DataFrame({p: _fold_mean_coefficients(cross_validators[p], risk) for p in penalizers_x})

    ax = axes[1, risk-1]
    add_panel_text(ax, risk_letters[risk-1])
    ax.tick_params(axis='both', which='both', labelsize=ticksize)
    for coef_path in j1_params_df.to_numpy():
        ax.plot(penalizers_x, coef_path, lw=1)
    ax.set_ylabel(f'{n_splits}-Fold Mean Coefficient Value', fontsize=axes_title_fontsize)
    ax.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
    ax.set_title(rf'$\beta_{risk}$ - {risk_name}', fontsize=axes_title_fontsize)
    ax.axvline(chosen_lambda, color='tab:blue', alpha=1, ls='--', lw=1)

    # A coefficient counts as non-zero if it is still non-zero after rounding to 3 decimals.
    nonzero_counts = (j1_params_df.round(3).abs() > 0).sum(axis=0)
    ax_nonzero.scatter(penalizers_x, nonzero_counts.to_numpy(), color=risk_color, alpha=0.8, marker='P',
                       label=risk_name)
    print(f"Risk {risk}: {nonzero_counts.loc[chosen_lambda]} non-zero coefficients at chosen lambda {chosen_lambda}")

add_panel_text(ax_nonzero, 'c')
ax_nonzero.tick_params(axis='both', which='both', labelsize=ticksize)
ax_nonzero.set_xlabel(r'Log ($\lambda$)', fontsize=axes_title_fontsize)
ax_nonzero.set_ylabel('Number of Non-Zero Coefficients', fontsize=axes_title_fontsize)
ax_nonzero.axvline(chosen_lambda, color='tab:blue', alpha=1, ls='--', lw=1)
ax_nonzero.legend(loc='lower left', fontsize=legend_size)

# --- g-i: AUC(t) at the chosen lambda -------------------------------------------
auc_ticks = np.arange(0, 1.1, 0.1)
for risk in range(1, 4):
    risk_name, risk_color = risk_names[risk-1], risk_colors[risk-1]
    ax = axes[2, risk-1]
    add_panel_text(ax, risk_letters[3+risk-1])
    ax.tick_params(axis='both', which='both', labelsize=ticksize)

    # AUC rows for this risk; mean/std are taken column-wise (presumably over folds), one value per time.
    auc_t = cross_validators[chosen_lambda].results.loc[slicer['AUC', :, risk]]
    mean_auc, std_auc = auc_t.mean(), auc_t.std()
    ax.errorbar(mean_auc.index, mean_auc.values, yerr=std_auc.values, fmt="o", color=risk_color, alpha=0.8)
    ax.set_yticks(auc_ticks)
    ax.set_yticklabels([c.round(1) for c in auc_ticks])
    ax.set_xlabel(r'Time', fontsize=axes_title_fontsize)
    ax.set_ylabel('AUC (t)', fontsize=axes_title_fontsize)
    ax.set_title(fr'{risk_name}, Log ($\lambda$) = {chosen_lambda}', fontsize=axes_title_fontsize)
    ax.set_ylim([0, 1])
    ax.axhline(0.5, ls='--', color='k', alpha=0.3)

    # Secondary axis: observed events per time for this risk.
    # NOTE(review): `counts` is not defined in this cell; it must already exist in the kernel
    # (events per time, one column per risk) -- check that Restart & Run All defines it before here.
    ax2 = ax.twinx()
    ax2.bar(counts.index, counts[risk].values.squeeze(), color=risk_color, alpha=0.8, width=0.4)
    ax2.set_ylabel('Number of observed events', fontsize=axes_title_fontsize, color=risk_color)
    ax2.tick_params(axis='y', colors=risk_color)
    ax2.set_ylim([0, 5100])
    ax2.tick_params(axis='both', which='both', labelsize=ticksize)

fig.tight_layout()

fig.savefig(os.path.join(OUTPUT_DIR, 'mimic_regularization_fig.png'), dpi=300)


In [None]:
# Summary figure: alpha_{jt} per event type j and time t for Lee et al., the un-penalized
# two-step fitter and the regularized two-step fitter at `chosen_lambda`, over bars of the
# number of observed events per time.
filename = 'mimic_summary_.png'

first_model_name = 'Lee et al.'
second_model_name = 'two-step'
third_model_name = 'Regularized two-step'
times = range(1, ADMINISTRATIVE_CENSORING+1)

lee_colors = ['tab:blue', 'tab:green', 'tab:red']
two_step_colors = ['navy', 'darkgreen', 'tab:brown']
reg_two_step_colors = ['darkviolet', 'olive', 'maroon']
true_colors = ['tab:blue', 'tab:green', 'tab:red']

fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Observed events per time: index X, one column per J (missing combinations -> 0).
counts = fitters_table.groupby(['J', 'X'])['pid'].count().unstack('J').fillna(0)

two_step_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_two_step_alpha.csv'),
                                       index_col=['J', 'X'])

lee_alpha_k_results = pd.read_csv(os.path.join(OUTPUT_DIR, f'{case}_lee_alpha.csv'),
                                  index_col=[0, 1, 2])

ax.tick_params(axis='both', which='both', labelsize=15)

# Regularized two-step alphas at the chosen lambda, averaged over CV folds; index (J, X).
reg_alpha_by_fold = pd.concat(
    [cross_validators[chosen_lambda].models[i_fold].alpha_df.set_index(['J', 'X'])['alpha_jt']
     for i_fold in range(n_splits)],
    axis=1,
)
reg_alpha_mean = reg_alpha_by_fold.mean(axis=1)
reg_alpha_mean.name = chosen_lambda  # label with the lambda that was actually used

for j in [1, 2, 3]:

    tmp_alpha = lee_alpha_k_results.loc[slicer[j, COEF_COL, :]].mean(axis=1)
    # The time is embedded in the parameter name between ')[' and the following ']'.
    tmp_alpha.index = [int(idx.split(')[')[1].split(']')[0]) for idx in tmp_alpha.index]
    tmp_alpha = pd.Series(tmp_alpha.values.squeeze().astype(float), index=tmp_alpha.index)

    ax.scatter(tmp_alpha.index, tmp_alpha.values,
               label=f'J={j} ({first_model_name})', color=lee_colors[j-1], marker='o', alpha=0.4, s=40)

    tmp_alpha = two_step_alpha_k_results.loc[slicer[j, 'alpha_jt']]
    ax.scatter(tmp_alpha.index, tmp_alpha.values,
               label=f'J={j} ({second_model_name})', color=two_step_colors[j-1], marker='*', alpha=0.7, s=20)

    # Align on `times` so x and y always have equal length; times without an estimate become NaN
    # and are simply not drawn instead of raising a length mismatch.
    reg_alpha_j = reg_alpha_mean.xs(j, level='J').reindex(times)
    ax.scatter(times, reg_alpha_j.values,
               label=f'J={j} ({third_model_name})', color=reg_two_step_colors[j-1], marker='>', alpha=0.7, s=20)

ax.set_xlabel(r'Time', fontsize=18)
ax.set_ylabel(r'$\alpha_{jt}$', fontsize=18)
ax.legend(loc='upper right', fontsize=12)
ax.set_ylim([-13, 4.5])

# Twin y-axis: observed events per time. J=2 / J=3 use align='edge' with widths +0.4 / -0.4,
# which puts them to the right / left of the tick; J=1 is centred.
ax2 = ax.twinx()
for event_type, bar_kwargs in (
    (1, {'color': 'navy', 'alpha': 0.4, 'width': 0.4}),
    (2, {'color': 'darkgreen', 'alpha': 0.4, 'align': 'edge', 'width': 0.4}),
    (3, {'color': 'tab:red', 'alpha': 0.6, 'align': 'edge', 'width': -0.4}),
):
    ax2.bar(counts.index, counts[event_type].values.squeeze(), label=f'J={event_type}', **bar_kwargs)
ax2.legend(loc='upper center', fontsize=12)
ax2.set_ylabel('Number of observed events', fontsize=16, color='red')
ax2.tick_params(axis='y', colors='red')
ax2.set_ylim([0, 8500])
ax2.tick_params(axis='both', which='both', labelsize=15)

fig.tight_layout()

if filename is not None:
    fig.savefig(os.path.join(OUTPUT_DIR, filename), dpi=300)