# In [None]:
# analysis/aggregate_runs.ipynb

# %%
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from scipy.stats import t

# ------------------------
# Config
# ------------------------
# Directory searched for "*.pt" run files; summary.json is written here too.
results_dir = "../results"
# Two-sided confidence level handed to mean_confidence_interval (0.95 -> 95% CI).
confidence = 0.95
# Explicit list of run file paths; when None, every "*.pt" file in results_dir is used.
manual_run_paths = None
# When True, the aggregated metrics and curves are written to summary.json.
save_summary = True

# ------------------------
# Helper functions
# ------------------------
def load_runs(run_paths):
    """
    Load every run file listed in ``run_paths`` with ``torch.load``.

    Everything is mapped onto the CPU so that runs saved from a GPU job can
    be aggregated on a machine without CUDA, and so that any tensors they
    contain can be converted to NumPy afterwards.

    NOTE(review): ``torch.load`` unpickles the file contents; only point
    this at run files from a trusted source.

    Args:
        run_paths: iterable of paths readable by ``torch.load``.

    Returns:
        List of the loaded objects, in the same order as ``run_paths``.
    """
    return [torch.load(path, map_location="cpu") for path in run_paths]

def mean_confidence_interval(data, confidence=0.95):
    """
    Compute the mean and the half-width of a two-sided Student-t confidence
    interval along axis=0.

    Args:
        data: array-like whose first axis indexes the samples.
        confidence: confidence level, e.g. 0.95 for a 95% interval.

    Returns:
        (mean, h): the sample mean and the interval half-width, i.e. the
        interval is ``mean ± h``. With fewer than two samples the spread
        cannot be estimated, so ``h`` is NaN (same shape as ``mean``).
    """
    data = np.array(data)
    n = len(data)
    mean = np.mean(data, axis=0)
    if n < 2:
        # The ddof=1 standard deviation needs at least two samples. Return
        # NaN explicitly rather than through a degenerate division that
        # emits RuntimeWarnings.
        return mean, np.full(np.shape(mean), np.nan)
    # Standard error of the mean, scaled by the t quantile for n - 1 degrees
    # of freedom.
    sem = np.std(data, axis=0, ddof=1) / np.sqrt(n)
    h = sem * t.ppf((1 + confidence) / 2., n - 1)
    return mean, h

# ------------------------
# Collect run files
# ------------------------
# Use the explicit list when one is given; otherwise discover run files on disk.
if manual_run_paths is not None:
    run_paths = manual_run_paths
else:
    # glob returns matches in arbitrary, filesystem-dependent order. Sorting
    # makes the run order (and the run_paths list stored in the summary)
    # reproducible.
    run_paths = sorted(glob(os.path.join(results_dir, "*.pt")))

if not run_paths:
    # Name the source that was actually empty instead of always blaming
    # results_dir.
    source = "manual_run_paths" if manual_run_paths is not None else results_dir
    raise ValueError(f"No run files found in {source}")

print(f"Found {len(run_paths)} run files.")
runs = load_runs(run_paths)

# ------------------------
# Aggregate predictive metrics
# ------------------------
# Take every metric name seen in any run (first-seen order), not only those of
# the first run, so a metric absent from runs[0] is still reported.
metrics_keys = list(dict.fromkeys(k for r in runs for k in r["test_metrics"]))
agg_metrics = {}
for k in metrics_keys:
    # Aggregate only over the runs that recorded this metric. Filling the
    # gaps with NaN would turn the mean and CI of the whole metric into NaN.
    vals = [r["test_metrics"][k] for r in runs if k in r["test_metrics"]]
    if len(vals) < len(runs):
        print(f"Note: {k} is present in only {len(vals)}/{len(runs)} runs.")
    mean, ci = mean_confidence_interval(vals, confidence)
    # Plain Python floats keep the summary JSON-serializable.
    agg_metrics[k] = {"mean": float(mean), "ci": float(ci)}

# Display metrics
print("Aggregated predictive metrics (mean ± CI):")
for k, v in agg_metrics.items():
    print(f"{k}: {v['mean']:.4f} ± {v['ci']:.4f}")

# ------------------------
# Aggregate training/validation curves
# ------------------------
history_keys = ["train_loss", "train_acc", "val_loss", "val_acc"]


def _aggregate_curve(key):
    """Return the across-run mean and CI half-width for history entry `key`."""
    histories = [r["train_val_history"][key] for r in runs]
    # Runs may have different lengths; cut every history to the shortest one
    # so they stack into a rectangular array.
    shortest = min(map(len, histories))
    stacked = np.array([h[:shortest] for h in histories])
    curve_mean, curve_ci = mean_confidence_interval(stacked, confidence)
    return {"mean": curve_mean, "ci": curve_ci}


agg_curves = {key: _aggregate_curve(key) for key in history_keys}

# ------------------------
# Plot curves with confidence intervals
# ------------------------
# Percentage shown in the plot titles. Formatting with :g instead of int()
# avoids truncation: int(0.57 * 100) is 56 and int(0.975 * 100) is 97.
ci_label = f"{confidence * 100:g}"

# One figure per history key: mean curve with a shaded mean ± ci band.
for key in history_keys:
    mean = agg_curves[key]["mean"]
    ci = agg_curves[key]["ci"]
    steps = np.arange(len(mean))

    plt.figure(figsize=(6, 4))
    plt.plot(steps, mean, label=f"Mean {key}")
    plt.fill_between(steps, mean - ci, mean + ci, alpha=0.3, color='C0')
    plt.title(f"{key} ± {ci_label}% CI")
    plt.xlabel("Step")
    plt.ylabel(key)
    plt.legend()
    plt.grid(True)
    plt.show()

# ------------------------
# Optionally save aggregated summary
# ------------------------
def _json_safe(obj):
    """Return a copy of obj with non-finite floats (NaN, ±inf) replaced by None."""
    if isinstance(obj, dict):
        return {key: _json_safe(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_json_safe(val) for val in obj]
    if isinstance(obj, float) and not np.isfinite(obj):
        return None
    return obj


if save_summary:
    summary_path = os.path.join(results_dir, "summary.json")
    summary = {
        "metrics": agg_metrics,
        # ndarray -> nested Python lists so json can serialize the curves.
        "curves": {k: {"mean": agg_curves[k]["mean"].tolist(),
                       "ci": agg_curves[k]["ci"].tolist()} for k in history_keys},
        "run_paths": run_paths
    }
    with open(summary_path, "w") as f:
        # By default json.dump writes NaN/Infinity tokens, which are not
        # valid JSON and are rejected by strict parsers. A NaN CI occurs
        # whenever only one run is aggregated, so write those values as null.
        # allow_nan=False ensures no such token can slip through.
        json.dump(_json_safe(summary), f, indent=2, allow_nan=False)
    print(f"Aggregated summary saved -> {summary_path}")

# %%

