In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns

from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})

from scipy.stats import wasserstein_distance
from tqdm import tqdm

In [None]:
# Category levels used to slice the length-of-stay (LOS) distributions.
# These string codes are used both to build the params/los/*.txt file names
# and to filter the 'acuity' / 'disposition' / 'complexity' columns of the
# simulation output, so they must match those exactly.
list_acuity = [
    '1', '2', '3', '4', '5',
]
list_disposition = [
    'HOME',
    'WARD',
    'ICU',
]
list_complexity = [
    'LOW',
    'MODERATE',
    'HIGH',
]

In [None]:
# Get target (observed) LOS distributions from params/los/*.txt.
# Each file is parsed with json.load; the contents are treated as an iterable
# of LOS values (presumably a JSON array of numbers - TODO confirm format).
LOS_DIR = 'params/los'


def _load_los(filename):
    """Parse one JSON LOS file from LOS_DIR and return its contents."""
    with open(f'{LOS_DIR}/{filename}') as filehandle:
        return json.load(filehandle)


def _load_los_group(acuity_val, disposition_val, complexity_val):
    """Return the LOS values for one acuity-disposition-complexity group.

    A missing group file yields an empty list (presumably not every
    combination has a file). Any other failure, e.g. malformed JSON, is
    raised instead of being silently skipped.
    """
    try:
        return _load_los(f'los_groupname_{acuity_val}-{disposition_val}-{complexity_val}.txt')
    except FileNotFoundError:
        return []


target_los_all = _load_los('los_overall.txt')
target_los_acuity = {acuity_val: _load_los(f'los_acuity_{acuity_val}.txt') for acuity_val in list_acuity}
target_los_disposition = {
    disposition_val: _load_los(f'los_disposition_{disposition_val}.txt') for disposition_val in list_disposition
}
target_los_complexity = {
    complexity_val: _load_los(f'los_complexity_{complexity_val}.txt') for complexity_val in list_complexity
}

# Read every group file exactly once; both two-way views below are built from this cache.
los_by_group = {
    (acuity_val, disposition_val, complexity_val): _load_los_group(acuity_val, disposition_val, complexity_val)
    for acuity_val in list_acuity
    for disposition_val in list_disposition
    for complexity_val in list_complexity
}

# acuity x disposition: concatenate the groups over complexity (in list_complexity order).
target_los_acuity_disposition = {
    f'{acuity_val}_{disposition_val}': [
        los
        for complexity_val in list_complexity
        for los in los_by_group[(acuity_val, disposition_val, complexity_val)]
    ]
    for acuity_val in list_acuity
    for disposition_val in list_disposition
}

# acuity x complexity: concatenate the groups over disposition (in list_disposition order).
target_los_acuity_complexity = {
    f'{acuity_val}_{complexity_val}': [
        los
        for disposition_val in list_disposition
        for los in los_by_group[(acuity_val, disposition_val, complexity_val)]
    ]
    for acuity_val in list_acuity
    for complexity_val in list_complexity
}

In [None]:
def load_results(simulation_type, run_id):
    """Load and combine every per-seed result CSV of one simulation run.

    Reads ``experiments/<simulation_type>/output_<run_id>/seed_<seed>`` for every
    ``seed_*`` file in the run folder. In each frame, ``ed_los`` is divided by 60
    (presumably minutes -> hours; TODO confirm against the simulator), and
    ``case_len`` is added as the string length of ``destination_record``.

    Parameters
    ----------
    simulation_type : str
        Sub-folder of ``experiments/`` that holds the run folders.
    run_id : str
        Run identifier; the folder read is ``output_<run_id>``.

    Returns
    -------
    (pandas.DataFrame, list of pandas.DataFrame)
        All seeds concatenated (per-seed row indices are kept, so index labels
        repeat across seeds) and the list of per-seed frames, ordered by seed
        name so repeated runs give the same order.

    Raises
    ------
    FileNotFoundError
        If the run folder does not exist.
    ValueError
        If the run folder contains no ``seed_*`` result files.
    """
    output_folder = f'experiments/{simulation_type}/output_{run_id}'
    if not os.path.isdir(output_folder):
        raise FileNotFoundError(f'No run folder found at {output_folder}')

    # Result files are named seed_<seed>. Match on the prefix so unrelated files that merely
    # contain "seed" (e.g. "seeds.txt", "random_seed.log") are not misparsed. Sorting makes
    # the run order deterministic (iteration order of a set of str varies between sessions).
    seed_list = sorted({
        file.split('_')[1] for file in os.listdir(output_folder) if file.startswith('seed_')
    })
    if not seed_list:
        raise ValueError(f'No seed_* result files found in {output_folder}')

    column_dtypes = {'acuity': str, 'disposition': str, 'complexity': str, 'ed_los': float}
    df_results_per_run = []
    for seed in tqdm(seed_list):
        df_seed = pd.read_csv(f'{output_folder}/seed_{seed}', dtype=column_dtypes)
        df_seed['ed_los'] = df_seed['ed_los'] / 60
        df_seed['case_len'] = df_seed['destination_record'].str.len()
        df_results_per_run.append(df_seed)

    df_results_concatenated = pd.concat(df_results_per_run)
    return df_results_concatenated, df_results_per_run

In [None]:
# Load both runs of experiment_2: run '1' feeds the *_pro_* frames and run '2'
# feeds the *_ret_* frames. Each call returns (all seeds concatenated, per-seed frames).
df_pro_concatenated, df_pro_per_run = load_results(simulation_type='experiment_2', run_id='1')
df_ret_concatenated, df_ret_per_run = load_results(simulation_type='experiment_2', run_id='2')

In [None]:
# LOS density grid: target distribution (black line) vs. run '1' (blue fill) vs. run '2' (green fill).
# Row 0 slices by acuity; the following rows slice by acuity x disposition, then acuity x complexity.
# The last column of each row holds that row's marginal distribution (overall / disposition / complexity).
# Each panel is annotated with MD (medians: target, run 1, run 2) and WD (Wasserstein distance of
# each run to the target).
PRO_COLOR = '#0072B2'
RET_COLOR = '#009E73'
DISPOSITION_LABELS = {'HOME': 'Home', 'WARD': 'Ward', 'ICU': 'ICU'}
COMPLEXITY_LABELS = {'LOW': 'Low', 'MODERATE': 'Moderate', 'HIGH': 'High'}


def _fmt_stat(value):
    """Format a panel statistic with one decimal, or 'NA' when it is unavailable (None)."""
    return 'NA' if value is None else f'{value:.1f}'


def _los_subset(df_results, **conditions):
    """Return df_results['ed_los'] restricted to rows where every `column == value` condition holds."""
    mask = np.ones(len(df_results), dtype=bool)
    for column, value in conditions.items():
        mask &= (df_results[column] == value).to_numpy()
    return df_results.loc[mask, 'ed_los']


def _plot_los_panel(ax, title, target_data, pro_data, ret_data, title_pad=3):
    """Draw one LOS density panel and annotate it with medians and Wasserstein distances.

    Statistics are shown only when the target sample is non-empty. A simulated
    statistic is shown as 'NA' when that run's sample is empty, because
    scipy.stats.wasserstein_distance rejects empty inputs and an empty median
    is undefined.
    """
    sns.kdeplot(ax=ax, data=target_data, color='black')
    ax.set_xticks(np.arange(0, 13, 3))
    ax.set_yticks(np.arange(0, 0.41, 0.1))
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 0.4)

    sns.kdeplot(ax=ax, data=pro_data, color=PRO_COLOR, linewidth=0, fill=True, alpha=0.5, bw_adjust=1.5)
    sns.kdeplot(ax=ax, data=ret_data, color=RET_COLOR, linewidth=0, fill=True, alpha=0.5, bw_adjust=1.5)

    ax.set_title(title, fontsize='small', color='black', pad=title_pad)

    has_target = len(target_data) > 0
    target_md = np.median(target_data) if has_target else None
    sim_stats = []
    for sim_data in (pro_data, ret_data):
        if has_target and len(sim_data) > 0:
            sim_stats.append((np.median(sim_data), wasserstein_distance(target_data, sim_data)))
        else:
            sim_stats.append((None, None))
    (pro_md, pro_wd), (ret_md, ret_wd) = sim_stats

    # Text positions are in data coordinates (x in [0, 12], y in [0, 0.4]).
    md_row = [(0.6, 'MD:', 'black'), (3.8, _fmt_stat(target_md), 'black'),
              (6.8, _fmt_stat(pro_md), PRO_COLOR), (9.8, _fmt_stat(ret_md), RET_COLOR)]
    wd_row = [(0.6, 'WD:', 'black'),
              (6.8, _fmt_stat(pro_wd), PRO_COLOR), (9.8, _fmt_stat(ret_wd), RET_COLOR)]
    for y, row in ((0.355, md_row), (0.32, wd_row)):
        for x, text, color in row:
            ax.text(x, y, text, fontsize=8, color=color)

    ax.set_ylabel('Density', fontsize='small')
    ax.set_xlabel('Length of stay', fontsize='small')
    ax.grid(True, linestyle=':')


def _plot_grouping_rows(axes, first_row, column, values, labels, target_by_acuity, target_marginal,
                        title_pads=None):
    """Fill one figure row per value of `column`: acuity x value panels, then the marginal panel.

    Reads list_acuity, df_pro_concatenated and df_ret_concatenated from the notebook namespace.
    """
    title_pads = title_pads or {}
    marginal_col = len(list_acuity)
    for row, value in enumerate(values, start=first_row):
        label = labels.get(value, value)
        pad = title_pads.get(value, 3)
        for col, acuity_val in enumerate(list_acuity):
            _plot_los_panel(
                axes[row, col], f'{acuity_val}, {label}',
                target_by_acuity[f'{acuity_val}_{value}'],
                _los_subset(df_pro_concatenated, acuity=acuity_val, **{column: value}),
                _los_subset(df_ret_concatenated, acuity=acuity_val, **{column: value}),
                title_pad=pad,
            )
        _plot_los_panel(
            axes[row, marginal_col], label, target_marginal[value],
            _los_subset(df_pro_concatenated, **{column: value}),
            _los_subset(df_ret_concatenated, **{column: value}),
            title_pad=pad,
        )


n_rows = 1 + len(list_disposition) + len(list_complexity)
n_cols = len(list_acuity) + 1  # extra column for each row's marginal panel
# squeeze=False keeps `axes` 2-D even if a category list is shortened to a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2*n_cols, 2*n_rows), sharey=True, sharex=True, squeeze=False)
fig.dpi = 600

# Row 0: per-acuity panels, plus the overall distribution in the last column.
for col, acuity_val in enumerate(list_acuity):
    _plot_los_panel(
        axes[0, col], acuity_val, target_los_acuity[acuity_val],
        _los_subset(df_pro_concatenated, acuity=acuity_val),
        _los_subset(df_ret_concatenated, acuity=acuity_val),
    )
_plot_los_panel(
    axes[0, len(list_acuity)], 'Overall', target_los_all,
    _los_subset(df_pro_concatenated), _los_subset(df_ret_concatenated),
)

_plot_grouping_rows(axes, 1, 'disposition', list_disposition, DISPOSITION_LABELS,
                    target_los_acuity_disposition, target_los_disposition)
# The 'Low' complexity titles used pad=4 (vs 3 elsewhere) in the original figure; kept as-is.
_plot_grouping_rows(axes, 1 + len(list_disposition), 'complexity', list_complexity, COMPLEXITY_LABELS,
                    target_los_acuity_complexity, target_los_complexity, title_pads={'LOW': 4})

plt.subplots_adjust(wspace=0.15, hspace=0.25)
