# LE Hyperparameter Search (`lr_decay` and `lr_category_decay` / epochs)
- Sweep: `lr_decay` and `lr_category_decay` combinations.

This notebook runs a Localized Entropy (LE) hyperparameter search using `configs/hyper.json`.

Optimization objective, in priority order (lower is better for every metric; each later metric only breaks ties in the ones above it):
1. `ece_small` (computed on predictions where `p <= 0.01`; threshold from `evaluation.small_prob_max`)
2. `per_condition_calibration_error`
3. `ece`
4. `global_calibration_error`
5. `test_le` (tie-breaker)

The notebook also reports per-condition calibration tables and metric traces across epochs.


In [None]:
%matplotlib inline
from __future__ import annotations

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from localized_entropy.hyper_search import build_search_context, run_le_hyper_search

# Output formatting: fixed-point NumPy printing with 6 decimals, and pandas
# tables allowed to show up to 200 rows / 50 columns before truncating.
np.set_printoptions(suppress=True, precision=6)
pd.options.display.max_rows = 200
pd.options.display.max_columns = 50

# The search context is built from this config file.
CONFIG_PATH = "configs/hyper.json"
ctx = build_search_context(CONFIG_PATH)

# Echo the key facts about the loaded context before running anything.
print(f"Evaluation split for LE loss: {ctx.eval_name}")
print(f"Data source: {ctx.data_source}")
print(f"Num conditions/categories: {ctx.splits.num_conditions}")



In [None]:
# Candidate decay factors for the sweep.
# NOTE: the numeric suffixes carry no meaning on their own -- in the grid below
# DECAY_VALUES_1 supplies `lr_category_decay` and DECAY_VALUES_2 supplies
# `lr_decay`.
DECAY_VALUES_1 = [0.99, 0.995, 0.999, 0.9999]
DECAY_VALUES_2 = [0.999, 0.9999]

# Full Cartesian product of the two lists; `lr_category_decay` is the outer
# (slowest-varying) loop, `lr_decay` the inner one.
param_grid = [
    {"lr_decay": float(lr_decay), "lr_category_decay": float(lr_category_decay)}
    for lr_category_decay in DECAY_VALUES_1
    for lr_decay in DECAY_VALUES_2
]

# Report each list under the hyperparameter it actually feeds in the grid
# (previously the two labels were swapped, misdescribing the sweep).
print("lr_decay values:", DECAY_VALUES_2)
print("lr_category_decay values:", DECAY_VALUES_1)
print("Number of grid points:", len(param_grid))



In [None]:
def apply_params(cfg_run: dict, train_params: dict, params: dict) -> None:
    """Write one grid point's settings into ``train_params`` in place.

    Args:
        cfg_run: Accepted but not used by this function (presumably part of
            the callback signature expected by the search driver).
        train_params: Mapping that is mutated; receives ``epochs``,
            ``lr_decay`` and ``lr_category_decay``.
        params: Grid point providing ``lr_decay`` and ``lr_category_decay``.

    Raises:
        KeyError: If ``params`` lacks one of the two decay keys.
    """
    # Every run in this sweep trains for the same fixed number of epochs.
    train_params["epochs"] = 5
    for key in ("lr_decay", "lr_category_decay"):
        train_params[key] = float(params[key])


# Run the sweep. Rows are ranked by the objective metrics in priority order;
# the hyperparameters and epoch are appended as final sort keys so the
# ordering is deterministic. All eight keys sort ascending.
results_df, condition_df = run_le_hyper_search(
    ctx,
    param_grid=param_grid,
    apply_params=apply_params,
    record_params=lambda p: {
        name: float(p[name]) for name in ("lr_decay", "lr_category_decay")
    },
    sort_by=[
        "ece_small",
        "per_condition_calibration_error",
        "ece",
        "global_calibration_error",
        "test_le",
        "lr_category_decay",
        "lr_decay",
        "epoch",
    ],
    ascending=[True] * 8,
)

# The first row is reported as the overall best (assumes the search returns
# results_df already ordered by `sort_by` -- confirm against the driver).
best_row = results_df.iloc[0]
print("Best objective result overall (ECE-small prioritized):")
print(best_row.to_dict())



In [None]:
def plot_metric_lines(metric_df: pd.DataFrame, value_col: str, title: str, y_label: str) -> None:
    """Draw ``value_col`` against epoch, one line per distinct ``lr_decay``.

    Args:
        metric_df: Frame with ``lr_decay``, ``epoch`` and ``value_col`` columns.
        value_col: Name of the metric column plotted on the y-axis.
        title: Figure title.
        y_label: Y-axis label.
    """
    plt.figure(figsize=(11, 6))

    line_style = {"marker": "o", "linewidth": 1.5, "markersize": 3}
    # Largest decay first, so legend entries run from high to low.
    for decay in sorted(metric_df["lr_decay"].unique(), reverse=True):
        rows = metric_df["lr_decay"] == decay
        trace = metric_df.loc[rows, ["epoch", value_col]].sort_values("epoch")
        plt.plot(
            trace["epoch"],
            trace[value_col],
            label=f"lr_decay={decay:.6g}",
            **line_style,
        )

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=3, fontsize=8)

    # Any metric whose column name mentions "calibration" gets a fixed 0-3 y-range.
    wants_fixed_range = "calibration" in value_col.lower()
    if wants_fixed_range:
        plt.ylim(0, 3)

    plt.tight_layout()
    plt.show()


# Per-lr_category_decay report: for each distinct lr_category_decay value
# (largest first) show the ranked results table, the best row, two metric
# plots, the per-condition table, and a grid of per-condition calibration plots.

# Columns shown in the ranked results table for each lr_category_decay.
table_cols = ["lr_category_decay", "lr_decay", "epoch", "ece_small", "per_condition_calibration_error", "ece", "global_calibration_error", "test_le", "global_calibration"]
lr_category_decay_order = sorted(results_df["lr_category_decay"].unique(), reverse=True)

for lr_category_decay in lr_category_decay_order:
    print("=" * 80)
    print(f"lr_category_decay = {lr_category_decay:.6g}")

    # Restrict both frames to this lr_category_decay; re-rank the results by the
    # objective metrics in priority order, then lr_decay and epoch.
    cat_results = results_df.loc[results_df["lr_category_decay"] == lr_category_decay].copy()
    cat_results = cat_results.sort_values(["ece_small", "per_condition_calibration_error", "ece", "global_calibration_error", "test_le", "lr_decay", "epoch"], ascending=[True, True, True, True, True, True, True]).reset_index(drop=True)
    cat_condition = condition_df.loc[condition_df["lr_category_decay"] == lr_category_decay].copy()

    # `display` is not imported in this file -- presumably the IPython/Jupyter
    # built-in; only the top 50 ranked rows are shown.
    display(cat_results.loc[:, table_cols].head(50))

    # After the sort above, row 0 is the best (lr_decay, epoch) pair for this
    # lr_category_decay.
    best_row_cat = cat_results.iloc[0]
    print("Most optimal lr_decay (by minimum ECE-small, then per-condition calibration error):", float(best_row_cat["lr_decay"]))
    print("Best epoch for that lr_decay:", int(best_row_cat["epoch"]))
    print("Best ece_small:", float(best_row_cat["ece_small"]))
    print("Best per_condition_calibration_error:", float(best_row_cat["per_condition_calibration_error"]))
    print("Best ece:", float(best_row_cat["ece"]))
    print("Best test_le (tie-breaker):", float(best_row_cat["test_le"]))

    # NOTE(review): the "p <= 0.01" in this title is hard-coded text; the header
    # says the threshold comes from `evaluation.small_prob_max` -- confirm they agree.
    plot_metric_lines(
        metric_df=cat_results,
        value_col="ece_small",
        title=f"ECE-small ({ctx.eval_name}, p <= 0.01) across lr_decay and epoch | lr_category_decay={lr_category_decay:.6g}",
        y_label="ECE-small",
    )

    plot_metric_lines(
        metric_df=cat_results,
        value_col="per_condition_calibration_error",
        title=f"Per-condition calibration error across lr_decay and epoch | lr_category_decay={lr_category_decay:.6g}",
        y_label="Per-condition calibration error",
    )

    # Columns shown in the per-condition table (same list on every iteration).
    condition_table_cols = [
        "lr_category_decay",
        "lr_decay",
        "epoch",
        "condition",
        "count",
        "base_rate",
        "pred_mean",
        "calibration",
        "calibration_abs_error",
        "ece",
        "ece_small",
    ]
    display(cat_condition.loc[:, condition_table_cols].sort_values(["condition", "lr_decay", "epoch"]).reset_index(drop=True))

    # Nothing to plot for this lr_category_decay; skip the subplot grid.
    if cat_condition.empty:
        print("No per-condition records to plot.")
        continue

    # One subplot per condition, laid out three per row.
    cond_ids = sorted(cat_condition["condition"].unique())
    n = len(cond_ids)
    ncols = 3
    nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows))
    # Flatten to 1-D so axes can be indexed by condition position regardless of
    # the grid shape returned by plt.subplots.
    axes = np.array(axes).reshape(-1)

    lr_decay_order = sorted(cat_condition["lr_decay"].unique(), reverse=True)
    epoch_order = sorted(cat_condition["epoch"].unique())

    for i, cond_id in enumerate(cond_ids):
        ax = axes[i]
        block = cat_condition.loc[cat_condition["condition"] == cond_id, ["lr_decay", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        # One line per lr_decay (largest first); combinations with no rows for
        # this condition are skipped.
        for lr_decay in lr_decay_order:
            lr_block = block.loc[block["lr_decay"] == lr_decay, ["epoch", "calibration"]].sort_values("epoch")
            if lr_block.empty:
                continue
            ax.plot(
                lr_block["epoch"],
                lr_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                label=f"{lr_decay:.4g}",
            )
        # Same epoch ticks and fixed 0-3 y-range on every subplot so panels are
        # directly comparable.
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        ax.set_ylim(0, 3)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=6)

    # Hide the leftover axes in the last row when n is not a multiple of ncols.
    for j in range(len(cond_ids), len(axes)):
        axes[j].axis("off")

    fig.suptitle(f"Per-condition calibration ratio across epoch (line per lr_decay) | lr_category_decay={lr_category_decay:.6g}")
    fig.tight_layout()
    plt.show()

