# Localized Entropy Notebook

This notebook runs the end-to-end pipeline using the config in `configs/default.json`.
Switch between synthetic and CTR data, tune the model, and toggle plots from the config file.


In [None]:
"""Set up imports, config, output capture, and device/seed."""
%matplotlib inline
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from localized_entropy.config import (
    load_and_resolve,
    loss_label,
    get_condition_label,
    get_data_source,
    resolve_loss_modes,
)
from localized_entropy.utils import init_device, set_seed
from localized_entropy.data.pipeline import prepare_data
from localized_entropy.training import evaluate, predict_probs
from localized_entropy.analysis import (
    print_pred_summary,
    print_pred_stats_by_condition,
    collect_le_stats_per_condition,
    collect_logits,
    bce_log_loss,
    roc_auc_score,
    pr_auc_score,
    expected_calibration_error,
    binary_classification_metrics,
    per_condition_metrics,
    per_condition_mean,
    per_condition_calibration,
    per_condition_calibration_from_base_rates,
    summarize_per_ad_train_eval_rates,
)
from localized_entropy.plotting import (
    plot_training_distributions,
    plot_eval_log10p_hist,
    plot_loss_curves,
    plot_eval_predictions_by_condition,
    plot_calibration_ratio_by_condition,
    plot_pred_to_train_rate,
    plot_le_stats_per_condition,
    plot_grad_sq_sums_by_condition,
    plot_ctr_filter_stats,
    plot_feature_distributions_by_condition,
    plot_label_rates_by_condition,
    build_eval_epoch_plotter,
    build_eval_batch_plotter,
)
from localized_entropy.experiments import (
    resolve_eval_bundle,
    resolve_train_eval_bundle,
    build_loss_loaders,
    select_eval_loader,
    build_model,
    train_single_loss,
    build_seed_sequence,
    run_repeated_loss_experiments,
)
from localized_entropy.compare import (
    compare_bce_le_runs,
    summarize_model_metrics,
    format_comparison_table,
    format_bce_le_summary,
    build_repeat_metrics_frame,
    summarize_repeat_metrics,
    build_wilcoxon_summary,
    format_wilcoxon_summary,
    build_per_condition_calibration_wilcoxon,
    sort_per_condition_wilcoxon_frame,
)
from localized_entropy.outputs import build_output_paths, start_notebook_output_capture
# Console formatting for torch tensors and numpy arrays.
torch.set_printoptions(precision=4)
np.set_printoptions(precision=4, suppress=True)

# --- Config: load, resolve overrides, and validate the requested loss mode(s) ---
CONFIG_PATH = "configs/default.json"
cfg = load_and_resolve(CONFIG_PATH)
train_cfg = cfg["training"]
loss_modes = resolve_loss_modes(train_cfg.get("loss_mode", "localized_entropy"))
if not loss_modes:
    # Report the raw configured value (None when the key is absent).
    raise ValueError(f"Unsupported loss_mode: {train_cfg.get('loss_mode')}")

# --- Output capture: one set of output paths per loss mode ---
output_paths = {mode: build_output_paths(cfg, mode) for mode in loss_modes}
# Stop a capture left over from an earlier execution of this cell (if any)
# before starting a fresh one.
if hasattr(globals().get("notebook_output_capture"), "stop"):
    notebook_output_capture.stop()
notebook_output_capture = start_notebook_output_capture(output_paths)
experiment_cfg = cfg['experiment']
print(f"Using experiment: {experiment_cfg.get('name', experiment_cfg.get('active', 'unknown'))}")

# --- Device, model dtype, and RNG seeds ---
device_cfg = cfg.get("device", {})
use_mps_flag = bool(device_cfg.get("use_mps", True))
device, use_cuda, use_mps, non_blocking = init_device(use_mps=use_mps_flag)
# float64 only when the resolved device is CPU *and* MPS was not requested.
cpu_float64 = device.type == "cpu" and not use_mps_flag
if cpu_float64:
    model_dtype = torch.float64
else:
    model_dtype = torch.float32
set_seed(cfg['project']['seed'], use_cuda)


In [None]:
"""Prepare data loaders, resolve eval split, and plot pre-training data distributions."""
data_bundle = prepare_data(cfg, device, use_cuda, use_mps)
splits = data_bundle.splits
loaders = data_bundle.loaders
plots_cfg = cfg['plots']
train_cfg = cfg['training']
data_source = get_data_source(cfg)
condition_label = get_condition_label(cfg)
# Loss modes and output paths are re-resolved here from the current config
# (the setup cell performs the same resolution).
loss_modes = resolve_loss_modes(train_cfg.get('loss_mode', 'localized_entropy'))
if not loss_modes:
    raise ValueError(f"Unsupported loss_mode: {train_cfg.get('loss_mode')}")
train_multi = len(loss_modes) > 1
train_bce_le = ("bce" in loss_modes and "localized_entropy" in loss_modes)
output_paths = {loss_mode: build_output_paths(cfg, loss_mode) for loss_mode in loss_modes}
if train_multi:
    print(f"Training loss modes: {', '.join(loss_modes)}")
# Value range passed to the per-condition prediction plots; must be a
# 2-item list/tuple, otherwise fall back to (-12, 0).
raw_eval_value_range = plots_cfg.get('eval_pred_value_range', [-12, 0])
if isinstance(raw_eval_value_range, (list, tuple)) and len(raw_eval_value_range) == 2:
    eval_value_range = tuple(raw_eval_value_range)
else:
    print("[WARN] plots.eval_pred_value_range should be a 2-item list; using default (-12, 0).")
    eval_value_range = (-12, 0)
# Configure evaluation target via configs/default.json -> evaluation.split.
# Keep a train-time eval split with labels/conds for plots even if eval/test lacks labels.
# Resolve evaluation split + training eval split (for labeled diagnostics)
eval_cfg = cfg.get('evaluation', {})
eval_split, eval_loader, eval_labels, eval_conds, eval_name = resolve_eval_bundle(
    cfg, splits, loaders
)
eval_has_labels = eval_labels is not None
train_eval_loader, train_eval_conds, train_eval_name = resolve_train_eval_bundle(
    eval_split,
    eval_loader,
    eval_labels,
    eval_conds,
    eval_name,
    loaders,
    splits,
)
le_base_rates_train = None
le_base_rates_train_eval = None
le_base_rates_eval = None
use_true_le_base_rates = False
# Optional: use true synthetic base rates (per-condition mean of the true
# probabilities) for the LE denominator.
if data_source == 'synthetic' and cfg.get('synthetic', {}).get('use_true_base_rates_for_le', False):
    base_rates_train = per_condition_mean(splits.p_train, splits.c_train, splits.num_conditions)
    base_rates_eval = per_condition_mean(splits.p_eval, splits.c_eval, splits.num_conditions)
    if base_rates_train is not None:
        use_true_le_base_rates = True
        le_base_rates_train = base_rates_train
        le_base_rates_train_eval = base_rates_train if train_eval_name == 'Train' else base_rates_eval
        # Only Train/Eval splits have true rates available; anything else gets None.
        if eval_name == 'Train':
            le_base_rates_eval = base_rates_train
        elif eval_name == 'Eval':
            le_base_rates_eval = base_rates_eval
        else:
            le_base_rates_eval = None
        print('[INFO] Using true per-condition base rates for LE denominator (synthetic).')
# Logging: split selection and loader notes
if cfg.get('logging', {}).get('print_eval_split', True):
    print(f"Evaluation split: {eval_name}")
    if train_eval_name != eval_name:
        print(f"Training eval split: {train_eval_name}")
if cfg.get('logging', {}).get('print_loader_note', True):
    print(loaders.loader_note)
# Optional CTR filter stats plot
if data_source == 'ctr' and cfg.get('ctr', {}).get('plot_filter_stats', False):
    ctr_stats = data_bundle.plot_data.get('ctr_stats')
    if ctr_stats:
        plot_ctr_filter_stats(ctr_stats['stats_df'], ctr_stats['labels'], ctr_stats['filter_col'])
# Optional pre-training data distribution plots
if plots_cfg.get('data_before_training', False):
    synth = data_bundle.plot_data.get('synthetic')
    if synth:
        plot_training_distributions(
            synth['net_worth'],
            synth['ages'],
            synth['probs'],
            synth['conds'],
            synth['num_conditions'],
        )
    elif plots_cfg.get('ctr_data_distributions', True):
        ctr_plot = data_bundle.plot_data.get('ctr_distributions')
        if ctr_plot:
            # Skip high-cardinality device count features to keep plots readable.
            skip_features = {"device_ip_count", "device_id_count"}
            # Single pass: keep (column index, name) pairs for the retained features.
            kept_features = [
                (i, name)
                for i, name in enumerate(ctr_plot['feature_names'])
                if name not in skip_features
            ]
            feature_names = [name for _, name in kept_features]
            if feature_names:
                feature_indices = [i for i, _ in kept_features]
                filtered_xnum = ctr_plot['xnum'][:, feature_indices]
                # Apply log scaling only to configured numeric features.
                log10_features = set(plots_cfg.get('ctr_log10_features', []))
                log10_features &= set(feature_names)
                plot_feature_distributions_by_condition(
                    filtered_xnum,
                    ctr_plot['conds'],
                    feature_names,
                    ctr_plot['num_conditions'],
                    max_features=int(plots_cfg.get('ctr_max_features', 3)),
                    log10_features=log10_features,
                    density=bool(plots_cfg.get('ctr_use_density', False)),
                )
            else:
                print('CTR distribution plots: all features filtered; skipping.')
            if plots_cfg.get('ctr_label_rates', True):
                plot_label_rates_by_condition(
                    ctr_plot['labels'],
                    ctr_plot['conds'],
                    ctr_plot['num_conditions'],
                )
        else:
            print('CTR plot_sample_size is disabled or empty; skipping CTR distributions.')
    else:
        print('CTR distribution plots are disabled in config.')
else:
    print('Skipping training data distribution plots before training.')


In [None]:
"""Print diagnostic stats for features, conditions, and labels."""
# Split-level diagnostics: per-split feature, condition, and label statistics.
from localized_entropy.analysis import (
    print_condition_stats,
    print_feature_stats,
    print_label_stats,
)

print('Diagnostics: splits')
n_conditions = splits.num_conditions
# (split name, features, conditions, labels, optional?) — Train/Eval are
# always reported; Test sections are skipped when the relevant array is None.
diag_splits = (
    ('Train', splits.x_train, splits.c_train, splits.y_train, False),
    ('Eval', splits.x_eval, splits.c_eval, splits.y_eval, False),
    ('Test', splits.x_test, splits.c_test, splits.y_test, True),
)

# Feature distributions per split
for split_name, feats, _, _, optional in diag_splits:
    if optional and feats is None:
        continue
    print_feature_stats(f'{split_name} features', feats)

# Condition counts per split
for split_name, _, conds, _, optional in diag_splits:
    if optional and conds is None:
        continue
    print_condition_stats(f'{split_name} conds', conds, n_conditions)

# Label base rates per split
for split_name, _, conds, targets, optional in diag_splits:
    if optional and targets is None:
        continue
    print_label_stats(f'{split_name} labels', targets, conds, n_conditions)


In [None]:
"""Build model instances for each requested loss mode."""
models = {}
for loss_mode in loss_modes:
    # Re-seed before each build, presumably so every loss mode starts from the
    # same initial weights (assumes build_model draws from the seeded RNG).
    set_seed(cfg['project']['seed'], use_cuda)
    model = build_model(cfg, splits, device, dtype=model_dtype)
    models[loss_mode] = model
    # Jupyter only renders a cell's *final* expression, so a bare `model`
    # inside this loop never displayed anything; print explicitly instead.
    if len(loss_modes) > 1:
        print(f"Model ({loss_label(loss_mode)}):")
    print(model)


In [None]:
"""Check initial logits/probabilities before training."""
# Use a single batch to sanity-check untrained logits/probabilities
# Diagnostics: initial logits/prob stats (untrained)
try:
    batch = next(iter(train_eval_loader))
except StopIteration:
    batch = None
if batch is None:
    print(f"{train_eval_name} loader empty; skipping init logits diagnostics.")
else:
    # Batch includes labels and per-sample weights (both ignored for init diagnostics)
    x_b, x_cat_b, c_b, y_b, _ = batch
    x_b = x_b.to(device, non_blocking=non_blocking)
    x_cat_b = x_cat_b.to(device, non_blocking=non_blocking)
    c_b = c_b.to(device, non_blocking=non_blocking)
    for loss_mode, model in models.items():
        label = loss_label(loss_mode)
        # Run the diagnostic forward pass in eval mode so dropout does not
        # randomize it and normalization layers do not update running stats
        # before training starts; restore the previous mode afterwards.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                logits = model(x_b, x_cat_b, c_b)
                probs = torch.sigmoid(logits)
        finally:
            model.train(was_training)
        # Tensors created under no_grad need no detach before conversion.
        logits_np = logits.cpu().numpy().reshape(-1)
        probs_np = probs.cpu().numpy().reshape(-1)
        print(
            f"{label} init batch logits: n={logits_np.size:,} min={logits_np.min():.6g} "
            f"max={logits_np.max():.6g} mean={logits_np.mean():.6g} std={logits_np.std():.6g}"
        )
        print(
            f"{label} init batch probs:  n={probs_np.size:,} min={probs_np.min():.6g} "
            f"max={probs_np.max():.6g} mean={probs_np.mean():.6g} std={probs_np.std():.6g}"
        )


In [None]:
"""Inspect pre-training predictions by condition when enabled."""
# Per-condition prediction stats/plots for each untrained model on the
# training-eval split (skipped entirely when the plot flag is off).
if plots_cfg.get('eval_pred_by_condition', True):
    for loss_mode, model in models.items():
        loss_name = loss_label(loss_mode)
        pretrain_eval_preds = predict_probs(
            model, train_eval_loader, device, non_blocking=non_blocking
        )
        if pretrain_eval_preds.size == 0:
            print(f"{train_eval_name} set is empty after filtering; skipping pre-training plot.")
            continue
        if train_eval_conds is None:
            print(f"{train_eval_name} conditions unavailable; skipping per-condition stats.")
            continue
        stats_name = f"Pre-Training {train_eval_name} ({loss_name})"
        plot_title = (
            f"Pre-Training {train_eval_name} Predictions by {condition_label} ({loss_name})"
        )
        print_pred_stats_by_condition(
            pretrain_eval_preds,
            train_eval_conds,
            splits.num_conditions,
            name=stats_name,
        )
        plot_eval_predictions_by_condition(
            pretrain_eval_preds,
            train_eval_conds,
            splits.num_conditions,
            value_range=eval_value_range,
            title=plot_title,
        )


In [None]:
"""Train models with optional eval callbacks and collect run results."""
# Train one model per loss mode; results are stored in run_results[loss_mode].
# NOTE(review): loop variables below (e.g. loss_train_cfg, eval_every_n_batches,
# focal_cfg) remain notebook globals after this cell and hold only the LAST
# loss mode's values; later cells should read per-mode values from
# loss_train_cfgs instead.
run_results = {}
# Collect logits/targets when LE stats or comparisons need them later.
# Only possible when the evaluation split has labels.
need_eval_logits = (
    eval_has_labels
    and (
        train_bce_le
        or plots_cfg.get('le_stats', True)
        or plots_cfg.get('print_le_stats_table', True)
    )
)
plot_eval_epochs = plots_cfg.get('eval_pred_by_condition', True)
# Resolve per-loss training overrides (batch size, lr, epochs).
loss_train_cfgs = {}
loss_loaders_by_mode = {}
loss_eval_loaders = {}
loss_train_eval_loaders = {}
for loss_mode in loss_modes:
    # Each loss mode gets its own loaders built from its (possibly overridden) training config.
    loss_loaders, loss_train_cfg = build_loss_loaders(cfg, loss_mode, splits, device, use_cuda, use_mps)
    loss_train_cfgs[loss_mode] = loss_train_cfg
    loss_loaders_by_mode[loss_mode] = loss_loaders
    loss_eval_loader = select_eval_loader(eval_split, loss_loaders)
    loss_eval_loaders[loss_mode] = loss_eval_loader
    # Only the loader is needed here; conds/name were already resolved in the data cell.
    loss_train_eval_loader, _, _ = resolve_train_eval_bundle(
        eval_split,
        loss_eval_loader,
        eval_labels,
        eval_conds,
        eval_name,
        loss_loaders,
        splits,
    )
    loss_train_eval_loaders[loss_mode] = loss_train_eval_loader
# Run training for each loss mode
for loss_mode, model in models.items():
    # Re-seed so every loss mode trains from the same RNG state.
    set_seed(cfg['project']['seed'], use_cuda)
    loss_name = loss_label(loss_mode)
    print(f"[INFO] Training {loss_name} model")
    loss_train_cfg = loss_train_cfgs[loss_mode]
    # Focal-loss parameters are read for every mode (None when not configured).
    focal_cfg = loss_train_cfg.get('focal', {}) if isinstance(loss_train_cfg, dict) else {}
    focal_alpha = focal_cfg.get('alpha') if isinstance(focal_cfg, dict) else None
    focal_gamma = focal_cfg.get('gamma') if isinstance(focal_cfg, dict) else None
    # 0 (or a missing/None value) disables intra-epoch evaluation.
    eval_every_n_batches = int(loss_train_cfg.get('eval_every_n_batches', 0) or 0)
    plot_eval_batches = plot_eval_epochs and eval_every_n_batches > 0
    track_eval_batch_losses = plots_cfg.get('loss_curves', True) and eval_every_n_batches > 0
    eval_callback = (
        build_eval_epoch_plotter(
            train_eval_name,
            train_eval_conds,
            splits.num_conditions,
            eval_value_range,
            condition_label,
            loss_name,
        )
        if plot_eval_epochs else None
    )
    eval_batch_callback = (
        build_eval_batch_plotter(
            train_eval_name,
            train_eval_conds,
            splits.num_conditions,
            eval_value_range,
            condition_label,
            loss_name,
        )
        if plot_eval_batches else None
    )
    # Category learning-rate options apply only to synthetic data with the LE
    # loss; 'lr_category' falls back to the 'LRCategory' key.
    lr_category = None
    lr_zero_after_epochs = None
    if data_source == 'synthetic' and loss_mode == 'localized_entropy':
        lr_category = loss_train_cfg.get('lr_category', loss_train_cfg.get('LRCategory'))
        lr_zero_after_epochs = loss_train_cfg.get('lr_zero_after_epochs')
    # Optional cross-batch settings from training.localized_entropy.cross_batch (LE only).
    le_cross_batch_cfg = None
    if loss_mode == 'localized_entropy':
        le_cfg = loss_train_cfg.get('localized_entropy')
        if isinstance(le_cfg, dict):
            le_cross_batch_cfg = le_cfg.get('cross_batch')
    run_results[loss_mode] = train_single_loss(
        model=model,
        loss_mode=loss_mode,
        train_loader=loss_loaders_by_mode[loss_mode].train_loader,
        train_eval_loader=loss_train_eval_loaders[loss_mode],
        eval_loader=loss_eval_loaders[loss_mode],
        device=device,
        epochs=loss_train_cfg['epochs'],
        lr=loss_train_cfg['lr'],
        lr_category=lr_category,
        lr_zero_after_epochs=lr_zero_after_epochs,
        eval_has_labels=eval_has_labels,
        le_base_rates_train=le_base_rates_train,
        le_base_rates_train_eval=le_base_rates_train_eval,
        le_base_rates_eval=le_base_rates_eval,
        focal_alpha=focal_alpha,
        focal_gamma=focal_gamma,
        non_blocking=non_blocking,
        plot_eval_hist_epochs=plots_cfg.get('eval_hist_epochs', False),
        eval_callback=eval_callback,
        # Batch-level eval runs only if something consumes it (plots or loss tracking).
        eval_every_n_batches=(
            eval_every_n_batches if (plot_eval_batches or track_eval_batch_losses) else None
        ),
        eval_batch_callback=eval_batch_callback,
        collect_eval_batch_losses=track_eval_batch_losses,
        collect_grad_sq_sums=plots_cfg.get('grad_sq_by_condition', False),
        debug_gradients=loss_train_cfg.get('debug_gradients', False),
        le_cross_batch_cfg=le_cross_batch_cfg,
        print_embedding_table=loss_train_cfg.get('print_embedding_table', False),
        collect_eval_logits=need_eval_logits,
    )


In [None]:
"""Plot training/eval loss curves."""
# Plot per-epoch train/eval losses (plus batch-level eval losses when the
# loss mode was configured with eval_every_n_batches > 0).
if plots_cfg.get('loss_curves', True):
    for loss_mode, result in run_results.items():
        output_path = output_paths[loss_mode]["loss_curves"]
        # Resolve eval_every_n_batches for THIS loss mode. The variable left
        # over from the training loop only reflects the last trained mode,
        # which is wrong when per-loss configs differ.
        mode_train_cfg = loss_train_cfgs.get(loss_mode, train_cfg)
        mode_eval_every = int(mode_train_cfg.get('eval_every_n_batches', 0) or 0)
        plot_loss_curves(
            result.train_losses,
            result.eval_losses,
            result.loss_label,
            output_path=output_path,
            eval_batch_losses=result.eval_batch_losses if mode_eval_every > 0 else None,
        )


In [None]:
"""Evaluate trained models, compute metrics, and render plots."""
eval_cfg = cfg.get('evaluation', {})
for loss_mode, result in run_results.items():
    label = result.loss_label
    eval_preds = result.eval_preds
    eval_loss = result.eval_loss
    model = result.model
    loss_train_cfg = loss_train_cfgs.get(loss_mode, train_cfg)
    # Assumes train_losses holds one entry before training plus one per epoch
    # (hence the -1) — TODO confirm against train_single_loss.
    epochs_completed = max(0, len(result.train_losses) - 1)
    display_name = eval_name if not train_multi else f"{eval_name} ({label})"
    print(f"=== {label} model evaluation ===")
    # Summary + plots
    if plots_cfg.get('print_eval_summary', True):
        print_pred_summary(display_name, eval_preds, labels=eval_labels, conds=eval_conds)
    if plots_cfg.get('eval_pred_hist', True):
        plot_eval_log10p_hist(eval_preds.astype(np.float32), epoch=epochs_completed, name=display_name)
    if plots_cfg.get('eval_pred_by_condition', True):
        if eval_conds is None:
            print(f"{display_name} conditions unavailable; skipping eval predictions by condition.")
        else:
            plot_eval_predictions_by_condition(
                eval_preds,
                eval_conds,
                splits.num_conditions,
                value_range=eval_value_range,
                name=display_name,
                output_path=output_paths[loss_mode]["post_training_eval_predictions"],
                title=(
                    f"Post-Training {display_name} Predictions by {condition_label} ("
                    f"{label})"
                ),
            )
    # Metrics that require labels
    if eval_has_labels:
        print(f"Final {eval_name} {label}: {eval_loss:.10f}")
        bins = int(eval_cfg.get('ece_bins', 20))
        min_count = int(eval_cfg.get('ece_min_count', 1))
        small_prob_max_cfg = float(eval_cfg.get('small_prob_max', 0.01))
        small_prob_quantile = float(eval_cfg.get('small_prob_quantile', 0.1))
        total_bce = bce_log_loss(eval_preds, eval_labels)
        total_ece, total_ece_table = expected_calibration_error(
            eval_preds, eval_labels, bins=bins, min_count=min_count
        )
        # Define low-probability calibration subset; fall back to a quantile if needed.
        small_threshold = small_prob_max_cfg
        small_mask = eval_preds <= small_threshold
        # np.quantile raises on an empty array, so only use the fallback when there is data.
        if not small_mask.any() and eval_preds.size > 0:
            quantile_threshold = float(np.quantile(eval_preds, small_prob_quantile))
            print(
                f"[INFO] No preds <= {small_threshold:g}; using {small_prob_quantile:.2f} quantile "
                f"threshold {quantile_threshold:g} for small-prob calibration."
            )
            small_threshold = quantile_threshold
            small_mask = eval_preds <= small_threshold
        if small_mask.any():
            total_ece_small, _ = expected_calibration_error(
                eval_preds[small_mask],
                eval_labels[small_mask],
                bins=bins,
                min_count=min_count,
            )
        else:
            total_ece_small = float('nan')
        print(f"{label} Total BCE (log loss): {total_bce:.8f}")
        print(f"{label} Total ECE: {total_ece:.8f}")
        print(f"{label} Total ECE (small p<= {small_threshold:g}): {total_ece_small:.8f}")
        total_auc = roc_auc_score(eval_preds, eval_labels)
        total_pr_auc = pr_auc_score(eval_preds, eval_labels)
        print(f"{label} Total ROC-AUC: {total_auc:.8f}")
        print(f"{label} Total PR-AUC (AP): {total_pr_auc:.8f}")
        cls_metrics = binary_classification_metrics(eval_preds, eval_labels)
        print(f"{label} Total Accuracy@0.5: {cls_metrics['accuracy']:.8f}")
        print(f"{label} Total F1@0.5: {cls_metrics['f1']:.8f}")
        if eval_cfg.get('print_calibration_table', False):
            print(total_ece_table.to_string(index=False))
        if eval_conds is None:
            print(f"[WARN] {display_name} conditions unavailable; skipping per-condition metrics.")
        else:
            per_ad = per_condition_metrics(
                eval_preds,
                eval_labels,
                eval_conds,
                bins=bins,
                min_count=min_count,
                small_prob_max=small_threshold,
            )
            if eval_cfg.get('print_per_ad', True):
                top_k = int(eval_cfg.get('per_ad_top_k', 10))
                print(per_ad.head(top_k).to_string(index=False))
        if plots_cfg.get('eval_calibration_ratio', True):
            if eval_conds is None:
                print(f"[WARN] {display_name} conditions unavailable; skipping calibration ratio plot.")
            else:
                # Prefer true synthetic base rates when configured; otherwise
                # use observed labels (always available inside this branch).
                if use_true_le_base_rates and le_base_rates_eval is not None:
                    cal_df = per_condition_calibration_from_base_rates(
                        eval_preds,
                        eval_conds,
                        le_base_rates_eval,
                    )
                else:
                    cal_df = per_condition_calibration(
                        eval_preds,
                        eval_labels,
                        eval_conds,
                    )
                if cal_df is None:
                    print(f"[WARN] {display_name} calibration data unavailable; skipping calibration ratio plot.")
                else:
                    plot_calibration_ratio_by_condition(
                        cal_df['base_rate'].to_numpy(),
                        cal_df['calibration'].to_numpy(),
                        name=display_name,
                        condition_label=condition_label,
                        title=f"{display_name} Calibration Ratio vs {condition_label} Base Rate",
                        output_path=output_paths[loss_mode]['calibration_ratio'],
                    )
        # Single-mode runs: also report the trained model's loss under other
        # configured loss functions (training.eval_compare_losses).
        if not train_multi:
            focal_cfg = loss_train_cfg.get('focal', {}) if isinstance(loss_train_cfg, dict) else {}
            focal_alpha = focal_cfg.get('alpha') if isinstance(focal_cfg, dict) else None
            focal_gamma = focal_cfg.get('gamma') if isinstance(focal_cfg, dict) else None
            for mode in loss_train_cfg.get('eval_compare_losses', []):
                if mode == loss_mode:
                    continue
                other_base_rates = (
                    le_base_rates_eval if (use_true_le_base_rates and mode == 'localized_entropy') else None
                )
                other_loss, _ = evaluate(
                    model,
                    eval_loader,
                    device,
                    loss_mode=mode,
                    focal_alpha=focal_alpha,
                    focal_gamma=focal_gamma,
                    base_rates=other_base_rates,
                    non_blocking=non_blocking,
                )
                other_label = loss_label(mode)
                print(f"Final {eval_name} {other_label}: {other_loss:.10f}")
    else:
        print(f"Final {eval_name} {label}: n/a (labels unavailable)")
    # Per-ad train click rates vs. eval prediction averages
    plot_df = summarize_per_ad_train_eval_rates(
        splits.y_train,
        splits.c_train,
        eval_preds,
        eval_conds,
        splits.num_conditions,
        condition_label=condition_label,
        eval_name=eval_name,
        top_k=int(eval_cfg.get('per_ad_top_k', 10)),
    )
    if plot_df is not None:
        plot_pred_to_train_rate(
            plot_df,
            condition_label=condition_label,
            eval_name=eval_name,
            output_path=output_paths[loss_mode]["pred_to_train_rate"],
        )


In [10]:
"""Run optional test-set inference/plots when configured."""
# Run inference on the separate test split (only when it exists and is not
# already the evaluation split). Test labels are not used here.
if loaders.test_loader is not None and eval_split != 'test':
    for loss_mode, result in run_results.items():
        label = result.loss_label
        display_name = 'Test' if not train_multi else f"Test ({label})"
        # Label the histogram with the epochs this run actually completed,
        # using the same convention as the evaluation cell, rather than the
        # global training.epochs, which ignores per-loss epoch overrides.
        epochs_completed = max(0, len(result.train_losses) - 1)
        test_preds = predict_probs(
            result.model,
            loaders.test_loader,
            device,
            non_blocking=non_blocking,
        )
        if test_preds.size > 0:
            print_pred_summary(display_name, test_preds, labels=None, conds=splits.c_test)
            if plots_cfg.get('eval_pred_hist', True):
                plot_eval_log10p_hist(
                    test_preds.astype(np.float32),
                    epoch=epochs_completed,
                    name=display_name,
                )
            if plots_cfg.get('eval_pred_by_condition', True):
                if splits.c_test is None:
                    print(f"{display_name} conditions unavailable; skipping test predictions by condition.")
                else:
                    plot_eval_predictions_by_condition(
                        test_preds,
                        splits.c_test,
                        splits.num_conditions,
                        value_range=eval_value_range,
                        name=display_name,
                        title=(
                            f"Post-Training {display_name} Predictions by {condition_label} ("
                            f"{label})"
                        ),
                    )
        else:
            print(f"{display_name} set is empty after filtering; skipping test plot.")


In [None]:
"""Compute localized entropy stats per condition and plot tables."""
need_le_stats = (
    plots_cfg.get('le_stats', True)
    or plots_cfg.get('print_le_stats_table', True)
    or train_bce_le
)
# LE stats need both labels and conditions for the evaluation split.
if need_le_stats and not eval_has_labels:
    print(f"{eval_name} labels unavailable; skipping localized entropy stats.")
elif need_le_stats and eval_conds is None:
    print(f"{eval_name} conditions unavailable; skipping localized entropy stats.")
elif need_le_stats:
    for loss_mode, result in run_results.items():
        label = result.loss_label
        z_all, y_all, c_all = result.eval_logits, result.eval_targets, result.eval_conds
        # Recompute logits/targets/conds when training did not cache them,
        # and store them back on the result for later cells.
        if any(part is None for part in (z_all, y_all, c_all)):
            z_all, y_all, c_all = collect_logits(
                result.model, eval_loader, device, non_blocking=non_blocking
            )
            result.eval_logits, result.eval_targets, result.eval_conds = z_all, y_all, c_all
        le_stats = collect_le_stats_per_condition(z_all, y_all, c_all, eps=1e-12)
        if plots_cfg.get('print_le_stats_table', True):
            print(f"LE stats table ({label})")
            print('cond\tnum\tden\tavg_p\t#y=1\t#y=0\tratio')
            for cond in sorted(le_stats):
                row = le_stats[cond]
                print("\t".join([
                    f"{cond}",
                    f"{row['Numerator']:.6g}",
                    f"{row['Denominator']:.6g}",
                    f"{row['Average prediction for denominator']:.6g}",
                    f"{row['Number of samples with label 1']}",
                    f"{row['Number of samples with label 0']}",
                    f"{row['Numerator/denominator']:.6g}",
                ]))
        if plots_cfg.get('le_stats', True):
            plot_le_stats_per_condition(
                le_stats,
                title=f"Localized Entropy terms per condition - {eval_name} set ({label})",
            )


In [None]:
"""Compare BCE vs LE and plot gradient diagnostics if enabled."""
compare_cfg = cfg.get('comparison', {})
compare_enabled = compare_cfg.get('enabled', True)
# Which loss runs finished in the training cell; has_focal adds focal columns below.
has_bce = 'bce' in run_results
has_le = 'localized_entropy' in run_results
has_focal = 'focal' in run_results
if train_bce_le and compare_enabled:
    if not eval_has_labels or eval_conds is None:
        print(f"[WARN] {eval_name} labels/conditions unavailable; skipping BCE vs LE comparison.")
    else:
        bce_result = run_results.get('bce')
        le_result = run_results.get('localized_entropy')
        if bce_result is None or le_result is None:
            print("[WARN] Missing BCE or LE results; skipping comparison.")
        else:
            comparison = compare_bce_le_runs(
                bce_result,
                le_result,
                eval_labels,
                eval_conds,
                condition_label=condition_label,
                sort_by=str(compare_cfg.get('sort_by', 'count')),
            )
            # NOTE(review): re-sorting by condition id means `comparison.sort_by`
            # does not decide which top_k rows are printed below; confirm intended.
            comparison_table = comparison.sort_values(condition_label)
            if compare_cfg.get('print_table', True):
                table_columns = [
                    condition_label,
                    'count',
                    'base_rate',
                    'bce_pred_mean',
                    'le_pred_mean',
                    'bce_logloss',
                    'le_logloss',
                    'bce_calibration',
                    'le_calibration',
                    'delta_calibration',
                    'bce_le_ratio',
                    'le_le_ratio',
                    'delta_le_ratio',
                ]
                title = f"{eval_name}: per-{condition_label} BCE vs LE comparison"
                print(title)
                print(format_comparison_table(
                    comparison_table,
                    table_columns,
                    top_k=int(compare_cfg.get('top_k', 20)),
                ))
                print('')
                cal_table = comparison_table[[condition_label, 'bce_calibration', 'le_calibration']].copy()
                if has_focal:
                    focal_result = run_results.get('focal')
                    if focal_result is not None:
                        focal_cal = per_condition_calibration(
                            focal_result.eval_preds,
                            eval_labels,
                            eval_conds,
                        )
                        # Rename the key column to condition_label and merge on it
                        # directly. A left_on/right_on merge followed by
                        # drop('condition') removes the join key itself when
                        # condition_label == 'condition' (both keys coalesce into
                        # one column), which then breaks the sort below.
                        focal_cal = focal_cal.rename(
                            columns={
                                'calibration': 'fl_calibration',
                                'condition': condition_label,
                            }
                        )
                        cal_table = cal_table.merge(
                            focal_cal[[condition_label, 'fl_calibration']],
                            on=condition_label,
                            how='left',
                        )
                # Distance from perfect calibration (ratio 1.0); lower is better.
                cal_table['bce_abs_1_minus_cal'] = (1.0 - cal_table['bce_calibration']).abs()
                cal_table['le_abs_1_minus_cal'] = (1.0 - cal_table['le_calibration']).abs()
                if 'fl_calibration' in cal_table.columns:
                    cal_table['fl_abs_1_minus_cal'] = (1.0 - cal_table['fl_calibration']).abs()
                cal_title = (
                    f"{eval_name}: per-{condition_label} abs(1 - calibration) (lower is better)"
                )
                print(cal_title)
                print(format_comparison_table(
                    cal_table.sort_values(condition_label),
                    [
                        condition_label,
                        'bce_abs_1_minus_cal',
                        'le_abs_1_minus_cal',
                        *(['fl_abs_1_minus_cal'] if 'fl_abs_1_minus_cal' in cal_table.columns else []),
                    ],
                    top_k=int(compare_cfg.get('top_k', 20)),
                ))
                print('')
                eval_cfg = cfg.get('evaluation', {})
                print(format_bce_le_summary(
                    bce_result,
                    le_result,
                    eval_labels,
                    eval_conds=eval_conds,
                    condition_label=condition_label,
                    ece_bins=int(eval_cfg.get('ece_bins', 20)),
                    ece_min_count=int(eval_cfg.get('ece_min_count', 1)),
                    small_prob_max=float(eval_cfg.get('small_prob_max', 0.01)),
                    small_prob_quantile=float(eval_cfg.get('small_prob_quantile', 0.1)),
                ))
# Global (non-per-condition) metric table for every available loss run,
# printed only when a focal-loss run is present.
if compare_enabled and has_focal:
    if eval_has_labels:
        eval_cfg = cfg.get('evaluation', {})
        loss_order = [
            ('bce', 'BCE'),
            ('localized_entropy', 'LE'),
            ('focal', 'Focal'),
        ]
        # (display name, run result) pairs in fixed display order; missing runs are dropped.
        available_runs = [(label, run_results.get(run_key)) for run_key, label in loss_order]
        metrics_rows = [
            (
                label,
                summarize_model_metrics(
                    run.eval_preds,
                    eval_labels,
                    ece_bins=int(eval_cfg.get('ece_bins', 20)),
                    ece_min_count=int(eval_cfg.get('ece_min_count', 1)),
                    threshold=0.5,
                    small_prob_max=float(eval_cfg.get('small_prob_max', 0.01)),
                    small_prob_quantile=float(eval_cfg.get('small_prob_quantile', 0.1)),
                ),
            )
            for label, run in available_runs
            if run is not None
        ]
        if metrics_rows:
            # Row labels shown in the table, paired with the summary dict keys.
            row_labels = ['accuracy@0.5', 'logloss', 'brier', 'ece', 'ece_small']
            summary_keys = ['accuracy', 'logloss', 'brier', 'ece', 'ece_small']
            frame_data = {'metric': row_labels}
            for label, metrics in metrics_rows:
                frame_data[label] = [metrics[k] for k in summary_keys]
            metrics_frame = pd.DataFrame(frame_data)
            columns = ['metric', *(label for label, _ in metrics_rows)]
            title = "BCE vs LE vs Focal summary (lower is better for logloss/brier/ece):"
            print(title)
            print(format_comparison_table(metrics_frame, columns, top_k=len(metrics_frame)))
    else:
        print(f"[WARN] {eval_name} labels unavailable; skipping focal comparison metrics.")

# Training-gradient diagnostics: per-condition mean squared gradient for BCE vs LE
# (plot + ratio table), then per-class grad MSE for each run.
grad_plot_enabled = plots_cfg.get('grad_sq_by_condition', False)
if grad_plot_enabled:
    if not train_bce_le:
        print('[WARN] Grad plot requires BCE and LE runs; skipping.')
    else:
        bce_result = run_results.get('bce')
        le_result = run_results.get('localized_entropy')
        if bce_result is None or le_result is None:
            print('[WARN] Missing BCE or LE results; skipping grad plot.')
        else:
            bce_grad_stats = bce_result.grad_sq_stats
            le_grad_stats = le_result.grad_sq_stats
            if bce_grad_stats is None or le_grad_stats is None:
                print('[WARN] Grad stats not collected; enable plots.grad_sq_by_condition to track.')
            else:
                # One top-k setting shared by the plot and the table (0 = all conditions).
                grad_top_k = int(plots_cfg.get('grad_sq_top_k', 0) or 0)
                bce_grads = bce_grad_stats.mean_by_condition
                le_grads = le_grad_stats.mean_by_condition
                plot_grad_sq_sums_by_condition(
                    bce_grads,
                    le_grads,
                    condition_label=condition_label,
                    title=(
                        f"Training gradient mean square by {condition_label} (BCE vs LE)"
                    ),
                    top_k=grad_top_k,
                    log10=bool(plots_cfg.get('grad_sq_log10', True)),
                )

                # Row position is used as the condition id in the table.
                grad_cond_ids = np.arange(len(bce_grads))
                if 0 < grad_top_k < grad_cond_ids.size:
                    # fmax ignores a NaN on one side; conditions NaN on both sides
                    # are ranked -inf so they land last (not first) after the
                    # descending flip below.
                    scores = np.fmax(bce_grads, le_grads)
                    scores = np.where(np.isnan(scores), -np.inf, scores)
                    grad_cond_ids = np.argsort(scores)[::-1][:grad_top_k]
                bce_vals = bce_grads[grad_cond_ids]
                le_vals = le_grads[grad_cond_ids]
                # LE/BCE ratio; NaN where the BCE value is zero.
                ratio = np.divide(
                    le_vals,
                    bce_vals,
                    out=np.full_like(le_vals, np.nan),
                    where=bce_vals != 0,
                )
                grad_table = pd.DataFrame(
                    {
                        condition_label: grad_cond_ids,
                        'bce_grad_mse': bce_vals,
                        'le_grad_mse': le_vals,
                        'le_over_bce_grad_mse': ratio,
                    }
                )
                grad_title = (
                    f"Training per-{condition_label} grad MSE ratio (LE/BCE)"
                )
                print(grad_title)
                print(format_comparison_table(
                    grad_table,
                    [
                        condition_label,
                        'bce_grad_mse',
                        'le_grad_mse',
                        'le_over_bce_grad_mse',
                    ],
                    top_k=len(grad_table),
                ))
                print('')

                def _print_grad_class_mse(label, stats):
                    """Print grad MSE and sample count for label 0 / label 1, plus the class ratio."""
                    mse = stats.class_mse
                    counts = stats.class_counts
                    class_ratio = stats.class_ratio
                    if mse is None or counts is None:
                        print(f"[WARN] {label} grad class MSE unavailable.")
                        return
                    count0 = int(counts[0]) if np.isfinite(counts[0]) else 0
                    count1 = int(counts[1]) if np.isfinite(counts[1]) else 0
                    # Guard None as well: np.isfinite(None) raises TypeError.
                    if class_ratio is not None and np.isfinite(class_ratio):
                        ratio_str = f"{class_ratio:.4f}"
                    else:
                        ratio_str = "nan"
                    print(
                        f"{label} grad MSE by class: "
                        f"0={mse[0]:.6e} (n={count0}), "
                        f"1={mse[1]:.6e} (n={count1}), "
                        f"ratio={ratio_str}"
                    )

                _print_grad_class_mse('BCE', bce_grad_stats)
                _print_grad_class_mse('LE', le_grad_stats)


In [13]:
"""Run repeated experiments and statistical comparisons."""
repeat_cfg = cfg.get('repeats', {})
repeat_enabled = bool(repeat_cfg.get('enabled', False))
repeat_runs = int(repeat_cfg.get('num_runs', 1) or 0)
repeat_seed_stride = int(repeat_cfg.get('seed_stride', 1) or 1)
include_base_run = bool(repeat_cfg.get('include_base_run', True))
if repeat_enabled:
    if repeat_runs < 1:
        print('[WARN] repeats.enabled is true but repeats.num_runs < 1; skipping repeats.')
    elif not train_bce_le:
        print('[WARN] Repeat runs require training BCE + LE; skipping.')
    elif not eval_has_labels:
        print(f"[WARN] {eval_name} labels unavailable; skipping repeat stats.")
    else:
        base_seed = int(cfg['project']['seed'])
        seeds = build_seed_sequence(base_seed, repeat_runs, repeat_seed_stride)
        repeat_results = {loss_mode: [] for loss_mode in loss_modes}
        run_seeds = []
        start_idx = 0
        # Reuse the already-computed base run when possible to avoid retraining.
        if include_base_run:
            missing = [mode for mode in loss_modes if mode not in run_results]
            if missing:
                print(f"[WARN] Missing base run results for {missing}; running all repeats from scratch.")
            else:
                for loss_mode in loss_modes:
                    repeat_results[loss_mode].append(run_results[loss_mode])
                run_seeds.append(seeds[0] if seeds else base_seed)
                start_idx = 1
        extra_seeds = seeds[start_idx:]
        # Train additional seeds for significance testing.
        if extra_seeds:
            print(f"[INFO] Running {len(extra_seeds)} repeated runs for significance testing.")
            extra_results = run_repeated_loss_experiments(
                cfg=cfg,
                loss_modes=loss_modes,
                splits=splits,
                eval_split=eval_split,
                eval_labels=eval_labels,
                eval_conds=eval_conds,
                eval_name=eval_name,
                device=device,
                use_cuda=use_cuda,
                use_mps=use_mps,
                model_dtype=model_dtype,
                seeds=extra_seeds,
                le_base_rates_train=le_base_rates_train,
                le_base_rates_train_eval=le_base_rates_train_eval,
                le_base_rates_eval=le_base_rates_eval,
                non_blocking=non_blocking,
                collect_eval_logits=False,
            )
            for loss_mode, runs in extra_results.items():
                repeat_results[loss_mode].extend(runs)
            run_seeds.extend(extra_seeds)
        bce_runs = repeat_results.get('bce', [])
        le_runs = repeat_results.get('localized_entropy', [])
        focal_runs = repeat_results.get('focal', [])
        if len(bce_runs) != len(le_runs):
            print('[WARN] Repeat run counts do not match between BCE and LE; skipping Wilcoxon test.')
        elif not bce_runs:
            print('[WARN] No repeated runs available for significance testing.')
        else:
            eval_cfg = cfg.get('evaluation', {})
            # Shared metric settings so every loss's repeat metrics are computed identically.
            repeat_metric_kwargs = dict(
                ece_bins=int(eval_cfg.get('ece_bins', 20)),
                ece_min_count=int(eval_cfg.get('ece_min_count', 1)),
                threshold=0.5,
                small_prob_max=float(eval_cfg.get('small_prob_max', 0.01)),
                small_prob_quantile=float(eval_cfg.get('small_prob_quantile', 0.1)),
                run_label='seed',
            )
            # Shared Wilcoxon options for every pairwise / per-condition test below.
            wilcoxon_kwargs = dict(
                zero_method=str(repeat_cfg.get('wilcoxon_zero_method', 'wilcox')),
                alternative=str(repeat_cfg.get('wilcoxon_alternative', 'two-sided')),
            )
            summary_columns = ['metric', 'n', 'mean', 'std', 'min', 'max']
            # Label runs by seed only when seeds and runs line up one-to-one
            # (le_runs has the same length as bce_runs in this branch).
            run_values = run_seeds if len(run_seeds) == len(bce_runs) else None
            bce_metrics = build_repeat_metrics_frame(
                bce_runs,
                eval_labels,
                run_values=run_values,
                **repeat_metric_kwargs,
            )
            le_metrics = build_repeat_metrics_frame(
                le_runs,
                eval_labels,
                run_values=run_values,
                **repeat_metric_kwargs,
            )
            print('Repeated-run metric summary (BCE):')
            print(format_comparison_table(
                summarize_repeat_metrics(bce_metrics),
                summary_columns,
                top_k=10,
            ))
            print('')
            print('Repeated-run metric summary (LE):')
            print(format_comparison_table(
                summarize_repeat_metrics(le_metrics),
                summary_columns,
                top_k=10,
            ))
            print('')
            if len(bce_runs) < 2:
                print('[WARN] Need at least 2 repeats for a Wilcoxon test; skipping.')
            else:
                wilcoxon_summary = build_wilcoxon_summary(
                    bce_metrics,
                    le_metrics,
                    **wilcoxon_kwargs,
                )
                print('Wilcoxon signed-rank test (delta > 0 favors LE):')
                print(format_wilcoxon_summary(wilcoxon_summary))
                if eval_conds is None:
                    print('[WARN] Eval conditions unavailable; skipping per-condition calibration test.')
                else:
                    per_cond_min_count = int(repeat_cfg.get('per_condition_min_count', 1))
                    per_cond_sort_by = str(repeat_cfg.get('per_condition_sort_by', 'p_value'))
                    per_cond_top_k = int(repeat_cfg.get('per_condition_top_k', 20))
                    per_cond_summary = build_per_condition_calibration_wilcoxon(
                        bce_runs,
                        le_runs,
                        eval_labels,
                        eval_conds,
                        min_count=per_cond_min_count,
                        **wilcoxon_kwargs,
                    )
                    if per_cond_summary is None or len(per_cond_summary) == 0:
                        print('[WARN] Per-condition calibration test returned no rows.')
                    else:
                        per_cond_summary = per_cond_summary.rename(columns={'condition': condition_label})
                        per_cond_summary = sort_per_condition_wilcoxon_frame(per_cond_summary, per_cond_sort_by)
                        print(f"Per-{condition_label} calibration Wilcoxon (abs pred_mean - base_rate):")
                        print(format_comparison_table(
                            per_cond_summary,
                            [condition_label, 'count', 'base_rate', 'bce_gap', 'le_gap', 'delta_mean', 'p_value'],
                            top_k=per_cond_top_k,
                        ))
            if focal_runs:
                # Focal may have a different run count than BCE (checked below),
                # so size its seed labels against focal_runs, not bce_runs.
                focal_run_values = run_seeds if len(run_seeds) == len(focal_runs) else None
                focal_metrics = build_repeat_metrics_frame(
                    focal_runs,
                    eval_labels,
                    run_values=focal_run_values,
                    **repeat_metric_kwargs,
                )
                print('')
                print('Repeated-run metric summary (Focal):')
                print(format_comparison_table(
                    summarize_repeat_metrics(focal_metrics),
                    summary_columns,
                    top_k=10,
                ))
                print('')
                if len(focal_runs) != len(bce_runs):
                    print('[WARN] Repeat run counts do not match between BCE and Focal; skipping Wilcoxon test.')
                elif len(focal_runs) < 2:
                    print('[WARN] Need at least 2 repeats for a Wilcoxon test (BCE vs Focal); skipping.')
                else:
                    focal_vs_bce = build_wilcoxon_summary(
                        bce_metrics,
                        focal_metrics,
                        **wilcoxon_kwargs,
                    )
                    print('Wilcoxon signed-rank test (delta > 0 favors Focal) [BCE vs Focal]:')
                    print(format_wilcoxon_summary(focal_vs_bce))
                if len(focal_runs) != len(le_runs):
                    print('[WARN] Repeat run counts do not match between LE and Focal; skipping Wilcoxon test.')
                elif len(focal_runs) < 2:
                    print('[WARN] Need at least 2 repeats for a Wilcoxon test (LE vs Focal); skipping.')
                else:
                    focal_vs_le = build_wilcoxon_summary(
                        le_metrics,
                        focal_metrics,
                        **wilcoxon_kwargs,
                    )
                    print('Wilcoxon signed-rank test (delta > 0 favors Focal) [LE vs Focal]:')
                    print(format_wilcoxon_summary(focal_vs_le))


In [14]:
"""Stop notebook output capture."""
# Best-effort teardown: only stop a capture object that an earlier cell
# created and that exposes a stop() method.
_output_capture = globals().get("notebook_output_capture")
if _output_capture is not None and hasattr(_output_capture, "stop"):
    _output_capture.stop()
del _output_capture