# LE Hyperparameter Search (`lr_category` / `lr` / epochs)

This notebook runs a Localized Entropy (LE) hyperparameter search using `configs/default.json` as-is.

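- Outer loop: category learning rate (`lr_category`) starts at `1.0` and is divided by 10 each step, down to `1e-6`.
- Middle loop: learning rate (`lr`) starts at `1.0` and halves each step, down to `1e-6`.
- Inner loop: train for `EPOCHS` (currently `15`) epochs and evaluate after each epoch.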
- Per-epoch collection:
  - LE loss on the test split when it has labels, otherwise on the eval split
  - Global calibration ratio on all labeled data
  - Per-condition calibration ratio on all labeled data


In [None]:
# `from __future__` must be the first statement; after `jupyter nbconvert --to script`
# the magic below becomes a regular statement, so it has to come after this import.
from __future__ import annotations

%matplotlib inline

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from localized_entropy.analysis import per_condition_calibration
from localized_entropy.config import load_and_resolve, get_data_source
from localized_entropy.data.pipeline import prepare_data
from localized_entropy.experiments import build_loss_loaders, build_model, train_single_loss
from localized_entropy.training import predict_probs, evaluate, compute_base_rates_from_loader
from localized_entropy.utils import init_device, set_seed

# Notebook display settings for printed arrays and DataFrames.
np.set_printoptions(precision=6, suppress=True)
pd.set_option("display.max_rows", 200)
pd.set_option("display.max_columns", 50)

# The config is loaded and resolved once and then used unchanged by every cell.
CONFIG_PATH = "configs/default.json"
cfg = load_and_resolve(CONFIG_PATH)

# Device selection. Models use float64 only when running on CPU with `device.use_mps`
# explicitly disabled in the config; otherwise float32.
device_cfg = cfg.get("device", {})
device, use_cuda, use_mps, non_blocking = init_device(use_mps=bool(device_cfg.get("use_mps", True)))
cpu_float64 = device.type == "cpu" and not bool(device_cfg.get("use_mps", True))
if cpu_float64:
    model_dtype = torch.float64
else:
    model_dtype = torch.float32

# Seed before data preparation, then build the data splits and LE loaders.
set_seed(int(cfg["project"]["seed"]), use_cuda)
prepared = prepare_data(cfg, device, use_cuda, use_mps)
splits = prepared.splits

loss_loaders, le_train_cfg = build_loss_loaders(cfg, "localized_entropy", splits, device, use_cuda, use_mps)
data_source = get_data_source(cfg)

# The "ctr" source has a labeled test split only when `ctr.test_has_labels` is true;
# every other data source is treated as having test labels.
test_has_labels = data_source != "ctr" or bool(cfg.get("ctr", {}).get("test_has_labels", False))

# Evaluate LE on the test split when it exists and is labeled; otherwise fall back to eval.
if loss_loaders.test_loader is None or splits.y_test is None or not test_has_labels:
    eval_name = "eval"
    eval_loader = loss_loaders.eval_loader
    eval_labels = np.asarray(splits.y_eval).reshape(-1)
    eval_conds = np.asarray(splits.c_eval).reshape(-1)
else:
    eval_name = "test"
    eval_loader = loss_loaders.test_loader
    eval_labels = np.asarray(splits.y_test).reshape(-1)
    eval_conds = np.asarray(splits.c_test).reshape(-1)

print(f"Evaluation split for LE loss: {eval_name}")
print(f"Data source: {data_source}")
print(f"Num conditions/categories: {splits.num_conditions}")


In [None]:
def lr_schedule_halving(start: float = 1.0, floor: float = 1e-6):
    """Return the halving schedule ``start, start/2, start/4, ...`` down to ``floor``.

    Args:
        start: First learning rate of the schedule.
        floor: Smallest value kept in the schedule; must be positive.

    Returns:
        List of floats in decreasing order, each ``>= floor``. The list is
        empty when ``start < floor``.

    Raises:
        ValueError: If ``floor`` is not positive. Repeated halving only
            approaches (and eventually underflows to) 0.0, which is never
            below a non-positive floor, so the loop would not terminate.
    """
    floor = float(floor)
    if not floor > 0.0:  # also rejects NaN
        raise ValueError(f"floor must be positive, got {floor!r}")

    values = []
    lr = float(start)
    while lr >= floor:
        values.append(lr)
        lr /= 2.0  # exact for binary floats, so no drift accumulates
    return values


def lr_category_schedule_decade(start: float = 1.0, floor: float = 1e-6):
    """Return the decade schedule ``start, start/10, start/100, ...`` down to ``floor``.

    Each value is computed directly as ``start / 10**k`` instead of by dividing
    the previous value by 10. Repeated division accumulates rounding error:
    with the defaults it ends at ``1.0000000000000002e-06`` rather than
    ``1e-06``, and such drift could also push an endpoint below ``floor``.

    Args:
        start: First category learning rate of the schedule.
        floor: Smallest value kept in the schedule; must be positive.

    Returns:
        List of floats in decreasing order, each ``>= floor``. The list is
        empty when ``start < floor``.

    Raises:
        ValueError: If ``floor`` is not positive (the schedule would never end).
    """
    floor = float(floor)
    if not floor > 0.0:  # also rejects NaN
        raise ValueError(f"floor must be positive, got {floor!r}")

    start = float(start)
    values = []
    exponent = 0
    lr_category = start
    while lr_category >= floor:
        values.append(lr_category)
        exponent += 1
        # One correctly rounded division per value; 10**exponent is exact as a float up to 1e22.
        lr_category = start / 10**exponent
    return values


def _global_calibration_ratio(preds: np.ndarray, labels: np.ndarray, eps: float = 1e-12) -> float:
    """Return ``mean(preds) / mean(labels)``, or NaN when ``mean(labels) <= eps``.

    A value of 1.0 means the average prediction equals the average label
    (presumably the positive rate for 0/1 labels). Values above 1 mean
    over-prediction on average; values below 1 mean under-prediction.
    """
    predicted_rate = float(np.mean(preds))
    observed_rate = float(np.mean(labels))
    if observed_rate > eps:
        return predicted_rate / observed_rate
    # A (near-)zero or NaN label mean has no meaningful ratio.
    return float("nan")


def collect_all_data_metrics(model: torch.nn.Module) -> tuple[float, pd.DataFrame]:
    """Compute calibration metrics for ``model`` over all labeled data.

    Predictions come from ``predict_probs`` on the train and eval loaders. The
    test loader is included too when the test split is labeled (test loader
    present, ``splits.y_test`` present and ``test_has_labels``). Each split's
    predictions are flattened and paired with that split's ``splits.y_*``
    labels and ``splits.c_*`` condition ids, and the splits are then
    concatenated.

    Args:
        model: Model to score.

    Returns:
        ``(global_calibration, per_condition_df)``:

        - ``global_calibration`` is the mean prediction divided by the mean
          label over all rows (NaN when the mean label is ~0).
        - ``per_condition_df`` is the result of ``per_condition_calibration``
          on the concatenated arrays.

    Raises:
        ValueError: If a split's prediction, label and condition counts differ,
            which would misalign predictions with their labels/conditions.
    """
    split_sources = [
        ("train", loss_loaders.train_loader, splits.y_train, splits.c_train),
        ("eval", loss_loaders.eval_loader, splits.y_eval, splits.c_eval),
    ]
    if loss_loaders.test_loader is not None and splits.y_test is not None and test_has_labels:
        split_sources.append(("test", loss_loaders.test_loader, splits.y_test, splits.c_test))

    preds_parts, labels_parts, conds_parts = [], [], []
    for split_name, loader, split_labels, split_conds in split_sources:
        # NOTE(review): rows are paired with `splits` purely by position. This assumes
        # `loader` yields samples in split order (no shuffling); confirm for the train
        # loader. The length check below cannot detect a reordering.
        preds = predict_probs(model, loader, device, non_blocking=non_blocking).reshape(-1)
        labels = np.asarray(split_labels).reshape(-1)
        conds = np.asarray(split_conds).reshape(-1)
        if not (len(preds) == len(labels) == len(conds)):
            raise ValueError(
                f"{split_name} split: got {len(preds)} predictions for "
                f"{len(labels)} labels and {len(conds)} conditions"
            )
        preds_parts.append(preds)
        labels_parts.append(labels)
        conds_parts.append(conds)

    all_preds = np.concatenate(preds_parts)
    all_labels = np.concatenate(labels_parts)
    all_conds = np.concatenate(conds_parts)

    global_calibration = _global_calibration_ratio(all_preds, all_labels)
    per_cond = per_condition_calibration(all_preds, all_labels, all_conds)
    return global_calibration, per_cond


# Search grid consumed by the training cell below.
# `lr` halves from 1.0 and `lr_category` is divided by 10 from 1.0; both
# schedules stop before going below 1e-6 (see the schedule helpers above).
LEARNING_RATES = lr_schedule_halving(start=1.0, floor=1e-6)
LR_CATEGORY_RATES = lr_category_schedule_decade(start=1.0, floor=1e-6)
# Passed as `epochs` to `train_single_loss` for every (lr_category, lr) pair.
EPOCHS = 15

# Sanity-print the grid size and its endpoints before the (long) search starts.
print(f"Number of lr points: {len(LEARNING_RATES)}")
print(f"Number of lr_category points: {len(LR_CATEGORY_RATES)}")
print("lr first/last:", LEARNING_RATES[0], LEARNING_RATES[-1])
print("lr_category first/last:", LR_CATEGORY_RATES[0], LR_CATEGORY_RATES[-1])


In [None]:
# One row per (lr_category, lr, epoch), plus one row per condition at each of those points.
epoch_records = []
condition_records = []

# dtype for the LE base-rate tensor. It follows the same rule as `model_dtype` in the
# setup cell (float64 only in the CPU/no-MPS case), so reuse it instead of re-deriving it.
first_param_dtype = model_dtype

# LE evaluation needs base rates from training data.
base_rates_train = compute_base_rates_from_loader(
    loss_loaders.train_loader,
    num_conditions=int(splits.num_conditions),
    device=device,
    dtype=first_param_dtype,
    non_blocking=non_blocking,
)

# Decay settings are the same for every run: the LE-specific value wins, then the
# global `training` section, then 1.0. They are resolved once, outside the grid loops.
lr_decay = float(le_train_cfg.get("lr_decay", cfg["training"].get("lr_decay", 1.0)))
lr_category_decay = float(
    le_train_cfg.get("lr_category_decay", cfg["training"].get("lr_category_decay", 1.0))
)

for lr_category in LR_CATEGORY_RATES:
    for lr in LEARNING_RATES:
        # Re-seed before building each model (presumably so every run starts from the
        # same initialization).
        set_seed(int(cfg["project"]["seed"]), use_cuda)
        model = build_model(cfg, splits, device, dtype=model_dtype)

        # The loop values and the model are bound as defaults, so the callback always
        # reports on the run it was created for, regardless of later rebinding.
        def on_epoch_eval(
            _eval_preds: np.ndarray,
            epoch: int,
            lr_value: float = lr,
            lr_category_value: float = lr_category,
            run_model: torch.nn.Module = model,
        ) -> None:
            """Eval callback for `train_single_loss`: record LE loss and calibration metrics."""
            le_loss, _ = evaluate(
                run_model,
                eval_loader,
                device,
                loss_mode="localized_entropy",
                base_rates=base_rates_train,
                non_blocking=non_blocking,
            )
            global_calibration, per_cond_df = collect_all_data_metrics(run_model)

            # The key stays "test_le" even when `eval_name` is "eval"; later cells read it.
            epoch_records.append(
                {
                    "lr_category": float(lr_category_value),
                    "lr": float(lr_value),
                    "epoch": int(epoch),
                    "test_le": float(le_loss),
                    "global_calibration": float(global_calibration),
                }
            )

            for row in per_cond_df.to_dict(orient="records"):
                condition_records.append(
                    {
                        "lr_category": float(lr_category_value),
                        "lr": float(lr_value),
                        "epoch": int(epoch),
                        "condition": int(row["condition"]),
                        "count": int(row["count"]),
                        "base_rate": float(row["base_rate"]),
                        "pred_mean": float(row["pred_mean"]),
                        "calibration": float(row["calibration"]),
                    }
                )

        # NOTE(review): `train_eval_loader` is the evaluation loader and
        # `le_base_rates_train_eval` reuses the train base rates. Confirm this is intended
        # rather than a non-shuffled view of the training data.
        _ = train_single_loss(
            model=model,
            loss_mode="localized_entropy",
            train_loader=loss_loaders.train_loader,
            train_eval_loader=eval_loader,
            eval_loader=eval_loader,
            device=device,
            epochs=EPOCHS,
            lr=float(lr),
            lr_decay=lr_decay,
            lr_category_decay=lr_category_decay,
            lr_category=float(lr_category),
            eval_has_labels=True,
            le_base_rates_train=base_rates_train,
            le_base_rates_train_eval=base_rates_train,
            le_base_rates_eval=base_rates_train,
            non_blocking=non_blocking,
            eval_callback=on_epoch_eval,
            plot_eval_hist_epochs=False,
            print_embedding_table=False,
        )

results_df = pd.DataFrame(epoch_records)
condition_df = pd.DataFrame(condition_records)

if results_df.empty:
    raise RuntimeError("Search produced no results.")

# Lowest LE first; ties are broken by smaller lr_category, then lr, then earlier epoch.
results_df = results_df.sort_values(["test_le", "lr_category", "lr", "epoch"], ascending=[True, True, True, True]).reset_index(drop=True)
best_row = results_df.iloc[0]

print("Best LE result overall:")
print(best_row.to_dict())


In [None]:
def plot_metric_lines(metric_df: pd.DataFrame, value_col: str, title: str, y_label: str) -> None:
    """Plot ``value_col`` against epoch, one line per learning rate, and show the figure.

    Args:
        metric_df: Rows containing at least ``lr``, ``epoch`` and ``value_col``.
        value_col: Column drawn on the y-axis.
        title: Figure title.
        y_label: Label for the y-axis.

    Lines are added from the largest ``lr`` to the smallest. When the column
    name contains "calibration" (case-insensitive), the y-range is fixed to
    ``[0, 3]``.
    """
    fig, ax = plt.subplots(figsize=(11, 6))

    for lr_value in sorted(metric_df["lr"].unique(), reverse=True):
        curve = metric_df.loc[metric_df["lr"] == lr_value, ["epoch", value_col]]
        curve = curve.sort_values("epoch")
        ax.plot(
            curve["epoch"],
            curve[value_col],
            marker="o",
            linewidth=1.5,
            markersize=3,
            label=f"lr={lr_value:.6g}",
        )

    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend(ncol=3, fontsize=8)
    if "calibration" in value_col.lower():
        ax.set_ylim(0, 3)

    fig.tight_layout()
    plt.show()


# Columns shown in the per-lr_category summary tables.
table_cols = ["lr_category", "lr", "epoch", "test_le", "global_calibration"]
condition_table_cols = ["lr_category", "lr", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"]
lr_category_order = sorted(results_df["lr_category"].unique(), reverse=True)

# `pd.DataFrame([])` has no columns when no per-condition rows were recorded. Guard the
# column lookup so the "no records" branch below is reachable instead of raising KeyError.
has_condition_rows = "lr_category" in condition_df.columns

for lr_category in lr_category_order:
    print("=" * 80)
    print(f"lr_category = {lr_category:.6g}")

    cat_results = results_df.loc[results_df["lr_category"] == lr_category].copy()
    cat_results = cat_results.sort_values(["test_le", "lr", "epoch"], ascending=[True, True, True]).reset_index(drop=True)
    if has_condition_rows:
        cat_condition = condition_df.loc[condition_df["lr_category"] == lr_category].copy()
    else:
        cat_condition = pd.DataFrame(columns=condition_table_cols)

    display(cat_results.loc[:, table_cols].head(50))

    # Non-empty: every lr_category value comes from results_df itself.
    best_row_cat = cat_results.iloc[0]
    print("Most optimal lr (by minimum test_le):", float(best_row_cat["lr"]))
    print("Best epoch for that lr:", int(best_row_cat["epoch"]))
    print("Best test_le:", float(best_row_cat["test_le"]))

    # Chart 1: LE across learning rate and epochs
    plot_metric_lines(
        metric_df=cat_results,
        value_col="test_le",
        title=f"LE ({eval_name}) across learning rate and epoch | lr_category={lr_category:.6g}",
        y_label="LE loss",
    )

    # Chart 2: Global calibration across learning rate and epochs
    plot_metric_lines(
        metric_df=cat_results,
        value_col="global_calibration",
        title=f"Global calibration ratio across learning rate and epoch | lr_category={lr_category:.6g}",
        y_label="Calibration ratio",
    )

    # Chart 3: Per-condition calibration charts (line per lr)
    # Per-condition calibration table used for charts (this lr_category)
    display(cat_condition.loc[:, condition_table_cols].sort_values(["condition", "lr", "epoch"]).reset_index(drop=True))

    if cat_condition.empty:
        print("No per-condition records to plot.")
        continue

    # One subplot per condition, laid out on a 3-column grid.
    cond_ids = sorted(cat_condition["condition"].unique())
    n = len(cond_ids)
    ncols = 3
    nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows))
    axes = np.array(axes).reshape(-1)  # flat list of axes regardless of grid shape

    lr_order = sorted(cat_condition["lr"].unique(), reverse=True)
    epoch_order = sorted(cat_condition["epoch"].unique())

    for i, cond_id in enumerate(cond_ids):
        ax = axes[i]
        block = cat_condition.loc[cat_condition["condition"] == cond_id, ["lr", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        for lr in lr_order:
            lr_block = block.loc[block["lr"] == lr, ["epoch", "calibration"]].sort_values("epoch")
            if lr_block.empty:
                continue
            ax.plot(
                lr_block["epoch"],
                lr_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                label=f"{lr:.3g}",
            )
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        ax.set_ylim(0, 3)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=6)

    # Hide the unused cells of the last grid row.
    for j in range(len(cond_ids), len(axes)):
        axes[j].axis("off")

    fig.suptitle(f"Per-condition calibration ratio across epoch (line per lr) | lr_category={lr_category:.6g}")
    fig.tight_layout()
    plt.show()
