# LE Hyperparameter Search (Batch Size / Epochs)

This notebook runs a Localized Entropy (LE) hyperparameter search using `configs/hyper.json` and sweeps `batch_size` with fixed `epochs=5`.

- Outer loop: `batch_size` grid
- Inner loop: train for `5` epochs and evaluate each epoch
- Per-epoch collection (computed only on the held-out split: the test split when it has labels, otherwise the eval split):
  - LE loss
  - Global calibration ratio (mean prediction / mean label)
  - Per-condition calibration ratio


In [None]:
%matplotlib inline
from __future__ import annotations

import copy
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from localized_entropy.config import get_data_source, load_and_resolve, resolve_ctr_config
from localized_entropy.data.pipeline import prepare_data
from localized_entropy.experiments import build_loss_loaders, build_model, train_single_loss
from localized_entropy.training import compute_base_rates_from_loader, evaluate, predict_probs
from localized_entropy.analysis import per_condition_calibration
from localized_entropy.utils import init_device, set_seed

# Display settings for numpy arrays and pandas tables printed by this notebook.
np.set_printoptions(precision=6, suppress=True)
pd.set_option("display.max_rows", 200)
pd.set_option("display.max_columns", 50)

# Load the resolved experiment configuration that drives the whole search.
CONFIG_PATH = "configs/hyper.json"
cfg = load_and_resolve(CONFIG_PATH)

# Device selection. `use_mps` defaults to True when the config does not set it.
device_cfg = cfg.get("device", {})
mps_requested = bool(device_cfg.get("use_mps", True))
device, use_cuda, use_mps, non_blocking = init_device(use_mps=mps_requested)

# float64 is used only when running on CPU with MPS switched off in the config;
# every other combination uses float32.
cpu_float64 = device.type == "cpu" and not mps_requested
if cpu_float64:
    model_dtype = torch.float64
else:
    model_dtype = torch.float32

# Seed before preparing data; the prepared splits are reused by every run.
set_seed(int(cfg["project"]["seed"]), use_cuda)
prepared = prepare_data(cfg, device, use_cuda, use_mps)
splits = prepared.splits

# Only the "ctr" data source can declare an unlabeled test split; every other
# source is treated as having test labels.
data_source = get_data_source(cfg)
if data_source == "ctr":
    test_has_labels = bool(resolve_ctr_config(cfg).get("test_has_labels", False))
else:
    test_has_labels = True

print(f"Data source: {data_source}")
print(f"Num conditions/categories: {splits.num_conditions}")


In [None]:
def _global_calibration_ratio(preds: np.ndarray, labels: np.ndarray, eps: float = 1e-12) -> float:
    """Return the ratio of the mean prediction to the mean label.

    A value of 1.0 means the average prediction equals the average label.

    Args:
        preds: Predicted values.
        labels: Observed labels.
        eps: Threshold at or below which the label mean is treated as zero.

    Returns:
        ``mean(preds) / mean(labels)``, or NaN when either input is empty or
        the label mean is at or below ``eps`` (the ratio is undefined there).
    """
    # Guard empty inputs explicitly: np.mean of an empty array emits
    # RuntimeWarnings and only reaches NaN as a side effect.
    if np.size(preds) == 0 or np.size(labels) == 0:
        return float("nan")

    label_mean = float(np.mean(labels))
    if label_mean <= eps:
        return float("nan")
    return float(np.mean(preds)) / label_mean


def collect_all_data_metrics(model: torch.nn.Module, eval_loader, eval_labels: np.ndarray, eval_conds: np.ndarray):
    """Compute global and per-condition calibration for one evaluation pass.

    Predictions are produced by running ``model`` over ``eval_loader`` and are
    flattened to a 1-D float64 array before being compared with the labels.

    Args:
        model: Model to evaluate.
        eval_loader: Loader handed to ``predict_probs``.
        eval_labels: Labels, one per prediction.
        eval_conds: Condition ids, one per prediction.

    Returns:
        A ``(global_calibration, per_condition)`` tuple, where the second item
        is whatever ``per_condition_calibration`` returns.

    Raises:
        ValueError: If predictions, labels and conditions differ in length.
    """
    raw_preds = predict_probs(model, eval_loader, device, non_blocking=non_blocking)
    eval_preds = np.asarray(raw_preds, dtype=np.float64).reshape(-1)

    # All three arrays must line up element by element.
    n_preds = eval_preds.shape[0]
    lengths_agree = n_preds == eval_labels.shape[0] and n_preds == eval_conds.shape[0]
    if not lengths_agree:
        raise ValueError(
            f"Mismatched eval lengths: preds={eval_preds.shape[0]} labels={eval_labels.shape[0]} conds={eval_conds.shape[0]}"
        )

    return (
        _global_calibration_ratio(eval_preds, eval_labels),
        per_condition_calibration(eval_preds, eval_labels, eval_conds),
    )


# Batch sizes visited by the sweep, in the order they are run.
BATCH_SIZES = [
    4096,
    8192,
    16384,
    25000,
    32768,
    65536,
]
# Number of training epochs used for every batch size.
EPOCHS = 5

print("Batch sizes to try:", BATCH_SIZES)
print("Epochs per run:", EPOCHS)


In [None]:
# Accumulators filled by the per-epoch callback: `epoch_records` gets one row
# per (batch_size, epoch); `condition_records` gets one row per
# (batch_size, epoch, condition).
epoch_records = []
condition_records = []

# Dtype used for the base-rate computation; it does not change between runs.
first_param_dtype = torch.float64 if cpu_float64 else torch.float32

for batch_size in BATCH_SIZES:
    # Private copy so this run's overrides never leak into the shared config.
    cfg_run = copy.deepcopy(cfg)
    train_cfg = cfg_run.setdefault("training", {})
    train_cfg["epochs"] = int(EPOCHS)
    train_cfg["batch_size"] = int(batch_size)

    # Make sure LE-specific overrides cannot shadow the swept values.
    le_override = (train_cfg.get("by_loss") or {}).get("localized_entropy")
    if isinstance(le_override, dict):
        # NOTE(review): presumably loss-level keys take precedence over the
        # plain `training` keys when the config is resolved -- confirm. Only
        # keys that already exist are rewritten, so the config shape is kept.
        for key, value in (("epochs", EPOCHS), ("batch_size", batch_size)):
            if key in le_override:
                le_override[key] = int(value)

        # Per-source overrides always receive the sweep values.
        le_by_source = le_override.get("by_source")
        if isinstance(le_by_source, dict):
            for src_cfg in le_by_source.values():
                if isinstance(src_cfg, dict):
                    src_cfg["epochs"] = int(EPOCHS)
                    src_cfg["batch_size"] = int(batch_size)

    loss_loaders, le_train_cfg = build_loss_loaders(cfg_run, "localized_entropy", splits, device, use_cuda, use_mps)

    # Score on the test split when it exists and is labeled; otherwise fall
    # back to the eval split.
    if loss_loaders.test_loader is not None and splits.y_test is not None and test_has_labels:
        eval_loader = loss_loaders.test_loader
        eval_labels = np.asarray(splits.y_test).reshape(-1)
        eval_conds = np.asarray(splits.c_test).reshape(-1)
        eval_name = "test"
    else:
        eval_loader = loss_loaders.eval_loader
        eval_labels = np.asarray(splits.y_eval).reshape(-1)
        eval_conds = np.asarray(splits.c_eval).reshape(-1)
        eval_name = "eval"

    # Base rates come from the training loader and are reused for every LE
    # evaluation in this run (training, train-eval and eval alike).
    base_rates_train = compute_base_rates_from_loader(
        loss_loaders.train_loader,
        num_conditions=int(splits.num_conditions),
        device=device,
        dtype=first_param_dtype,
        non_blocking=non_blocking,
    )

    # Re-seed right before building the model so each batch size starts from
    # the same seed.
    set_seed(int(cfg_run["project"]["seed"]), use_cuda)
    model = build_model(cfg_run, splits, device, dtype=model_dtype)

    def on_epoch_eval(_eval_preds: np.ndarray, epoch: int, bs: int = int(batch_size)) -> None:
        """Record LE loss and calibration metrics for one finished epoch.

        `bs` is frozen through its default value. The model, loader, labels
        and base rates are read from the notebook namespace when the callback
        runs, so it must only be invoked during this iteration's training call.
        """
        # NOTE(review): `_eval_preds` is ignored and predictions are recomputed
        # below; presumably they could be reused if they are known to come
        # from `eval_loader` in order -- confirm against `train_single_loss`.
        le_loss, _ = evaluate(
            model,
            eval_loader,
            device,
            loss_mode="localized_entropy",
            base_rates=base_rates_train,
            non_blocking=non_blocking,
        )
        global_calibration, per_cond_df = collect_all_data_metrics(model, eval_loader, eval_labels, eval_conds)

        epoch_records.append(
            {
                "batch_size": int(bs),
                "epoch": int(epoch),
                "test_le": float(le_loss),
                "global_calibration": float(global_calibration),
                "eval_split": eval_name,
            }
        )

        # Plain dict records avoid building a Series per row as iterrows does.
        condition_records.extend(
            {
                "batch_size": int(bs),
                "epoch": int(epoch),
                "condition": int(row["condition"]),
                "count": int(row["count"]),
                "base_rate": float(row["base_rate"]),
                "pred_mean": float(row["pred_mean"]),
                "calibration": float(row["calibration"]),
            }
            for row in per_cond_df.to_dict(orient="records")
        )

    # Resolve the learning rate lazily: the global `training.lr` is only
    # required when the LE config does not provide its own value.
    lr = le_train_cfg.get("lr")
    if lr is None:
        lr = cfg_run["training"]["lr"]

    # Only forward the cross-batch settings when they are a dict; copy them so
    # the trainer cannot mutate the resolved config.
    cross_batch_cfg = le_train_cfg.get("cross_batch")
    if isinstance(cross_batch_cfg, dict):
        cross_batch_cfg = copy.deepcopy(cross_batch_cfg)
    else:
        cross_batch_cfg = None

    train_single_loss(
        model=model,
        loss_mode="localized_entropy",
        train_loader=loss_loaders.train_loader,
        train_eval_loader=eval_loader,
        eval_loader=eval_loader,
        device=device,
        epochs=int(EPOCHS),
        lr=float(lr),
        lr_category=le_train_cfg.get("lr_category"),
        lr_decay=float(le_train_cfg.get("lr_decay", train_cfg.get("lr_decay", 1.0))),
        lr_category_decay=float(le_train_cfg.get("lr_category_decay", train_cfg.get("lr_category_decay", 1.0))),
        lr_zero_after_epochs=le_train_cfg.get("lr_zero_after_epochs"),
        eval_has_labels=True,
        le_base_rates_train=base_rates_train,
        le_base_rates_train_eval=base_rates_train,
        le_base_rates_eval=base_rates_train,
        non_blocking=non_blocking,
        eval_callback=on_epoch_eval,
        plot_eval_hist_epochs=False,
        print_embedding_table=False,
        le_cross_batch_cfg=cross_batch_cfg,
    )

# One row per (batch_size, epoch) collected by the sweep.
results_df = pd.DataFrame(epoch_records)

# Name the columns explicitly so that an empty record list still produces a
# frame with the expected schema; without them, later sorting or selecting by
# column raises KeyError instead of reaching the "nothing to plot" path.
condition_df = pd.DataFrame(
    condition_records,
    columns=["batch_size", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"],
)

if results_df.empty:
    raise RuntimeError("Search produced no results.")

# Lowest LE loss first; ties are broken by smaller batch size, then earlier epoch.
results_df = results_df.sort_values(["test_le", "batch_size", "epoch"]).reset_index(drop=True)
best_row = results_df.iloc[0]

print(f"Evaluation split for LE loss: {best_row['eval_split']}")
print("Best LE result:")
print(best_row.to_dict())


In [None]:
def plot_metric_lines(metric_df: pd.DataFrame, value_col: str, title: str, y_label: str) -> None:
    """Plot ``value_col`` against epoch with one line per batch size.

    Args:
        metric_df: Frame with ``batch_size``, ``epoch`` and ``value_col`` columns.
        value_col: Name of the column drawn on the y-axis.
        title: Figure title.
        y_label: Label for the y-axis.

    The y-axis is fixed to the range 0..3 whenever ``value_col`` contains
    "calibration" (case-insensitive). The figure is shown immediately.
    """
    fig, ax = plt.subplots(figsize=(11, 6))

    for bs in sorted(metric_df["batch_size"].unique()):
        rows_for_bs = metric_df["batch_size"] == bs
        curve = metric_df.loc[rows_for_bs, ["epoch", value_col]].sort_values("epoch")
        ax.plot(
            curve["epoch"],
            curve[value_col],
            marker="o",
            linewidth=1.5,
            markersize=3,
            label=f"batch_size={int(bs)}",
        )

    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend(ncol=3, fontsize=8)

    is_calibration_metric = "calibration" in value_col.lower()
    if is_calibration_metric:
        ax.set_ylim(0, 3)

    fig.tight_layout()
    plt.show()


# Columns shown in the two summary tables.
table_cols = ["batch_size", "epoch", "test_le", "global_calibration", "eval_split"]
condition_table_cols = ["batch_size", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"]

# Lowest LE loss first; ties are broken by smaller batch size, then earlier epoch.
results_sorted = results_df.sort_values(["test_le", "batch_size", "epoch"]).reset_index(drop=True)

if condition_df.empty:
    # A frame built from zero records may have no columns at all, and sorting
    # or selecting by name would then raise KeyError before the empty check
    # further down is reached. Use an empty frame with the expected schema.
    condition_sorted = pd.DataFrame(columns=condition_table_cols)
else:
    condition_sorted = condition_df.sort_values(["condition", "batch_size", "epoch"]).reset_index(drop=True)

display(results_sorted.loc[:, table_cols].head(200))

best_row = results_sorted.iloc[0]
print("Most optimal batch_size (by minimum test_le):", int(best_row["batch_size"]))
print("Best epoch for that batch_size:", int(best_row["epoch"]))
print("Best test_le:", float(best_row["test_le"]))

plot_metric_lines(
    metric_df=results_sorted,
    value_col="test_le",
    title="LE across batch size and epoch",
    y_label="LE loss",
)

plot_metric_lines(
    metric_df=results_sorted,
    value_col="global_calibration",
    title="Global calibration ratio across batch size and epoch",
    y_label="Calibration ratio",
)

display(condition_sorted.loc[:, condition_table_cols])

if condition_sorted.empty:
    print("No per-condition records to plot.")
else:
    # One subplot per condition, laid out on a grid three columns wide.
    cond_ids = sorted(condition_sorted["condition"].unique())
    n = len(cond_ids)
    ncols = 3
    nrows = int(math.ceil(n / ncols))
    # squeeze=False always yields a 2-D array of axes, even for a single row.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows), squeeze=False)
    axes = axes.ravel()

    batch_order = sorted(condition_sorted["batch_size"].unique())
    epoch_order = sorted(condition_sorted["epoch"].unique())

    for ax, cond_id in zip(axes, cond_ids):
        block = condition_sorted.loc[condition_sorted["condition"] == cond_id, ["batch_size", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        for batch_size in batch_order:
            bs_block = block.loc[block["batch_size"] == batch_size, ["epoch", "calibration"]].sort_values("epoch")
            if bs_block.empty:
                continue
            ax.plot(
                bs_block["epoch"],
                bs_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                label=f"{int(batch_size)}",
            )
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        # Same fixed 0..3 range on every subplot so the panels are comparable.
        ax.set_ylim(0, 3)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=6)

    # Hide the unused cells of the grid.
    for unused_ax in axes[n:]:
        unused_ax.axis("off")

    fig.suptitle("Per-condition calibration ratio across epoch (line per batch size)")
    fig.tight_layout()
    plt.show()
