# Localized Entropy Notebook

This notebook runs the end-to-end pipeline using the config in `configs/default.json`.
Switch between synthetic and CTR data, tune the model, and toggle plots from the config file.


In [None]:
%matplotlib inline
from pathlib import Path
import numpy as np
import torch
from localized_entropy.config import (
    load_and_resolve,
    loss_label,
    get_condition_label,
    get_data_source,
    resolve_loss_modes,
)
from localized_entropy.utils import init_device, set_seed
from localized_entropy.data.pipeline import prepare_data
from localized_entropy.training import evaluate, predict_probs
from localized_entropy.analysis import (
    print_pred_summary,
    print_pred_stats_by_condition,
    collect_le_stats_per_condition,
    collect_logits,
    bce_log_loss,
    roc_auc_score,
    pr_auc_score,
    expected_calibration_error,
    binary_classification_metrics,
    per_condition_metrics,
    summarize_per_ad_train_eval_rates,
)
from localized_entropy.plotting import (
    plot_training_distributions,
    plot_eval_log10p_hist,
    plot_loss_curves,
    plot_eval_predictions_by_condition,
    plot_pred_to_train_rate,
    plot_le_stats_per_condition,
    plot_grad_sq_sums_by_condition,
    plot_ctr_filter_stats,
    plot_feature_distributions_by_condition,
    plot_label_rates_by_condition,
    build_eval_epoch_plotter,
    build_eval_batch_plotter,
)
from localized_entropy.experiments import (
    resolve_eval_bundle,
    resolve_train_eval_bundle,
    build_model,
    train_single_loss,
    build_seed_sequence,
    run_repeated_loss_experiments,
)
from localized_entropy.compare import (
    compare_bce_le_runs,
    format_comparison_table,
    format_bce_le_summary,
    build_repeat_metrics_frame,
    summarize_repeat_metrics,
    build_wilcoxon_summary,
    format_wilcoxon_summary,
    build_per_condition_calibration_wilcoxon,
    sort_per_condition_wilcoxon_frame,
)
from localized_entropy.outputs import build_output_paths, start_notebook_output_capture
# Compact numeric printing for NumPy arrays and torch tensors.
np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4)

# Load the resolved configuration that the rest of the notebook reads.
CONFIG_PATH = "configs/default.json"
cfg = load_and_resolve(CONFIG_PATH)
train_cfg = cfg["training"]

# An empty resolution means the configured loss mode is not supported.
loss_modes = resolve_loss_modes(train_cfg.get("loss_mode", "localized_entropy"))
if not loss_modes:
    raise ValueError(f"Unsupported loss_mode: {train_cfg.get('loss_mode')}")

# One set of output paths per loss mode.
output_paths = {mode: build_output_paths(cfg, mode) for mode in loss_modes}

# When this cell is re-run, stop the capture object left by the previous run
# (if it exposes `stop`) before starting a new one.
if hasattr(globals().get("notebook_output_capture"), "stop"):
    notebook_output_capture.stop()
notebook_output_capture = start_notebook_output_capture(output_paths)

experiment_cfg = cfg['experiment']
print(f"Using experiment: {experiment_cfg.get('name', experiment_cfg.get('active', 'unknown'))}")

# Pick the device, then seed with the project seed.
device, use_cuda, non_blocking = init_device()
set_seed(cfg['project']['seed'], use_cuda)


In [None]:
# Build the data splits and loaders from the config.
data_bundle = prepare_data(cfg, device, use_cuda)
splits = data_bundle.splits
loaders = data_bundle.loaders

plots_cfg = cfg['plots']
train_cfg = cfg['training']
data_source = get_data_source(cfg)
condition_label = get_condition_label(cfg)

# Resolve the loss mode(s) again so this cell can be re-run on its own.
loss_modes = resolve_loss_modes(train_cfg.get('loss_mode', 'localized_entropy'))
if not loss_modes:
    raise ValueError(f"Unsupported loss_mode: {train_cfg.get('loss_mode')}")
train_both = len(loss_modes) > 1
output_paths = {mode: build_output_paths(cfg, mode) for mode in loss_modes}
if train_both:
    print(f"Training loss modes: {', '.join(loss_modes)}")

# The value range for prediction plots must be a 2-item list/tuple; otherwise use the default.
raw_eval_value_range = plots_cfg.get('eval_pred_value_range', [-12, 0])
eval_range_is_valid = (
    isinstance(raw_eval_value_range, (list, tuple)) and len(raw_eval_value_range) == 2
)
if not eval_range_is_valid:
    print("[WARN] plots.eval_pred_value_range should be a 2-item list; using default (-12, 0).")
eval_value_range = tuple(raw_eval_value_range) if eval_range_is_valid else (-12, 0)

# Configure evaluation target via configs/default.json -> evaluation.split.
eval_cfg = cfg.get('evaluation', {})
eval_split, eval_loader, eval_labels, eval_conds, eval_name = resolve_eval_bundle(
    cfg, splits, loaders
)
eval_has_labels = eval_labels is not None
train_eval_loader, train_eval_conds, train_eval_name = resolve_train_eval_bundle(
    eval_split, eval_loader, eval_labels, eval_conds, eval_name, loaders, splits
)

# Optional informational output, controlled by the `logging` config section.
logging_cfg = cfg.get('logging', {})
if logging_cfg.get('print_eval_split', True):
    print(f"Evaluation split: {eval_name}")
    if train_eval_name != eval_name:
        print(f"Training eval split: {train_eval_name}")
if logging_cfg.get('print_loader_note', True):
    print(loaders.loader_note)
# Optional CTR filter statistics plot (CTR data source only).
if data_source == 'ctr' and cfg.get('ctr', {}).get('plot_filter_stats', False):
    ctr_stats = data_bundle.plot_data.get('ctr_stats')
    if ctr_stats:
        plot_ctr_filter_stats(ctr_stats['stats_df'], ctr_stats['labels'], ctr_stats['filter_col'])

# Data distribution plots before training: synthetic plot data takes precedence,
# otherwise the CTR distribution sample is used when available and enabled.
if not plots_cfg.get('data_before_training', False):
    print('Skipping training data distribution plots before training.')
else:
    synth = data_bundle.plot_data.get('synthetic')
    if synth:
        plot_training_distributions(
            synth['net_worth'],
            synth['ages'],
            synth['probs'],
            synth['conds'],
            synth['num_conditions'],
        )
    elif not plots_cfg.get('ctr_data_distributions', True):
        print('CTR distribution plots are disabled in config.')
    else:
        ctr_plot = data_bundle.plot_data.get('ctr_distributions')
        if not ctr_plot:
            print('CTR plot_sample_size is disabled or empty; skipping CTR distributions.')
        else:
            # Drop the count features and keep names and column indices aligned
            # by filtering them in a single pass.
            skip_features = {"device_ip_count", "device_id_count"}
            kept_features = [
                (idx, name)
                for idx, name in enumerate(ctr_plot['feature_names'])
                if name not in skip_features
            ]
            feature_indices = [idx for idx, _ in kept_features]
            feature_names = [name for _, name in kept_features]
            if feature_names:
                filtered_xnum = ctr_plot['xnum'][:, feature_indices]
                # Only request log10 scaling for features that are actually plotted.
                log10_features = set(plots_cfg.get('ctr_log10_features', []))
                log10_features &= set(feature_names)
                plot_feature_distributions_by_condition(
                    filtered_xnum,
                    ctr_plot['conds'],
                    feature_names,
                    ctr_plot['num_conditions'],
                    max_features=int(plots_cfg.get('ctr_max_features', 3)),
                    log10_features=log10_features,
                    density=bool(plots_cfg.get('ctr_use_density', False)),
                )
            else:
                print('CTR distribution plots: all features filtered; skipping.')
            if plots_cfg.get('ctr_label_rates', True):
                plot_label_rates_by_condition(
                    ctr_plot['labels'],
                    ctr_plot['conds'],
                    ctr_plot['num_conditions'],
                )


In [None]:
# Diagnostics: input/label/condition stats
from localized_entropy.analysis import (
    print_condition_stats,
    print_feature_stats,
    print_label_stats,
)

print('Diagnostics: splits')

# Feature statistics: train and eval are always reported, test only when present.
feature_splits = [('Train features', splits.x_train), ('Eval features', splits.x_eval)]
if splits.x_test is not None:
    feature_splits.append(('Test features', splits.x_test))
for split_title, split_features in feature_splits:
    print_feature_stats(split_title, split_features)

# Condition statistics, same pattern.
cond_splits = [('Train conds', splits.c_train), ('Eval conds', splits.c_eval)]
if splits.c_test is not None:
    cond_splits.append(('Test conds', splits.c_test))
for split_title, split_conds in cond_splits:
    print_condition_stats(split_title, split_conds, splits.num_conditions)

# Label statistics; the test entry depends on test labels being present.
label_splits = [
    ('Train labels', splits.y_train, splits.c_train),
    ('Eval labels', splits.y_eval, splits.c_eval),
]
if splits.y_test is not None:
    label_splits.append(('Test labels', splits.y_test, splits.c_test))
for split_title, split_labels, split_conds in label_splits:
    print_label_stats(split_title, split_labels, split_conds, splits.num_conditions)


In [None]:
# Build one model per loss mode, keyed by loss mode.
models = {}
for loss_mode in loss_modes:
    # Re-seed with the project seed before each build so every model is
    # constructed from the same seed.
    set_seed(cfg['project']['seed'], use_cuda)
    model = build_model(cfg, splits, device)
    models[loss_mode] = model
    if len(loss_modes) > 1:
        print(f"Model ({loss_label(loss_mode)}):")
    # A bare `model` expression inside a loop is never displayed by the
    # notebook, so print the architecture explicitly in every case.
    print(model)


In [None]:
# Diagnostics: logits/probability stats of the untrained models on one batch.
# `None` stands in for "the loader produced no batch".
batch = next(iter(train_eval_loader), None)
if batch is None:
    print(f"{train_eval_name} loader empty; skipping init logits diagnostics.")
else:
    # Only the first three batch members are fed to the model here.
    x_b, x_cat_b, c_b, _y_b, _nw_b = batch
    x_b, x_cat_b, c_b = (
        tensor.to(device, non_blocking=non_blocking) for tensor in (x_b, x_cat_b, c_b)
    )
    for loss_mode, model in models.items():
        label = loss_label(loss_mode)
        with torch.no_grad():
            logits = model(x_b, x_cat_b, c_b)
            probs = torch.sigmoid(logits)
        # The tags are padded so both summary lines align.
        for tag, tensor in (("logits:", logits), ("probs: ", probs)):
            flat = tensor.detach().cpu().numpy().reshape(-1)
            print(
                f"{label} init batch {tag} n={flat.size:,} min={flat.min():.6g} "
                f"max={flat.max():.6g} mean={flat.mean():.6g} std={flat.std():.6g}"
            )


In [None]:
# Per-condition prediction stats and plots from the models before training.
if plots_cfg.get('eval_pred_by_condition', True):
    for loss_mode, model in models.items():
        loss_name = loss_label(loss_mode)
        pretrain_eval_preds = predict_probs(
            model, train_eval_loader, device, non_blocking=non_blocking
        )
        if not (pretrain_eval_preds.size > 0):
            print(f"{train_eval_name} set is empty after filtering; skipping pre-training plot.")
            continue
        if train_eval_conds is None:
            print(f"{train_eval_name} conditions unavailable; skipping per-condition stats.")
            continue
        print_pred_stats_by_condition(
            pretrain_eval_preds,
            train_eval_conds,
            splits.num_conditions,
            name=f"Pre-Training {train_eval_name} ({loss_name})",
        )
        pretrain_title = (
            f"Pre-Training {train_eval_name} Predictions by {condition_label} ("
            f"{loss_name})"
        )
        plot_eval_predictions_by_condition(
            pretrain_eval_preds,
            train_eval_conds,
            splits.num_conditions,
            value_range=eval_value_range,
            title=pretrain_title,
        )


In [None]:
# Train every model and keep the per-loss-mode results.
run_results = {}

# Eval logits are only collected when labels exist and something downstream
# (both-loss runs or the LE stats plot/table) is enabled.
wants_le_outputs = (
    train_both
    or plots_cfg.get('le_stats', True)
    or plots_cfg.get('print_le_stats_table', True)
)
need_eval_logits = eval_has_labels and wants_le_outputs

plot_eval_epochs = plots_cfg.get('eval_pred_by_condition', True)
eval_every_n_batches = int(train_cfg.get('eval_every_n_batches', 0) or 0)
plot_eval_batches = plot_eval_epochs and eval_every_n_batches > 0
track_eval_batch_losses = plots_cfg.get('loss_curves', True) and eval_every_n_batches > 0

# The batch-level eval interval is passed on only when something consumes it.
if plot_eval_batches or track_eval_batch_losses:
    batch_eval_interval = eval_every_n_batches
else:
    batch_eval_interval = None

for loss_mode, model in models.items():
    # Re-seed with the project seed before each training run.
    set_seed(cfg['project']['seed'], use_cuda)
    loss_name = loss_label(loss_mode)
    print(f"[INFO] Training {loss_name} model")

    # Both plotter factories take the same positional arguments.
    plotter_args = (
        train_eval_name,
        train_eval_conds,
        splits.num_conditions,
        eval_value_range,
        condition_label,
        loss_name,
    )
    eval_callback = None
    if plot_eval_epochs:
        eval_callback = build_eval_epoch_plotter(*plotter_args)
    eval_batch_callback = None
    if plot_eval_batches:
        eval_batch_callback = build_eval_batch_plotter(*plotter_args)

    run_results[loss_mode] = train_single_loss(
        model=model,
        loss_mode=loss_mode,
        train_loader=loaders.train_loader,
        train_eval_loader=train_eval_loader,
        eval_loader=eval_loader,
        device=device,
        epochs=train_cfg['epochs'],
        lr=train_cfg['lr'],
        eval_has_labels=eval_has_labels,
        non_blocking=non_blocking,
        plot_eval_hist_epochs=plots_cfg.get('eval_hist_epochs', False),
        eval_callback=eval_callback,
        eval_every_n_batches=batch_eval_interval,
        eval_batch_callback=eval_batch_callback,
        collect_eval_batch_losses=track_eval_batch_losses,
        collect_grad_sq_sums=plots_cfg.get('grad_sq_by_condition', False),
        collect_eval_logits=need_eval_logits,
    )


In [None]:
# Train/eval loss curves per loss mode, written to that mode's output path.
if plots_cfg.get('loss_curves', True):
    for loss_mode, result in run_results.items():
        output_path = output_paths[loss_mode]["loss_curves"]
        # Batch-level eval losses are passed only when batch-level evaluation was enabled.
        batch_losses = None
        if eval_every_n_batches > 0:
            batch_losses = result.eval_batch_losses
        plot_loss_curves(
            result.train_losses,
            result.eval_losses,
            result.loss_label,
            output_path=output_path,
            eval_batch_losses=batch_losses,
        )


In [None]:
# Post-training evaluation: summaries, plots and calibration/classification
# metrics for each trained loss mode.
eval_cfg = cfg.get('evaluation', {})
for loss_mode, result in run_results.items():
    label = result.loss_label
    eval_preds = result.eval_preds
    eval_loss = result.eval_loss
    model = result.model
    # With several loss modes, tag printed names with the loss label.
    display_name = eval_name if not train_both else f"{eval_name} ({label})"
    print(f"=== {label} model evaluation ===")
    if plots_cfg.get('print_eval_summary', True):
        print_pred_summary(display_name, eval_preds, labels=eval_labels, conds=eval_conds)
    if plots_cfg.get('eval_pred_hist', True):
        plot_eval_log10p_hist(eval_preds.astype(np.float32), epoch=train_cfg['epochs'], name=display_name)
    if plots_cfg.get('eval_pred_by_condition', True):
        if eval_conds is None:
            print(f"{display_name} conditions unavailable; skipping eval predictions by condition.")
        else:
            plot_eval_predictions_by_condition(
                eval_preds,
                eval_conds,
                splits.num_conditions,
                value_range=eval_value_range,
                name=display_name,
                output_path=output_paths[loss_mode]["post_training_eval_predictions"],
                # Use the plain split name here: display_name may already carry
                # the loss label, which would otherwise appear twice in the title.
                title=(
                    f"Post-Training {eval_name} Predictions by {condition_label} ("
                    f"{label})"
                ),
            )
    if not eval_has_labels:
        print(f"Final {eval_name} {label}: n/a (labels unavailable)")
    else:
        print(f"Final {eval_name} {label}: {eval_loss:.10f}")
        bins = int(eval_cfg.get('ece_bins', 20))
        min_count = int(eval_cfg.get('ece_min_count', 1))
        small_prob_max_cfg = float(eval_cfg.get('small_prob_max', 0.01))
        small_prob_quantile = float(eval_cfg.get('small_prob_quantile', 0.1))
        total_bce = bce_log_loss(eval_preds, eval_labels)
        total_ece, total_ece_table = expected_calibration_error(
            eval_preds, eval_labels, bins=bins, min_count=min_count
        )
        # Small-probability calibration: start from the configured absolute
        # threshold and fall back to a quantile of the predictions when no
        # prediction is at or below it.
        small_threshold = small_prob_max_cfg
        small_mask = eval_preds <= small_threshold
        if not small_mask.any():
            quantile_threshold = float(np.quantile(eval_preds, small_prob_quantile))
            print(
                f"[INFO] No preds <= {small_threshold:g}; using {small_prob_quantile:.2f} quantile "
                f"threshold {quantile_threshold:g} for small-prob calibration."
            )
            small_threshold = quantile_threshold
            small_mask = eval_preds <= small_threshold
        if small_mask.any():
            total_ece_small, _ = expected_calibration_error(
                eval_preds[small_mask],
                eval_labels[small_mask],
                bins=bins,
                min_count=min_count,
            )
        else:
            total_ece_small = float('nan')
        print(f"{label} Total BCE (log loss): {total_bce:.8f}")
        print(f"{label} Total ECE: {total_ece:.8f}")
        print(f"{label} Total ECE (small p<= {small_threshold:g}): {total_ece_small:.8f}")
        total_auc = roc_auc_score(eval_preds, eval_labels)
        total_pr_auc = pr_auc_score(eval_preds, eval_labels)
        print(f"{label} Total ROC-AUC: {total_auc:.8f}")
        print(f"{label} Total PR-AUC (AP): {total_pr_auc:.8f}")
        cls_metrics = binary_classification_metrics(eval_preds, eval_labels)
        print(f"{label} Total Accuracy@0.5: {cls_metrics['accuracy']:.8f}")
        print(f"{label} Total F1@0.5: {cls_metrics['f1']:.8f}")
        if eval_cfg.get('print_calibration_table', False):
            print(total_ece_table.to_string(index=False))
        if eval_conds is None:
            print(f"[WARN] {display_name} conditions unavailable; skipping per-condition metrics.")
        else:
            per_ad = per_condition_metrics(
                eval_preds,
                eval_labels,
                eval_conds,
                bins=bins,
                min_count=min_count,
                small_prob_max=small_threshold,
            )
            if eval_cfg.get('print_per_ad', True):
                top_k = int(eval_cfg.get('per_ad_top_k', 10))
                print(per_ad.head(top_k).to_string(index=False))
        if not train_both:
            # Single-loss run: also report the eval loss under the other
            # configured loss modes for the same trained model.
            for mode in train_cfg.get('eval_compare_losses', []):
                if mode == loss_mode:
                    continue
                other_loss, _ = evaluate(
                    model,
                    eval_loader,
                    device,
                    loss_mode=mode,
                    non_blocking=non_blocking,
                )
                other_label = loss_label(mode)
                print(f"Final {eval_name} {other_label}: {other_loss:.10f}")
    # Per-ad train click rates vs. eval prediction averages.
    # NOTE(review): this is called even when eval_conds is None; confirm the
    # helper handles that (the result is only plotted when it is not None).
    plot_df = summarize_per_ad_train_eval_rates(
        splits.y_train,
        splits.c_train,
        eval_preds,
        eval_conds,
        splits.num_conditions,
        condition_label=condition_label,
        eval_name=eval_name,
        top_k=eval_cfg.get('per_ad_top_k', 10),
    )
    if plot_df is not None:
        plot_pred_to_train_rate(
            plot_df,
            condition_label=condition_label,
            eval_name=eval_name,
            output_path=output_paths[loss_mode]["pred_to_train_rate"],
        )


In [None]:
# Predictions on the test split, when one exists and it was not already the
# evaluation split.
if loaders.test_loader is not None and eval_split != 'test':
    for loss_mode, result in run_results.items():
        label = result.loss_label
        display_name = 'Test' if not train_both else f"Test ({label})"
        test_preds = predict_probs(
            result.model,
            loaders.test_loader,
            device,
            non_blocking=non_blocking,
        )
        if not (test_preds.size > 0):
            print(f"{display_name} set is empty after filtering; skipping test plot.")
            continue
        # The summary is printed without labels (labels=None).
        print_pred_summary(display_name, test_preds, labels=None, conds=splits.c_test)
        if plots_cfg.get('eval_pred_hist', True):
            plot_eval_log10p_hist(
                test_preds.astype(np.float32),
                epoch=train_cfg['epochs'],
                name=display_name,
            )
        if plots_cfg.get('eval_pred_by_condition', True):
            if splits.c_test is None:
                print(f"{display_name} conditions unavailable; skipping test predictions by condition.")
            else:
                plot_eval_predictions_by_condition(
                    test_preds,
                    splits.c_test,
                    splits.num_conditions,
                    value_range=eval_value_range,
                    name=display_name,
                    # display_name may already carry the loss label; use the
                    # plain split name so the label appears once in the title.
                    title=(
                        f"Post-Training Test Predictions by {condition_label} ("
                        f"{label})"
                    ),
                )


In [None]:
# Data distribution plots after training: synthetic plot data takes precedence,
# otherwise the CTR distribution sample is used when available and enabled.
if not plots_cfg.get('data_after_training', False):
    print('Post-training training data plots are disabled.')
else:
    synth = data_bundle.plot_data.get('synthetic')
    if synth:
        plot_training_distributions(
            synth['net_worth'],
            synth['ages'],
            synth['probs'],
            synth['conds'],
            synth['num_conditions'],
        )
    elif not plots_cfg.get('ctr_data_distributions', True):
        print('CTR post-training plots are disabled in config.')
    else:
        ctr_plot = data_bundle.plot_data.get('ctr_distributions')
        if not ctr_plot:
            print('CTR plot_sample_size is disabled or empty; skipping CTR post-training plots.')
        else:
            # Drop the count features; names and column indices come from one pass.
            skip_features = {"device_ip_count", "device_id_count"}
            kept_features = [
                (idx, name)
                for idx, name in enumerate(ctr_plot['feature_names'])
                if name not in skip_features
            ]
            feature_names = [name for _, name in kept_features]
            if not feature_names:
                print('CTR distribution plots: all features filtered; skipping.')
            else:
                feature_indices = [idx for idx, _ in kept_features]
                filtered_xnum = ctr_plot['xnum'][:, feature_indices]
                # Only request log10 scaling for features that are actually plotted.
                log10_features = set(plots_cfg.get('ctr_log10_features', []))
                log10_features &= set(feature_names)
                plot_feature_distributions_by_condition(
                    filtered_xnum,
                    ctr_plot['conds'],
                    feature_names,
                    ctr_plot['num_conditions'],
                    max_features=int(plots_cfg.get('ctr_max_features', 3)),
                    log10_features=log10_features,
                    density=bool(plots_cfg.get('ctr_use_density', False)),
                )
            if plots_cfg.get('ctr_label_rates', True):
                plot_label_rates_by_condition(
                    ctr_plot['labels'],
                    ctr_plot['conds'],
                    ctr_plot['num_conditions'],
                )


In [None]:
# Localized entropy statistics per condition (table and/or plot) for each run.
need_le_stats = (
    plots_cfg.get('le_stats', True)
    or plots_cfg.get('print_le_stats_table', True)
    or train_both
)
if need_le_stats:
    # Both labels and conditions are required; report the first one that is missing.
    missing_input = None
    if not eval_has_labels:
        missing_input = 'labels'
    elif eval_conds is None:
        missing_input = 'conditions'

    if missing_input is not None:
        print(f"{eval_name} {missing_input} unavailable; skipping localized entropy stats.")
    else:
        for loss_mode, result in run_results.items():
            label = result.loss_label
            z_all = result.eval_logits
            y_all = result.eval_targets
            c_all = result.eval_conds
            # Collect logits now if training did not store them, and cache them on the result.
            if any(item is None for item in (z_all, y_all, c_all)):
                z_all, y_all, c_all = collect_logits(
                    result.model, eval_loader, device, non_blocking=non_blocking
                )
                result.eval_logits = z_all
                result.eval_targets = y_all
                result.eval_conds = c_all
            le_stats = collect_le_stats_per_condition(z_all, y_all, c_all, eps=1e-12)
            if plots_cfg.get('print_le_stats_table', True):
                # Tab-separated table: one header row, then one row per condition.
                print(f"LE stats table ({label})")
                print("\t".join(('cond', 'num', 'den', 'avg_p', '#y=1', '#y=0', 'ratio')))
                for cond in sorted(le_stats):
                    s = le_stats[cond]
                    row_fields = (
                        f"{cond}",
                        f"{s['Numerator']:.6g}",
                        f"{s['Denominator']:.6g}",
                        f"{s['Average prediction for denominator']:.6g}",
                        f"{s['Number of samples with label 1']}",
                        f"{s['Number of samples with label 0']}",
                        f"{s['Numerator/denominator']:.6g}",
                    )
                    print("\t".join(row_fields))
            if plots_cfg.get('le_stats', True):
                plot_le_stats_per_condition(
                    le_stats,
                    title=f"Localized Entropy terms per condition - {eval_name} set ({label})",
                )


In [None]:
# Per-condition BCE vs LE comparison; only runs when both loss modes were trained.
compare_cfg = cfg.get('comparison', {})
if train_both and compare_cfg.get('enabled', True):
    if eval_has_labels and eval_conds is not None:
        bce_result = run_results.get('bce')
        le_result = run_results.get('localized_entropy')
        if bce_result is not None and le_result is not None:
            comparison = compare_bce_le_runs(
                bce_result,
                le_result,
                eval_labels,
                eval_conds,
                condition_label=condition_label,
                sort_by=str(compare_cfg.get('sort_by', 'count')),
            )
            if compare_cfg.get('print_table', True):
                # Columns shown in the printed table; the first one is the
                # configured condition label.
                table_columns = [condition_label, 'count', 'base_rate']
                table_columns += ['bce_pred_mean', 'le_pred_mean']
                table_columns += ['bce_calibration', 'le_calibration', 'delta_calibration']
                table_columns += ['bce_le_ratio', 'le_le_ratio', 'delta_le_ratio']
                title = f"{eval_name}: per-{condition_label} BCE vs LE comparison"
                print(title)
                comparison_text = format_comparison_table(
                    comparison,
                    table_columns,
                    top_k=int(compare_cfg.get('top_k', 20)),
                )
                print(comparison_text)
                print('')
                eval_cfg = cfg.get('evaluation', {})
                summary_text = format_bce_le_summary(
                    bce_result,
                    le_result,
                    eval_labels,
                    ece_bins=int(eval_cfg.get('ece_bins', 20)),
                    ece_min_count=int(eval_cfg.get('ece_min_count', 1)),
                )
                print(summary_text)
        else:
            print("[WARN] Missing BCE or LE results; skipping comparison.")
    else:
        print(f"[WARN] {eval_name} labels/conditions unavailable; skipping BCE vs LE comparison.")

# Gradient sum-of-squares per condition, BCE vs LE; needs both runs.
grad_plot_enabled = plots_cfg.get('grad_sq_by_condition', False)
if grad_plot_enabled:
    if train_both:
        bce_result = run_results.get('bce')
        le_result = run_results.get('localized_entropy')
        if bce_result is not None and le_result is not None:
            bce_grads = bce_result.grad_sq_sum_per_condition
            le_grads = le_result.grad_sq_sum_per_condition
            if bce_grads is not None and le_grads is not None:
                grad_title = f"Training gradient sum of squares by {condition_label} (BCE vs LE)"
                plot_grad_sq_sums_by_condition(
                    bce_grads,
                    le_grads,
                    condition_label=condition_label,
                    title=grad_title,
                    top_k=int(plots_cfg.get('grad_sq_top_k', 0) or 0),
                    log10=bool(plots_cfg.get('grad_sq_log10', True)),
                )
            else:
                print('[WARN] Grad sums not collected; enable plots.grad_sq_by_condition to track.')
        else:
            print('[WARN] Missing BCE or LE results; skipping grad sum plot.')
    else:
        print('[WARN] Grad sum plot requires BCE and LE runs; skipping.')


In [None]:
# Repeated runs with different seeds, followed by per-metric and per-condition
# Wilcoxon signed-rank comparisons of BCE vs LE.
repeat_cfg = cfg.get('repeats', {})
repeat_enabled = bool(repeat_cfg.get('enabled', False))
repeat_runs = int(repeat_cfg.get('num_runs', 1) or 0)
repeat_seed_stride = int(repeat_cfg.get('seed_stride', 1) or 1)
include_base_run = bool(repeat_cfg.get('include_base_run', True))
if repeat_enabled:
    if repeat_runs < 1:
        print('[WARN] repeats.enabled is true but repeats.num_runs < 1; skipping repeats.')
    elif not train_both:
        print('[WARN] Repeat runs require training both loss modes; skipping.')
    elif not eval_has_labels:
        print(f"[WARN] {eval_name} labels unavailable; skipping repeat stats.")
    else:
        base_seed = int(cfg['project']['seed'])
        seeds = build_seed_sequence(base_seed, repeat_runs, repeat_seed_stride)
        repeat_results = {mode: [] for mode in loss_modes}
        run_seeds = []
        start_idx = 0
        if include_base_run:
            # Reuse the runs trained earlier in the notebook as the first repeat,
            # provided every loss mode has a result.
            missing = [mode for mode in loss_modes if mode not in run_results]
            if not missing:
                for loss_mode in loss_modes:
                    repeat_results[loss_mode].append(run_results[loss_mode])
                run_seeds.append(seeds[0] if seeds else base_seed)
                start_idx = 1
            else:
                print(f"[WARN] Missing base run results for {missing}; running all repeats from scratch.")
        extra_seeds = seeds[start_idx:]
        if extra_seeds:
            print(f"[INFO] Running {len(extra_seeds)} repeated runs for significance testing.")
            extra_results = run_repeated_loss_experiments(
                cfg=cfg,
                loss_modes=loss_modes,
                splits=splits,
                loaders=loaders,
                train_eval_loader=train_eval_loader,
                eval_loader=eval_loader,
                device=device,
                use_cuda=use_cuda,
                eval_has_labels=eval_has_labels,
                seeds=extra_seeds,
                non_blocking=non_blocking,
                collect_eval_logits=False,
            )
            for loss_mode, runs in extra_results.items():
                repeat_results[loss_mode].extend(runs)
            run_seeds.extend(extra_seeds)
        bce_runs = repeat_results.get('bce', [])
        le_runs = repeat_results.get('localized_entropy', [])
        if len(bce_runs) != len(le_runs):
            print('[WARN] Repeat run counts do not match between BCE and LE; skipping Wilcoxon test.')
        elif not bce_runs:
            print('[WARN] No repeated runs available for significance testing.')
        else:
            eval_cfg = cfg.get('evaluation', {})
            # Seeds label the runs only when there is exactly one seed per run.
            run_values = run_seeds if len(run_seeds) == len(bce_runs) else None
            # Both metric frames are built with the same settings.
            metrics_kwargs = dict(
                ece_bins=int(eval_cfg.get('ece_bins', 20)),
                ece_min_count=int(eval_cfg.get('ece_min_count', 1)),
                threshold=0.5,
                small_prob_max=float(eval_cfg.get('small_prob_max', 0.01)),
                small_prob_quantile=float(eval_cfg.get('small_prob_quantile', 0.1)),
                run_label='seed',
                run_values=run_values,
            )
            bce_metrics = build_repeat_metrics_frame(bce_runs, eval_labels, **metrics_kwargs)
            le_metrics = build_repeat_metrics_frame(le_runs, eval_labels, **metrics_kwargs)
            summary_columns = ['metric', 'n', 'mean', 'std', 'min', 'max']
            for summary_tag, metrics_frame in (('BCE', bce_metrics), ('LE', le_metrics)):
                print(f'Repeated-run metric summary ({summary_tag}):')
                print(format_comparison_table(
                    summarize_repeat_metrics(metrics_frame),
                    summary_columns,
                    top_k=10,
                ))
                print('')
            if len(bce_runs) < 2:
                print('[WARN] Need at least 2 repeats for a Wilcoxon test; skipping.')
            else:
                # The same test options are used for the overall and per-condition tests.
                wilcoxon_kwargs = dict(
                    zero_method=str(repeat_cfg.get('wilcoxon_zero_method', 'wilcox')),
                    alternative=str(repeat_cfg.get('wilcoxon_alternative', 'two-sided')),
                )
                wilcoxon_summary = build_wilcoxon_summary(
                    bce_metrics, le_metrics, **wilcoxon_kwargs
                )
                print('Wilcoxon signed-rank test (delta > 0 favors LE):')
                print(format_wilcoxon_summary(wilcoxon_summary))
                if eval_conds is None:
                    print('[WARN] Eval conditions unavailable; skipping per-condition calibration test.')
                else:
                    per_cond_min_count = int(repeat_cfg.get('per_condition_min_count', 1))
                    per_cond_sort_by = str(repeat_cfg.get('per_condition_sort_by', 'p_value'))
                    per_cond_top_k = int(repeat_cfg.get('per_condition_top_k', 20))
                    per_cond_summary = build_per_condition_calibration_wilcoxon(
                        bce_runs,
                        le_runs,
                        eval_labels,
                        eval_conds,
                        min_count=per_cond_min_count,
                        **wilcoxon_kwargs,
                    )
                    if per_cond_summary is None or len(per_cond_summary) == 0:
                        print('[WARN] Per-condition calibration test returned no rows.')
                    else:
                        # Show the configured condition label instead of the generic column name.
                        per_cond_summary = per_cond_summary.rename(columns={'condition': condition_label})
                        per_cond_summary = sort_per_condition_wilcoxon_frame(per_cond_summary, per_cond_sort_by)
                        per_cond_columns = [
                            condition_label, 'count', 'base_rate', 'bce_gap', 'le_gap', 'delta_mean', 'p_value',
                        ]
                        print(f'Per-{condition_label} calibration Wilcoxon (abs pred_mean - base_rate):')
                        print(format_comparison_table(
                            per_cond_summary,
                            per_cond_columns,
                            top_k=per_cond_top_k,
                        ))


In [16]:
# Stop the output capture started in the first cell, if it exists and exposes `stop`.
if hasattr(globals().get("notebook_output_capture"), "stop"):
    notebook_output_capture.stop()