# In [2]:
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
from cycler import cycler

# Load the run-results table. Rows that repeat the same
# (experiment_name, dataset, shots, config_label) combination are collapsed to
# their first occurrence, and the `shots` column is cast to int.
df = (
    pd.read_csv("./output/runs.csv")
    .drop_duplicates(subset=["experiment_name", "dataset", "shots", "config_label"])
    .astype({"shots": int})
)

# Global matplotlib styling applied to every figure created afterwards:
# base font size, figure resolution, and the tab10 palette as the colour cycle.
plt.rcParams["font.size"] = 12
plt.rcParams["figure.dpi"] = 150
plt.rcParams["axes.prop_cycle"] = cycler(color=plt.cm.tab10.colors)

# One entry per plotted metric:
# (column holding the mean, column holding the std, display label).
metrics = [
    ("acc_mean", "acc_std", "Accuracy"),
    ("ece_mean", "ece_std", "ECE"),
    ("aece_mean", "aece_std", "AECE"),
]

# Directory that receives the generated PNG files; created if missing.
outdir = "./output/plots"
os.makedirs(outdir, exist_ok=True)

# # --- Single-metric grouped BAR plots ---
# for dset, g in df.groupby("dataset"):
#     models = sorted(g["experiment_name"].unique())
#     shots = np.array(sorted(g["shots"].unique()))
#     x = np.arange(len(shots)); W = 0.8 / max(1, len(models))

#     for m_mean, m_std, m_name in metrics:
#         plt.figure(figsize=(8,6)); plt.title(f"{dset} — {m_name}")
#         for i, model in enumerate(models):
#             gm = g[g["experiment_name"]==model]
#             h = gm.groupby("shots", as_index=True).agg({m_mean:"mean", m_std:"mean"})
#             y    = h[m_mean].reindex(shots).to_numpy()
#             yerr = h[m_std].reindex(shots).to_numpy()
#             plt.bar(x + (i - (len(models)-1)/2)*W, y, W, yerr=yerr, capsize=3,
#                     edgecolor="black", linewidth=0.6, label=model)
#         plt.xticks(x, shots); plt.xlabel("Shots"); plt.ylabel(m_name)
#         plt.grid(True, ls="--", alpha=.3, axis="y")
#         if m_name in ["ECE","AECE"]:
#             plt.gca().yaxis.set_major_formatter(plt.FormatStrFormatter("%.3f"))
#         if m_name == "Accuracy":
#             lo, hi = g["acc_mean"].min(), g["acc_mean"].max()
#             plt.ylim(max(0, lo - 10), min(100, hi + 10))
#         plt.legend(frameon=False)
#         plt.tight_layout(); plt.savefig(f"{outdir}/{dset}_{m_name}.png"); plt.close()

# --- 3-metric BAR plots per dataset (legend only in AECE) ---
# For every dataset, draw one figure with three side-by-side panels, one per
# entry in `metrics`. Inside a panel the bars are grouped by shot count and each
# model (experiment_name) gets its own bar: the bar height is the mean of the
# metric's mean column over the rows sharing that (model, shots) pair, and the
# error bar is the mean of the matching std column.
# Each figure is written to <outdir>/<dataset>_all_metrics.png.

# Every column needed by any panel, so the aggregation can be done in one pass.
metric_cols = [col for m_mean, m_std, _ in metrics for col in (m_mean, m_std)]

for dset, g in df.groupby("dataset"):
    models = sorted(g["experiment_name"].unique())
    shots = np.array(sorted(g["shots"].unique()))
    x = np.arange(len(shots))      # one x-axis slot per distinct shot count
    W = 0.8 / max(1, len(models))  # bar width: a full group of bars spans 0.8 of a slot

    # Aggregate all metric columns once per model rather than repeating the
    # same filter + groupby for each of the three panels. Reindexing on `shots`
    # gives NaN for shot counts a model has no rows for.
    per_model = {
        model: (
            g[g["experiment_name"] == model]
            .groupby("shots", as_index=True)[metric_cols]
            .mean()
            .reindex(shots)
        )
        for model in models
    }

    fig, axes = plt.subplots(1, 3, figsize=(20, 8), sharex=True)
    for ax, (m_mean, m_std, m_name) in zip(axes, metrics):
        for i, model in enumerate(models):
            y = per_model[model][m_mean].to_numpy()
            yerr = per_model[model][m_std].to_numpy()
            # Offset bar i so the len(models) bars of a group are centred on the tick.
            ax.bar(x + (i - (len(models) - 1) / 2) * W, y, W, yerr=yerr, capsize=3,
                   edgecolor="black", linewidth=0.6, label=model)

        ax.set_title(m_name)
        ax.set_xlabel("Shots")
        ax.set_ylabel(m_name)
        ax.grid(True, ls="--", alpha=.3, axis="y")
        ax.set_xticks(x, shots)

        if m_name in ["ECE", "AECE"]:
            # Fixed three-decimal tick labels for the calibration-error panels.
            ax.yaxis.set_major_formatter(plt.FormatStrFormatter("%.3f"))
        if m_name == "Accuracy":
            # Zoom the y-axis to the dataset's raw acc_mean range padded by 10
            # and clamped to [0, 100].
            # NOTE(review): the padding and clamp presume accuracy is stored in
            # percent (0-100) - confirm against how runs.csv is produced.
            lo, hi = g["acc_mean"].min(), g["acc_mean"].max()
            ax.set_ylim(max(0, lo - 10), min(100, hi + 10))
        if m_name == "AECE":
            # The models are the same in every panel, so a single legend suffices.
            ax.legend(frameon=False, loc="best")

    # Operate on the explicit figure handle instead of pyplot's current-figure state.
    fig.tight_layout()
    fig.savefig(f"{outdir}/{dset}_all_metrics.png")
    plt.close(fig)
