# LE Hyperparameter Search (`amplification_rate` / epochs)

This notebook runs a Localized Entropy (LE) hyperparameter search using `configs/hyper.json` as-is.

- Outer loop: cross-batch `amplification_rate` starts at `1.0` and doubles each step, up to and including `1024.0`.
- Inner loop: train for `10` epochs and evaluate each epoch.
- Per-epoch collection:
  - LE loss on test/eval split
  - Global calibration ratio on eval/test split only
  - Per-condition calibration ratio on eval/test split only


In [None]:
%matplotlib inline
from __future__ import annotations

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from localized_entropy.hyper_search import build_search_context, run_le_hyper_search

# Display settings: 6-digit NumPy floats without scientific notation for small
# values, and pandas tables allowed to show up to 200 rows / 50 columns.
np.set_printoptions(precision=6, suppress=True)
pd.set_option("display.max_rows", 200)
pd.set_option("display.max_columns", 50)

# Build the search context once from the config file; later cells read `ctx`.
CONFIG_PATH = "configs/hyper.json"
ctx = build_search_context(CONFIG_PATH)

# Echo what the context resolved to, so the notebook output records the setup.
print(f"Evaluation split for LE loss: {ctx.eval_name}")
print(f"Data source: {ctx.data_source}")
print(f"Num conditions/categories: {ctx.splits.num_conditions}")



In [None]:
def amplification_rate_schedule_doubling(start: float = 1.0, max_value: float = 1024.0) -> list[float]:
    """Return the doubling schedule ``start, 2*start, 4*start, ...`` capped at ``max_value``.

    Args:
        start: First amplification rate. Must be a finite number greater than zero.
        max_value: Inclusive upper bound for the generated rates. Must be finite.

    Returns:
        The rates as floats in increasing order. The list is empty when
        ``start`` is already greater than ``max_value``.

    Raises:
        ValueError: If ``start`` is not a finite positive number, or if
            ``max_value`` is not finite.
    """
    amp = float(start)
    max_value = float(max_value)

    # Doubling zero or a negative number never climbs past the bound, and an
    # infinite bound is never exceeded; either case would loop forever while
    # the list keeps growing, so reject such inputs up front.
    if not (math.isfinite(amp) and amp > 0.0):
        raise ValueError(f"start must be a finite positive number, got {start!r}")
    if not math.isfinite(max_value):
        raise ValueError(f"max_value must be finite, got {max_value!r}")

    values = []
    while amp <= max_value:
        values.append(amp)
        amp *= 2.0
    return values


# Rates to search: 1.0 doubled repeatedly, up to and including 1024.0.
AMPLIFICATION_RATES = amplification_rate_schedule_doubling(start=1.0, max_value=1024.0)
# One single-key parameter dict per rate; this list is the grid passed to the search.
param_grid = [{"amplification_rate": float(v)} for v in AMPLIFICATION_RATES]

# Sanity-check the grid size and its endpoints.
print(f"Number of amplification_rate points: {len(AMPLIFICATION_RATES)}")
print("amplification_rate first/last:", AMPLIFICATION_RATES[0], AMPLIFICATION_RATES[-1])



In [None]:
def apply_params(cfg_run: dict, train_params: dict, params: dict) -> None:
    """Switch on LE cross-batch mode in ``train_params`` with the trial's rate.

    An existing ``le_cross_batch_cfg`` dict is updated in place so its other
    keys survive; anything else stored under that key (or nothing at all) is
    replaced by a fresh dict.

    Args:
        cfg_run: Per-run config. Accepted as part of the callback signature
            but not read here.
        train_params: Training parameters; mutated in place.
        params: Trial parameters; must contain ``"amplification_rate"``.
    """
    rate_key = "amplification_rate"
    settings = train_params.get("le_cross_batch_cfg")
    settings = settings if isinstance(settings, dict) else {}
    settings["enabled"] = True
    settings[rate_key] = float(params[rate_key])
    train_params["le_cross_batch_cfg"] = settings


# Run the search over every entry in param_grid. apply_params writes each trial's
# rate into the training parameters; record_params supplies the value recorded
# for the trial (presumably as the "amplification_rate" column - confirm in
# run_le_hyper_search).
# NOTE(review): sort_by/ascending are interpreted by run_le_hyper_search; this
# cell assumes results_df comes back ordered with the lowest test_le first.
results_df, condition_df = run_le_hyper_search(
    ctx,
    param_grid=param_grid,
    apply_params=apply_params,
    record_params=lambda p: {"amplification_rate": float(p["amplification_rate"])},
    sort_by=["test_le", "amplification_rate", "epoch"],
    ascending=[True, True, True],
)

# The first row is treated as the best result, relying on the ordering above.
best_row = results_df.iloc[0]
print("Best LE result overall:")
print(best_row.to_dict())



In [None]:
def plot_metric_lines(metric_df: pd.DataFrame, value_col: str, title: str, y_label: str) -> None:
    """Plot ``value_col`` against epoch, one line per amplification rate.

    Lines are drawn from the highest rate to the lowest, each sorted by epoch.
    When the column name contains "calibration" (case-insensitive), the y-axis
    is fixed to the range 0 to 3.

    Args:
        metric_df: Table with "amplification_rate", "epoch" and ``value_col`` columns.
        value_col: Name of the column plotted on the y-axis.
        title: Figure title.
        y_label: Label for the y-axis.
    """
    plt.figure(figsize=(11, 6))

    # Appearance shared by every curve in the figure.
    line_style = {"marker": "o", "linewidth": 1.5, "markersize": 3}

    rates_high_to_low = sorted(metric_df["amplification_rate"].unique(), reverse=True)
    for rate in rates_high_to_low:
        rows_for_rate = metric_df["amplification_rate"] == rate
        curve = metric_df.loc[rows_for_rate, ["epoch", value_col]].sort_values("epoch")
        plt.plot(curve["epoch"], curve[value_col], label=f"amp={rate:.6g}", **line_style)

    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=3, fontsize=8)

    # Calibration columns get a fixed y-range instead of autoscaling.
    is_calibration_metric = "calibration" in value_col.lower()
    if is_calibration_metric:
        plt.ylim(0, 3)

    plt.tight_layout()
    plt.show()


# Columns shown in the summary table.
table_cols = ["amplification_rate", "epoch", "test_le", "global_calibration"]
# Order the per-epoch results by lowest test_le first (ties broken by rate, then epoch).
results_sorted = results_df.sort_values(["test_le", "amplification_rate", "epoch"], ascending=[True, True, True]).reset_index(drop=True)
# Group the per-condition records by condition, then rate, then epoch.
condition_sorted = condition_df.sort_values(["condition", "amplification_rate", "epoch"]).reset_index(drop=True)

# Show only the 100 best rows to keep the output readable.
display(results_sorted.loc[:, table_cols].head(100))

# After the sort above, row 0 is the (rate, epoch) pair with the minimum test_le.
best_row = results_sorted.iloc[0]
print("Most optimal amplification_rate (by minimum test_le):", float(best_row["amplification_rate"]))
print("Best epoch for that amplification_rate:", int(best_row["epoch"]))
print("Best test_le:", float(best_row["test_le"]))

# LE loss per epoch, one line per amplification rate.
plot_metric_lines(
    metric_df=results_sorted,
    value_col="test_le",
    title=f"LE ({ctx.eval_name}) across amplification_rate and epoch",
    y_label="LE loss",
)

# Global calibration ratio per epoch; plot_metric_lines fixes the y-range to 0..3
# for column names containing "calibration".
plot_metric_lines(
    metric_df=results_sorted,
    value_col="global_calibration",
    title="Global calibration ratio across amplification_rate and epoch",
    y_label="Calibration ratio",
)

# Full per-condition table (not truncated).
condition_table_cols = ["amplification_rate", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"]
display(condition_sorted.loc[:, condition_table_cols])

# Per-condition calibration: one subplot per condition, one line per amplification rate.
if condition_sorted.empty:
    print("No per-condition records to plot.")
else:
    cond_ids = sorted(condition_sorted["condition"].unique())
    n = len(cond_ids)
    # Fixed 3-column grid with as many rows as needed to fit every condition.
    ncols = 3
    nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows))
    # Flatten so axes can be indexed by condition position whether the grid is 1-D or 2-D.
    axes = np.array(axes).reshape(-1)

    # Highest rate first; the epoch list is shared by every subplot for the x ticks.
    amp_order = sorted(condition_sorted["amplification_rate"].unique(), reverse=True)
    epoch_order = sorted(condition_sorted["epoch"].unique())

    for i, cond_id in enumerate(cond_ids):
        ax = axes[i]
        block = condition_sorted.loc[condition_sorted["condition"] == cond_id, ["amplification_rate", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        for amp in amp_order:
            amp_block = block.loc[block["amplification_rate"] == amp, ["epoch", "calibration"]].sort_values("epoch")
            # Skip rates that have no records for this condition.
            if amp_block.empty:
                continue
            ax.plot(
                amp_block["epoch"],
                amp_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                label=f"{amp:.3g}",
            )
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        # Same fixed y-range on every subplot so conditions are directly comparable.
        ax.set_ylim(0, 3)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=6)

    # Hide the leftover grid cells when the condition count is not a multiple of ncols.
    for j in range(len(cond_ids), len(axes)):
        axes[j].axis("off")

    fig.suptitle("Per-condition calibration ratio across epoch (line per amplification_rate)")
    fig.tight_layout()
    plt.show()

