# LE Hyperparameter Search (`norm` / epochs)

This notebook runs a Localized Entropy (LE) hyperparameter search using `configs/hyper.json`; each run overrides only the model normalization (`model.norm`) and the training epoch count.

- Outer loop: model normalization (`model.norm`), trying `null` (no value), `batchnorm`, and `layernorm`. The learning rate is not searched in this notebook.
- Inner loop: train for `5` epochs and evaluate after each epoch.
- Per-epoch collection:
  - LE loss on the eval/test split
  - Global calibration ratio on the eval/test split only
  - Per-condition calibration ratio on the eval/test split only


In [None]:
%matplotlib inline
from __future__ import annotations

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from localized_entropy.hyper_search import build_search_context, run_le_hyper_search

# Display settings for the NumPy arrays and pandas tables shown in later cells.
np.set_printoptions(suppress=True, precision=6)
pd.set_option("display.max_rows", 200, "display.max_columns", 50)

# Build the shared search context from the config file; later cells read
# ctx.eval_name, ctx.data_source and ctx.splits from it.
CONFIG_PATH = "configs/hyper.json"
ctx = build_search_context(CONFIG_PATH)

print(f"Evaluation split for LE loss: {ctx.eval_name}")
print(f"Data source: {ctx.data_source}")
print(f"Num conditions/categories: {ctx.splits.num_conditions}")



In [None]:
# Candidate values for model.norm; None is passed through unchanged (shown as "null").
NORM_VALUES = [None, "batchnorm", "layernorm"]
param_grid = [dict(norm=value) for value in NORM_VALUES]


def norm_label(norm_value):
    """Return a printable label for a norm setting.

    ``None`` maps to ``"null"``; any other value is converted with ``str``.
    """
    if norm_value is None:
        return "null"
    return str(norm_value)


# Echo the grid using the same labels that will appear in the results tables.
print("Norm values to try:", list(map(norm_label, NORM_VALUES)))



In [None]:
def apply_params(cfg_run: dict, train_params: dict, params: dict, *, epochs: int = 5) -> None:
    """Apply one grid point to a run's config and training parameters, in place.

    Args:
        cfg_run: Per-run config dict; ``cfg_run["model"]["norm"]`` is set,
            creating the ``"model"`` section if it is missing or null.
        train_params: Training-parameter dict; its ``"epochs"`` entry is set.
        params: One entry of ``param_grid``; must contain ``"norm"``.
        epochs: Number of training epochs per run (defaults to 5).
    """
    train_params["epochs"] = epochs
    # ``setdefault`` would keep an explicit ``"model": null`` and then fail on
    # item assignment, so treat a missing and a null section the same way.
    model_cfg = cfg_run.get("model")
    if model_cfg is None:
        model_cfg = cfg_run["model"] = {}
    model_cfg["norm"] = params["norm"]


# Run the search over param_grid. Result rows carry at least the "norm",
# "epoch" and "test_le" columns that the cells below rely on.
results_df, condition_df = run_le_hyper_search(
    ctx,
    param_grid=param_grid,
    apply_params=apply_params,
    record_params=lambda p: {"norm": norm_label(p["norm"])},
    sort_by=["test_le", "norm", "epoch"],
    ascending=[True, True, True],
)

if results_df.empty:
    print("No LE results were recorded.")
else:
    # Select the best row explicitly (lowest test_le, ties broken by norm then
    # epoch) instead of relying on the row order returned by the search helper.
    best_row = results_df.sort_values(
        ["test_le", "norm", "epoch"], ascending=[True, True, True]
    ).iloc[0]
    print("Best LE result overall:")
    print(best_row.to_dict())



In [None]:
def plot_metric_lines(
    metric_df: pd.DataFrame,
    value_col: str,
    title: str,
    y_label: str,
    ylim: tuple[float, float] | None = None,
) -> None:
    """Plot ``value_col`` against epoch with one line per ``norm`` value.

    Args:
        metric_df: Rows with at least ``"norm"``, ``"epoch"`` and ``value_col`` columns.
        value_col: Column plotted on the y axis.
        title: Figure title.
        y_label: Y-axis label.
        ylim: Optional ``(bottom, top)`` y-limits. When omitted, columns whose
            name contains "calibration" are clipped to ``(0, 3)``; other
            metrics are auto-scaled.
    """
    fig, ax = plt.subplots(figsize=(11, 6))
    # Alphabetical order plus explicit "CN" colours keeps each norm's colour
    # stable across figures and reruns; first-appearance order would follow
    # whatever the caller sorted by (e.g. test_le) and reshuffle the colours.
    norm_order = sorted(metric_df["norm"].drop_duplicates(), key=str)
    for idx, norm_name in enumerate(norm_order):
        block = metric_df.loc[metric_df["norm"] == norm_name, ["epoch", value_col]].sort_values("epoch")
        ax.plot(
            block["epoch"],
            block[value_col],
            marker="o",
            linewidth=1.5,
            markersize=3,
            color=f"C{idx % 10}",
            label=f"norm={norm_name}",
        )
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    if norm_order:
        # Skip the legend on empty input to avoid matplotlib's "no artists" warning.
        ax.legend(ncol=3, fontsize=8)
    if ylim is None and "calibration" in value_col.lower():
        ylim = (0, 3)
    if ylim is not None:
        ax.set_ylim(*ylim)
    fig.tight_layout()
    plt.show()


# Summary tables and plots for the norm search.
table_cols = ["norm", "epoch", "test_le", "global_calibration"]
results_sorted = results_df.sort_values(["test_le", "norm", "epoch"], ascending=[True, True, True]).reset_index(drop=True)
condition_sorted = condition_df.sort_values(["condition", "norm", "epoch"]).reset_index(drop=True)

display(results_sorted.loc[:, table_cols].head(100))

if results_sorted.empty:
    print("No LE results to summarize.")
else:
    # results_sorted is ordered by ascending test_le, so row 0 is the best run/epoch.
    best_row = results_sorted.iloc[0]
    print("Most optimal norm (by minimum test_le):", str(best_row["norm"]))
    print("Best epoch for that norm:", int(best_row["epoch"]))
    print("Best test_le:", float(best_row["test_le"]))

plot_metric_lines(
    metric_df=results_sorted,
    value_col="test_le",
    title=f"LE ({ctx.eval_name}) across norm and epoch",
    y_label="LE loss",
)

plot_metric_lines(
    metric_df=results_sorted,
    value_col="global_calibration",
    title="Global calibration ratio across norm and epoch",
    y_label="Calibration ratio",
)

condition_table_cols = ["norm", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"]
display(condition_sorted.loc[:, condition_table_cols])

if condition_sorted.empty:
    print("No per-condition records to plot.")
else:
    cond_ids = sorted(condition_sorted["condition"].unique())
    n = len(cond_ids)
    ncols = 3
    nrows = math.ceil(n / ncols)
    # squeeze=False always yields a 2-D array of Axes, so flattening is uniform.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows), squeeze=False)
    axes = axes.reshape(-1)

    # One fixed colour per norm for the whole grid. Each Axes has its own colour
    # cycle, so without this a norm missing for some condition would shift every
    # following norm to a different colour in that panel.
    norm_order = sorted(condition_sorted["norm"].drop_duplicates(), key=str)
    norm_colors = {norm_name: f"C{idx % 10}" for idx, norm_name in enumerate(norm_order)}
    epoch_order = sorted(condition_sorted["epoch"].unique())

    for i, cond_id in enumerate(cond_ids):
        ax = axes[i]
        block = condition_sorted.loc[condition_sorted["condition"] == cond_id, ["norm", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        for norm_name in norm_order:
            norm_block = block.loc[block["norm"] == norm_name, ["epoch", "calibration"]].sort_values("epoch")
            if norm_block.empty:
                continue
            ax.plot(
                norm_block["epoch"],
                norm_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                color=norm_colors[norm_name],
                label=norm_name,
            )
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        ax.set_ylim(0, 3)
        ax.grid(True, alpha=0.3)
        # Only draw a legend when at least one labelled line was plotted.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=6)

    # Hide the unused trailing panels of the last grid row.
    for j in range(n, len(axes)):
        axes[j].axis("off")

    fig.suptitle("Per-condition calibration ratio across epoch (line per norm)")
    fig.tight_layout()
    plt.show()

