# Inference Analysis Notebook

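This notebook evaluates the `ed_reattendance` inference outputs from `outputs/inference/train/20260204_173740/ed_reattendance`. It:

- loads the predictions and prints a few example model outputs;
- reports AUROC, AUPRC and Brier score with bootstrapped 95% confidence intervals;
- plots ROC, precision–recall and calibration curves with bootstrap uncertainty.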

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    brier_score_loss,
    roc_curve,
    precision_recall_curve,
)
from sklearn.calibration import calibration_curve
from sklearn.utils import resample
import seaborn as sns

# Global plot styling shared by every figure produced below.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": [10, 6]})

## 1. Data Loading

In [None]:
data_path = "outputs/inference/train/20260204_173740/ed_reattendance/0.parquet"
df = pd.read_parquet(data_path)

# Ensure correct types. Map booleans explicitly instead of ``astype(int)``:
# ``astype(int)`` raises on a missing label, whereas the mapping turns it into
# NaN so that the ``dropna`` in the metrics section can filter it out.
df["label"] = df["boolean_value"].map({True: 1, False: 0})
assert df["label"].notna().sum() == df["boolean_value"].notna().sum(), (
    "boolean_value contains non-boolean entries that could not be mapped to 0/1"
)
# Unparseable probabilities become NaN and are likewise dropped downstream.
df["prob"] = pd.to_numeric(df["parsed_probability"], errors="coerce")

# Normalize probability if needed. This assumes every row uses the same scale
# (all in [0, 1] or all in [0, 100]); a mix of both scales cannot be detected here.
if df["prob"].max() > 1.0:
    print("Detected probabilities > 1, normalizing by dividing by 100.")
    df["prob"] = df["prob"] / 100.0

print(f"Loaded {len(df)} rows.")
print(
    f"Missing labels: {df['label'].isna().sum()}, "
    f"unparseable probabilities: {df['prob'].isna().sum()}"
)
df.head()

## 2. Visualize Examples

In [None]:
# Print a few raw prompts and (truncated) model outputs for a qualitative check.
N_EXAMPLES = 3
MAX_OUTPUT_CHARS = 500

for idx, row in df.head(N_EXAMPLES).iterrows():
    raw_output = row["model_output"]
    # Non-string outputs (e.g. None/NaN) cannot be sliced; show their repr instead.
    output = raw_output if isinstance(raw_output, str) else str(raw_output)
    # Only signal truncation when text was actually cut off.
    ellipsis = "..." if len(output) > MAX_OUTPUT_CHARS else ""
    print("=" * 80)
    print(f"Row {idx}")
    print(f"Prompt: {row['prompt']}")
    print("-" * 20)
    print(f"Model Output:\n{output[:MAX_OUTPUT_CHARS]}{ellipsis}")
    print("\n")

## 3. Metrics with Bootstrapped Confidence Intervals

In [None]:
def calculate_metrics(y_true, y_prob):
    """Score predicted probabilities against binary labels.

    Returns a dict with the keys "AUROC", "AUPRC" (average precision) and
    "Brier", evaluated in that order. ``roc_auc_score`` raises ValueError when
    ``y_true`` contains a single class; the bootstrap code relies on that to
    skip degenerate resamples.
    """
    scorers = {
        "AUROC": roc_auc_score,
        "AUPRC": average_precision_score,
        "Brier": brier_score_loss,
    }
    return {name: scorer(y_true, y_prob) for name, scorer in scorers.items()}


def bootstrap_metrics(y_true, y_prob, n_bootstrap=1000, seed=42, ci=95):
    """Compute metrics on the full data plus percentile-bootstrap confidence intervals.

    Parameters
    ----------
    y_true : array-like
        Binary labels (0/1).
    y_prob : array-like
        Predicted probabilities, aligned element-wise with ``y_true``.
    n_bootstrap : int
        Number of resamples, each drawn with replacement and the same size as the data.
    seed : int
        Seed for a local random generator; the global NumPy RNG is not modified.
    ci : float
        Confidence level in percent. The interval is stored under the key
        ``f"{ci:g}% CI"`` (``"95% CI"`` by default).

    Returns
    -------
    dict
        ``{metric: {"value": point_estimate, "<ci>% CI": (lower, upper)}}`` for
        every metric returned by ``calculate_metrics``.

    Raises
    ------
    ValueError
        If the inputs differ in length, if ``calculate_metrics`` fails on the full
        data, or if no bootstrap resample could be scored.
    """
    # Positional fancy indexing below needs ndarrays; a pandas Series with a
    # non-default index would otherwise be indexed by label.
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    if len(y_true) != len(y_prob):
        raise ValueError(
            f"y_true and y_prob differ in length ({len(y_true)} vs {len(y_prob)})."
        )

    n = len(y_true)
    # A local RandomState draws the same index stream that global
    # ``np.random.seed(seed)`` + ``sklearn.utils.resample`` produced, without
    # leaking seeded state into unrelated cells.
    rng = np.random.RandomState(seed)

    base_metrics = calculate_metrics(y_true, y_prob)

    metrics_list = []
    n_skipped = 0
    for _ in range(n_bootstrap):
        indices = rng.randint(0, n, size=n)
        try:
            metrics_list.append(calculate_metrics(y_true[indices], y_prob[indices]))
        except ValueError:
            # Resample contains only one class (AUROC undefined): skip it.
            n_skipped += 1

    if not metrics_list:
        raise ValueError(
            f"None of the {n_bootstrap} bootstrap resamples could be scored."
        )
    if n_skipped:
        print(
            f"Skipped {n_skipped} of {n_bootstrap} bootstrap resamples that could not be scored."
        )

    metrics_df = pd.DataFrame(metrics_list)
    tail = (100 - ci) / 2
    ci_key = f"{ci:g}% CI"

    results = {}
    for metric, value in base_metrics.items():
        lower, upper = np.percentile(metrics_df[metric], [tail, 100 - tail])
        results[metric] = {"value": value, ci_key: (lower, upper)}
    return results


# Keep only rows with both a label and a parseable probability.
valid_df = df.dropna(subset=["label", "prob"])
print(
    f"Using {len(valid_df)} of {len(df)} rows "
    f"({len(df) - len(valid_df)} dropped for missing label/probability)."
)
# Explicit dtypes give plain ndarrays even if the label column became float
# because it held NaNs before filtering.
y_true = valid_df["label"].to_numpy(dtype=int)
y_prob = valid_df["prob"].to_numpy(dtype=float)

# Report out-of-range probabilities before clipping, so that parse errors or a
# wrong normalisation are not silently absorbed into the metrics.
n_out_of_range = int(((y_prob < 0) | (y_prob > 1)).sum())
if n_out_of_range:
    print(f"Warning: clipping {n_out_of_range} probabilities outside [0, 1].")
y_prob = np.clip(y_prob, 0, 1)

results = bootstrap_metrics(y_true, y_prob)

for metric, data in results.items():
    lower, upper = data["95% CI"]
    print(f"{metric}: {data['value']:.4f} (95% CI: {lower:.4f} - {upper:.4f})")

## 4. Plots with Bootstrapping

In [None]:
def _bootstrap_indices(y_true, n_bootstrap, rng):
    """Yield up to ``n_bootstrap`` resample index arrays that contain both classes."""
    n = len(y_true)
    for _ in range(n_bootstrap):
        indices = rng.randint(0, n, size=n)
        # Single-class resamples give degenerate curves; skip them.
        if len(np.unique(y_true[indices])) >= 2:
            yield indices


def _plot_roc(ax, y_true, y_prob, n_bootstrap, rng):
    """Plot the ROC curve and a pointwise percentile-bootstrap 95% band on ``ax``."""
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    ax.plot(fpr, tpr, color="blue", label="Actual", linewidth=2)
    ax.plot([0, 1], [0, 1], color="gray", linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")

    # Interpolate every bootstrap curve onto a shared FPR grid so the curves
    # can be summarised pointwise.
    base_fpr = np.linspace(0, 1, 101)
    tprs = []
    for indices in _bootstrap_indices(y_true, n_bootstrap, rng):
        boot_fpr, boot_tpr, _ = roc_curve(y_true[indices], y_prob[indices])
        tpr_interp = np.interp(base_fpr, boot_fpr, boot_tpr)
        tpr_interp[0] = 0.0
        tprs.append(tpr_interp)

    if not tprs:
        return  # no usable resample: show the point estimate only
    # Pointwise 2.5th/97.5th percentiles are a genuine bootstrap 95% interval;
    # a mean +/- 2*std band would assume normality and be centred on the
    # bootstrap mean rather than on the observed curve.
    tprs_lower, tprs_upper = np.percentile(np.array(tprs), [2.5, 97.5], axis=0)
    ax.fill_between(
        base_fpr, tprs_lower, tprs_upper, color="blue", alpha=0.2, label="95% CI"
    )


def _plot_prc(ax, y_true, y_prob, n_bootstrap, rng):
    """Plot the precision-recall curve with faint bootstrap curves (at most 50) behind it."""
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    ax.plot(recall, precision, color="green", label="Actual", linewidth=2)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Precision-Recall Curve")

    # PR curves are awkward to interpolate onto a common grid, so overlay the
    # individual resampled curves instead; cap their number to limit clutter.
    for indices in _bootstrap_indices(y_true, min(n_bootstrap, 50), rng):
        boot_p, boot_r, _ = precision_recall_curve(y_true[indices], y_prob[indices])
        ax.plot(boot_r, boot_p, color="green", alpha=0.05)


def _plot_calibration(ax, y_true, y_prob, n_bootstrap, rng):
    """Plot a 10-bin reliability curve with faint bootstrap curves (at most 50)."""
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10)
    ax.plot(prob_pred, prob_true, marker="o", label="Actual", color="red")
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
    ax.set_xlabel("Mean Predicted Probability")
    ax.set_ylabel("Fraction of Positives")
    ax.set_title("Calibration Curve")

    for indices in _bootstrap_indices(y_true, min(n_bootstrap, 50), rng):
        boot_true, boot_pred = calibration_curve(
            y_true[indices], y_prob[indices], n_bins=10
        )
        ax.plot(boot_pred, boot_true, color="red", alpha=0.1)


def plot_curve_with_bootstrap(
    y_true, y_prob, curve_type="roc", n_bootstrap=100, ax=None, seed=42
):
    """Plot a ROC, precision-recall, or calibration curve with bootstrap uncertainty.

    Parameters
    ----------
    y_true, y_prob : array-like
        Binary labels and predicted probabilities, aligned element-wise.
    curve_type : {"roc", "prc", "calibration"}
        Which curve to draw.
    n_bootstrap : int
        Resamples for the ROC band; PRC/calibration overlay at most 50 curves.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    seed : int
        Seed for a local random generator, so figures are reproducible.

    Returns
    -------
    The Axes that was drawn on.

    Raises
    ------
    ValueError
        If ``curve_type`` is not one of the supported values.
    """
    plotters = {
        "roc": _plot_roc,
        "prc": _plot_prc,
        "calibration": _plot_calibration,
    }
    if curve_type not in plotters:
        raise ValueError(
            f"Unknown curve_type {curve_type!r}; expected one of {sorted(plotters)}."
        )
    if ax is None:
        _, ax = plt.subplots()

    # ndarrays make the fancy indexing in the helpers positional.
    plotters[curve_type](
        ax,
        np.asarray(y_true),
        np.asarray(y_prob),
        n_bootstrap,
        np.random.RandomState(seed),
    )
    ax.legend()
    return ax


# One panel per curve type, side by side.
panel_types = ["roc", "prc", "calibration"]
fig, axes = plt.subplots(1, len(panel_types), figsize=(20, 6))

for panel_ax, panel_type in zip(axes, panel_types):
    plot_curve_with_bootstrap(y_true, y_prob, panel_type, ax=panel_ax)

plt.tight_layout()
plt.show()