# Localized Entropy Notebook

This notebook runs the end-to-end pipeline (data preparation, training, evaluation, and diagnostics) using the config in `configs/default.json`.
Edit that config file to switch between synthetic and CTR data, tune the model and training settings, and toggle individual plots.


In [None]:
%matplotlib inline
import numpy as np
import torch

from localized_entropy.config import load_and_resolve, loss_label, get_condition_label, get_data_source
from localized_entropy.utils import init_device, set_seed
from localized_entropy.data.pipeline import prepare_data
from localized_entropy.models import ConditionProbNet
from localized_entropy.training import train_with_epoch_plots, evaluate, predict_probs
from localized_entropy.analysis import (
    print_pred_summary,
    collect_le_stats_per_condition,
    collect_logits,
    bce_log_loss,
    roc_auc_score,
    pr_auc_score,
    expected_calibration_error,
    per_condition_metrics,
)
from localized_entropy.plotting import (
    plot_training_distributions,
    plot_eval_log10p_hist,
    plot_loss_curves,
    plot_eval_predictions_by_condition,
    plot_le_stats_per_condition,
    plot_ctr_filter_stats,
    plot_feature_distributions_by_condition,
    plot_label_rates_by_condition,
)

# Compact numeric printing for both NumPy arrays and torch tensors.
np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4)

CONFIG_PATH = "configs/default.json"
cfg = load_and_resolve(CONFIG_PATH)

# Report the experiment: prefer its 'name', then its 'active' key, else 'unknown'.
experiment_cfg = cfg['experiment']
experiment_name = experiment_cfg.get('name', experiment_cfg.get('active', 'unknown'))
print(f"Using experiment: {experiment_name}")

# Resolve the compute device once; set_seed also receives the CUDA flag.
device, use_cuda, non_blocking = init_device()
set_seed(cfg['project']['seed'], use_cuda)

def print_pred_stats_by_condition(
    preds: np.ndarray,
    conds: np.ndarray,
    num_conditions: int,
    *,
    name: str = 'Eval',
) -> None:
    """Print min / max / mean of predictions for each condition id.

    Both ``preds`` and ``conds`` are flattened before use and must have the
    same number of elements. Condition ids must be non-negative (a
    ``np.bincount`` requirement). Only ids in ``[0, num_conditions)`` are
    reported; a condition without samples is printed as ``n=0``.

    Args:
        preds: Predicted probabilities (any shape).
        conds: Integer condition id per prediction (any shape).
        num_conditions: Number of condition ids to report.
        name: Prefix for the header line.
    """
    values = np.asarray(preds, dtype=np.float64).ravel()
    groups = np.asarray(conds, dtype=np.int64).ravel()
    total = int(num_conditions)
    sizes = np.bincount(groups, minlength=total)

    print(f"{name} prediction stats per condition:")
    for idx in range(total):
        size = int(sizes[idx])
        if not size:
            print(f"  cond {idx}: n=0")
            continue
        subset = values[groups == idx]
        print(
            f"  cond {idx}: n={size} min={subset.min():.6g} "
            f"max={subset.max():.6g} mean={subset.mean():.6g}"
        )


In [None]:
# Create reusable top-N filtered datasets + per-condition stats (if enabled).
from pathlib import Path

import pandas as pd

filter_col = cfg["ctr"].get("filter_col")
label_col = cfg["ctr"].get("label_col", "click")
filter_top_k = cfg.get("ctr", {}).get("filter_top_k")
filter_top_k = int(filter_top_k) if filter_top_k is not None else None
if filter_top_k is not None and filter_top_k <= 0:
    filter_top_k = None
read_rows = cfg.get("ctr", {}).get("read_rows")
read_rows = int(read_rows) if read_rows else None

train_path = Path(cfg["ctr"]["train_path"])
test_path = Path(cfg["ctr"]["test_path"])

if filter_col and filter_top_k:
    ad_id_freq_txt = Path("results/ad_id_impressions.txt")
    ad_id_freq_csv = Path("results/ad_id_impressions.csv")
    filtered_train_path = Path(f"data/train_top_{filter_top_k}.csv")
    filtered_test_path = Path(f"data/test_top_{filter_top_k}.csv")

    if ad_id_freq_txt.exists():
        freq_df = pd.read_csv(
            ad_id_freq_txt,
            sep=r"\s+",
            header=None,
            names=["impressions", filter_col],
        )
        freq_df = freq_df.sort_values("impressions", ascending=False).reset_index(drop=True)
    else:
        counts = pd.Series(dtype="int64")
        for chunk in pd.read_csv(
            train_path,
            usecols=[filter_col],
            chunksize=1_000_000,
            nrows=read_rows,
        ):
            counts = counts.add(chunk[filter_col].value_counts(), fill_value=0)
        freq_df = (
            counts.astype(int)
            .sort_values(ascending=False)
            .rename("impressions")
            .reset_index()
            .rename(columns={"index": filter_col})
        )

    ad_id_freq_csv.parent.mkdir(parents=True, exist_ok=True)
    freq_df.to_csv(ad_id_freq_csv, index=False)

    top_ids = freq_df.head(filter_top_k)[filter_col].tolist()
    print(f"Top {filter_top_k} {filter_col} values by impressions: {top_ids}")

    def filter_to_ids(input_path: Path, output_path: Path, ids: list) -> int:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        wrote = False
        row_count = 0
        for chunk in pd.read_csv(input_path, chunksize=1_000_000, nrows=read_rows):
            filtered = chunk[chunk[filter_col].isin(ids)]
            if filtered.empty:
                continue
            row_count += len(filtered)
            filtered.to_csv(
                output_path,
                mode="w" if not wrote else "a",
                header=not wrote,
                index=False,
            )
            wrote = True
        if not wrote:
            pd.read_csv(input_path, nrows=0).to_csv(output_path, index=False)
        return row_count

    train_rows = filter_to_ids(train_path, filtered_train_path, top_ids)
    test_rows = filter_to_ids(test_path, filtered_test_path, top_ids)

    cfg["ctr"]["train_path"] = str(filtered_train_path)
    if test_rows > 0:
        cfg["ctr"]["test_path"] = str(filtered_test_path)
    else:
        print("[WARN] Filtered test set is empty; keeping unfiltered test rows.")
        cfg["ctr"]["test_path"] = str(test_path)

    filtered_train = pd.read_csv(filtered_train_path, usecols=[filter_col, label_col])
    stats = (
        filtered_train.groupby(filter_col)[label_col]
        .agg(mean="mean", std=lambda s: s.std(ddof=0), impressions="size")
        .reset_index()
    )
    stats["std"] = stats["std"].fillna(0.0)
    epsilon = 1e-12
    stats["log10_mean"] = np.log10(stats["mean"].clip(lower=epsilon))
    stats = stats.sort_values("impressions", ascending=False)
    print(stats[[filter_col, "impressions", "mean", "std", "log10_mean"]].to_string(index=False))
else:
    print("Skipping top-k filter; using full CTR train/test paths.")


In [None]:
data_bundle = prepare_data(cfg, device, use_cuda)
splits = data_bundle.splits
loaders = data_bundle.loaders
plots_cfg = cfg['plots']
train_cfg = cfg['training']
data_source = get_data_source(cfg)
condition_label = get_condition_label(cfg)

if cfg.get('logging', {}).get('print_loader_note', True):
    print(loaders.loader_note)

# Optional plot of the filter-column stats stored in data_bundle.plot_data.
if data_source == 'ctr' and cfg.get('ctr', {}).get('plot_filter_stats', False):
    ctr_stats = data_bundle.plot_data.get('ctr_stats')
    if ctr_stats:
        plot_ctr_filter_stats(ctr_stats['stats_df'], ctr_stats['labels'], ctr_stats['filter_col'])

# Pre-training data plots: synthetic distributions if present, otherwise
# per-condition CTR feature distributions and label rates.
if not plots_cfg.get('data_before_training', False):
    print('Skipping training data distribution plots before training.')
else:
    synth = data_bundle.plot_data.get('synthetic')
    if synth:
        plot_training_distributions(
            synth['net_worth'],
            synth['ages'],
            synth['probs'],
            synth['conds'],
            synth['num_conditions'],
        )
    elif not plots_cfg.get('ctr_data_distributions', True):
        print('CTR distribution plots are disabled in config.')
    else:
        ctr_plot = data_bundle.plot_data.get('ctr_distributions')
        if not ctr_plot:
            print('CTR plot_sample_size is disabled or empty; skipping CTR distributions.')
        else:
            log10_features = set(plots_cfg.get('ctr_log10_features', []))
            plot_feature_distributions_by_condition(
                ctr_plot['xnum'],
                ctr_plot['conds'],
                ctr_plot['feature_names'],
                ctr_plot['num_conditions'],
                max_features=int(plots_cfg.get('ctr_max_features', 3)),
                log10_features=log10_features,
                density=bool(plots_cfg.get('ctr_use_density', False)),
            )
            if plots_cfg.get('ctr_label_rates', True):
                plot_label_rates_by_condition(
                    ctr_plot['labels'],
                    ctr_plot['conds'],
                    ctr_plot['num_conditions'],
                )


In [None]:
# Diagnostics: input/label/condition stats
from localized_entropy.analysis import (
    print_condition_stats,
    print_feature_stats,
    print_label_stats,
)

print('Diagnostics: splits')

# Feature stats: train and eval always, test only when a test split exists.
feature_sets = [('Train features', splits.x_train), ('Eval features', splits.x_eval)]
if splits.x_test is not None:
    feature_sets.append(('Test features', splits.x_test))
for diag_title, diag_values in feature_sets:
    print_feature_stats(diag_title, diag_values)

# Condition-id histograms per split.
cond_sets = [('Train conds', splits.c_train), ('Eval conds', splits.c_eval)]
if splits.c_test is not None:
    cond_sets.append(('Test conds', splits.c_test))
for diag_title, diag_conds in cond_sets:
    print_condition_stats(diag_title, diag_conds, splits.num_conditions)

# Label stats (per condition) per split; test is gated on test labels existing.
label_sets = [
    ('Train labels', splits.y_train, splits.c_train),
    ('Eval labels', splits.y_eval, splits.c_eval),
]
if splits.y_test is not None:
    label_sets.append(('Test labels', splits.y_test, splits.c_test))
for diag_title, diag_labels, diag_conds in label_sets:
    print_label_stats(diag_title, diag_labels, diag_conds, splits.num_conditions)


In [None]:
# Build the model from the config and the shapes of the prepared splits.
model_cfg = cfg['model']
num_numeric = splits.x_train.shape[1]  # number of numeric input columns
cat_dims = splits.cat_sizes
embed_dim = model_cfg['embed_dim']
# Categorical embeddings default to the condition embedding width.
cat_embed_dim = model_cfg.get('cat_embed_dim', embed_dim)
model = ConditionProbNet(
    num_conditions=splits.num_conditions,
    num_numeric=num_numeric,
    embed_dim=embed_dim,
    cat_dims=cat_dims,
    cat_embed_dim=cat_embed_dim,
    hidden_sizes=tuple(model_cfg['hidden_sizes']),
    p_drop=model_cfg['dropout'],
)
model = model.to(device)
model  # display the module as the cell output


In [None]:
# Diagnostics: initial logits/prob stats (untrained)
try:
    batch = next(iter(loaders.eval_loader))
except StopIteration:
    batch = None

if batch is None:
    print('Eval loader empty; skipping init logits diagnostics.')
else:
    # Only the first three batch fields are model inputs; the rest are not
    # needed for this diagnostic.
    x_b, x_cat_b, c_b, *_ = batch
    x_b = x_b.to(device, non_blocking=non_blocking)
    x_cat_b = x_cat_b.to(device, non_blocking=non_blocking)
    c_b = c_b.to(device, non_blocking=non_blocking)

    # A newly constructed torch module is in training mode. The model is
    # built with `p_drop`, so dropout would make these stats stochastic.
    # Evaluate in eval mode, then restore the previous mode so the training
    # cell starts from the same state.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(x_b, x_cat_b, c_b)
            probs = torch.sigmoid(logits)
    finally:
        model.train(was_training)

    # Tensors produced under no_grad carry no graph, so no detach() is needed.
    logits_np = logits.cpu().numpy().reshape(-1)
    probs_np = probs.cpu().numpy().reshape(-1)
    print(
        f"Init batch logits: n={logits_np.size:,} min={logits_np.min():.6g} max={logits_np.max():.6g} "
        f"mean={logits_np.mean():.6g} std={logits_np.std():.6g}"
    )
    print(
        f"Init batch probs:  n={probs_np.size:,} min={probs_np.min():.6g} max={probs_np.max():.6g} "
        f"mean={probs_np.mean():.6g} std={probs_np.std():.6g}"
    )


In [None]:
loss_name = loss_label(train_cfg['loss_mode'])

# Baseline: eval predictions of the untrained model, grouped by condition.
if plots_cfg.get('eval_pred_by_condition', True):
    pretrain_eval_preds = predict_probs(
        model,
        loaders.eval_loader,
        device,
        non_blocking=non_blocking,
    )
    if pretrain_eval_preds.size == 0:
        print('Eval set is empty after filtering; skipping pre-training plot.')
    else:
        pretrain_tag = f"Pre-Training ({loss_name})"
        pretrain_title = f"Pre-Training Eval Predictions by {condition_label} ({loss_name})"
        print_pred_stats_by_condition(
            pretrain_eval_preds,
            splits.c_eval,
            splits.num_conditions,
            name=pretrain_tag,
        )
        plot_eval_predictions_by_condition(
            pretrain_eval_preds,
            splits.c_eval,
            splits.num_conditions,
            value_range=(-12, 0),
            title=pretrain_title,
        )


In [None]:
loss_name = loss_label(train_cfg['loss_mode'])


def _show_eval_preds(preds: np.ndarray, prefix: str) -> None:
    """Plot eval predictions by condition under a title starting with ``prefix``.

    Shared by the epoch and batch callbacks below. When ``preds`` is empty,
    a skip message is printed instead of plotting.
    """
    if preds.size == 0:
        print(f"{prefix}: eval set empty; skipping eval plots.")
        return
    plot_eval_predictions_by_condition(
        preds,
        splits.c_eval,
        splits.num_conditions,
        value_range=(-12, 0),
        title=f"{prefix} Eval Predictions by {condition_label} ({loss_name})",
        print_counts=False,
    )


def plot_eval_epoch_preds(preds: np.ndarray, epoch: int) -> None:
    """Per-epoch eval callback handed to ``train_with_epoch_plots``."""
    _show_eval_preds(preds, f"Epoch {epoch}")


def plot_eval_batch_preds(preds: np.ndarray, epoch: int, batch_idx: int) -> None:
    """Every-N-batches eval callback handed to ``train_with_epoch_plots``."""
    _show_eval_preds(preds, f"Epoch {epoch} Batch {batch_idx}")


# Batch-level eval plots need both per-condition plots enabled and a
# positive batch interval (None / 0 in the config disables them).
plot_eval_epochs = plots_cfg.get('eval_pred_by_condition', True)
eval_every_n_batches = int(train_cfg.get('eval_every_n_batches', 0) or 0)
plot_eval_batches = plot_eval_epochs and eval_every_n_batches > 0

train_losses, eval_losses = train_with_epoch_plots(
    model=model,
    train_loader=loaders.train_loader,
    val_loader=loaders.eval_loader,
    device=device,
    epochs=train_cfg['epochs'],
    lr=train_cfg['lr'],
    non_blocking=non_blocking,
    plot_eval_hist_epochs=plots_cfg.get('eval_hist_epochs', False),
    loss_mode=train_cfg['loss_mode'],
    eval_callback=plot_eval_epoch_preds if plot_eval_epochs else None,
    eval_every_n_batches=eval_every_n_batches if plot_eval_batches else None,
    eval_batch_callback=plot_eval_batch_preds if plot_eval_batches else None,
)


In [None]:
# `label` is also used by the evaluation cell for titles and printouts.
label = loss_label(train_cfg['loss_mode'])
show_loss_curves = plots_cfg.get('loss_curves', True)
if show_loss_curves:
    plot_loss_curves(train_losses, eval_losses, label)


In [None]:
eval_loss, eval_preds = evaluate(
    model,
    loaders.eval_loader,
    device,
    loss_mode=train_cfg['loss_mode'],
    non_blocking=non_blocking,
)
if plots_cfg.get('print_eval_summary', True):
    print_pred_summary('Eval', eval_preds, labels=splits.y_eval, conds=splits.c_eval)

if plots_cfg.get('eval_pred_hist', True):
    plot_eval_log10p_hist(eval_preds.astype(np.float32), epoch=train_cfg['epochs'])

if plots_cfg.get('eval_pred_by_condition', True):
    plot_eval_predictions_by_condition(
        eval_preds,
        splits.c_eval,
        splits.num_conditions,
        value_range=(-12, 0),
        title=(
            f"Post-Training Eval Predictions by {condition_label} ("
            f"{label})"
        ),
    )

print(f"Final Evaluation {label}: {eval_loss:.10f}")

# Calibration / ranking settings.
eval_cfg = cfg.get('evaluation', {})
bins = int(eval_cfg.get('ece_bins', 20))
min_count = int(eval_cfg.get('ece_min_count', 1))
small_prob_max_cfg = float(eval_cfg.get('small_prob_max', 0.01))
small_prob_quantile = float(eval_cfg.get('small_prob_quantile', 0.1))

if eval_preds.size == 0:
    # The eval split can be emptied by filtering. np.quantile raises on an
    # empty array, and calibration/ranking metrics need samples, so skip them.
    print('Eval set is empty after filtering; skipping calibration and ranking metrics.')
else:
    total_bce = bce_log_loss(eval_preds, splits.y_eval)
    total_ece, total_ece_table = expected_calibration_error(
        eval_preds, splits.y_eval, bins=bins, min_count=min_count
    )

    # "Small-probability" subset for calibration: use the configured cutoff,
    # or fall back to a quantile of the predictions if nothing falls below it.
    small_threshold = small_prob_max_cfg
    small_mask = eval_preds <= small_threshold
    if not small_mask.any():
        quantile_threshold = float(np.quantile(eval_preds, small_prob_quantile))
        print(
            f"[INFO] No preds <= {small_threshold:g}; using {small_prob_quantile:.2f} quantile "
            f"threshold {quantile_threshold:g} for small-prob calibration."
        )
        small_threshold = quantile_threshold
        small_mask = eval_preds <= small_threshold

    if small_mask.any():
        total_ece_small, _ = expected_calibration_error(
            eval_preds[small_mask],
            splits.y_eval[small_mask],
            bins=bins,
            min_count=min_count,
        )
    else:
        total_ece_small = float('nan')

    print(f"Total BCE (log loss): {total_bce:.8f}")
    print(f"Total ECE: {total_ece:.8f}")
    print(f"Total ECE (small p<= {small_threshold:g}): {total_ece_small:.8f}")
    total_auc = roc_auc_score(eval_preds, splits.y_eval)
    total_pr_auc = pr_auc_score(eval_preds, splits.y_eval)
    print(f"Total ROC-AUC: {total_auc:.8f}")
    print(f"Total PR-AUC (AP): {total_pr_auc:.8f}")

    if eval_cfg.get('print_calibration_table', False):
        print(total_ece_table.to_string(index=False))

    per_ad = per_condition_metrics(
        eval_preds,
        splits.y_eval,
        splits.c_eval,
        bins=bins,
        min_count=min_count,
        small_prob_max=small_threshold,
    )
    if eval_cfg.get('print_per_ad', True):
        top_k = int(eval_cfg.get('per_ad_top_k', 10))
        print(per_ad.head(top_k).to_string(index=False))

# Report the same trained model under the other configured loss modes.
for mode in train_cfg.get('eval_compare_losses', []):
    if mode == train_cfg['loss_mode']:
        continue
    other_loss, _ = evaluate(
        model,
        loaders.eval_loader,
        device,
        loss_mode=mode,
        non_blocking=non_blocking,
    )
    other_label = loss_label(mode)
    print(f"Final Evaluation {other_label}: {other_loss:.10f}")


In [None]:
# Test split has no labels here, so only a prediction summary is printed.
if loaders.test_loader is not None:
    test_preds = predict_probs(
        model,
        loaders.test_loader,
        device,
        non_blocking=non_blocking,
    )
    if test_preds.size == 0:
        print('Test set is empty after filtering; skipping summary.')
    elif plots_cfg.get('print_eval_summary', True):
        print_pred_summary('Test', test_preds, labels=None, conds=splits.c_test)


In [None]:
# Re-plot the training data distributions after training (if enabled):
# synthetic distributions if present, otherwise CTR per-condition plots.
if not plots_cfg.get('data_after_training', False):
    print('Post-training training data plots are disabled.')
else:
    synth = data_bundle.plot_data.get('synthetic')
    if synth:
        plot_training_distributions(
            *(synth[key] for key in ('net_worth', 'ages', 'probs', 'conds', 'num_conditions'))
        )
    elif not plots_cfg.get('ctr_data_distributions', True):
        print('CTR post-training plots are disabled in config.')
    else:
        ctr_plot = data_bundle.plot_data.get('ctr_distributions')
        if not ctr_plot:
            print('CTR plot_sample_size is disabled or empty; skipping CTR post-training plots.')
        else:
            log10_features = set(plots_cfg.get('ctr_log10_features', []))
            n_ctr_conditions = ctr_plot['num_conditions']
            plot_feature_distributions_by_condition(
                ctr_plot['xnum'],
                ctr_plot['conds'],
                ctr_plot['feature_names'],
                n_ctr_conditions,
                max_features=int(plots_cfg.get('ctr_max_features', 3)),
                log10_features=log10_features,
                density=bool(plots_cfg.get('ctr_use_density', False)),
            )
            if plots_cfg.get('ctr_label_rates', True):
                plot_label_rates_by_condition(
                    ctr_plot['labels'],
                    ctr_plot['conds'],
                    ctr_plot['num_conditions'],
                )


In [None]:
show_le_plot = plots_cfg.get('le_stats', True)
show_le_table = plots_cfg.get('print_le_stats_table', True)

if show_le_plot or show_le_table:
    # Logits, labels and condition ids collected from the eval loader.
    z_all, y_all, c_all = collect_logits(model, loaders.eval_loader, device, non_blocking=non_blocking)
    le_stats = collect_le_stats_per_condition(z_all, y_all, c_all, eps=1e-12)

    if show_le_table:
        # Tab-separated table, one row per condition id in sorted order.
        print("\t".join(['cond', 'num', 'den', 'avg_p', '#y=1', '#y=0', 'ratio']))
        for cond in sorted(le_stats):
            s = le_stats[cond]
            row = [
                f"{cond}",
                f"{s['Numerator']:.6g}",
                f"{s['Denominator']:.6g}",
                f"{s['Average prediction for denominator']:.6g}",
                f"{s['Number of samples with label 1']}",
                f"{s['Number of samples with label 0']}",
                f"{s['Numerator/denominator']:.6g}",
            ]
            print("\t".join(row))

    if show_le_plot:
        plot_le_stats_per_condition(le_stats, title='Localized Entropy terms per condition - Eval set')
