In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns

from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})

from tqdm import tqdm
from scipy.stats import wasserstein_distance

In [None]:
def load_results(simulation_type, run_id):
    """Load and pool the per-seed result files of one simulation run.

    Every file named ``seed_<id>`` inside
    ``experiments/<simulation_type>/output_<run_id>`` is read as CSV, two
    derived columns are computed, and the per-seed frames are concatenated.

    Parameters
    ----------
    simulation_type : str
        Name of the sub-folder of ``experiments`` that holds the run outputs.
    run_id : str or int
        Run identifier; the folder that is read is ``output_<run_id>``.

    Returns
    -------
    pandas.DataFrame
        All seeds concatenated in sorted (lexicographic) seed order, so the
        row order is reproducible between kernel sessions. Each file keeps
        its own row index (the index is not reset). ``ed_los`` is divided by
        60 and ``case_len`` holds ``destination_record.str.len()``.

    Raises
    ------
    FileNotFoundError
        If the run folder does not exist.
    ValueError
        If the run folder contains no ``seed_*`` files.
    """
    output_folder = os.path.join('experiments', simulation_type, f'output_{run_id}')
    if not os.path.isdir(output_folder):
        raise FileNotFoundError(f'Run folder not found: {output_folder}')

    # Only names that start with 'seed_' can correspond to the 'seed_<id>' path
    # read below. A set removes duplicate ids; sorting makes the order
    # deterministic (plain set iteration order varies with string hashing).
    seed_list = sorted({
        file.split('_')[1]
        for file in os.listdir(output_folder)
        if file.startswith('seed_')
    })
    if not seed_list:
        raise ValueError(f'No seed_* result files found in {output_folder}')

    results_list = []
    for seed in tqdm(seed_list):
        df_seed = pd.read_csv(os.path.join(output_folder, f'seed_{seed}'))
        # NOTE(review): presumably converts minutes to hours -- confirm the unit of 'ed_los'.
        df_seed['ed_los'] = df_seed['ed_los'] / 60
        # Length of each 'destination_record' entry as read from the CSV.
        df_seed['case_len'] = df_seed.destination_record.str.len()
        results_list.append(df_seed)

    return pd.concat(results_list)

def _plot_distribution(target_distribution, generated_distribution, color, branching_type, xlabel,
                       x_max, y_max, target_text_dx, generated_text_dx, text_y, stacked_text_y,
                       median_format):
    """Draw a small target-vs-generated density comparison and show it.

    The target sample is drawn as a black KDE line and the generated sample as
    a filled KDE area in ``color``; dashed vertical lines mark both medians.

    Parameters
    ----------
    target_distribution, generated_distribution : array-like of numbers
        Samples to compare.
    color : matplotlib colour
        Colour of the generated distribution, its median line and its label.
    branching_type : str
        Used in the legend label. ``'Independent'`` and ``'Conditional'`` also
        select how the median values are annotated; any other value draws no
        median annotations.
    xlabel : str
        Label of the x-axis.
    x_max, y_max : number
        Upper axis limits; ``x_max`` is also the KDE clip bound and must be an
        even integer because ticks are placed at 0, ``x_max // 2`` and ``x_max``.
    target_text_dx, generated_text_dx : number
        Horizontal offsets of the annotations relative to the rounded medians.
    text_y, stacked_text_y : number
        Heights of the first-row annotation and of the annotation stacked below it.
    median_format : str
        Format spec applied to the rounded medians (``''`` or e.g. ``'.0f'``).
    """
    fig, axes = plt.subplots(1, 1, figsize=(1.5, 1))
    fig.dpi = 600

    target_median = np.median(target_distribution)
    generated_median = np.median(generated_distribution)

    axes.axvline(target_median, color='black', linestyle='--', label='Target\ndistribution')
    axes.axvline(generated_median, color=color, linestyle='--', label=branching_type + '\nbranching')

    # Pass the axes explicitly instead of relying on pyplot's "current axes".
    sns.kdeplot(data=target_distribution, color='black', bw_adjust=2, clip=(0, x_max), ax=axes)
    sns.kdeplot(data=generated_distribution, color=color, linewidth=0, fill=True, alpha=0.5,
                bw_adjust=2, clip=(0, x_max), ax=axes)

    axes.set_xticks(np.arange(0, x_max + 1, x_max // 2))
    axes.set_xlim(0, x_max)
    axes.set_ylim(0, y_max)
    axes.set_yticks([0.0, y_max])

    # Annotations are positioned from, and display, the medians rounded to one decimal.
    median_target = round(target_median, 1)
    median_generated = round(generated_median, 1)
    target_text = format(median_target, median_format)
    generated_text = format(median_generated, median_format)

    if branching_type == 'Independent':
        # Target value to the right of its line, generated value shifted by its own offset.
        axes.text(median_target + target_text_dx, text_y, target_text, fontsize=8, color='black')
        axes.text(median_generated + generated_text_dx, text_y, generated_text, fontsize=8, color=color)
    elif branching_type == 'Conditional':
        # Both values are stacked next to the target median line.
        axes.text(median_target + target_text_dx, text_y, target_text, fontsize=8, color='black')
        axes.text(median_target + target_text_dx, stacked_text_y, generated_text, fontsize=8, color=color)

    axes.legend()
    sns.move_legend(axes, 'lower center', bbox_to_anchor=(.5, 1), ncol=2, title=None,
                    frameon=False, columnspacing=2, handlelength=0.8, handletextpad=0.8,
                    reverse=False, fontsize='small')

    legend = axes.get_legend()
    legend.get_title().set_fontsize('small')
    axes.set_ylabel('Density', fontsize='small')
    axes.set_xlabel(xlabel, fontsize='small')
    axes.tick_params(axis='both', which='both', labelsize='small')

    plt.show()


def plot_distribution_los(target_distribution, generated_distribution, color, branching_type, xlabel):
    """Plot target vs. generated length-of-stay densities on a 0-12 x-axis.

    Densities are clipped to [0, 12], the y-axis spans [0, 0.2], and the
    medians (rounded to one decimal) are annotated for the ``'Independent'``
    and ``'Conditional'`` branching types. The figure is shown; nothing is
    returned.
    """
    _plot_distribution(
        target_distribution, generated_distribution, color, branching_type, xlabel,
        x_max=12, y_max=0.2,
        target_text_dx=0.4, generated_text_dx=-2,
        text_y=0.17, stacked_text_y=0.14,
        median_format='',
    )


def plot_distribution_caselen(target_distribution, generated_distribution, color, branching_type, xlabel):
    """Plot target vs. generated case-length densities on a 0-30 x-axis.

    Densities are clipped to [0, 30], the y-axis spans [0, 0.1], and the
    medians are annotated without decimals for the ``'Independent'`` and
    ``'Conditional'`` branching types. The figure is shown; nothing is
    returned.
    """
    _plot_distribution(
        target_distribution, generated_distribution, color, branching_type, xlabel,
        x_max=30, y_max=0.1,
        target_text_dx=0.8, generated_text_dx=-2.8,
        text_y=0.085, stacked_text_y=0.07,
        median_format='.0f',
    )

In [None]:
# Pool the per-seed outputs of the two 'experiment_1' runs.
# Run '1' is the one plotted as 'Independent' branching, run '2' as 'Conditional'.
df_pro = load_results(simulation_type='experiment_1', run_id='1')
df_ret = load_results(simulation_type='experiment_1', run_id='2')

In [None]:
# The target length-of-stay sample is stored as JSON inside a .txt file.
with open('params/los/los_overall.txt') as filehandle:
    los_overall_list = json.loads(filehandle.read())

# One figure per run, each compared against the same target sample.
plot_distribution_los(
    target_distribution=los_overall_list,
    generated_distribution=df_pro['ed_los'],
    color='#E69F00',
    branching_type='Independent',
    xlabel='Length of stay',
)
plot_distribution_los(
    target_distribution=los_overall_list,
    generated_distribution=df_ret['ed_los'],
    color='#56B4E9',
    branching_type='Conditional',
    xlabel='Length of stay',
)

In [None]:
# The target case-length sample is stored as JSON inside a .txt file.
with open('params/caselen/caselen_overall.txt') as filehandle:
    caselen_overall_list = json.loads(filehandle.read())

# One figure per run, each compared against the same target sample.
plot_distribution_caselen(
    target_distribution=caselen_overall_list,
    generated_distribution=df_pro['case_len'],
    color='#E69F00',
    branching_type='Independent',
    xlabel='Resource usage',
)
plot_distribution_caselen(
    target_distribution=caselen_overall_list,
    generated_distribution=df_ret['case_len'],
    color='#56B4E9',
    branching_type='Conditional',
    xlabel='Resource usage',
)

In [None]:
# Wasserstein distance between the target length-of-stay sample and the df_pro results.
wasserstein_distance(u_values=los_overall_list, v_values=df_pro['ed_los'])

In [None]:
# Wasserstein distance between the target length-of-stay sample and the df_ret results.
wasserstein_distance(u_values=los_overall_list, v_values=df_ret['ed_los'])

In [None]:
# Wasserstein distance between the target case-length sample and the df_pro results.
wasserstein_distance(u_values=caselen_overall_list, v_values=df_pro['case_len'])

In [None]:
# Wasserstein distance between the target case-length sample and the df_ret results.
wasserstein_distance(u_values=caselen_overall_list, v_values=df_ret['case_len'])