# BCE Hyperparameter Search (`ler` / epochs)

This notebook runs a BCE-only hyperparameter search using `configs/hyper.json` (see `CONFIG_PATH`) exactly as-is.

- Outer loop: learning rate (`ler`) starts at `0.1` and halves each step, down to a floor of `1e-5`.
- Inner loop: train for `15` epochs (`EPOCHS`) and evaluate at every epoch.
- Evaluation at each epoch collects:
  - BCE on test/eval split
  - Global calibration ratio on eval/test split only
  - Per-category (condition) calibration ratio on eval/test split only


In [None]:
%matplotlib inline
from __future__ import annotations

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from localized_entropy.analysis import bce_log_loss, per_condition_calibration
from localized_entropy.config import load_and_resolve, get_data_source, resolve_ctr_config
from localized_entropy.data.pipeline import prepare_data
from localized_entropy.experiments import build_loss_loaders, build_model, train_single_loss
from localized_entropy.training import predict_probs
from localized_entropy.utils import init_device, set_seed

# Console/notebook display limits for numpy and pandas output.
np.set_printoptions(precision=6, suppress=True)
for _option_name, _option_value in (("display.max_rows", 200), ("display.max_columns", 50)):
    pd.set_option(_option_name, _option_value)

# Load the search configuration.
CONFIG_PATH = "configs/hyper.json"
cfg = load_and_resolve(CONFIG_PATH)

# Device selection. float64 is used only when running on CPU with MPS
# explicitly disabled in the config; every other case uses float32.
device_cfg = cfg.get("device", {})
mps_requested = bool(device_cfg.get("use_mps", True))
device, use_cuda, use_mps, non_blocking = init_device(use_mps=mps_requested)
cpu_float64 = device.type == "cpu" and not mps_requested
if cpu_float64:
    model_dtype = torch.float64
else:
    model_dtype = torch.float32

# Seed before preparing data so the prepared splits are reproducible.
set_seed(int(cfg["project"]["seed"]), use_cuda)
prepared = prepare_data(cfg, device, use_cuda, use_mps)
splits = prepared.splits

loss_loaders, bce_train_cfg = build_loss_loaders(cfg, "bce", splits, device, use_cuda, use_mps)
data_source = get_data_source(cfg)

# Only the "ctr" data source can declare an unlabeled test split; every other
# source is treated as having test labels.
if data_source == "ctr":
    test_has_labels = bool(resolve_ctr_config(cfg).get("test_has_labels", False))
else:
    test_has_labels = True

# Prefer the labeled test split for evaluation; otherwise fall back to eval.
use_test_split = (
    loss_loaders.test_loader is not None
    and splits.y_test is not None
    and test_has_labels
)
if use_test_split:
    eval_name = "test"
    eval_loader = loss_loaders.test_loader
    _raw_labels, _raw_conds = splits.y_test, splits.c_test
else:
    eval_name = "eval"
    eval_loader = loss_loaders.eval_loader
    _raw_labels, _raw_conds = splits.y_eval, splits.c_eval
eval_labels = np.asarray(_raw_labels).reshape(-1)
eval_conds = np.asarray(_raw_conds).reshape(-1)

print(f"Evaluation split for BCE loss: {eval_name}")
print(f"Data source: {data_source}")
print(f"Num conditions/categories: {splits.num_conditions}")


In [None]:
def lr_schedule_halving(start: float = 1.0, floor: float = 1e-6) -> list[float]:
    """Return the halving schedule ``start, start/2, start/4, ...``.

    Values are generated while they are greater than or equal to ``floor``.

    Args:
        start: First (largest) learning rate of the schedule.
        floor: Smallest learning rate that may be included; must be > 0.

    Returns:
        The learning rates in descending order.

    Raises:
        ValueError: If ``floor`` is not positive, if ``start`` is not a finite
            number, or if no value is generated (``start`` below ``floor``).
    """
    if floor <= 0:
        raise ValueError("floor must be > 0")
    lr = float(start)
    # Halving infinity yields infinity again, so without this guard the loop
    # below would never terminate and the list would grow without bound.
    if not math.isfinite(lr):
        raise ValueError(f"start must be finite, got start={start}")
    values = []
    while lr >= floor:
        values.append(lr)
        lr /= 2.0
    if not values:
        raise ValueError(f"No learning rates generated: start={start} floor={floor}")
    return values


def _global_calibration_ratio(preds: np.ndarray, labels: np.ndarray, eps: float = 1e-12) -> float:
    """Return mean(preds) / mean(labels), or NaN when mean(labels) <= eps."""
    predicted_rate = float(np.mean(preds))
    observed_rate = float(np.mean(labels))
    # A (near-)zero observed rate would make the ratio meaningless.
    return float("nan") if observed_rate <= eps else predicted_rate / observed_rate


def collect_all_data_metrics(model: torch.nn.Module) -> tuple[float, pd.DataFrame]:
    """Score ``model`` on the evaluation split and return its calibration metrics.

    Predictions are produced with ``predict_probs`` over the module-level
    ``eval_loader`` and flattened to a 1-D float64 array.

    Returns:
        A pair ``(global_ratio, per_condition)``: the global calibration ratio
        and the result of ``per_condition_calibration`` for the same
        predictions, labels and condition ids.

    Raises:
        ValueError: If predictions, ``eval_labels`` and ``eval_conds`` do not
            all have the same length.
    """
    probs = np.asarray(
        predict_probs(model, eval_loader, device, non_blocking=non_blocking),
        dtype=np.float64,
    ).reshape(-1)

    n_preds = probs.shape[0]
    n_labels = eval_labels.shape[0]
    n_conds = eval_conds.shape[0]
    if not (n_preds == n_labels == n_conds):
        raise ValueError(
            f"Mismatched eval lengths: preds={n_preds} labels={n_labels} conds={n_conds}"
        )

    global_ratio = _global_calibration_ratio(probs, eval_labels)
    per_condition = per_condition_calibration(probs, eval_labels, eval_conds)
    return global_ratio, per_condition


# Search grid: epochs trained per learning rate, and the halving LR schedule.
EPOCHS = 15
LEARNING_RATES = lr_schedule_halving(start=0.1, floor=1e-5)

print(f"Number of LR points: {len(LEARNING_RATES)}")
print("First/last LR:", LEARNING_RATES[0], LEARNING_RATES[-1])


In [None]:
# One record per (lr, epoch), and one per (lr, epoch, condition).
epoch_records = []
condition_records = []

for lr in LEARNING_RATES:
    # Re-seed and rebuild the model for each learning rate so every run
    # starts from the same seed.
    set_seed(int(cfg["project"]["seed"]), use_cuda)
    model = build_model(cfg, splits, device, dtype=model_dtype)

    def on_epoch_eval(eval_preds: np.ndarray, epoch: int) -> None:
        """Record BCE and calibration metrics for the current lr at ``epoch``."""
        lr_value = float(lr)
        epoch_idx = int(epoch)
        tag = f"[CAL] lr={lr_value:.6g} epoch={epoch_idx}"

        test_bce = bce_log_loss(eval_preds.reshape(-1), eval_labels)
        global_calibration, per_cond_df = collect_all_data_metrics(model)
        print(f"{tag} global={float(global_calibration):.6f}")

        if per_cond_df.empty:
            per_cond_text = "<none>"
        else:
            ordered = per_cond_df.sort_values("condition")
            per_cond_text = ", ".join(
                f"c{int(r.condition)}={float(r.calibration):.6f}"
                for r in ordered.itertuples(index=False)
            )
        print(f"{tag} per-condition: {per_cond_text}")

        epoch_records.append(
            dict(
                lr=lr_value,
                epoch=epoch_idx,
                test_bce=float(test_bce),
                global_calibration=float(global_calibration),
            )
        )
        condition_records.extend(
            dict(
                lr=lr_value,
                epoch=epoch_idx,
                condition=int(row["condition"]),
                count=int(row["count"]),
                base_rate=float(row["base_rate"]),
                pred_mean=float(row["pred_mean"]),
                calibration=float(row["calibration"]),
            )
            for _, row in per_cond_df.iterrows()
        )

    # The return value is not needed; all metrics are gathered via the callback.
    _ = train_single_loss(
        model=model,
        loss_mode="bce",
        train_loader=loss_loaders.train_loader,
        train_eval_loader=eval_loader,
        eval_loader=eval_loader,
        device=device,
        epochs=EPOCHS,
        lr=float(lr),
        eval_has_labels=True,
        non_blocking=non_blocking,
        eval_callback=on_epoch_eval,
        plot_eval_hist_epochs=False,
        print_embedding_table=False,
    )

results_df = pd.DataFrame(epoch_records)
condition_df = pd.DataFrame(condition_records)

if results_df.empty:
    raise RuntimeError("Search produced no results.")

# Rank by lowest BCE; ties are broken by smaller lr, then by earlier epoch.
results_df = results_df.sort_values(by=["test_bce", "lr", "epoch"], ascending=True)
results_df = results_df.reset_index(drop=True)
best_row = results_df.iloc[0]

print("Best BCE result:")
print(best_row.to_dict())


In [None]:
def plot_metric_lines(metric_df: pd.DataFrame, value_col: str, title: str, y_label: str) -> None:
    """Plot ``value_col`` against epoch, drawing one line per learning rate.

    Lines are drawn from the largest learning rate to the smallest. When
    ``value_col`` contains "calibration" (case-insensitive) the y-axis is
    fixed to the range [0, 3].
    """
    line_style = dict(marker="o", linewidth=1.5, markersize=3)

    plt.figure(figsize=(11, 6))
    for rate in sorted(metric_df["lr"].unique(), reverse=True):
        rows = metric_df.loc[metric_df["lr"] == rate, ["epoch", value_col]]
        rows = rows.sort_values("epoch")
        plt.plot(rows["epoch"], rows[value_col], label=f"ler={rate:.6g}", **line_style)

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=3, fontsize=8)
    is_calibration_metric = "calibration" in value_col.lower()
    if is_calibration_metric:
        plt.ylim(0, 3)
    plt.tight_layout()
    plt.show()



def display_sortable_table(df: pd.DataFrame, table_id: str = "bce_final_results") -> None:
    """Render ``df`` as an HTML table and try to enhance it with DataTables.

    The table is emitted with ``df.to_html`` (no index column) together with
    ``<link>``/``<script>`` tags that load jQuery and DataTables from public
    CDNs. The inline script only initialises DataTables when both libraries
    are present on ``window`` at the moment it runs; otherwise the plain HTML
    table is left as-is.

    Args:
        df: Data to display.
        table_id: HTML ``id`` given to the table and used as the jQuery
            selector; an existing DataTable with that id is destroyed first.
    """
    # Imported lazily so that only this function requires IPython.
    from IPython.display import HTML, display

    html_table = df.to_html(index=False, table_id=table_id, classes="display compact", border=0)
    # Doubled braces below are literal JavaScript braces inside the f-string.
    html = f"""
<link rel="stylesheet" href="https://cdn.datatables.net/1.13.8/css/jquery.dataTables.min.css" />
<script src="https://code.jquery.com/jquery-3.7.1.min.js"></script>
<script src="https://cdn.datatables.net/1.13.8/js/jquery.dataTables.min.js"></script>
{html_table}
<script>
(function() {{
  var tableId = '#{table_id}';
  if (window.jQuery && jQuery.fn && jQuery.fn.DataTable) {{
    if (jQuery.fn.DataTable.isDataTable(tableId)) {{
      jQuery(tableId).DataTable().destroy();
    }}
    jQuery(tableId).DataTable({{
      paging: true,
      pageLength: 100,
      lengthMenu: [[25, 50, 100, -1], [25, 50, 100, 'All']],
      ordering: true,
      info: true,
      autoWidth: false,
      scrollX: true
    }});
  }}
}})();
</script>
"""
    display(HTML(html))


# Columns shown in the final results table, in display order.
table_cols = ["lr", "epoch", "test_bce", "global_calibration"]
# `table_cols` was previously defined but never applied; selecting it here
# pins the table to the intended columns and column order.
final_table_df = results_df.loc[:, table_cols].copy()

# Show all rows/columns in a sortable table (with pandas fallback).
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    try:
        display_sortable_table(final_table_df, table_id="bce_final_results")
    except Exception as exc:
        # Deliberate best-effort: any Python-side rendering failure falls back
        # to the plain DataFrame display instead of aborting the report.
        print(f"[WARN] Sortable table renderer unavailable ({exc}); falling back to plain display.")
        display(final_table_df)

print("Most optimal ler (by minimum test_bce):", float(best_row["lr"]))
print("Best epoch for that ler:", int(best_row["epoch"]))
print("Best test_bce:", float(best_row["test_bce"]))

# Charts 1 and 2: BCE and global calibration across learning rate and epochs.
# Each entry is (value column, chart title, y-axis label).
_summary_charts = (
    ("test_bce", f"BCE ({eval_name}) across learning rate and epoch", "BCE"),
    (
        "global_calibration",
        "Global calibration ratio across learning rate and epoch",
        "Calibration ratio",
    ),
)
for _value_col, _chart_title, _y_label in _summary_charts:
    plot_metric_lines(
        metric_df=results_df,
        value_col=_value_col,
        title=_chart_title,
        y_label=_y_label,
    )

# Chart 3: Per-condition calibration charts
# One subplot per condition, with one line per learning rate.
condition_table_cols = ["lr", "epoch", "condition", "count", "base_rate", "pred_mean", "calibration"]

if condition_df.empty:
    # A DataFrame built from an empty record list has no columns, so the
    # column selection for the table must only run once this guard has passed.
    print("No per-condition records to plot.")
else:
    # Per-condition calibration table used for charts
    display(
        condition_df.loc[:, condition_table_cols]
        .sort_values(["condition", "lr", "epoch"])
        .reset_index(drop=True)
    )

    cond_ids = sorted(condition_df["condition"].unique())
    n = len(cond_ids)
    ncols = 3
    nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 3.8 * nrows))
    # Flatten the grid so subplots can be indexed by condition position.
    axes = np.array(axes).reshape(-1)

    lr_order = sorted(condition_df["lr"].unique(), reverse=True)
    epoch_order = sorted(condition_df["epoch"].unique())

    for i, cond_id in enumerate(cond_ids):
        ax = axes[i]
        block = condition_df.loc[condition_df["condition"] == cond_id, ["lr", "epoch", "calibration"]]
        ax.set_title(f"Condition {cond_id}")
        for lr in lr_order:
            lr_block = block.loc[block["lr"] == lr, ["epoch", "calibration"]].sort_values("epoch")
            if lr_block.empty:
                continue
            ax.plot(
                lr_block["epoch"],
                lr_block["calibration"],
                marker="o",
                linewidth=1.2,
                markersize=2.5,
                label=f"{lr:.3g}",
            )
        ax.set_xticks(epoch_order)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Calibration")
        ax.set_ylim(0, 2)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=6)

    # Hide the unused cells of the last grid row.
    for j in range(len(cond_ids), len(axes)):
        axes[j].axis("off")

    fig.suptitle("Per-condition calibration ratio across epoch (line per ler)")
    fig.tight_layout()
    plt.show()
