In [1]:
import numpy as np
import pandas as pd

from collections import OrderedDict
from pysankey2 import Sankey
from pysankey2.utils import setColorConf
from venn import venn

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})

from dataset_path import output_path

## Load model results

In [2]:
# Which model's result files to load; it is part of every results file name.
model_type = 'rf'


def _read_task_results(task_name, **read_csv_kwargs):
    """Read the results CSV of one prediction task for the current `model_type`."""
    csv_path = f'{output_path}/results_{task_name}_{model_type}.csv'
    return pd.read_csv(csv_path, **read_csv_kwargs)


df_task1 = _read_task_results('hospitalization')
df_task2 = _read_task_results('critical')
# low_memory=False makes pandas infer column dtypes from the whole file rather than per chunk.
df_task3 = _read_task_results('revisit', low_memory=False)
df_task4 = _read_task_results('triage')

In [3]:
from functools import reduce

# Per-task result frames to combine into one patient-level table.
all_dfs = [df_task1, df_task2, df_task3, df_task4]

# Outer-merge the frames pairwise, left to right. No `on=` is given, so pandas
# joins on *every* column the frames have in common (not a single key).
# NOTE(review): if any shared column differs between the CSVs (e.g. float
# formatting), matching rows will not join and patients can be duplicated.
merged_df = reduce(lambda left, right: pd.merge(left, right, how='outer'), all_dfs)

# Patients missing from the revisit results (e.g. admitted patients) have NaN
# label_revisit. Fill with 1, which the label mapping renders as 'no, no', so
# these patients count as a correct negative and never as a revisit miss.
merged_df['label_revisit'] = merged_df['label_revisit'].fillna(1)

In [4]:
# Replace numeric per-patient outcome codes with readable labels.
# Binary tasks: 1..4 -> 'no, no' / 'no, yes' / 'yes, no' / 'yes, yes'
# (presumably '<true label>, <prediction>'; 'no, no' and 'yes, yes' are the
# correct predictions).
binary_label_names = {1: 'no, no', 2: 'no, yes', 3: 'yes, no', 4: 'yes, yes'}
for label_column in ['label_hospitalization', 'label_critical', 'label_revisit']:
    merged_df[label_column] = merged_df[label_column].replace(binary_label_names)

# Triage task: -1 / 0 / 1 -> under-, correctly, over-triaged.
mapping = {-1: 'under', 0: 'correct', 1: 'over'}
merged_df['label_triage'] = merged_df['label_triage'].replace(mapping)

In [None]:
# Share of each prediction outcome within each task's own cohort.
# Computed from the per-task frames rather than merged_df: in merged_df,
# label_revisit was NaN-filled with 'no, no' for every patient outside the
# revisit cohort. Dividing those counts by len(df_task3) gave revisit
# proportions summing to more than 100%.
outcome_label_names = {1: 'no, no', 2: 'no, yes', 3: 'yes, no', 4: 'yes, yes'}
triage_label_names = {-1: 'under', 0: 'correct', 1: 'over'}

for df_task, label_column, label_names in [
    (df_task1, 'label_hospitalization', outcome_label_names),
    (df_task2, 'label_critical', outcome_label_names),
    (df_task3, 'label_revisit', outcome_label_names),
    (df_task4, 'label_triage', triage_label_names),
]:
    display(df_task[label_column].replace(label_names).value_counts(normalize=True).sort_index())

In [6]:
# One Sankey layer per task, left to right: hospitalisation, critical outcome,
# revisit, triage. Values are copied out as plain lists so df_sankey gets a
# fresh 0..n-1 index independent of merged_df.
sankey_layer_sources = {
    'layer1': 'label_hospitalization',
    'layer2': 'label_critical',
    'layer3': 'label_revisit',
    'layer4': 'label_triage',
}
results_dict = {layer: merged_df[source].to_list() for layer, source in sankey_layer_sources.items()}
df_sankey = pd.DataFrame.from_dict(results_dict, orient='columns')

In [None]:
# Top-to-bottom order of the nodes in each Sankey layer
# (layer1 = hospitalisation, layer2 = critical, layer3 = revisit, layer4 = triage).
layer_labels = OrderedDict([
    ('layer1', ['no, yes', 'yes, no', 'no, no', 'yes, yes']),
    ('layer2', ['no, yes', 'yes, no', 'no, no', 'yes, yes']),
    ('layer3', ['no, yes', 'yes, no', 'no, no', 'yes, yes']),
    ('layer4', ['under', 'over', 'correct']),
])

# Every distinct node label occurring in any layer.
node_labels = list(set().union(*(set(df_sankey[layer]) for layer in ['layer1', 'layer2', 'layer3', 'layer4'])))
colors = setColorConf(len(node_labels), colors='Accent')
cls_map = dict(zip(node_labels, colors))

# Two-colour scheme: correct predictions share colors[0], every other label
# gets colors[1]. NOTE(review): '>4, >4' and '<=4, <=4' are not produced by the
# label mappings used here; presumably left over from another label scheme.
good_color = ['correct', 'yes, yes', 'no, no', '>4, >4', '<=4, <=4']
new_cls_map = {label: (colors[0] if label in good_color else colors[1]) for label in cls_map}

sky_auto_global_colors = Sankey(df_sankey, colorMode="global", colorDict=new_cls_map, stripColor='left', layerLabels=layer_labels)
# NOTE(review): the private attribute is also set directly; looks like a
# workaround to force the node order above. Confirm it is still needed.
sky_auto_global_colors._layerLabels = layer_labels
fig, ax = sky_auto_global_colors.plot(strip_kws={'linewidth': 0}, figSize=(8,4), fontSize=8)
fig.dpi = 600

In [None]:
# Per-task misprediction flag: 0 when the outcome label is a correct prediction,
# 1 otherwise (missing labels also count as 1, as before).
correct_binary_outcomes = ['yes, yes', 'no, no']
merged_df['miss_hospitalization'] = (~merged_df['label_hospitalization'].isin(correct_binary_outcomes)).astype(int)
merged_df['miss_critical'] = (~merged_df['label_critical'].isin(correct_binary_outcomes)).astype(int)
merged_df['miss_revisit'] = (~merged_df['label_revisit'].isin(correct_binary_outcomes)).astype(int)
merged_df['miss_triage'] = (merged_df['label_triage'] != 'correct').astype(int)

# Number of tasks (0-4) mispredicted for each patient. Each task is counted
# exactly once; previously miss_revisit was added twice and miss_critical never.
miss_columns = ['miss_hospitalization', 'miss_critical', 'miss_revisit', 'miss_triage']
merged_df['miss_total'] = merged_df[miss_columns].sum(axis=1)

miss_total_counts = merged_df.groupby('miss_total')['stay_id'].count()
display(miss_total_counts)
display(miss_total_counts / len(merged_df) * 100)

In [None]:
# For each task, the set of stay_ids it mispredicted; the Venn diagram shows
# how these sets overlap across tasks.
miss_flag_by_group = {
    "hospitalisation": 'miss_hospitalization',
    "critical outcomes": 'miss_critical',
    "72hr reattendance": 'miss_revisit',
    "triage": 'miss_triage',
}
patient_grp = {
    group_name: set(merged_df.loc[merged_df[flag_column] == 1, 'stay_id'])
    for group_name, flag_column in miss_flag_by_group.items()
}

cmap = ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd']
venn(patient_grp, legend_loc="upper left", cmap=cmap)

## Update patient characteristics

In [10]:
# Update acuity
# Turn acuity levels 1.0..5.0 into the strings '1'..'5' so they behave as
# category labels (grouping, legends) rather than numbers.
for acuity_level in range(1, 6):
    merged_df['triage_acuity'] = merged_df['triage_acuity'].replace(float(acuity_level), str(acuity_level))

In [11]:
# Update gender
# Spell out the single-letter gender codes for display.
for gender_code, gender_name in (('F', 'FEMALE'), ('M', 'MALE')):
    merged_df['gender'] = merged_df['gender'].replace(gender_code, gender_name)

In [12]:
# Update ethnicity
# Collapse detailed ethnicity values into five broad groups. Replacements are
# applied one group at a time, in this order; missing values (NaN) and any
# listed minor category become 'OTHERS'.
ethnicity_groups = {
    'ASIAN': ['ASIAN',
              'ASIAN - ASIAN INDIAN',
              'ASIAN - CHINESE',
              'ASIAN - KOREAN',
              'ASIAN - SOUTH EAST ASIAN'],
    'BLACK': ['BLACK/AFRICAN',
              'BLACK/AFRICAN AMERICAN',
              'BLACK/CAPE VERDEAN',
              'BLACK/CARIBBEAN ISLAND'],
    'HISPANIC': ['HISPANIC OR LATINO',
                 'HISPANIC/LATINO - CENTRAL AMERICAN',
                 'HISPANIC/LATINO - COLUMBIAN',
                 'HISPANIC/LATINO - CUBAN',
                 'HISPANIC/LATINO - DOMINICAN',
                 'HISPANIC/LATINO - GUATEMALAN',
                 'HISPANIC/LATINO - HONDURAN',
                 'HISPANIC/LATINO - MEXICAN',
                 'HISPANIC/LATINO - PUERTO RICAN',
                 'HISPANIC/LATINO - SALVADORAN'],
    'WHITE': ['WHITE',
              'WHITE - BRAZILIAN',
              'WHITE - EASTERN EUROPEAN',
              'WHITE - OTHER EUROPEAN',
              'WHITE - RUSSIAN'],
    'OTHERS': ['MULTIPLE RACE/ETHNICITY',
               'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER',
               'OTHER',
               'PATIENT DECLINED TO ANSWER',
               'PORTUGUESE',
               'SOUTH AMERICAN',
               'UNABLE TO OBTAIN',
               'UNKNOWN',
               'AMERICAN INDIAN/ALASKA NATIVE',
               np.nan],
}

for broad_group, detailed_values in ethnicity_groups.items():
    merged_df['ethnicity'] = merged_df['ethnicity'].replace(detailed_values, broad_group)

In [13]:
# Update age group

# Define age bins and labels for 5 groups
# Bin edges and their labels. Intervals are left-closed (right=False), i.e.
# [0, 20), [20, 40), ..., [80, 200): an age of exactly 20 lands in '20-39'.
# Ages outside [0, 200) or missing become NaN.
bins_5 = [0, 20, 40, 60, 80, 200]
labels_5 = ['0-19', '20-39', '40-59', '60-79', '80+']

# Store the band as a new categorical column 'age_group'.
merged_df['age_group'] = pd.cut(merged_df['age'], bins=bins_5, labels=labels_5, right=False)

In [14]:
# Update CCI
def calculate_cci(row):
    # Define the weights for each condition
    weights = {
        'cci_MI': 1,
        'cci_CHF': 1,
        'cci_PVD': 1,
        'cci_Stroke': 1,
        'cci_Dementia': 1,
        'cci_Pulmonary': 1,
        'cci_Rheumatic': 1,
        'cci_PUD': 1,
        'cci_Liver1': 1,
        'cci_DM1': 1,
        'cci_DM2': 2,
        'cci_Paralysis': 2,
        'cci_Renal': 2,
        'cci_Cancer1': 2,
        'cci_Liver2': 3,
        'cci_Cancer2': 6,
        'cci_HIV': 6
    }

    # Calculate the total CCI score
    cci_score = sum(weights[condition] for condition in weights if row[condition] == 1)

    # Group the CCI score into categories
    if cci_score == 0:
        cci_group = 'No'
    elif cci_score == 1 or cci_score == 2:
        cci_group = 'Low'
    elif cci_score == 3 or cci_score == 4:
        cci_group = 'Moderate'
    else:
        cci_group = 'High'

    return cci_score, cci_group

# NOTE(review): conditions_list is not referenced anywhere in this notebook;
# the columns actually scored are those in calculate_cci's weight table.
conditions_list = ["cci_MI", "cci_CHF", "cci_PVD", "cci_Stroke", "cci_Dementia", "cci_Pulmonary", "cci_Rheumatic", "cci_PUD", "cci_Liver1", "cci_DM1", "cci_DM2", "cci_Paralysis", "cci_Renal", "cci_Cancer1", "cci_Liver2", "cci_Cancer2", "cci_HIV"]
# Score every patient row by row. Wrapping the (score, group) tuple in a
# pd.Series makes apply return two columns, stored as CCI_Score / CCI_Group.
merged_df[['CCI_Score', 'CCI_Group']] = merged_df.apply(lambda row: pd.Series(calculate_cci(row)), axis=1)

## Analyse patient groups

### Distribution across demographics and clinical labels

In [None]:
palette = ['#E69F00', '#56B4E9', '#009E73', '#D55E00', '#CC79A7']

# For each attribute: show the % of patients at each misprediction count
# within every category, then draw the same breakdown as horizontal bars
# normalised to 100% per category.
for demo_type in ['triage_acuity', 'age_group', 'gender', 'ethnicity', 'CCI_Group', 'disposition']:
    count_per_miss = merged_df.groupby([demo_type, 'miss_total']).stay_id.count()
    count_per_category = merged_df.groupby([demo_type]).stay_id.count()
    display(count_per_miss / count_per_category * 100)

    # One bar per category, so scale the figure height with the category count.
    n_categories = len(set(merged_df[demo_type]))
    fig, ax = plt.subplots(1, 1, figsize=(4, 0.3 * n_categories))
    fig.dpi = 600
    ax = sns.histplot(
        data=merged_df.sort_values(demo_type),
        y=demo_type,
        hue='miss_total',
        hue_order=[4, 3, 2, 1, 0],
        stat='percent',
        multiple='fill',
        shrink=0.8,
        alpha=0.5,
        palette=palette,
        ax=ax,
    )
    sns.move_legend(
        ax, 'upper left',
        bbox_to_anchor=(1, 1.05), ncol=1, title='Misprediction count', frameon=False,
        columnspacing=0.9, handlelength=0.8, handletextpad=0.5, reverse=True,
    )
    ax.set_ylabel(demo_type)
    ax.set_xlabel('Percentage')
    # multiple='fill' scales each bar to 0..1; relabel the ticks as percentages.
    ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    ax.set_xticklabels(['0', '20', '40', '60', '80', '100'])
    plt.show()


### Quality of care

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 3))
fig.dpi = 600

palette = ['#E69F00', '#56B4E9', '#009E73', '#D55E00', '#CC79A7']

# One curve per admission status (legend title 'Admitted': False / True).
# y is the ED length of stay minus the median LOS of the same-status patients
# with zero mispredictions. The line is the median per misprediction count;
# the band spans the 25th-75th percentiles.
for idx, disp_type in enumerate([False, True]):
    df_group = merged_df[merged_df['outcome_hospitalization'] == disp_type].copy()
    line_color = palette[idx]

    baseline_los = df_group.loc[df_group['miss_total'] == 0, 'ed_los_hours'].median()
    df_group['rel_los'] = df_group['ed_los_hours'] - baseline_los

    grouped = (
        df_group
        .groupby('miss_total')
        .agg(
            median=('rel_los', 'median'),
            # Named "ci95_*" but these are the 25th/75th percentiles, not a 95% CI.
            ci95_lower=('rel_los', lambda x: np.percentile(x, 25)),
            ci95_upper=('rel_los', lambda x: np.percentile(x, 75)),
            count=('rel_los', 'count'),
        )
        .reset_index()
    )

    ax.plot(grouped['miss_total'], grouped['median'], marker='o', markersize=4,
            markeredgecolor='w', markeredgewidth=0.75, label=disp_type, color=line_color)
    ax.fill_between(grouped['miss_total'], grouped['ci95_lower'], grouped['ci95_upper'],
                    alpha=0.2, linewidth=0, color=line_color)

ax.set_ylim(-6, 6)
ax.set_yticks(np.arange(-6, 6.1, 2))
ax.set_xlim(0, 4)
ax.set_xticks(np.arange(0, 5, 1))

# Zero line: same length of stay as the correctly-predicted baseline.
ax.axhline(0, color='black', linewidth=1)
ax.grid(linewidth=0.5)
ax.set_xlabel('Misprediction count')
ax.set_ylabel('Relative median length of stay')

ncol = 2
ybox = 1.25
ax.legend(bbox_to_anchor=(0.5, ybox), loc='upper center', frameon=False, handlelength=0.9, ncol=ncol, title='Admitted', columnspacing=0.9)

plt.show()

In [17]:
def get_stat_per_demo(df_input, is_admitted, demo_type):
    """Summarise relative ED length of stay per category and misprediction count.

    Only admitted (``outcome_hospitalization == 1``) or non-admitted (``== 0``)
    patients are used. Within each category of ``demo_type``, ``rel_los`` is
    ``ed_los_hours`` minus the median ``ed_los_hours`` of that category's
    patients with ``miss_total == 0``.

    Args:
        df_input: patient-level frame with 'outcome_hospitalization',
            'miss_total', 'ed_los_hours' and ``demo_type`` columns. Not modified.
        is_admitted: True to use admitted patients, False for non-admitted.
        demo_type: column to stratify by. Rows where it is missing are
            excluded, because ``astype(str)`` would otherwise turn them into a
            literal 'nan' category that a later ``dropna()`` cannot remove.

    Returns:
        DataFrame with columns [demo_type, 'miss_total', 'median',
        'ci95_lower', 'ci95_upper', 'count'] (one row per category and
        misprediction count). Despite their names, the ci95 columns hold the
        25th and 75th percentiles of rel_los.
    """
    admitted_code = 1 if is_admitted else 0
    keep = (df_input['outcome_hospitalization'] == admitted_code) & df_input[demo_type].notna()
    sub_group = df_input[keep].copy()

    # Stringify so numeric, categorical and text attributes group/sort/label alike.
    sub_group[demo_type] = sub_group[demo_type].astype(str)

    # Per-category baseline: median LOS of the correctly predicted patients.
    baseline_los = (
        sub_group[sub_group['miss_total'] == 0]
        .groupby(demo_type)['ed_los_hours']
        .median()
    )
    sub_group['rel_los'] = sub_group['ed_los_hours'] - sub_group[demo_type].map(baseline_los)

    grouped = sub_group.groupby([demo_type, 'miss_total']).agg(
        median=('rel_los', 'median'),
        ci95_lower=('rel_los', lambda x: np.percentile(x, 25)),
        ci95_upper=('rel_los', lambda x: np.percentile(x, 75)),
        count=('rel_los', 'count')
    )
    return grouped.reset_index()

def plot_grouped(grouped, is_admitted, demo_type):
    """Plot relative-LOS curves, one per category of ``demo_type``.

    Args:
        grouped: frame as returned by ``get_stat_per_demo``; reads the columns
            ``demo_type``, 'miss_total', 'median', 'ci95_lower' and 'ci95_upper'.
        is_admitted: accepted for call-site symmetry; not used by the plot.
        demo_type: stratifying column; also picks the legend layout and title.

    Each category gets a median line plus a shaded band between ci95_lower and
    ci95_upper. Colours cycle through a 5-colour palette, so more than five
    categories reuse colours instead of raising IndexError.
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    fig.dpi = 600
    palette = ['#E69F00', '#56B4E9', '#009E73', '#D55E00', '#CC79A7']
    group_labels = sorted(set(grouped[demo_type]))

    for idx, label_id in enumerate(group_labels):
        color = palette[idx % len(palette)]
        sub_group = grouped[grouped[demo_type] == label_id]
        ax.plot(sub_group['miss_total'], sub_group['median'], color=color, marker='o', markersize=4,
                markeredgecolor='w', markeredgewidth=0.75, label=label_id)
        ax.fill_between(sub_group['miss_total'], sub_group['ci95_lower'], sub_group['ci95_upper'],
                        color=color, alpha=0.2, linewidth=0)

    ax.set_ylim(-6, 6)
    ax.set_yticks(np.arange(-6, 6.1, 2))
    ax.set_xlim(0, 4)
    ax.set_xticks(np.arange(0, 5, 1))
    ax.grid(linewidth=0.5)
    # Zero line: same length of stay as the correctly-predicted baseline.
    ax.axhline(0, color='black', linewidth=1)
    ax.set_xlabel('Misprediction count')
    ax.set_ylabel('Relative median length of stay')

    # (legend columns, legend y-anchor) per attribute; taller legends sit higher.
    legend_layout = {
        'triage_acuity': (5, 1.25),
        'age_group': (3, 1.35),
        'gender': (2, 1.25),
        'ethnicity': (3, 1.35),
        'CCI_Group': (2, 1.35),
    }
    ncol, ybox = legend_layout.get(demo_type, (1, 1.25))
    ax.legend(bbox_to_anchor=(0.5, ybox), loc='upper center', frameon=False, handlelength=0.9, ncol=ncol, title=demo_type, columnspacing=0.9)

    plt.show()

In [None]:
is_admitted = False

# Non-admitted patients: one relative-LOS figure per stratifying attribute.
for demo_type in ['triage_acuity', 'age_group', 'gender', 'ethnicity', 'CCI_Group']:
    grouped = get_stat_per_demo(merged_df, is_admitted, demo_type).dropna()
    plot_grouped(grouped, is_admitted, demo_type)

In [None]:
is_admitted = True

# Admitted patients: one relative-LOS figure per stratifying attribute.
for demo_type in ['triage_acuity', 'age_group', 'gender', 'ethnicity', 'CCI_Group']:
    grouped = get_stat_per_demo(merged_df, is_admitted, demo_type).dropna()
    plot_grouped(grouped, is_admitted, demo_type)