This notebook serves to reproduce all visualizations used in the paper. To run this notebook, first execute `main.py`.

In [None]:
# Open all required libraries and format for paper
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from sklearn.metrics import log_loss, roc_auc_score
import statsmodels.formula.api as smf

import yaml

def load_config(config_path):
    """Read a YAML configuration file and return its parsed content.

    Args:
        config_path: Location of the YAML file (string or path-like).

    Returns:
        The object produced by ``yaml.safe_load`` for the file.
    """
    # Explicit encoding: without it open() falls back to the platform locale,
    # which can mis-decode non-ASCII values in the config.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
    
# Global figure styling used by every plot in this notebook.
sns.set(style="whitegrid")

# Titles, axis labels, tick labels and legends all share one font size.
_paper_rc = dict.fromkeys(
    (
        'axes.titlesize',
        'axes.labelsize',
        'xtick.labelsize',
        'ytick.labelsize',
        'legend.fontsize',
    ),
    35,
)
_paper_rc['pdf.fonttype'] = 42  # For vector text in PDFs
mpl.rcParams.update(_paper_rc)

# Colorblind-friendly palette
sns.set_palette("colorblind")

# In the notebook the experiment and config location are set directly
# rather than being read from the command line.
experiment = "mimic"
config_path = "./configs"

config = load_config(Path(config_path, f"{experiment}_readmission.yaml"))

In [None]:
from pathlib import Path
import pickle
import numpy as np
from collections import defaultdict

# Downstream prediction tasks evaluated in this notebook.
tasks = [
    "long_los",
    "in_hospital_mortality",
    "readmission",
]

# Model identifiers as they appear in the prediction file names.
# ("gpt-5" is currently left out of the evaluated set.)
models = [
    "Qwen3-8B",
    "Qwen3-14B",
    "Qwen3-32B",
    "gemma-3-4b-it",
    "gemma-3-27b-it",
    "gpt-oss-20b",
    "gpt-oss-120b",
    "Llama-3.1-8B-Instruct",
    "Llama-3.1-70B-Instruct",
    "Mistral-Small-3.1-24B-Instruct-2503",
    "medgemma-27b-text-it",
]

# Model identifier -> two-line display label.
renaming_model = {
    "Qwen3-8B": "Qwen3\n8B",
    "Qwen3-14B": "Qwen3\n14B",
    "Qwen3-32B": "Qwen3\n32B",
    "gpt-oss-20b": "GPT-OSS\n20B",
    "gpt-oss-120b": "GPT-OSS\n120B",
    "gpt-5": "GPT-5",
    "gemma-3-4b-it": "Gemma\n4B",
    "gemma-3-12b-it": "Gemma\n12B",
    "gemma-3-27b-it": "Gemma\n27B",
    "Llama-3.1-8B-Instruct": "Llama3.1\n8B",
    "Llama-3.1-70B-Instruct": "Llama3.1\n70B",
    "medgemma-27b-text-it": "MedGemma\n27B",
    "Mistral-Small-3.1-24B-Instruct-2503": "Mistral 3.1\n24B",
}

# (display name, flags) per variant. The display name is used in result keys;
# the flags select which prediction file is loaded.
variants = [
    ("Dropped", {"exp_missing": False, "prompt_missing": False}),
    ("Instruction", {"exp_missing": False, "prompt_missing": True}),
    ("Indicator", {"exp_missing": True, "prompt_missing": False}),
    ("Indicator + Inst.", {"exp_missing": True, "prompt_missing": True}),
]

# Directory containing the pickled prediction files, taken from the loaded config.
predictions_dir = Path(config["predictions_dir"])
# Small constant added inside np.log(...) in the log-loss computations below
# so that predictions of exactly 0 or 1 do not produce log(0).
eps = 1e-12

def safe_load_pickle(path):
    """Load a pickled object from ``path``; return ``None`` if that fails.

    Loading is best-effort: a missing, unreadable or corrupt file is reported
    on stdout (path plus the reason) and ``None`` is returned so callers can
    skip it.

    Args:
        path: Location of the pickle file.

    Returns:
        The unpickled object, or ``None`` when the file could not be loaded.
    """
    try:
        with open(path, "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; only use this on
            # prediction files produced by this project.
            return pickle.load(f)
    except Exception as exc:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # propagate while every loading failure remains non-fatal.
        print(f"{path} ({type(exc).__name__}: {exc})")
        return None

def file_name(task, model_name, exp_missing, labs_only, prompt_missing):
    """Return the path of the prediction pickle for one task/model/flag combination.

    The file lives in ``predictions_dir`` and its name encodes all five
    arguments, in order, separated by underscores.
    """
    basename = f"predictions_{task}_{model_name}_{exp_missing}_{labs_only}_{prompt_missing}.pkl"
    return predictions_dir.joinpath(basename)

def is_prob(x):
    """Return True when ``x`` is a real number in the closed interval [0, 1].

    ``None``, strings, other non-numeric objects and NaN yield False.

    Args:
        x: Candidate prediction value.

    Returns:
        bool: Whether ``x`` can be used as a probability.
    """
    import numbers  # local import keeps this cell self-contained

    # numbers.Real accepts int/float as before and additionally NumPy scalar
    # types (e.g. np.float32), which the plain (int, float) check rejected.
    if not isinstance(x, numbers.Real):
        return False
    return 0.0 <= float(x) <= 1.0

# Build, for every task, prediction/label lists that are aligned across all
# (model, variant) combinations and the optional logistic-regression baseline.
#
# task_results_dicts: {task: {(model, variant_name): {"predictions", "labels", ...}}}
#   plus ("Log", "Dropped") / ("Log", "Indicator") entries when a baseline exists.
# invalid_counts: {task: {(model, variant_name): number of subjects whose
#   prediction was missing or not a probability}}
task_results_dicts = {}
invalid_counts = {}


def _binary_log_loss(label, pred):
    """Per-sample binary cross-entropy; eps keeps np.log finite for pred in {0, 1}."""
    return -(label * np.log(pred + eps) + (1 - label) * np.log(1 - pred + eps))


for task in tasks:
    config["downstream_task"] = task
    print(f"Processing {task}...")

    # --- Load baseline (optional) ---
    baseline = safe_load_pickle(predictions_dir / f"baseline_predictions_{task}.pkl")
    if baseline is None:
        print(f"⚠️ No baseline predictions for {task}")

    # --- Load all model/variant dicts ---
    # subject_dicts: {(model, variant_name): {sid: {...}}}
    subject_dicts = {}
    for m in models:
        for vname, vflags in variants:
            fname = file_name(task, m, vflags["exp_missing"], config["labs_only"], vflags["prompt_missing"])
            d = safe_load_pickle(fname)
            if d is None:
                print(f"⚠️ No predictions for {task} | {m} | missing={vflags['exp_missing']} | prompt={vflags['prompt_missing']}")
                d = {}
            subject_dicts[(m, vname)] = d

    # --- Compute common subject IDs across all non-empty dicts (+ baseline if present) ---
    nonempty_sets = [set(d.keys()) for d in subject_dicts.values() if d]
    if not nonempty_sets:
        print(f"⚠️ No non-empty prediction dicts for {task}. Skipping.")
        task_results_dicts[task] = {}
        continue

    common_subject_ids = set.intersection(*nonempty_sets)
    baseline_index_map = {}
    if baseline is not None:
        # subject id -> row position inside the baseline arrays
        baseline_index_map = {
            sid: i for i, sid in enumerate(baseline["lr_no_missing"]["subject_ids"])
        }
        common_subject_ids = common_subject_ids & set(baseline_index_map)

    if not common_subject_ids:
        print(f"⚠️ No overlapping subject IDs for {task}. Skipping.")
        task_results_dicts[task] = {}
        continue

    # --- Pre-allocate aligned collectors ---
    aligned_preds          = defaultdict(list)
    aligned_serializations = defaultdict(list)
    aligned_responses      = defaultdict(list)
    aligned_loglosses      = defaultdict(list)

    aligned_labels             = []  # single list reused across variants
    aligned_baseline_preds_1   = []
    aligned_baseline_preds_2   = []
    aligned_missing_counts     = []
    aligned_baseline_logloss_1 = []
    aligned_baseline_logloss_2 = []
    aligned_baseline_indices   = []

    invalid_counts[task] = {key: 0 for key in subject_dicts}

    # --- Iterate through subjects and collect aligned rows ---
    for sid in common_subject_ids:
        # label from the baseline when available, otherwise from the first
        # prediction dict that carries one
        if baseline is not None and sid in baseline_index_map:
            label = baseline["lr_no_missing"]["labels"][baseline_index_map[sid]]
        else:
            label = None
            for d in subject_dicts.values():
                if sid in d:
                    label = d[sid].get("label", None)
                    if label is not None:
                        break
        if label is None:
            # cannot validate log loss or continue safely
            continue

        # A subject is kept only if EVERY (model, variant) has a prediction in
        # [0, 1]. All failing combinations are counted (no early exit), so the
        # per-model counts do not depend on iteration order.
        preds_ok = True
        for key, d in subject_dicts.items():
            if sid not in d or "prediction" not in d[sid] or not is_prob(d[sid]["prediction"]):
                invalid_counts[task][key] += 1
                preds_ok = False
        if not preds_ok:
            continue

        # The label is recorded for every retained subject, independent of the
        # baseline, so "labels" always lines up with the per-model predictions.
        aligned_labels.append(label)

        # collect per variant/model
        for key, d in subject_dicts.items():
            pred = float(d[sid]["prediction"])

            aligned_preds[key].append(pred)
            aligned_serializations[key].append(d[sid].get("serialization"))
            aligned_responses[key].append(d[sid].get("response"))
            aligned_loglosses[key].append(_binary_log_loss(label, pred))

        # baseline (optional)
        if baseline is not None and sid in baseline_index_map:
            b_idx = baseline_index_map[sid]
            b_pred1 = float(baseline["lr_no_missing"]["predictions"][b_idx])
            b_pred2 = float(baseline["lr_with_missing"]["predictions"][b_idx])

            aligned_baseline_preds_1.append(b_pred1)
            aligned_baseline_preds_2.append(b_pred2)
            aligned_missing_counts.append(baseline["lr_with_missing"]["missing_counts"][b_idx])
            aligned_baseline_logloss_1.append(_binary_log_loss(label, b_pred1))
            aligned_baseline_logloss_2.append(_binary_log_loss(label, b_pred2))
            aligned_baseline_indices.append(b_idx)

    # --- Build the final task dict programmatically ---
    task_bundle = {}

    # add model/variant blocks (same order as models x variants)
    for key in subject_dicts:
        task_bundle[key] = {
            "predictions": aligned_preds[key],
            "labels": aligned_labels,  # same aligned order
            "serializations": aligned_serializations[key],
            "responses": aligned_responses[key],
            "log_losses": aligned_loglosses[key],
        }

    # add baselines (if present)
    if baseline is not None:
        task_bundle[("Log", "Dropped")] = {
            "predictions": aligned_baseline_preds_1,
            "labels": aligned_labels,
            "log_losses": aligned_baseline_logloss_1,
        }
        task_bundle[("Log", "Indicator")] = {
            "predictions": aligned_baseline_preds_2,
            "labels": aligned_labels,
            "missing_counts": aligned_missing_counts,
            "log_losses": aligned_baseline_logloss_2,
            "feature_importances": baseline["lr_with_missing"].get("feature_importances"),
            "indices": aligned_baseline_indices,
        }

    task_results_dicts[task] = task_bundle

    print(f"Num overlapping for task {task}: {len(aligned_labels)}")

In [None]:
# Display the number of errors per model

def make_invalid_table_latex(
    invalid_counts: dict,
    variant_order: list[str] | None = None,
    task_order: list[str] | None = None,
    caption: str = "Invalid prediction counts by task and variant.",
    label: str = "tab:invalid_counts",
) -> str:
    """
    invalid_counts: {task: {variant: count}}
    Output: LaTeX table with tasks as columns, variants as rows (integers only).
    """
    # Collect tasks/variants
    tasks = list(invalid_counts.keys())
    variants = set()
    for t, d in invalid_counts.items():
        variants.update(d.keys())
    # tasks = task_order or sorted(tasks)
    variants = variant_order or sorted(variants)

    # Build wide table (variants x tasks), fill missing with 0 and cast to int
    df = pd.DataFrame(
        {t: {v: invalid_counts.get(t, {}).get(v, 0) for v in variants} for t in tasks}
    ).fillna(0).astype(int)
    df.index.name = "Variant"
    return df.to_latex()


# Table rows: every (model, variant) key of the first task. The logistic
# regression baseline entries (keyed with "Log") are filtered out by name
# rather than by position, so this also works when no baseline was loaded.
methods = [key for key in task_results_dicts[tasks[0]] if key[0] != "Log"]

# Example call (no totals; integers only)
latex = make_invalid_table_latex(
    invalid_counts,
    variant_order=methods,
    caption="Invalid predictions per task/variant.",
    label="tab:invalid_counts"
)
print(latex)

In [None]:
# Compute performance and calibration
def bootstrap_performance(labels, predictions, n_bootstrap=1000, n_bins=5, ece_strategy='adaptive', pi=None, eps=1e-12, seed=None):
    """Bootstrap the expected calibration error (ECE) and AUROC.

    Args:
        labels: Binary ground-truth labels.
        predictions: Predicted probabilities, in the same order as ``labels``.
        n_bootstrap: Number of bootstrap resamples.
        n_bins: Number of bins used for the ECE.
        ece_strategy: ``'quantile'`` bins each resample by the quantiles of its
            predictions. Any other value, including the default ``'adaptive'``,
            uses equal-width bins on [0, 1].
            NOTE(review): the default name suggests quantile binning was
            intended; confirm before changing.
        pi: Optional prevalence used to reweight samples: positives get weight
            ``pi / 0.5`` and negatives ``(1 - pi) / 0.5``. Presumably this
            assumes the evaluated sample is class-balanced; confirm. ``None``
            weights all samples equally.
        eps: Unused; kept so existing calls keep working.
        seed: Optional seed for a private ``np.random.RandomState``. ``None``
            (default) draws from the global NumPy RNG as before.

    Returns:
        ``{"ece": {...}, "auroc": {...}}`` where each inner dict holds the
        bootstrap ``mean`` and the 2.5th / 97.5th percentiles as ``lower`` /
        ``upper``. Resamples containing a single class have no defined AUROC
        and are left out of the AUROC statistics.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    n_samples = len(labels)
    if n_samples == 0:
        raise ValueError("bootstrap_performance needs at least one sample")
    if len(predictions) != n_samples:
        raise ValueError(
            f"labels and predictions differ in length ({n_samples} vs {len(predictions)})"
        )

    # The module and a RandomState instance both expose .choice
    rng = np.random if seed is None else np.random.RandomState(seed)

    ece_bootstrap = []
    auroc_bootstrap = []

    for _ in range(n_bootstrap):
        # Bootstrap sample
        indices = rng.choice(n_samples, n_samples, replace=True)
        boot_labels = labels[indices]
        boot_preds = predictions[indices]

        # Bin edges for this resample
        if ece_strategy == 'quantile':
            edges = np.quantile(boot_preds, np.linspace(0, 1, n_bins + 1))
            edges[0], edges[-1] = 0.0, 1.0
        else:
            edges = np.linspace(0, 1, n_bins + 1)

        if pi is None:
            w = np.ones(n_samples, dtype=float)
        else:
            w = np.where(boot_labels == 1, pi / 0.5, (1 - pi) / (1 - 0.5)).astype(float)

        ece = 0.0
        total_weight = w.sum()

        for b in range(n_bins):
            lower_ok = boot_preds >= edges[b]
            if b < n_bins - 1:
                mask = lower_ok & (boot_preds < edges[b + 1])
            else:
                # last bin is closed on the right so predictions of 1.0 are kept
                mask = lower_ok & (boot_preds <= edges[b + 1])
            if not mask.any():
                continue

            ww = w[mask]
            conf = np.average(boot_preds[mask], weights=ww)
            acc = np.average(boot_labels[mask], weights=ww)
            ece += (ww.sum() / total_weight) * abs(acc - conf)
        ece_bootstrap.append(ece)

        # AUROC is undefined when the resample holds only one class
        if np.unique(boot_labels).size < 2:
            auroc_bootstrap.append(np.nan)
        else:
            auroc_bootstrap.append(roc_auc_score(boot_labels, boot_preds))

    ece_bootstrap = np.array(ece_bootstrap)
    auroc_bootstrap = np.array(auroc_bootstrap, dtype=float)

    return {
        "ece": {
            "mean": np.mean(ece_bootstrap),
            "lower": np.percentile(ece_bootstrap, 2.5),
            "upper": np.percentile(ece_bootstrap, 97.5),
        },
        "auroc": {
            # nan-aware so single-class resamples do not poison the summary
            "mean": np.nanmean(auroc_bootstrap),
            "lower": np.nanpercentile(auroc_bootstrap, 2.5),
            "upper": np.nanpercentile(auroc_bootstrap, 97.5),
        },
    }

In [None]:
# Get bootstrapped performances
# Seed the global NumPy RNG first: bootstrap_performance resamples with
# np.random, so without a seed the reported means and intervals change from
# run to run.
np.random.seed(42)

# task_metrics: {task: {(model, variant): bootstrap summary}}
task_metrics = {}
for task in tasks:
    task_metrics[task] = {}

    for key, d in task_results_dicts[task].items():
        print(key)
        # entries without any aligned prediction cannot be bootstrapped
        if len(d["predictions"]) > 0:
            task_metrics[task][key] = bootstrap_performance(d["labels"], d["predictions"])

In [None]:
# Create regression: how does the change in the metric relative to the
# 'Dropped' variant depend on model size, prompt instruction and
# missingness-indicator serialization?
metric = 'auroc'

for task in tasks:
    # Kept from the original cell; the steps below do not draw random numbers.
    np.random.seed(42)

    # Columns are (model, variant); rows are mean / lower / upper.
    # The logistic-regression baseline ('Log') is excluded: it has no size in
    # its name and would otherwise be treated as a 'gpt' model of size 300.
    task_data = pd.DataFrame({
        key: task_metrics[task][key][metric]
        for key in task_metrics[task]
        if key[0] != 'Log'
    })

    # Difference performance
    difference = task_data.sub(task_data.xs('Dropped', axis=1, level=1), level=0, axis=1)
    difference = difference.loc[:, difference.columns.get_level_values(1) != 'Dropped']

    # Extract model and size
    pattern = r'^([^-]+).*?(\d+[Bb])'
    model_size = difference.columns.get_level_values(0).to_series().str.extract(pattern)
    model_size.columns = ['Model', 'Size']
    # Names without a parsable size fall back to 300 / 'gpt'; Size is log-transformed.
    model_size['Size'] = model_size['Size'].str.replace(r'[Bb]', '', regex=True).fillna(300).astype(int).apply(np.log)
    model_size['Model'] = model_size['Model'].fillna('gpt')
    model_size['Prompt'] = difference.columns.get_level_values(1).to_series().str.contains('Inst').values
    model_size['Serialization'] = difference.columns.get_level_values(1).to_series().str.contains('Ind').values

    difference.columns = pd.MultiIndex.from_frame(model_size)

    # Regress the mean difference on model size, the prompt-instruction flag,
    # the indicator-serialization flag, and their interactions with size.
    selection = ['mean', 'Size', 'Prompt', 'Serialization']

    model = smf.ols(formula='mean ~ Size * C(Prompt) + Size * C(Serialization)', data=difference.T.reset_index()[selection]).fit()
    print(model.summary())

In [None]:
# Scatter each model's metric for the 'Dropped' variant (x axis) against a
# missingness-aware variant (y axis). Marker size reflects model size and
# marker color the model family.
metric = 'ece'

# Same model-name pattern as the regression cell, redefined here so this cell
# does not depend on that cell having been run.
pattern = r'^([^-]+).*?(\d+[Bb])'

# savefig does not create missing directories
Path('outputs', experiment).mkdir(parents=True, exist_ok=True)

cmap = plt.cm.viridis
# Display names for model families in the (optional) legend
matches = {
    'gpt': 'GPT', 'medgemma': 'Med-Gemma', 'gemma':'Gemma'
}

for task in tasks:
    # The per-task table does not depend on the plotted variant, so it is built once.
    task_data = pd.DataFrame({
        key: task_metrics[task][key][metric]
        for key in task_metrics[task]
        if key[0] != 'Log'
    })
    model_size = task_data.columns.get_level_values(0).to_series().str.extract(pattern)
    model_size.columns = ['Model', 'Size']
    model_size['Size'] = model_size['Size'].str.replace(r'[Bb]', '', regex=True).fillna(300).astype(int)
    model_size['Model'] = model_size['Model'].fillna('gpt')
    model_size['Variant'] = task_data.columns.get_level_values(1)
    task_data.columns = pd.MultiIndex.from_frame(model_size)
    task_data = task_data.loc['mean']

    dropped = task_data.xs('Dropped', level=2)
    marker_sizes = 10 * (dropped.index.get_level_values('Size') + 10)

    # Model families as categorical codes (named `families` so the global
    # `models` list used when loading predictions is not overwritten).
    families = dropped.index.get_level_values('Model').astype("category")
    model_codes = families.codes
    unique_models = families.categories
    unique_codes = np.unique(model_codes)
    # avoid dividing by zero when only one family is present
    color_span = max(len(unique_codes) - 1, 1)

    if task == "long_los":
        task_name = "Long LOS"
    elif task == "death" or task == "in_hospital_mortality":
        task_name = "Mortality"
    else:
        task_name = task.capitalize()

    for prompt in ['Indicator + Inst.', 'Indicator', 'Instruction']:
        # Plot the gain in the metric given size
        plt.figure(figsize=(8,6))

        plt.scatter(dropped,
                    task_data.xs(prompt, level=2),
                    s=marker_sizes,
                    c=model_codes,
                    cmap='viridis', alpha=0.75)

        # LEGEND
        ########
        # One legend entry per model family
        handles = [
            plt.scatter([], [],
                color=cmap(code / color_span),
                s=200,              # legend marker size
                marker='o',
                label=matches[model] if model in matches else model)
            for code, model in zip(unique_codes, unique_models)
        ]

        #plt.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5))
        #########

        plt.title(f"{task_name}")

        # Logistic-regression baseline reference lines (only when available)
        if ('Log', 'Dropped') in task_metrics[task]:
            plt.axvline(task_metrics[task][('Log', 'Dropped')][metric]['mean'], ls = '--', lw = 3, c = 'orange')
        if ('Log', 'Indicator') in task_metrics[task]:
            plt.axhline(task_metrics[task][('Log', 'Indicator')][metric]['mean'], ls = '--', lw = 3, c = 'orange')

        # Identity line: points on it are unchanged relative to 'Dropped'
        plt.axline((task_data.min(),task_data.min()), slope=1, ls = ':', c = 'k')

        if metric == 'auroc':
            if task == 'long_los':
                plt.ylim(0.55, 0.8)
                plt.xlim(0.55, 0.8)
            elif task_name == 'Mortality':
                plt.ylim(0.65, 0.9)
                plt.xlim(0.65, 0.9)
            else:
                plt.ylim(0.5, 0.7)
                plt.xlim(0.5, 0.7)
        else:
            plt.ylim(-0.05, 0.55)
            plt.xlim(-0.05, 0.55)

        plt.ylabel(metric.upper() + ' ' + prompt)
        plt.xlabel(metric.upper() + ' Dropped')
        plt.savefig('outputs/{}/{}_{}_{}.png'.format(experiment, prompt.replace(' + ', '_'), task, metric), bbox_inches='tight')
        plt.show()
        plt.close()