# LE Hyperparameter Search (`lr_category` / `lr` / epochs)

This notebook runs a Localized Entropy (LE) hyperparameter search using `configs/hyper.json` as-is.

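- Outer loop: category learning rate (`lr_category`) starts at `1.0` and is divided by 10 each step, down to `0.001`.
- Inner loop: learning rate (`lr`) starts at `0.5` and halves each step while it stays at or above `0.001`.
- For each (`lr_category`, `lr`) pair: train for `5` epochs and evaluate after every epoch.
- Per-epoch collection:
  - LE loss on the eval/test split
  - Global calibration ratio on the eval/test split only
  - Per-condition calibration ratio on the eval/test split only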


In [None]:
%matplotlib inline
from __future__ import annotations

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from localized_entropy.hyper_search import build_search_context, run_le_hyper_search

# Notebook display settings: print numpy floats with 6 digits and no
# scientific notation for small values; let pandas show larger tables.
np.set_printoptions(suppress=True, precision=6)
pd.set_option("display.max_rows", 200)
pd.set_option("display.max_columns", 50)

# The search context is built once from the hyper-search config and shared by
# every run below (it exposes the eval split name, data source and splits).
CONFIG_PATH = "configs/hyper.json"
ctx = build_search_context(CONFIG_PATH)

# Quick sanity summary of what the search will run against.
print(f"Evaluation split for LE loss: {ctx.eval_name}")
print(f"Data source: {ctx.data_source}")
print(f"Num conditions/categories: {ctx.splits.num_conditions}")



In [None]:
def lr_schedule_halving(start: float = 1.0, floor: float = 1e-6) -> list[float]:
    """Return a halving schedule ``start, start/2, start/4, ...`` down to ``floor``.

    Values are emitted while they are >= ``floor``. A value equal to ``floor``
    up to float rounding is also kept, so the floor is inclusive.

    Args:
        start: First (largest) value. A start that is zero, negative, or
            already below ``floor`` yields an empty list.
        floor: Smallest value allowed in the schedule; must be > 0.

    Returns:
        The schedule as a list of floats in decreasing order.

    Raises:
        ValueError: If ``floor`` is not a positive number or ``start`` is not
            finite. With ``floor <= 0`` or ``start == inf`` the halving loop
            would otherwise never terminate.
    """
    start = float(start)
    floor = float(floor)
    if not floor > 0.0:  # written this way so NaN is rejected too
        raise ValueError(f"floor must be a positive number, got {floor!r}")
    if not math.isfinite(start):
        raise ValueError(f"start must be finite, got {start!r}")

    values: list[float] = []
    lr = start
    # Dividing by 2 only changes the binary exponent (outside the subnormal
    # range), so repeated halving does not accumulate rounding error.
    while lr >= floor or math.isclose(lr, floor, rel_tol=1e-9):
        values.append(lr)
        lr /= 2.0
    return values


def lr_category_schedule_decade(start: float = 1.0, floor: float = 1e-6) -> list[float]:
    """Return a decade-spaced schedule ``start, start/10, start/100, ...``.

    Values are emitted while they are >= ``floor``. A value equal to ``floor``
    up to float rounding is also kept, so the floor is inclusive. Each value
    is computed as ``start / 10**k`` with a single division instead of by
    repeatedly dividing the previous value. Rounding error therefore does not
    accumulate across steps; for example, ``start=1.0`` yields exactly
    ``1e-05`` and ``1e-06``.

    Args:
        start: First (largest) value. A start that is zero, negative, or
            already below ``floor`` yields an empty list.
        floor: Smallest value allowed in the schedule; must be > 0.

    Returns:
        The schedule as a list of floats in decreasing order.

    Raises:
        ValueError: If ``floor`` is not a positive number or ``start`` is not
            finite. With ``floor <= 0`` or ``start == inf`` the loop would
            otherwise never terminate.
    """
    start = float(start)
    floor = float(floor)
    if not floor > 0.0:  # written this way so NaN is rejected too
        raise ValueError(f"floor must be a positive number, got {floor!r}")
    if not math.isfinite(start):
        raise ValueError(f"start must be finite, got {start!r}")

    values: list[float] = []
    power = 0
    while True:
        # One correctly rounded division per value: no drift between steps.
        lr_category = start / 10**power
        if lr_category < floor and not math.isclose(lr_category, floor, rel_tol=1e-9):
            break
        values.append(lr_category)
        power += 1
    return values


# Grid axes: lr halves from 0.5 down to the 0.001 floor; lr_category falls by
# decades from 1.0 down to 0.001. Each run trains for EPOCHS epochs.
LEARNING_RATES = lr_schedule_halving(start=0.5, floor=0.001)
LR_CATEGORY_RATES = lr_category_schedule_decade(start=1.0, floor=0.001)
EPOCHS = 5

# Full Cartesian product of the two axes; lr_category is the outer
# (slowest-varying) axis, lr the inner one.
param_grid = [
    dict(lr_category=float(cat_rate), lr=float(rate))
    for cat_rate in LR_CATEGORY_RATES
    for rate in LEARNING_RATES
]

print(f"Number of lr points: {len(LEARNING_RATES)}")
print(f"Number of lr_category points: {len(LR_CATEGORY_RATES)}")
print("lr first/last:", LEARNING_RATES[0], LEARNING_RATES[-1])
print("lr_category first/last:", LR_CATEGORY_RATES[0], LR_CATEGORY_RATES[-1])



In [None]:
def apply_params(cfg_run: dict, train_params: dict, params: dict) -> None:
    """Write one grid point's hyperparameters into ``train_params`` in place.

    ``lr`` and ``lr_category`` are copied from ``params`` as floats, and
    ``epochs`` is set to the notebook-level ``EPOCHS`` constant. ``cfg_run``
    is not touched here; presumably it is part of the callback signature
    expected by ``run_le_hyper_search`` (verify against that function).
    """
    for key in ("lr", "lr_category"):
        train_params[key] = float(params[key])
    train_params["epochs"] = int(EPOCHS)


results_df, condition_df = run_le_hyper_search(
    ctx,
    param_grid=param_grid,
    apply_params=apply_params,
    # Both learning rates are recorded as floats for each grid point.
    record_params=lambda point: {key: float(point[key]) for key in ("lr_category", "lr")},
    # Lowest test LE first; ties broken by smaller lr_category, lr, then earlier epoch.
    sort_by=["test_le", "lr_category", "lr", "epoch"],
    ascending=[True] * 4,
)

# Assumes run_le_hyper_search returns results_df ordered by sort_by/ascending,
# making the first row the overall best (TODO confirm in hyper_search).
best_row = results_df.iloc[0]
print("Best LE result overall:")
print(best_row.to_dict())



In [None]:
def plot_metric_lines(metric_df: pd.DataFrame, value_col: str, title: str, y_label: str) -> None:
    """Draw ``value_col`` versus epoch with one line per distinct ``lr``, then show it.

    Args:
        metric_df: Rows with at least ``lr``, ``epoch`` and ``value_col`` columns.
        value_col: Column plotted on the y-axis.
        title: Figure title.
        y_label: Y-axis label.

    Lines are added from the largest ``lr`` to the smallest. When
    ``value_col`` contains "calibration" (case-insensitive), the y-axis is
    fixed to [0, 3].
    """
    pin_calibration_axis = "calibration" in value_col.lower()

    plt.figure(figsize=(11, 6))
    for rate in sorted(metric_df["lr"].unique(), reverse=True):
        rows = metric_df.loc[metric_df["lr"] == rate, ["epoch", value_col]]
        rows = rows.sort_values("epoch")
        plt.plot(
            rows["epoch"],
            rows[value_col],
            marker="o",
            linewidth=1.5,
            markersize=3,
            label=f"lr={rate:.6g}",
        )

    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=3, fontsize=8)
    if pin_calibration_axis:
        plt.ylim(0, 3)
    plt.tight_layout()
    plt.show()


table_cols = ["lr_category", "lr", "epoch", "test_le", "global_calibration"]
condition_table_cols = ["lr_category", "lr", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"]
lr_category_order = sorted(results_df["lr_category"].unique(), reverse=True)


def _plot_condition_calibration_grid(cat_condition: pd.DataFrame, lr_category: float, ncols: int = 3) -> None:
    """Plot calibration vs. epoch with one subplot per condition and one line per lr.

    Args:
        cat_condition: Non-empty per-condition records for one ``lr_category``
            (uses the ``condition``, ``lr``, ``epoch`` and ``calibration`` columns).
        lr_category: Category learning rate these records belong to; shown in the title.
        ncols: Number of subplot columns in the grid.
    """
    cond_ids = sorted(cat_condition["condition"].unique())
    nrows = int(math.ceil(len(cond_ids) / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows))
    # plt.subplots returns a single Axes or a 1-D/2-D array depending on the grid
    # shape; flatten so conditions can be mapped onto axes linearly.
    axes = np.array(axes).reshape(-1)

    lr_order = sorted(cat_condition["lr"].unique(), reverse=True)
    epoch_order = sorted(cat_condition["epoch"].unique())

    for ax, cond_id in zip(axes, cond_ids):
        block = cat_condition.loc[cat_condition["condition"] == cond_id, ["lr", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        for lr in lr_order:
            lr_block = block.loc[block["lr"] == lr, ["epoch", "calibration"]].sort_values("epoch")
            if lr_block.empty:
                continue
            ax.plot(
                lr_block["epoch"],
                lr_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                label=f"{lr:.3g}",
            )
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        ax.set_ylim(0, 3)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=6)

    # Hide the unused trailing cells of the last grid row.
    for ax in axes[len(cond_ids):]:
        ax.axis("off")

    fig.suptitle(f"Per-condition calibration ratio across epoch (line per lr) | lr_category={lr_category:.6g}")
    fig.tight_layout()
    plt.show()


for lr_category in lr_category_order:
    print("=" * 80)
    print(f"lr_category = {lr_category:.6g}")

    cat_results = (
        results_df.loc[results_df["lr_category"] == lr_category]
        .sort_values(["test_le", "lr", "epoch"], ascending=[True, True, True])
        .reset_index(drop=True)
    )
    display(cat_results.loc[:, table_cols].head(50))

    # cat_results is sorted by test_le ascending, so its first row is the best (lr, epoch).
    best_row_cat = cat_results.iloc[0]
    print("Most optimal lr (by minimum test_le):", float(best_row_cat["lr"]))
    print("Best epoch for that lr:", int(best_row_cat["epoch"]))
    print("Best test_le:", float(best_row_cat["test_le"]))

    plot_metric_lines(
        metric_df=cat_results,
        value_col="test_le",
        title=f"LE ({ctx.eval_name}) across learning rate and epoch | lr_category={lr_category:.6g}",
        y_label="LE loss",
    )

    plot_metric_lines(
        metric_df=cat_results,
        value_col="global_calibration",
        title=f"Global calibration ratio across learning rate and epoch | lr_category={lr_category:.6g}",
        y_label="Calibration ratio",
    )

    cat_condition = condition_df.loc[condition_df["lr_category"] == lr_category].copy()
    # Check before rendering the table, so an empty category reports the message
    # instead of first displaying an empty table.
    if cat_condition.empty:
        print("No per-condition records to plot.")
        continue

    display(cat_condition.loc[:, condition_table_cols].sort_values(["condition", "lr", "epoch"]).reset_index(drop=True))
    _plot_condition_calibration_grid(cat_condition, lr_category)

