In [None]:
import pandas as pd
pd.set_option('display.max_columns', None)
import joblib
import numpy as np
import seaborn as sns
sns.set_theme()
import matplotlib.pyplot as plt
fs = 14


In [None]:
import os

def count_sequence(row, model='trad'):
    """Length of the leading run of ``False`` predictions for one model.

    Inspects ``row['days_past_pred_{model}_0']`` .. ``row['days_past_pred_{model}_4']``
    in day order and returns the index of the first entry that does not compare
    equal to ``False`` (a ``True``, a NaN, ...). When all five entries are
    ``False`` the result is 5. Entries after the first non-``False`` day are
    never read.

    Args:
        row: mapping-like row (e.g. a pandas Series from ``DataFrame.apply``).
        model: model suffix used in the column names ('trad', 'nd', 'random').

    Returns:
        int in ``[0, 5]``.
    """
    for day in range(5):
        # Compare with `== False` on purpose: NaN (missing day) compares unequal and ends the run.
        if not (row[f'days_past_pred_{model}_{day}'] == False):  # noqa: E712
            return day
    return 5


def _leading_false_count(wide, model, horizon=5):
    """Vectorized per-row count of leading ``False`` predictions (same rule as ``count_sequence``).

    For each row, counts how many of ``days_past_pred_{model}_0 .. _{horizon-1}``
    compare equal to ``False`` before the first one that does not (True or NaN).
    """
    cols = [f'days_past_pred_{model}_{i}' for i in range(horizon)]
    # `== False` (not `~`) so NaN cells from missing pivot entries compare False and stop the run.
    is_false = (wide[cols] == False).to_numpy(dtype=int)  # noqa: E712
    # The cumulative product stays 1 only while every earlier day was False; its sum is the run length.
    return np.cumprod(is_false, axis=1).sum(axis=1)


def calculate_custom_metricts(trad, nd, threshold=0.5, horizon=5, rng=None):
    """Compute saved/missed-day metrics for the traditional, next-day and random models.

    (The misspelled name is kept so existing callers keep working.)

    Both inputs are inner-joined on ``ID``, ``starttime`` and ``days_past``. A
    prediction is True when the model's ``'True'`` column exceeds ``threshold``.
    The random baseline is True when a uniform [0, 1) draw exceeds ``threshold``.
    The predictions are pivoted to one row per (``ID``, ``starttime``,
    ``lot_in_days``). For each model, ``predicted_days_of_treatment`` is the
    number of leading False predictions over days 0..horizon-1.

    Per row:
      * saved_days  = 0 if lot_in_days >= horizon else lot_in_days - predicted
        (can be negative when predicted > lot_in_days)
      * missed_days = 0 if lot_in_days <= horizon else horizon - predicted

    NOTE(review): rows with lot_in_days == horizon get 0 saved and 0 missed days
    but still count in both the RASAD and RAMAD denominators — confirm intended.

    Args:
        trad: traditional-model frame with columns ID, starttime, days_past,
            'False', 'True', lot_in_days.
        nd: next-day-model frame with columns ID, starttime, days_past, 'False', 'True'.
        threshold: decision threshold applied to the 'True' columns and the random draws.
        horizon: number of prediction days considered (default 5, as before).
        rng: None to draw the random baseline from NumPy's global RNG (original
            behaviour), or an int seed / ``np.random.Generator`` for reproducible draws.

    Returns:
        pd.Series with 'threshold' followed by GASAD_*, GAMAD_*, RASAD_*, RAMAD_*
        for trad, nd and random. GA* are means over all rows. RASAD averages
        saved days over rows with lot_in_days <= horizon. RAMAD averages missed
        days over rows with lot_in_days >= horizon.
    """
    models = ('trad', 'nd', 'random')
    nd = nd[['ID', 'starttime', 'days_past', 'False', 'True']]
    trad = trad[['ID', 'starttime', 'days_past', 'False', 'True', 'lot_in_days']]

    combined = trad.merge(nd, on=['ID', 'starttime', 'days_past'], suffixes=('_trad', '_nd'))

    if rng is None:
        random_scores = np.random.rand(len(combined))
    else:
        # default_rng returns a Generator unchanged and seeds a new one from an int.
        random_scores = np.random.default_rng(rng).random(len(combined))

    combined = combined.assign(
        pred_random=random_scores > threshold,
        pred_trad=combined['True_trad'] > threshold,
        pred_nd=combined['True_nd'] > threshold,
        days_past=combined['days_past'].astype(int).astype(str),
    ).drop(columns=['False_trad', 'False_nd', 'True_trad', 'True_nd'])

    wide = combined.pivot_table(
        index=['ID', 'starttime', 'lot_in_days'],
        columns='days_past',
        values=['pred_trad', 'pred_nd', 'pred_random'],
        aggfunc='first',
    )
    wide.columns = [f'days_past_{value}_{day}' for value, day in wide.columns]
    wide = wide.reset_index()

    lot = wide['lot_in_days']
    for name in models:
        predicted = _leading_false_count(wide, name, horizon)
        wide[f'predicted_days_of_treatment_{name}'] = predicted
        wide[f'saved_days_{name}'] = np.where(lot >= horizon, 0, lot - predicted)
        wide[f'missed_days_{name}'] = np.where(lot <= horizon, 0, horizon - predicted)

    short_stays = lot <= horizon
    long_stays = lot >= horizon
    metric_specs = (
        ('GASAD', 'saved_days', None),
        ('GAMAD', 'missed_days', None),
        ('RASAD', 'saved_days', short_stays),
        ('RAMAD', 'missed_days', long_stays),
    )

    res = {'threshold': threshold}
    for prefix, column, mask in metric_specs:
        for name in models:
            values = wide[f'{column}_{name}']
            res[f'{prefix}_{name}'] = (values if mask is None else values[mask]).mean()
    return pd.Series(res)

def compare_models(database='mimic',
                  lookback=2,
                  prediction_time_points='random',
                  numberofsamples=1,
                  sample_train=None,
                  sample_test=None,
                  seed:int=42,
                  inc_ab=False,
                  has_microbiology=False,
                  model='LGBMClassifier',
                  dropout=0.0,
                  hidden_dim=256,
                  lamb=0.1,
                  num_lin_layers=2,
                  num_stacked_lstm=3,
                  is_tuned=True,
                  lr=0.01,
                  bs=128,
                  use_relus=False,
                  use_batchnormalization=False,
                  nd_lookback=7,
                  aggregated_hours=4,
                  thresholds=None):
    """Evaluate the traditional classifier and the next-day (LSTM) model across thresholds.

    Steps:
      1. Load the traditional test features/labels for prediction time points 0-4.
      2. Score them with the pickled traditional model.
      3. Load the next-day model's saved test predictions.
      4. Call ``calculate_custom_metricts`` once per threshold.

    Most arguments only select directories under ``data/``.

    Args:
        model: name of the traditional model's result sub-directory.
        nd_lookback, aggregated_hours: path components of the next-day results
            (defaults 7 and 4, the values previously hard-coded).
        thresholds: iterable of decision thresholds; defaults to 0.00, 0.01, ..., 1.00.

    Returns:
        DataFrame with one row per threshold (RangeIndex) holding the metrics
        returned by ``calculate_custom_metricts``.

    Raises:
        ValueError: if ``prediction_time_points`` is not 'random'; no other path
            layout is handled here.
    """
    if prediction_time_points != 'random':
        raise ValueError(
            f"prediction_time_points={prediction_time_points!r} is not supported; only 'random' is handled"
        )
    time_point = ('random', numberofsamples)

    def _dash(value):
        # Directory names encode decimals with '-' instead of '.'.
        return str(value).replace('.', '-')

    run_dir = f'{database}/microbiology_res_{has_microbiology}/ab_{inc_ab}/seed_{seed}/'
    traditional_path = 'data/model_input/traditional/' + run_dir
    traditional_model_path = (
        'data/results/traditional/' + run_dir
        + f'lookback_{lookback}/time_point{time_point}/sample_{sample_train}_{sample_test}/{model}/'
    )

    # Complete traditional test set: all five prediction time points, concatenated once.
    X_parts, y_parts = [], []
    for tp in range(5):
        suffix = f'_test_time_point_{tp}_lookback_{lookback}.parquet'
        X_parts.append(pd.read_parquet(traditional_path + 'X' + suffix))
        y_parts.append(pd.read_parquet(traditional_path + 'y' + suffix))
    X_traditional = pd.concat(X_parts, axis=0, join='outer').fillna(0)
    y_traditional = pd.concat(y_parts, axis=0, join='outer').fillna(0)

    print(traditional_model_path)
    # NOTE(review): joblib.load unpickles arbitrary objects — only load trusted model files.
    classifier = joblib.load(traditional_model_path + 'model.pkl')

    # Align the test features with the columns the model was trained on (same time_point as the model).
    X_trained_original = pd.read_parquet(
        f'{traditional_path}X_train_time_point_{time_point}_lookback_{lookback}.parquet'
    )
    X_traditional = X_traditional[X_trained_original.columns]

    # assumes classifier.classes_ is ordered [False, True] — TODO confirm.
    pred_proba_test = pd.DataFrame(classifier.predict_proba(X_traditional), columns=['False', 'True'])
    y_traditional = pd.concat([y_traditional.reset_index(drop=True), pred_proba_test], axis=1)

    nd_path = (
        f"data/results/lstm/{database}/microbiology_res_{has_microbiology}/ab_{inc_ab}/use_censored_True/"
        f"lookback_{nd_lookback}/aggregated_hours_{aggregated_hours}/seed_{seed}/"
        f"dropout_{_dash(dropout)}/lambda_{_dash(lamb)}/num_lin_layers_{num_lin_layers}/"
        f"num_stacked_lstm_{num_stacked_lstm}/hidden_dim_{hidden_dim}/lr_{_dash(lr)}/bs_{bs}/"
        f"is_tuned_{is_tuned}/use_relus_{use_relus}/use_bn_{use_batchnormalization}/test_gt_and_preds.csv"
    )
    y_next_day = pd.read_csv(nd_path)
    print(nd_path)

    # NOTE(review): only the next-day starttime is floored to 5 minutes; the merge assumes the
    # traditional starttime is already on that grid — confirm.
    y_next_day['starttime'] = pd.to_datetime(y_next_day['starttime']).dt.floor('5min')
    y_next_day['days_past'] = y_next_day['days_past'].astype(int)

    if thresholds is None:
        thresholds = np.arange(0, 1.01, 0.01).round(2).tolist()
    rows = [calculate_custom_metricts(y_traditional, y_next_day, threshold=t) for t in thresholds]
    return pd.DataFrame(rows)

    
# Evaluate every seed and stack the per-threshold metrics into `res`.
# Seeding NumPy's global RNG per seed makes the random baseline reproducible on re-run.
SEEDS = [42, 43, 44, 45, 46]
per_seed_results = []
for seed in SEEDS:
    np.random.seed(seed)
    df = compare_models(seed=seed)
    df['seed'] = seed
    per_seed_results.append(df)
# Concatenate once instead of growing an empty frame inside the loop.
res = pd.concat(per_seed_results)

display(res)



In [None]:
# GASAD_*: mean saved days over all rows, per threshold. Line = mean over seeds, band = ±1 std.
mean_std_data = res.groupby('threshold').agg(['mean', 'std']).reset_index()
thresholds = mean_std_data['threshold']

fig, ax = plt.subplots(figsize=(5, 5))
for col in ['GASAD_trad', 'GASAD_nd', 'GASAD_random']:
    mean_col = mean_std_data[col]['mean']
    std_col = mean_std_data[col]['std']
    sns.lineplot(x=thresholds, y=mean_col, ax=ax, label=col)
    ax.fill_between(thresholds, mean_col - std_col, mean_col + std_col, alpha=0.2)

ax.set_xlabel('Threshold', fontsize=fs)
ax.set_ylabel('Mean saved days', fontsize=fs)
ax.legend(fontsize=fs)
ax.tick_params(axis='both', labelsize=fs)
fig.tight_layout()
# savefig fails if the output directory does not exist yet.
os.makedirs('images/experiments/comparison', exist_ok=True)
fig.savefig('images/experiments/comparison/comparison_GASAD.png')
plt.show()

In [None]:
# GAMAD_*: mean missed days over all rows, per threshold. Line = mean over seeds, band = ±1 std.
mean_std_data = res.groupby('threshold').agg(['mean', 'std']).reset_index()
thresholds = mean_std_data['threshold']

fig, ax = plt.subplots(figsize=(5, 5))
for col in ['GAMAD_trad', 'GAMAD_nd', 'GAMAD_random']:
    mean_col = mean_std_data[col]['mean']
    std_col = mean_std_data[col]['std']
    sns.lineplot(x=thresholds, y=mean_col, ax=ax, label=col)
    ax.fill_between(thresholds, mean_col - std_col, mean_col + std_col, alpha=0.2)

ax.set_xlabel('Threshold', fontsize=fs)
ax.set_ylabel('Mean missed days', fontsize=fs)
ax.legend(fontsize=fs)
ax.tick_params(axis='both', labelsize=fs)
fig.tight_layout()
# savefig fails if the output directory does not exist yet.
os.makedirs('images/experiments/comparison', exist_ok=True)
fig.savefig('images/experiments/comparison/comparison_GAMAD.png')
plt.show()

In [None]:
# RASAD_*: mean saved days over rows with lot_in_days <= 5, per threshold.
# Line = mean over seeds, band = ±1 std.
mean_std_data = res.groupby('threshold').agg(['mean', 'std']).reset_index()
thresholds = mean_std_data['threshold']

fig, ax = plt.subplots(figsize=(5, 5))
for col in ['RASAD_trad', 'RASAD_nd', 'RASAD_random']:
    mean_col = mean_std_data[col]['mean']
    std_col = mean_std_data[col]['std']
    sns.lineplot(x=thresholds, y=mean_col, ax=ax, label=col)
    ax.fill_between(thresholds, mean_col - std_col, mean_col + std_col, alpha=0.2)

ax.set_xlabel('Threshold', fontsize=fs)
ax.set_ylabel('Mean saved days', fontsize=fs)
ax.legend(fontsize=fs)
ax.tick_params(axis='both', labelsize=fs)
fig.tight_layout()
# savefig fails if the output directory does not exist yet.
os.makedirs('images/experiments/comparison', exist_ok=True)
fig.savefig('images/experiments/comparison/comparison_RASAD.png')
plt.show()

In [None]:
# RAMAD_*: mean missed days over rows with lot_in_days >= 5, per threshold.
# Line = mean over seeds, band = ±1 std.
mean_std_data = res.groupby('threshold').agg(['mean', 'std']).reset_index()
thresholds = mean_std_data['threshold']

fig, ax = plt.subplots(figsize=(5, 5))
for col in ['RAMAD_trad', 'RAMAD_nd', 'RAMAD_random']:
    mean_col = mean_std_data[col]['mean']
    std_col = mean_std_data[col]['std']
    sns.lineplot(x=thresholds, y=mean_col, ax=ax, label=col)
    ax.fill_between(thresholds, mean_col - std_col, mean_col + std_col, alpha=0.2)

ax.set_xlabel('Threshold', fontsize=fs)
ax.set_ylabel('Mean missed days', fontsize=fs)
ax.legend(fontsize=fs)
ax.tick_params(axis='both', labelsize=fs)
fig.tight_layout()
# savefig fails if the output directory does not exist yet.
os.makedirs('images/experiments/comparison', exist_ok=True)
fig.savefig('images/experiments/comparison/comparison_RAMAD.png')
plt.show()

In [None]:


# Thresholds at which one model beats the other on BOTH criteria: more saved days (*SAD)
# and fewer missed days (*MAD). Uses metrics averaged over all seeds in `res`, matching
# the plots above, instead of whichever single seed's frame the loop happened to leave in `df`.
seed_mean = (
    res.drop(columns='seed', errors='ignore')
    .groupby('threshold')
    .mean()
    .reset_index()
)

comparisons = [
    ("nd was better on global scale", 'GASAD', 'GAMAD', 'nd', 'trad'),
    ("nd was better on restricted scale", 'RASAD', 'RAMAD', 'nd', 'trad'),
    ("trad was better on global scale", 'GASAD', 'GAMAD', 'trad', 'nd'),
    ("trad was better on restricted scale", 'RASAD', 'RAMAD', 'trad', 'nd'),
]
for title, saved, missed, winner, loser in comparisons:
    print(title)
    wins = (
        (seed_mean[f'{saved}_{winner}'] > seed_mean[f'{saved}_{loser}'])
        & (seed_mean[f'{missed}_{winner}'] < seed_mean[f'{missed}_{loser}'])
    )
    display(seed_mean[wins])
