In [None]:
import os
import gc
import h5py
import pandas as pd
import numpy as np

import sklearn.metrics as metrics

import matplotlib.pyplot as plt
plt.style.use("paper.mplstyle")
import seaborn as sns

import sys
sys.path.append("../")
import utils_plot

In [None]:
# Plot geometry. matplotlib figure sizes are given in inches, so a length in
# centimetres is multiplied by `cm` (see the `figsize=(w*...*cm, h*cm)` calls below).
cm = 1/2.54
plot_alpha = 0.2  # NOTE(review): not referenced by the visible cells, which pass alpha=0.2 literally
h = 5.8  # figure height in cm
w = 6    # base figure width in cm (scaled per figure)

# Input directory with the event-based evaluation results and output directory for figures.
resp_res_path = "/cluster/work/grlab/clinical/hirid2/research/event_based_analysis/resp/"
resp_fig_path = "paper_figures_resp"
# exist_ok=True makes the cell idempotent without a separate existence check.
os.makedirs(resp_fig_path, exist_ok=True)

In [None]:
def read_event_based_pr_single_split(res_path, 
                                     pred_win=1440, 
                                     min_event_gap=0, 
                                     t_silence=30,
                                     t_buffer=0,
                                     t_reset=30,
                                     calibration_scaler=1,
                                     random_classifier=False):
    """
    Load the event-based precision/recall table of one split for one alarm configuration.

    res_path: path to the event-based evaluation results
    pred_win: future prediction window size (in minutes)
    min_event_gap: the minimal event gap length (in minutes), any gap smaller should be closed
    t_silence: alarm silencing time (in minutes)
    t_buffer: minimal buffer time before event
    t_reset: alarm reset time after patient recovers from failure event
    calibration_scaler: scaler to calibrate the prevalence (AUPRC of the random classifier)
    random_classifier: if True read only files whose name contains "rand", otherwise only
        files whose name does not contain it

    Returns a DataFrame with the CSV columns plus "recall" = CE/(CE+ME) and
    "precision" = TA/(TA+FA), computed after FA has been multiplied by
    `calibration_scaler`. Rows are sorted by ("tau", "recall", "precision") and only the
    last row per recall value is kept.

    Raises FileNotFoundError if no result file matches the configuration.
    """
    
    prefix_str = "tg-%d_tr-%d_dt-%d_ws-%d_ts-%d"%(min_event_gap,
                                                  t_reset,
                                                  t_buffer,
                                                  pred_win,
                                                  t_silence) # prefix for different configuration
    want_random = bool(random_classifier)
    frames = []
    # sorted() makes the concatenation order independent of the file-system listing order
    for f in sorted(os.listdir(res_path)):
        # the prefix must be followed by "_" or "." so that e.g. ts-30 does not match ts-300
        if prefix_str+"_" not in f and prefix_str+"." not in f:
            continue
        if ("rand" in f) == want_random:
            frames.append(pd.read_csv(os.path.join(res_path, f)))

    if not frames:
        raise FileNotFoundError("no event-based result file matching '%s' (random_classifier=%s) in %s"
                                % (prefix_str, want_random, res_path))

    res = pd.concat(frames, ignore_index=True)
    # plain column assignment replaces the column instead of writing float values
    # into the existing (possibly integer) FA column
    res["FA"] = calibration_scaler * res["FA"] # calibrate the false alarms using the scale
    
    res["recall"] = res["CE"] / (res["CE"]+res["ME"])
    res["precision"] = res["TA"] / (res["TA"]+res["FA"])
    res = res.sort_values(["tau", "recall", "precision"])
    res = res.drop_duplicates("recall", keep="last")
    res = res.reset_index(drop=True)
    return res

def read_event_based_pr_multi_splits(res_path, 
                                     splits,
                                     pred_win=1440, 
                                     min_event_gap=0, 
                                     t_silence=30,
                                     t_buffer=0,
                                     t_reset=30,
                                     calibration_scaler=1,
                                     random_classifier=False):
    """
    Load the event-based precision/recall tables of several splits.

    res_path: path to the event-based evaluation results
    splits: list of splits (a single split name given as a string is also accepted);
        each split is read from the sub-directory <res_path>/<split>
    pred_win: future prediction window size (in minutes)
    min_event_gap: the minimal event gap length (in minutes), any gap smaller should be closed
    t_silence: alarm silencing time (in minutes)
    t_buffer: minimal buffer time before event
    t_reset: alarm reset time after patient recovers from failure event
    calibration_scaler: scaler to calibrate the prevalence (AUPRC of the random classifier)
    random_classifier: forwarded to read_event_based_pr_single_split

    Returns a dictionary mapping each split name to the result of
    read_event_based_pr_single_split for that split.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(splits, str):
        splits = [splits]

    return {split: read_event_based_pr_single_split(os.path.join(res_path, split),
                                                    pred_win=pred_win,
                                                    min_event_gap=min_event_gap,
                                                    t_silence=t_silence,
                                                    t_buffer=t_buffer,
                                                    t_reset=t_reset,
                                                    random_classifier=random_classifier,
                                                    calibration_scaler=calibration_scaler)
            for split in splits}


def plot_metric_vs_setting(curves, ylabel="auprc", xlabel="t_silence", fixed_rec=0.8, color='C0'):
    """
    Plot one summary metric (mean with a +/- std band across splits) against a setting.

    curves: a dictionary containing the configuration of all curves in the same plot.
        Each configuration must provide "res_path", "splits", "pred_win", "min_event_gap",
        "t_silence", "t_buffer", "t_reset", "random_classifier" and the key named by
        `xlabel`; "calibration_scaler" (default 1) and "single_point" (default False)
        are optional.
    ylabel: "auprc" plots the area under the precision-recall curve; any other value
        plots the precision (in percent) at the recall closest to `fixed_rec`.
        Also used as the y-axis label.
    xlabel: configuration key giving the x position of each entry; also the x-axis label
    fixed_rec: target recall for the precision metric
    color: matplotlib color of the line and the shaded band

    Returns (xtick_vals, metric_vals): the list of x positions and an array with one
    [mean, std] row per entry of `curves`.

    Raises ValueError if ylabel is "auprc" for an entry flagged as "single_point"
    (no AUPRC is computed for such entries).
    """
    xtick_vals = []
    metric_vals = []
    for name, cfg in curves.items():
        res = read_event_based_pr_multi_splits(cfg["res_path"], 
                                               cfg["splits"], 
                                               pred_win=cfg["pred_win"],
                                               min_event_gap=cfg["min_event_gap"],
                                               t_silence=cfg["t_silence"],
                                               t_buffer=cfg["t_buffer"],
                                               t_reset=cfg["t_reset"],
                                               random_classifier=cfg["random_classifier"],
                                               calibration_scaler=cfg.get("calibration_scaler", 1))

        # One precision column per split, indexed by recall; interpolate every split
        # onto the union of recall values and keep only rows defined for all splits.
        aggr_res = [v.set_index("recall").sort_index().rename(columns={"precision": k})[[k]]
                    for k, v in res.items()]
        aggr_res = pd.concat(aggr_res, axis=1)
        aggr_res = aggr_res.sort_index()
        aggr_res = aggr_res.interpolate(method="index")
        aggr_res = aggr_res[aggr_res.isnull().sum(axis=1)==0]

        single_point = bool(cfg.get("single_point", False))
        if single_point:
            aggr_res = aggr_res[aggr_res.index<1]

        xtick_vals.append(cfg[xlabel])
        if ylabel=="auprc":
            if single_point:
                raise ValueError("AUPRC is not computed for single_point curve '%s'" % name)
            aucs = [metrics.auc(aggr_res.index, aggr_res[k]) for k in aggr_res.columns]
            # np.std is the population std (ddof=0), unlike the pandas std used below
            metric_vals.append([np.mean(aucs), np.std(aucs)])
        else:
            precision_mean = aggr_res.mean(axis=1)
            if single_point and aggr_res.shape[1] <= 1:
                # a single split has no spread; keep a Series so .loc works below
                precision_std = pd.Series(0.0, index=aggr_res.index)
            else:
                precision_std = aggr_res.std(axis=1)
            # available recall value closest to the requested one
            esti_rec = aggr_res.index[np.argmin(np.abs(aggr_res.index-fixed_rec))]
            metric_vals.append([precision_mean.loc[esti_rec]*100, 
                                precision_std.loc[esti_rec]*100])

    metric_vals = np.array(metric_vals)
    plt.plot(xtick_vals, metric_vals[:,0], color=color)   
    if ylabel=="auprc":
        print(["%2.2f (%2.2f)"%(metric_vals[i,0],metric_vals[i,1]) for i in range(len(metric_vals))])
    else:
        print(["%2.1f (%2.1f)"%(metric_vals[i,0],metric_vals[i,1]) for i in range(len(metric_vals))])
        
    plt.fill_between(xtick_vals, 
                     metric_vals[:,0]-metric_vals[:,1],
                     metric_vals[:,0]+metric_vals[:,1],
                     alpha=0.2, 
                     color=color)
    
    plt.xticks(xtick_vals, xtick_vals)
    plt.xlim([0,xtick_vals[-1]])
    plt.grid(alpha=0.2)    
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    return xtick_vals, metric_vals


In [None]:
# Metric shown by the plotting code of this cell; the alternative is
# "Precision @ Recall=80%".
metric = "AUPRC"

# Early warning system configuration shared by every curve of this cell.
resp_ews_configs = {
    "pred_win": 1440,
    "min_event_gap": 0,
    "t_silence": 240,
    "t_buffer": 0,
    "t_reset": 30,
}
splits = ["temporal_%d" % split_id for split_id in range(1, 6)]

# Scaling factor applied to the false alarms of the UMCDB curves
# (passed as calibration_scaler below).
# NOTE(review): prev_db1 / prev_db2 are presumably the event prevalences of the two
# databases — confirm which value belongs to which cohort.
prev_db1 = 0.151
prev_db2 = 0.219
calibrated_s = (1/prev_db1-1) / (1/prev_db2-1)

# One configuration dictionary per model family, keyed by "#variables = k" for the
# top-k variable models (k = 1..20).
curves_hirid = {}
curves_hirid_drop_pharma = {}
curves_umcdb = {}
curves_umcdb_drop_pharma = {}
curves_umcdb_transported = {}

for i, num_var in enumerate(np.arange(1,21)):
    curve_name = "#variables = %d"%num_var
    # Entries common to every curve of this model size.
    base = dict(color="C%d"%i, linestyle="-", random_classifier=False)
    hirid_dir = os.path.join(resp_res_path, "WorseFromZeroOrOne_var%d_all"%num_var)
    umcdb_dir = os.path.join(resp_res_path, "WorseFromZeroOrOne_val_umc_%dvar_all_rerun"%num_var)
    transported_dir = os.path.join(resp_res_path, "WorseFromZeroOrOne_val_transported_%dvars_v2_all"%num_var)

    curves_hirid[curve_name] = dict(res_path=hirid_dir, **base,
                                    num_var=num_var, splits=splits)
    curves_umcdb[curve_name] = dict(res_path=umcdb_dir, **base,
                                    calibration_scaler=calibrated_s,
                                    num_var=num_var, splits=splits)
    curves_umcdb_transported[curve_name] = dict(res_path=transported_dir, **base,
                                                calibration_scaler=calibrated_s,
                                                num_var=num_var,
                                                splits=['1','2','3','4','5'])

    # The "drop pharma" families reuse the regular result directories for k <= 4,
    # have dedicated directories for k >= 7 and have no entry for k = 5, 6.
    if num_var<=4:
        drop_hirid_dir, drop_umcdb_dir = hirid_dir, umcdb_dir
    elif num_var>=7:
        drop_hirid_dir = os.path.join(resp_res_path, "WorseFromZeroOrOne_rmsRF_var%d_all"%num_var)
        drop_umcdb_dir = os.path.join(resp_res_path, "WorseFromZeroOrOne_val_umc_drop_pharma_%dvar_all_rerun"%num_var)
    else:
        continue

    curves_hirid_drop_pharma[curve_name] = dict(res_path=drop_hirid_dir, **base,
                                                num_var=num_var, splits=splits)
    curves_umcdb_drop_pharma[curve_name] = dict(res_path=drop_umcdb_dir, **base,
                                                calibration_scaler=calibrated_s,
                                                num_var=num_var, splits=splits)

# Add the shared early-warning-system settings to every curve configuration.
for curve_name in curves_hirid:
    families = [curves_hirid, curves_umcdb, curves_umcdb_transported]
    if curve_name in curves_umcdb_drop_pharma:
        families += [curves_umcdb_drop_pharma, curves_hirid_drop_pharma]
    for family in families:
        family[curve_name].update(resp_ews_configs)
    
    
# Tick labels for the x axis: the variable added at each model size k = 1..20,
# in order (label i belongs to x position i+1).
lst_var = ["FiO$_2$", "SpO$_2$", "PaO$_2$", "Supplemental FiO$_2$ (%)", "Loop diuretics", 
           "Heparin", "Ventilator peak pressure", "Supplemental oxygen", "Propofol", "PEEPs", 
           "Pressure support", "GCS Motor", "Ventilator respiratory rate", "Sex", "Estimated FiO$_2$", 
           "GCS Verbal", "Dobutamine", "Benzodiazepine", "MV(exp)", "PaCO$_2$"]

# Overview figure: one curve per model family against the number of variables.
# Every call gets its own colour; with the default colour all curves (and their
# legend entries) would be drawn identically and could not be told apart.
# The legend alternates a label for the line with "_nolegend_" for the shaded band
# that plot_metric_vs_setting draws right after it.
if metric=="AUPRC":
    plt.figure(figsize=(w*1.8*cm, h*cm))
    xticks_hirid_drop_pharma, auprc_hirid_drop_pharma = plot_metric_vs_setting(curves_hirid_drop_pharma, ylabel="auprc", xlabel="num_var", fixed_rec=0.8, color="C0")
    xticks_umcdb_drop_pharma, auprc_umcdb_drop_pharma = plot_metric_vs_setting(curves_umcdb_drop_pharma, ylabel="auprc", xlabel="num_var", fixed_rec=0.8, color="C1")
    xticks_hirid, auprc_hirid = plot_metric_vs_setting(curves_hirid, ylabel="auprc", xlabel="num_var", fixed_rec=0.8, color="C2")
    xticks_umcdb, auprc_umcdb = plot_metric_vs_setting(curves_umcdb, ylabel="auprc", xlabel="num_var", fixed_rec=0.8, color="C3")
    xticks_umcdb_transported, auprc_umcdb_transported = plot_metric_vs_setting(curves_umcdb_transported, ylabel="auprc", xlabel="num_var", fixed_rec=0.8, color="C4")
    plt.xticks(np.arange(1,21), lst_var, rotation=30, horizontalalignment="right")
    plt.legend(["HiRID -> HiRID", "_nolegend_", "HiRID -> UMCDB", "_nolegend_","HiRID -> HiRID (+pharma)", "_nolegend_", "HiRID -> UMCDB (+pharma)", "_nolegend_", "HiRID -> UMCDB (transported)", "_nolegend_"])
    plt.ylabel("AUPRC")
    plt.tight_layout()
    plt.show()
else:
    # NOTE(review): this branch does not compute the *_hirid_drop_pharma and
    # *_umcdb_transported results that the AUPRC branch provides.
    plt.figure(figsize=(w*1.8*cm, h*cm))
    xticks_hirid, auprc_hirid = plot_metric_vs_setting(curves_hirid, ylabel="precision", xlabel="num_var", fixed_rec=0.8, color="C0")
    xticks_umcdb, auprc_umcdb = plot_metric_vs_setting(curves_umcdb, ylabel="precision", xlabel="num_var", fixed_rec=0.8, color="C1")
    xticks_umcdb_drop_pharma, auprc_umcdb_drop_pharma = plot_metric_vs_setting(curves_umcdb_drop_pharma, ylabel="precision", xlabel="num_var", fixed_rec=0.8, color="C2")
    plt.xticks(np.arange(1,21), lst_var, rotation=30, horizontalalignment="right")
    plt.legend(["HiRID -> HiRID", "_nolegend_", "HiRID -> UMCDB", "_nolegend_", "HiRID -> UMCDB (no pharma)", "_nolegend_"])
    plt.ylabel("Precision @ Recall=80%")
    plt.xlabel("")
    plt.tight_layout()
    plt.show()

In [None]:
# Tick labels for the x axis: the variable added at each model size k = 1..20,
# in order (label i belongs to x position i+1).
lst_var = ["FiO$_2$", "SpO$_2$", "PaO$_2$", "Supplemental FiO$_2$ (%)", "Loop diuretics", 
           "Heparin", "Ventilator peak pressure", "Supplemental oxygen", "Propofol", "PEEPs", 
           "Pressure support", "GCS Motor", "Ventilator respiratory rate", "Sex", "Estimated FiO$_2$", 
           "GCS Verbal", "Dobutamine", "Benzodiazepine", "MV(exp)", "PaCO$_2$"]
# Metric name used for the y-axis label and the output file name of this figure.
metric = 'AUPRC'
# Figure 2: event-based metric against the number of top-k variables, one line
# (with a +/- std band) per model family.
plt.figure(figsize=(w*1.8*cm, h*cm))

# x positions left out for the curves whose label has no "-p"
# (presumably the ranks of the pharma variables in lst_var — confirm).
skipped_ranks = [5,6,9,17,18]

# (label, x positions, [mean, std] rows, colour, marker, zorder)
curve_specs = [("RMS-RF: HiRID->HiRID", xticks_hirid_drop_pharma, auprc_hirid_drop_pharma, "C0", '.', 4),
               ("RMS-RF: HiRID->UMCDB", xticks_umcdb_drop_pharma, auprc_umcdb_drop_pharma, "C1", '*',3),
               ("RMS-RF-p: HiRID->HiRID", xticks_hirid, auprc_hirid, "C9", 'd',2),
               ("RMS-RF-p: HiRID->UMCDB", xticks_umcdb, auprc_umcdb, "C8", 'o',1),
               ("RMS-RF-p (transported): HiRID->UMCDB", xticks_umcdb_transported, auprc_umcdb_transported, "C7", '.',5)]

for i, (label, xtick_vals, metric_vals, color, marker, zorder) in enumerate(curve_specs):
    xtick_vals = np.array(xtick_vals)
    metric_vals = np.array(metric_vals)

    # Boolean mask of the points that are drawn for this curve.
    if '-p' in label:
        shown = np.ones(len(xtick_vals), dtype=bool)
    else:
        shown = ~np.isin(xtick_vals, skipped_ranks)
    shown_x = xtick_vals[shown]
    shown_mean = metric_vals[shown,0]
    shown_std = metric_vals[shown,1]

    plt.plot(shown_x, shown_mean, label=label, color=color, marker=marker, zorder=zorder)
    plt.fill_between(shown_x,
                     shown_mean-shown_std,
                     shown_mean+shown_std,
                     alpha=0.2, label="_nolegend_", color=color)

    plt.xticks(xtick_vals, xtick_vals)
    plt.xlim([0,xtick_vals[-1]+1])
    if metric!="AUPRC":
        plt.ylim([min(auprc_hirid[:,0].min(), auprc_umcdb[:,0].min(), auprc_umcdb_drop_pharma[:,0].min())-7, plt.ylim()[1]])
    plt.grid(alpha=0.2)

    if i==3:
        # For the fourth curve, mark every step where the metric drops by at least one
        # (truncated) percent relative to the previous k: red arrow, red percentage and
        # a red tick label. All other tick labels stay black.
        ticklabelcolor = ['k']
        for j in range(1, len(xtick_vals)):
            x = xtick_vals[j]
            auprc_diff = metric_vals[j,0]-metric_vals[j-1,0]
            rel_change = auprc_diff/metric_vals[j-1,0]*100
            if int(rel_change)==0 or auprc_diff>0:
                ticklabelcolor.append('k')
                continue

            # Arrow and text geometry depend on the scale of the metric.
            if metric=="AUPRC":
                arrow_dy, arrow_head, text_drop = -0.02, 0.015, 0.06
            else:
                arrow_dy, arrow_head, text_drop = -1, 0.5, 4
            plt.arrow(x, metric_vals[j,0], 0, arrow_dy, linewidth=1, head_width=0.1, head_length=arrow_head, color="C3",label="_nolegend_")
            plt.text(x, metric_vals[j,0]-text_drop, "%d%%"%(-rel_change) ,color="C3", horizontalalignment="center", weight="bold")
            ticklabelcolor.append('C3')

# Replace the numeric ticks by the variable names and colour them.
plt.xticks(np.arange(1,21), lst_var, rotation=30, horizontalalignment="right")
ax = plt.gca()
for tick_label, tick_color in zip(ax.get_xticklabels(), ticklabelcolor):
    tick_label.set_color(tick_color)
plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc="lower left", ncols=2, mode="expand", borderaxespad=0.)
plt.ylabel(metric)
plt.xlabel('Top $k$ variables')
plt.tight_layout()
fig_name = "fig2_subfig8" if metric=="AUPRC" else "fig2_subfig9"
plt.savefig(os.path.join(resp_fig_path, fig_name))
plt.show()

In [None]:
# Time-point-based results (one row per task, with error bars) for the four
# evaluation settings.
tp_result_pattern = "/cluster/work/grlab/clinical/hirid2/research/RESP_RELEASE/evaluation/time_point_based/%s/task_results_with_error_bars.tsv"
df_auprc_hirid = pd.read_csv(tp_result_pattern % "eval_fail_prefix_model", sep="\t")
df_auprc_umcdb = pd.read_csv(tp_result_pattern % "eval_fail_prefix_model_val", sep="\t")
df_auprc_umcdb_drop_pharma = pd.read_csv(tp_result_pattern % "eval_fail_prefix_model_val_no_pharma", sep="\t")
df_auprc_hirid_drop_pharma = pd.read_csv(tp_result_pattern % "eval_fail_prefix_model_no_pharma", sep="\t")

metric = "AUPRC"
# metric = "Precision @ Recall=80%"

# First of the two adjacent columns that are read for the chosen metric
# (presumably mean and error bar — confirm against the TSV header).
first_col = 1 if metric=="AUPRC" else 3

# The number of variables is parsed from the "_"-separated task name: field 4 for the
# HiRID tables, field 5 without its last three characters for the UMCDB tables.
hirid_task_to_k = lambda task: int(task.split("_")[4])
umcdb_task_to_k = lambda task: int(task.split("_")[5][:-3])

xticks_hirid = df_auprc_hirid.task.map(hirid_task_to_k).values
auprc_hirid = df_auprc_hirid.iloc[:,first_col:first_col+2].values

xticks_umcdb = df_auprc_umcdb.task.map(umcdb_task_to_k).values
auprc_umcdb = df_auprc_umcdb.iloc[:,first_col:first_col+2].values

xticks_umcdb_drop_pharma = df_auprc_umcdb_drop_pharma.task.map(umcdb_task_to_k).values
auprc_umcdb_drop_pharma = df_auprc_umcdb_drop_pharma.iloc[:,first_col:first_col+2].values

# In this table the metric columns are read one position further to the right.
xticks_hirid_drop_pharma = df_auprc_hirid_drop_pharma.task.map(hirid_task_to_k).values
auprc_hirid_drop_pharma = df_auprc_hirid_drop_pharma.iloc[:,first_col+1:first_col+3].values
    
# Tick labels for the x axis of this figure: the variable added at each model size
# k = 1..20, in order (label i belongs to x position i+1).
lst_var = ["Pressure support", "Benzodiazepine", "Norepinephrine", "FiO$_2$", "Propofol", "MV spont servo", 
           "Ventilator respiratory rate", "RR sp. m", "Ventilation presence", "Supplemental oxygen", 
           "Insulin [fast acting]", "Loop diuretics", "Supplemental FiO$_2$[%]", "MV[exp]", "Ventilator peak pressure",
           "Heparin", "PaCO$_2$", "Emergency admission", "PEEPs", "Ventilation mode group"]

# Figure 3: time-point-based metric against the number of top-k variables, one line
# (with a +/- error band) per model family.
plt.figure(figsize=(w*2*cm, h*cm))

# Row positions (0-based, out of 20) drawn for the "lite" curves; the positions
# 1, 2, 4, 10, 11 and 15 are left out
# (presumably the pharma variables in lst_var — confirm).
all_positions = np.arange(20)
lite_positions = all_positions[~np.isin(all_positions, [1,2,4,10,11,15])]

# (label, x positions, [mean, error] rows, colour, marker, zorder)
curve_specs = [("RMS-EF: HiRID->HiRID ", xticks_hirid, auprc_hirid, "C9", '.', 4),
               ("RMS-EF: HiRID->UMCDB", xticks_umcdb, auprc_umcdb, "C1", "*", 3),
               ("RMS-EF-lite: HiRID->HiRID", xticks_hirid_drop_pharma, auprc_hirid_drop_pharma, "C0", "d", 1),
               ("RMS-EF-lite: HiRID->UMCDB", xticks_umcdb_drop_pharma, auprc_umcdb_drop_pharma, "C8", "o", 1)]

for i, (label, xtick_vals, metric_vals, color, marker, zorder) in enumerate(curve_specs):
    metric_vals = np.array(metric_vals)

    # Select the rows that are drawn for this curve.
    shown = lite_positions if 'lite' in label else slice(None)
    shown_x = xtick_vals[shown]
    shown_mean = metric_vals[shown,0]
    shown_err = metric_vals[shown,1]

    plt.plot(shown_x, shown_mean, label=label, color=color, marker=marker, zorder=zorder)
    plt.fill_between(shown_x,
                     shown_mean-shown_err,
                     shown_mean+shown_err,
                     alpha=0.2, label="_nolegend_", color=color)

    plt.xticks(xtick_vals, xtick_vals)
    plt.xlim([0,xtick_vals[-1]+1])
    plt.grid(alpha=0.2)

    if i==1:
        # For the second curve, mark every step where the metric drops by at least one
        # (truncated) percent relative to the previous k: red arrow, red percentage and
        # a red tick label. All other tick labels stay black.
        ticklabelcolor = ['k']
        for j in range(1, len(xtick_vals)):
            x = xtick_vals[j]
            auprc_diff = metric_vals[j,0]-metric_vals[j-1,0]
            rel_change = auprc_diff/metric_vals[j-1,0]*100
            if int(rel_change)==0 or auprc_diff>0:
                ticklabelcolor.append('k')
                continue

            # Arrow and text geometry depend on the scale of the metric.
            if metric=="AUPRC":
                arrow_dy, arrow_head, text_drop = -0.02, 0.015, 0.09
            else:
                arrow_dy, arrow_head, text_drop = -1, 0.5, 4
            plt.arrow(x, metric_vals[j,0], 0, arrow_dy, linewidth=1, head_width=0.1, head_length=arrow_head, color="C3",label="_nolegend_", zorder=10)
            plt.text(x, metric_vals[j,0]-text_drop, "%d%%"%(-rel_change) ,color="C3", horizontalalignment="center", weight="bold")
            ticklabelcolor.append('C3')

# Replace the numeric ticks by the variable names and colour them.
plt.xticks(np.arange(1,21), lst_var, rotation=30, horizontalalignment="right")
ax = plt.gca()
for tick_label, tick_color in zip(ax.get_xticklabels(), ticklabelcolor):
    tick_label.set_color(tick_color)

plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc="lower left", ncols=2, mode="expand", borderaxespad=0.)
plt.ylabel(metric)
plt.xlabel('Top $k$ variables')
plt.tight_layout()
fig_name = "fig3_subfig2" if metric=="AUPRC" else "fig3_subfig3"
plt.savefig(os.path.join(resp_fig_path, fig_name))
plt.show()