# Localized Entropy Loss Function

- Inputs: Age (float), NetWorth (float), MedicalCondition (categorical via embedding).
- Probabilities follow the generated data and are converted to labels with numpy binomial (n=1).
- Training uses only labels (not true probabilities).
- Pre-training: plot N-line distributions on training data for: log10(NetWorth), Age, and log10(true probability).
- Split 90/10 into train/eval.
- Model includes dropout; trains with BCE (BCEWithLogitsLoss), batch size=10,000.
- Configurable epochs (default 30) and learning rate.
- After each epoch: plot eval predicted probability distribution with x-axis=log10(pred p).
- After training: plot Train vs Eval BCE curves; print final BCE; collect eval predictions and plot eval charts analogous to input charts (N lines).


In [None]:

%matplotlib inline
from typing import Tuple, List, Optional
import os
os.environ["LOCALIZED_ENTROPY_NUM_WORKERS"] = "1"
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Global display settings: matplotlib style plus 4-digit printing for numpy/torch.
plt.style.use('seaborn-v0_8')
np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4)

# Pick the compute device once; later cells read USE_CUDA / device / NON_BLOCKING.
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')
# Passed as the `non_blocking` flag of `.to(device, ...)` in the train/eval loops.
NON_BLOCKING = USE_CUDA

if USE_CUDA:
    gpu_name = torch.cuda.get_device_name(device)
    print(f'Using CUDA device: {gpu_name}')
    # Enable the cuDNN benchmark mode flag.
    torch.backends.cudnn.benchmark = True
else:
    print('CUDA not available, defaulting to CPU.')

# Bare expression: in a notebook cell this displays the selected device.
device


In [None]:
# Probability generation
def _sigmoid(x: np.ndarray, mu: float, s: float) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-(x - mu) / s))

def generate_probs(
    num_samples: int,
    mu_ln: float,
    sigma_ln: float,
    sig_mu: float,
    sig_s: float,
    mu_age: float,
    sigma_age: float,
    interestScale: float,
    min_age: int = 10,
    max_age: int = 100,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Draw synthetic (net worth, age) pairs and their event probability.

    The probability is the product of two factors, clipped to [0, 1]:
      * a sigmoid of log10(net_worth + 1) with centre ``sig_mu`` and scale ``sig_s``;
      * a normal pdf over age (``mu_age``, ``sigma_age``), normalised by its
        value at age 50 and divided by ``interestScale``.

    Net worth is lognormal(``mu_ln``, ``sigma_ln``); ages are integers drawn
    with ``rng.integers(min_age, max_age)``. A fresh default generator is
    created when ``rng`` is None.

    Returns (net_worth, ages, probabilities), all as float32 arrays.
    """
    gen = np.random.default_rng() if rng is None else rng

    # Wealth is drawn before age; the order fixes how the random stream is consumed.
    net_worth = gen.lognormal(mean=mu_ln, sigma=sigma_ln, size=num_samples)
    wealth_factor = _sigmoid(np.log10(net_worth + 1.0), mu=sig_mu, s=sig_s)

    ages = gen.integers(min_age, max_age, size=num_samples)
    # Normalise the age curve by its height at age 50; guard a non-positive height.
    reference = norm.pdf(50, loc=mu_age, scale=sigma_age)
    if not reference > 0:
        reference = 1.0
    age_factor = norm.pdf(ages, loc=mu_age, scale=sigma_age) / reference / interestScale

    probs = np.clip(wealth_factor * age_factor, 0.0, 1.0)
    return (
        net_worth.astype(np.float32),
        ages.astype(np.float32),
        probs.astype(np.float32),
    )

def sample_condition_params(rng: np.random.Generator) -> Tuple[float, float, float, float, float, float, float]:
    """
    Sample the generation parameters for one condition.

    ``mu_ln`` and ``sigma_ln`` are fixed; every other parameter is picked
    with ``rng.choice`` from a 50-point evenly spaced grid. Returns
    (mu_ln, sigma_ln, sig_mu, sig_s, mu_age, sigma_age, interestScale).
    """
    def pick(low: float, high: float):
        # One draw from a 50-point linspace grid over [low, high].
        return rng.choice(np.linspace(low, high, 50))

    mu_ln, sigma_ln = 10.0, 1.5
    # Draw order matters: it determines how the generator's stream is consumed.
    sig_mu = pick(5, 7)
    sig_s = pick(0.2, 0.3)
    mu_age = pick(30, 70)
    sigma_age = pick(10, 30)
    interestScale = 10.0 ** pick(0, 0.001)
    return mu_ln, sigma_ln, sig_mu, sig_s, mu_age, sigma_age, interestScale

def make_dataset(
    num_conditions: int,
    min_samples_per_condition: int = 90_000,
    max_samples_per_condition: int = 100_000,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build the synthetic dataset, one chunk per condition.

    For each condition id the sample count is drawn from
    [min_samples_per_condition, max_samples_per_condition], generation
    parameters are sampled, features and probabilities are generated, and
    labels are drawn as Binomial(1, p). All randomness comes from a single
    generator seeded with ``seed``.

    Returns: ages, net_worth, condition_ids (int64), labels, probs
    """
    rng = np.random.default_rng(seed)
    # One bucket per output column, in return order.
    buckets: Tuple[list, ...] = ([], [], [], [], [])
    for cond in range(num_conditions):
        n = int(rng.integers(min_samples_per_condition, max_samples_per_condition + 1))
        net_worth, ages, probs = generate_probs(n, *sample_condition_params(rng), rng=rng)
        labels = rng.binomial(n=1, p=probs).astype(np.float32)
        cond_ids = np.full_like(ages, fill_value=cond, dtype=np.float32)
        for bucket, chunk in zip(buckets, (ages, net_worth, cond_ids, labels, probs)):
            bucket.append(chunk)
    ages, net_worth, conds, labels, probs = [
        np.concatenate(bucket, axis=0) for bucket in buckets
    ]
    return ages, net_worth, conds.astype(np.int64), labels, probs


In [None]:
# Plotting helpers
def _density_lines(
    values: np.ndarray,
    groups: np.ndarray,
    num_conditions: int,
    *,
    bins: int = 100,
    transform: Optional[str] = None,  # None or 'log10'
    value_range: Optional[Tuple[float, float]] = None,
    title: str = '',
    x_label: str = ''
) -> None:
    """
    Plot one normalised-histogram line per condition on a shared axis.

    ``values`` are optionally log10-transformed (after clipping at 1e-12),
    binned on a common grid spanning ``value_range`` (or the data min/max),
    and drawn as density curves for every condition id in
    ``range(num_conditions)`` that has at least one sample in ``groups``.
    """
    data = values.astype(np.float64).copy()
    if transform == 'log10':
        data = np.log10(np.clip(data, 1e-12, None))

    if value_range is not None:
        lo, hi = value_range
    else:
        lo, hi = float(np.nanmin(data)), float(np.nanmax(data))
    # Degenerate or non-finite range: fall back to the unit interval.
    if not (np.isfinite(lo) and np.isfinite(hi)) or lo == hi:
        lo, hi = 0.0, 1.0

    edges = np.linspace(lo, hi, bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    plt.figure(figsize=(10, 6))
    for cond in range(int(num_conditions)):
        member = groups == cond
        if np.any(member):
            density, _ = np.histogram(data[member], bins=edges, density=True)
            plt.plot(centers, density, label=f'Condition {cond}')
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel('Density')
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=2, fontsize=8)
    plt.tight_layout()
    plt.show()

def plot_training_distributions(net_worth: np.ndarray,
                                ages: np.ndarray,
                                probs: np.ndarray,
                                conds: np.ndarray,
                                num_conditions: int) -> None:
    """
    Draw the three per-condition input charts, in order:
    log10(NetWorth), Age, and log10(true probability).
    """
    panels = (
        dict(
            values=net_worth,
            transform='log10',
            title='Training Data: Distribution by Condition (log10(NetWorth))',
            x_label='log10(NetWorth)',
        ),
        dict(
            values=ages,
            transform=None,
            title='Training Data: Distribution by Condition (Age)',
            x_label='Age',
        ),
        dict(
            values=probs,
            transform='log10',
            value_range=(-12, 0),
            title='Training Data: Distribution by Condition (log10(true probability))',
            x_label='log10(true p)',
        ),
    )
    # Every panel shares the grouping, condition count and bin count.
    for panel in panels:
        _density_lines(groups=conds, num_conditions=num_conditions, bins=120, **panel)

def plot_eval_log10p_hist(preds: np.ndarray, epoch: int, bins: int = 100) -> None:
    """
    Show a density histogram of log10(predicted probability) for one epoch.

    Predictions are clipped to [1e-12, 1] before the log; the x-axis is
    fixed to [-12, 0].
    """
    clipped = np.clip(preds, 1e-12, 1.0)
    log10p = np.log10(clipped)

    plt.figure(figsize=(8, 5))
    plt.hist(
        log10p,
        bins=bins,
        range=(-12, 0),
        density=True,
        color='#4477aa',
        alpha=0.85,
    )
    plt.title(f'Eval Predicted Probability: log10(p) (Epoch {epoch})')
    plt.xlabel('log10(pred p)')
    plt.ylabel('Density')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


In [None]:
# Dataset (numeric features + categorical condition)
class MedCondDataset(Dataset):
    """
    Map-style dataset of (numeric features, condition id, label, net worth).

    ``x_num`` must be a 2-D array with exactly two columns; ``conds`` and
    ``labels`` must have the same length as ``x_num``. When ``net_worth`` is
    omitted a zero vector of matching length is stored instead. With
    ``share_memory=True`` every CPU tensor is moved to shared memory.
    """

    def __init__(
        self,
        x_num: np.ndarray,
        conds: np.ndarray,
        labels: np.ndarray,
        net_worth: Optional[np.ndarray] = None,
        share_memory: bool = False,
    ):
        assert x_num.ndim == 2 and x_num.shape[1] == 2
        assert len(x_num) == len(conds) == len(labels)

        def to_tensor(array, dtype: torch.dtype) -> torch.Tensor:
            # Contiguous tensor of the requested dtype.
            return torch.as_tensor(array, dtype=dtype).contiguous()

        self.x = to_tensor(x_num, torch.float32)
        self.c = to_tensor(conds, torch.long)
        self.y = to_tensor(labels, torch.float32)
        if net_worth is not None:
            assert len(net_worth) == len(labels)
            self.nw = to_tensor(net_worth, torch.float32)
        else:
            self.nw = torch.zeros(len(labels), dtype=torch.float32)

        if share_memory:
            for stored in (self.x, self.c, self.y, self.nw):
                if stored.device.type == 'cpu':
                    stored.share_memory_()

    def __len__(self) -> int:
        """Number of samples (one label per sample)."""
        return self.y.numel()

    def __getitem__(self, idx: int):
        """Return (features, condition, label, net_worth) for ``idx``."""
        return self.x[idx], self.c[idx], self.y[idx], self.nw[idx]

class TensorBatchLoader:
    """
    Minimal batch iterator over a tuple of pre-staged tensors.

    All tensors must share their first dimension. Each iteration yields a
    tuple of row-aligned slices selected with ``index_select``; indices are
    created on the device of the first tensor, optionally in a fresh random
    order per iteration.
    """

    def __init__(self, tensors: Tuple[torch.Tensor, ...], batch_size: int, shuffle: bool):
        assert len(tensors) > 0
        rows = tensors[0].shape[0]
        assert all(t.shape[0] == rows for t in tensors[1:]), 'All tensors must share the first dimension.'
        self.tensors = tensors
        self.batch_size = int(max(1, batch_size))
        self.shuffle = shuffle
        self.length = rows
        self.device = tensors[0].device

    def __len__(self) -> int:
        """Number of batches per pass (the last batch may be short)."""
        full, remainder = divmod(self.length, self.batch_size)
        return full + (1 if remainder else 0)

    @property
    def num_workers(self) -> int:
        """Always 0; mirrors the DataLoader attribute of the same name."""
        return 0

    def __iter__(self):
        if self.shuffle:
            order = torch.randperm(self.length, device=self.device)
        else:
            order = torch.arange(self.length, device=self.device, dtype=torch.long)
        for begin in range(0, self.length, self.batch_size):
            rows = order[begin:begin + self.batch_size]
            yield tuple(t.index_select(0, rows) for t in self.tensors)

class ConditionProbNet(nn.Module):
    """
    MLP over two numeric features concatenated with a condition embedding.

    Each hidden layer is Linear -> ReLU -> Dropout; a final Linear maps to a
    single logit per sample. ``hidden_sizes`` defaults to (256, 256, 128, 64).
    """

    def __init__(
        self,
        num_conditions: int,
        embed_dim: int = 16,
        hidden_sizes: Optional[Tuple[int, ...]] = None,
        p_drop: float = 0.3,
    ):
        super().__init__()
        widths = (256, 256, 128, 64) if hidden_sizes is None else hidden_sizes
        self.embedding = nn.Embedding(num_conditions, embed_dim)

        stack: List[nn.Module] = []
        fan_in = embed_dim + 2  # embedding plus the two numeric inputs
        for width in widths:
            stack += [nn.Linear(fan_in, width), nn.ReLU(), nn.Dropout(p=p_drop)]
            fan_in = width
        stack.append(nn.Linear(fan_in, 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x_num: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Return raw logits with the trailing singleton dimension removed."""
        features = torch.cat([x_num, self.embedding(cond)], dim=-1)
        return self.net(features).squeeze(-1)

# Custom BCE loss (loop-based, with logits)
def _loop_condition_weight(condition_weights: Optional[np.ndarray], cond_id: int) -> float:
    """Return the positive, finite weight for ``cond_id``, or 1.0 if unavailable."""
    if condition_weights is None:
        return 1.0
    try:
        # Explicit bounds check: a negative id must not wrap around to the end
        # of the weight table through negative indexing.
        if not 0 <= cond_id < len(condition_weights):
            return 1.0
        weight = float(condition_weights[cond_id])
    except (IndexError, KeyError, TypeError, ValueError):
        return 1.0
    if not np.isfinite(weight) or weight <= 0:
        return 1.0
    return weight


def custom_bce_with_logits_loop(
    logits: torch.Tensor,
    targets: torch.Tensor,
    conditions_np: np.ndarray,
    net_worth_np: np.ndarray,
    condition_weights: Optional[np.ndarray] = None,
    nw_threshold: Optional[float] = None,
    nw_multiplier: float = 1.0,
    reduction: str = 'mean'
) -> torch.Tensor:
    """
    Loop-based, per-sample weighted BCE-with-logits.

    Each sample contributes ``w * BCE(z, y)`` where BCE uses the numerically
    stable form ``max(z, 0) - z*y + log1p(exp(-|z|))``.

    Weights
    -------
      * ``condition_weights[cond_id]`` when it exists, is finite and positive;
        otherwise 1.0. Ids outside ``[0, len(condition_weights))`` (including
        negative ids) fall back to 1.0.
      * If ``nw_threshold`` is given and ``nw_multiplier != 1.0``, samples
        whose finite net worth is >= the threshold are further multiplied by
        ``nw_multiplier``.

    Parameters
    ----------
      logits, targets:  tensors with the same number of elements.
      conditions_np:    per-sample condition ids (array-like).
      net_worth_np:     per-sample raw net worth (array-like).
      reduction:        'mean' divides the weighted sum by the number of
                        samples (an empty input yields 0 rather than NaN);
                        any other value returns the weighted sum.

    Returns
    -------
      Scalar tensor on the device/dtype of ``logits``.
    """
    logits_flat = logits.reshape(-1)
    targets_flat = targets.reshape(-1)
    conds_flat = np.asarray(conditions_np).reshape(-1)
    nw_flat = np.asarray(net_worth_np).reshape(-1)

    # Loop-invariant: whether the net-worth upweighting is active at all.
    use_nw_boost = (nw_threshold is not None) and (nw_multiplier != 1.0)
    nw_factor = float(nw_multiplier)

    total = torch.zeros((), device=logits.device, dtype=logits.dtype)
    for i, z in enumerate(logits_flat):
        y = targets_flat[i]
        w = _loop_condition_weight(condition_weights, int(conds_flat[i]))
        if use_nw_boost:
            nw_val = float(nw_flat[i])
            if np.isfinite(nw_val) and (nw_val >= nw_threshold):
                w = w * nw_factor
        # Standard stable BCE-with-logits term, weighted
        per_sample = torch.clamp_min(z, 0.0) - z * y + torch.log1p(torch.exp(-torch.abs(z)))
        total = total + (w * per_sample)

    if reduction == 'mean':
        # max(1, n) keeps an empty batch at 0 instead of 0/0 = NaN.
        return total / max(1, logits_flat.shape[0])
    # 'sum' and any unrecognised reduction both return the weighted sum.
    return total

def localized_entropy_bce_torch(
    logits: torch.Tensor,
    targets: torch.Tensor,
    conditions: torch.Tensor,
    base_rates: Optional[torch.Tensor] = None,         # optional per-condition base rates p_j
    net_worth: Optional[torch.Tensor] = None,          # unused currently
    condition_weights: Optional[torch.Tensor] = None,  # optional per-condition weight
    nw_threshold: Optional[float] = None,              # unused currently
    nw_multiplier: float = 1.0,                        # unused currently
    reduction: str = 'mean',
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    Localized Entropy (LE) implementation using PyTorch.

    Mathematics
    -----------
      Let classes be indexed by j = 1..M, with N_j samples in class j,
      labels y_i in {0,1}, predicted probs ŷ_i = σ(z_i) from logits z_i, and
      per-class base rate p_j = mean(y | class j).

      Define cross-entropy over class j:
        CE_j(y, q) = Σ_{i in class j} [ -y_i * log(q_i) - (1 - y_i) * log(1 - q_i) ]

      Localized Entropy:
        LE = ( Σ_{j=1..M}  CE_j(y, ŷ) / CE_j(y, p_j) ) / ( Σ_{j=1..M} N_j )

      - Numerator uses per-sample predictions (stable BCE-with-logits).
      - Denominator uses a constant predictor at the class base rate p_j.
      - p_j is clamped to [eps, 1 - max(eps, machine epsilon of the logits
        dtype)]. The dtype-aware upper bound matters: in float32, 1 - 1e-12
        rounds to exactly 1.0, which would let log1p(-p_j) become -inf and
        turn the denominator into NaN for a class whose labels are all 1.
      - We divide by total samples Σ_j N_j (not by number of classes).

    Gradient safety
    ---------------
      - All math is in torch; gradients flow through logits z_i in the numerator.
      - Denominator depends on labels / supplied base rates only, so it does
        not carry gradients w.r.t. the logits.

    Parameters
    ----------
      logits:        shape (N,) or (N,1). Raw model outputs z_i (floating dtype).
      targets:       shape (N,). Binary labels {0,1}.
      conditions:    shape (N,). Integer class IDs per sample.
      base_rates:    Optional 1D tensor indexed by class id. When an entry
                     exists for a class and is finite it is used as p_j
                     (moved to the targets' device/dtype); otherwise the
                     batch mean of that class is used.
      net_worth, nw_threshold, nw_multiplier:
                     Accepted for signature compatibility; not used.
      condition_weights:
                     Optional 1D tensor of per-class weights indexed by class
                     id. A finite, positive entry scales that class term;
                     anything else leaves the term unweighted.
      reduction:     'sum' returns LE * N; any other value returns LE.
      eps:           Small constant for numerical stability.

    Returns
    -------
      Scalar torch.Tensor (loss). An empty input yields 0.
    """
    # Flatten and align dtypes
    z = logits.reshape(-1)                 # logits z_i
    y = targets.reshape(-1).to(z.dtype)    # labels as float
    c = conditions.reshape(-1).to(torch.long)

    # Largest p_j strictly below 1 that the working dtype can represent.
    pj_upper = 1.0 - max(eps, torch.finfo(z.dtype).eps)

    # Stable BCE-with-logits per-sample numerator terms
    # BCE(z,y) = clamp_min(z,0) - z*y + log1p(exp(-|z|))
    bce_per = torch.clamp_min(z, 0) - z * y + torch.log1p(torch.exp(-torch.abs(z)))

    total = z.new_zeros(())             # accumulator for Σ_j (Num_j / Den_j)
    N = y.numel()                       # total number of samples (Σ_j N_j)

    for cid in torch.unique(c):
        idx = int(cid)                  # one host read per class, reused below
        mask = (c == cid)

        # NUMERATOR: CE_j(y, ŷ) = Σ_i BCE(z_i, y_i) over class j
        num = bce_per[mask].sum()

        # Label counts for class j
        n = mask.sum()                    # N_j (int tensor)
        ones = y[mask].sum()              # Σ y_i (float)
        zeros = n.to(y.dtype) - ones      # N_j - Σ y_i

        # p_j: batch mean by default, overridden by a usable supplied base rate.
        pj = ones / n.clamp_min(1)
        if base_rates is not None and 0 <= idx < base_rates.numel():
            supplied = base_rates[idx].to(device=y.device, dtype=y.dtype)
            if torch.isfinite(supplied):
                pj = supplied
        pj = pj.clamp(eps, pj_upper)

        # DENOMINATOR: CE at constant q_i = p_j:
        # Σ_i [ -y_i*log(p_j) - (1 - y_i)*log(1 - p_j) ]
        den = ones * (-torch.log(pj)) + zeros * (-torch.log1p(-pj))

        # Normalized class contribution: CE_j(y, ŷ) / CE_j(y, p_j)
        class_term = num / den.clamp_min(eps)

        # Optional per-class weight w_j; invalid or missing weights are ignored.
        if condition_weights is not None and 0 <= idx < condition_weights.numel():
            w = condition_weights[idx]
            if torch.isfinite(w) and (w > 0):
                class_term = class_term * w

        total = total + class_term

    # Aggregate across classes and average by total samples Σ_j N_j
    loss = total / max(N, 1)

    if reduction == 'sum':
        return loss * N
    return loss

In [None]:

@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    condition_weights: Optional[np.ndarray] = None,
    nw_threshold: Optional[float] = None,
    nw_multiplier: float = 1.0,
    non_blocking: bool = False,
) -> Tuple[float, np.ndarray]:
    """
    Run ``model`` over ``loader`` in eval mode and score it with the
    localized-entropy loss.

    Each batch of (x, cond, y, net_worth) is moved to ``device``. The loss
    uses batch-local base rates (no ``base_rates`` are passed). The reported
    loss is the average of per-batch losses weighted by batch size.

    Parameters
    ----------
      condition_weights: optional per-condition weights forwarded to the loss;
                         converted to a tensor once, on the first batch.
      nw_threshold, nw_multiplier: forwarded to the loss unchanged.
      non_blocking:      forwarded to ``Tensor.to`` for the batch transfers.

    Returns
    -------
      (mean_loss, preds) where ``preds`` is the concatenation of
      ``sigmoid(logits)`` over all batches as a numpy array. For an empty
      loader this is ``(0.0, empty float32 array)``.

    Raises
    ------
      RuntimeError if ``device`` is CUDA but the first batch (or its logits)
      is not on CUDA.

    Note: the model is left in eval mode on return.
    """
    model.eval()
    total_loss = 0.0
    total_count = 0
    preds_all = []
    weights_t: Optional[torch.Tensor] = None  # built lazily, reused for every batch
    verified_cuda_batch = False
    for x, c, y, nw in loader:
        x = x.to(device, non_blocking=non_blocking)
        c = c.to(device, non_blocking=non_blocking)
        y = y.to(device, non_blocking=non_blocking)
        nw = nw.to(device, non_blocking=non_blocking)
        logits = model(x, c)
        # One-time sanity check that the whole batch really lives on the GPU.
        if (device.type == 'cuda') and (not verified_cuda_batch):
            tensors = (x, c, y, nw, logits)
            if any(t.device.type != 'cuda' for t in tensors):
                raise RuntimeError('Expected CUDA tensors during evaluation but found CPU tensors.')
            verified_cuda_batch = True
        if (condition_weights is not None) and (weights_t is None):
            weights_t = torch.as_tensor(condition_weights, device=logits.device, dtype=logits.dtype)
        loss = localized_entropy_bce_torch(
            logits=logits,
            targets=y,
            conditions=c,
            net_worth=nw,
            condition_weights=weights_t,
            nw_threshold=nw_threshold,
            nw_multiplier=nw_multiplier,
        )
        batch_n = x.size(0)
        total_loss += float(loss.item()) * batch_n
        total_count += batch_n
        preds_all.append(torch.sigmoid(logits).detach().cpu().numpy())
    mean_loss = total_loss / max(1, total_count)
    # np.concatenate rejects an empty list, so handle an empty loader explicitly.
    preds = np.concatenate(preds_all, axis=0) if preds_all else np.empty(0, dtype=np.float32)
    return mean_loss, preds

# Per-epoch cumulative base-rate tracker
class StreamingBaseRate:
    """
    Running per-condition label mean (base rate), accumulated batch by batch.

    Keeps, for every condition id in ``[0, num_conditions)``, the number of
    samples seen and the sum of their labels since the last ``reset()``.
    """

    def __init__(self, num_conditions: int, device: torch.device, dtype: torch.dtype = torch.float32):
        self.num_conditions = int(num_conditions)
        self.device = device
        self.dtype = dtype  # floating dtype used for the label sums and rates
        self.reset()

    def reset(self):
        """Zero all counters."""
        self.counts = torch.zeros(self.num_conditions, dtype=torch.long, device=self.device)
        self.sum_ones = torch.zeros(self.num_conditions, dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def update(self, y: torch.Tensor, c: torch.Tensor):
        """
        Add one batch of labels ``y`` with condition ids ``c``.

        NOTE(review): ids are assumed to lie in [0, num_conditions); this
        method does not validate them.
        """
        c = c.reshape(-1).to(torch.long)
        y = y.reshape(-1).to(self.dtype)
        cnt = torch.bincount(c, minlength=self.num_conditions)
        s1 = torch.bincount(c, weights=y, minlength=self.num_conditions)
        self.counts += cnt
        self.sum_ones += s1

    @torch.no_grad()
    def rates(self, eps: float = 1e-12) -> torch.Tensor:
        """
        Return the current base rate per condition, clamped away from 0 and 1.

        Conditions with no samples yet report the lower bound ``eps``. The
        upper bound is ``1 - max(eps, machine epsilon of the dtype)``: with
        float32, ``1 - 1e-12`` rounds to exactly 1.0, so a plain
        ``clamp(eps, 1 - eps)`` would let a rate of 1.0 through and break a
        later ``log(1 - p)``.
        """
        denom = self.counts.clamp_min(1).to(self.sum_ones.dtype)
        upper = 1.0 - max(eps, torch.finfo(self.sum_ones.dtype).eps)
        return (self.sum_ones / denom).clamp(eps, upper)

def train_with_epoch_plots(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int,
    lr: float,
    condition_weights: Optional[np.ndarray] = None,
    nw_threshold: Optional[float] = None,
    nw_multiplier: float = 1.0,
    non_blocking: bool = False,
    plot_eval_hist_epochs: bool = False,
) -> Tuple[List[float], List[float]]:
    """
    Train ``model`` with Adam on the localized-entropy loss, evaluating on
    ``val_loader`` after every epoch.

    Within an epoch, a ``StreamingBaseRate`` tracker is updated with each
    batch *before* the loss is computed, and its cumulative per-condition
    rates are used as the loss denominator's base rates. The tracker starts
    fresh every epoch.

    Parameters
    ----------
      epochs, lr:            number of epochs and Adam learning rate.
      condition_weights:     optional per-condition weights for the loss
                             (converted to a tensor once).
      nw_threshold, nw_multiplier: forwarded to the loss / ``evaluate``.
      non_blocking:          forwarded to ``Tensor.to`` for batch transfers.
      plot_eval_hist_epochs: when True, plot the eval log10(p) histogram
                             after every epoch.

    Returns
    -------
      (train_losses, val_losses): one value per epoch. The printed labels say
      "BCE" but the values are the localized-entropy loss averaged over
      samples.

    Raises
    ------
      RuntimeError if ``device`` is CUDA but the model parameters or the
      first training batch are not on CUDA.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    train_losses: List[float] = []
    val_losses: List[float] = []

    first_param = next(model.parameters(), None)
    if device.type == 'cuda':
        if (first_param is not None) and (first_param.device.type != 'cuda'):
            raise RuntimeError('Model parameters must be on CUDA when device is CUDA.')
        torch.cuda.reset_peak_memory_stats(device)

    # Loop-invariant settings for the per-epoch base-rate tracker.
    num_conds_model = model.embedding.num_embeddings if hasattr(model, 'embedding') else 1
    p_dtype = first_param.dtype if first_param is not None else torch.float32
    weights_t: Optional[torch.Tensor] = None  # built lazily on the first batch

    for epoch in range(1, epochs + 1):
        model.train()
        # Fresh cumulative base rates (class label means) for this epoch.
        br_tracker = StreamingBaseRate(num_conds_model, device=device, dtype=p_dtype)
        running = 0.0
        count = 0
        verified_cuda_batch = False
        epoch_start = time.time()
        for x, c, y, nw in train_loader:
            x = x.to(device, non_blocking=non_blocking)
            c = c.to(device, non_blocking=non_blocking)
            y = y.to(device, non_blocking=non_blocking)
            nw = nw.to(device, non_blocking=non_blocking)
            # One-time sanity check that batches really live on the GPU.
            if (device.type == 'cuda') and (not verified_cuda_batch):
                tensors = (x, c, y, nw)
                if any(t.device.type != 'cuda' for t in tensors):
                    raise RuntimeError('Detected CPU tensors in the training loop while using CUDA.')
                verified_cuda_batch = True
            opt.zero_grad(set_to_none=True)
            logits = model(x, c)
            if (condition_weights is not None) and (weights_t is None):
                weights_t = torch.as_tensor(condition_weights, device=logits.device, dtype=logits.dtype)
            # Update per-epoch base rates with current batch, then use them for denominator
            br_tracker.update(y, c)
            loss = localized_entropy_bce_torch(
                logits=logits,
                targets=y,
                conditions=c,
                base_rates=br_tracker.rates(),
                net_worth=nw,
                condition_weights=weights_t,
                nw_threshold=nw_threshold,
                nw_multiplier=nw_multiplier,
            )
            loss.backward()
            opt.step()
            batch_n = x.size(0)
            running += float(loss.item()) * batch_n
            count += batch_n
        if device.type == 'cuda':
            # Make the wall-clock measurement include all queued GPU work.
            torch.cuda.synchronize(device)
        train_bce = running / max(1, count)
        val_bce, preds = evaluate(
            model, val_loader, device,
            condition_weights=condition_weights,
            nw_threshold=nw_threshold,
            nw_multiplier=nw_multiplier,
            non_blocking=non_blocking,
        )
        train_losses.append(train_bce)
        val_losses.append(val_bce)
        epoch_time = time.time() - epoch_start
        log_msg = f'Epoch {epoch:3d}/{epochs} | Train BCE: {train_bce:.6f} | Eval BCE: {val_bce:.6f} | wall {epoch_time:.2f}s'
        if device.type == 'cuda':
            mem_alloc = torch.cuda.memory_allocated(device) / 1e6
            peak_mem = torch.cuda.max_memory_allocated(device) / 1e6
            log_msg += f" | cuda_mem={mem_alloc:.1f}MB (peak {peak_mem:.1f}MB)"
        print(log_msg)
        # Per-epoch histogram: must live inside the loop so every epoch is
        # plotted and `preds` / `epoch` are always defined when used.
        if plot_eval_hist_epochs:
            plot_eval_log10p_hist(preds.astype(np.float32), epoch)

    # With epochs <= 0 there is nothing to report.
    if train_losses:
        print(f'Final Train BCE: {train_losses[-1]:.10f}')
        print(f'Final Eval  BCE: {val_losses[-1]:.10f}')
    return train_losses, val_losses


In [None]:


# Configuration & data generation
# All names below are module-level settings read by this and later cells.
seed = 42
epochs = 12
lr = 1e-3
batch_size = 25_000
num_conditions = 10
# min == max, so make_dataset draws exactly this many samples per condition.
min_samples_per_condition = 200_000
max_samples_per_condition = 200_000
PLOT_DATA_BEFORE_TRAINING = False
PLOT_DATA_AFTER_TRAINING = True
PLOT_EVAL_HIST_EPOCHS = False
# When True (and CUDA is available) the full dataset is staged on the GPU.
MOVE_DATASET_TO_CUDA = True
model_hidden_sizes = (256, 256, 128, 64)
model_embed_dim = 16
model_dropout = 0.3

# Seed the global numpy RNG (used by np.random.shuffle below) and torch RNGs.
# make_dataset uses its own generator built from `seed`.
np.random.seed(seed)
torch.manual_seed(seed)
if USE_CUDA:
    torch.cuda.manual_seed_all(seed)

print('Generating dataset...')
ages, net_worth, conds, labels, probs = make_dataset(
    num_conditions=num_conditions,
    min_samples_per_condition=min_samples_per_condition,
    max_samples_per_condition=max_samples_per_condition,
    seed=seed,
)
n_total = len(labels)
print(f'Total samples: {n_total:,}')

# Split 90/10 into train/eval via a random permutation of row indices
idx = np.arange(n_total)
np.random.shuffle(idx)
split = int(0.9 * n_total)
train_idx = idx[:split]
eval_idx = idx[split:]

# Apply the same index split to every column so rows stay aligned.
age_tr, age_ev = ages[train_idx], ages[eval_idx]
nw_tr, nw_ev = net_worth[train_idx], net_worth[eval_idx]
cond_tr, cond_ev = conds[train_idx], conds[eval_idx]
y_tr, y_ev = labels[train_idx], labels[eval_idx]
# True probabilities are kept for plotting only; they are not fed to the loaders.
p_tr, p_ev = probs[train_idx], probs[eval_idx]

if PLOT_DATA_BEFORE_TRAINING:
    plot_training_distributions(nw_tr, age_tr, p_tr, cond_tr, num_conditions)
else:
    print('Skipping training data distribution plots before training.')

# Build numeric features: [Age, log10(NetWorth)] then standardize from training stats
# Net worth is clipped at 1e-12 before the log to avoid log10(0).
log10_nw_tr = np.log10(np.clip(nw_tr, 1e-12, None))
log10_nw_ev = np.log10(np.clip(nw_ev, 1e-12, None))
xnum_tr = np.stack([age_tr, log10_nw_tr], axis=1)
xnum_ev = np.stack([age_ev, log10_nw_ev], axis=1)
# Mean/std come from the training split only and are reused for the eval split.
mu = xnum_tr.mean(axis=0)
sd = xnum_tr.std(axis=0)
# Guard against division by a (near-)zero standard deviation.
sd[sd < 1e-6] = 1.0
xnum_tr_n = (xnum_tr - mu) / sd
xnum_ev_n = (xnum_ev - mu) / sd

# Two loading strategies:
#   * CUDA + MOVE_DATASET_TO_CUDA: keep all tensors on the GPU and batch them
#     with TensorBatchLoader.
#   * otherwise: MedCondDataset + torch DataLoader, optionally with workers.
use_tensor_loader = USE_CUDA and MOVE_DATASET_TO_CUDA
if use_tensor_loader:
    print('Staging datasets directly on CUDA for batch sampling.')
    # Tuple order (features, condition, label, net worth) matches what the
    # train/eval loops unpack from each batch.
    train_tensors = (
        torch.as_tensor(xnum_tr_n, dtype=torch.float32, device=device),
        torch.as_tensor(cond_tr, dtype=torch.long, device=device),
        torch.as_tensor(y_tr, dtype=torch.float32, device=device),
        torch.as_tensor(nw_tr, dtype=torch.float32, device=device),
    )
    eval_tensors = (
        torch.as_tensor(xnum_ev_n, dtype=torch.float32, device=device),
        torch.as_tensor(cond_ev, dtype=torch.long, device=device),
        torch.as_tensor(y_ev, dtype=torch.float32, device=device),
        torch.as_tensor(nw_ev, dtype=torch.float32, device=device),
    )
    train_loader = TensorBatchLoader(train_tensors, batch_size=batch_size, shuffle=True)
    # Eval is not shuffled, so predictions stay aligned with cond_ev.
    eval_loader = TensorBatchLoader(eval_tensors, batch_size=batch_size, shuffle=False)
    loader_note = (
        f'TensorBatchLoader on CUDA (batches per epoch: {len(train_loader)} / {len(eval_loader)}).'
    )
else:
    train_ds = MedCondDataset(xnum_tr_n, cond_tr, y_tr, net_worth=nw_tr)
    eval_ds = MedCondDataset(xnum_ev_n, cond_ev, y_ev, net_worth=nw_ev)
    loader_common = dict(batch_size=batch_size, drop_last=False, pin_memory=USE_CUDA)
    max_workers = os.cpu_count() or 1
    # Worker count: the LOCALIZED_ENTROPY_NUM_WORKERS environment variable
    # wins when set (clamped to [0, cpu count]; unparsable values mean 0).
    # Without it: 0 workers on CUDA, up to 2 on CPU.
    env_override = os.environ.get('LOCALIZED_ENTROPY_NUM_WORKERS')
    if env_override is not None:
        try:
            num_workers = max(0, min(int(env_override), max_workers))
        except ValueError:
            num_workers = 0
    else:
        num_workers = 0 if USE_CUDA else min(2, max_workers)
    # Stays empty when no workers are requested; that is the signal used
    # below to skip the worker probe.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = dict(num_workers=num_workers, persistent_workers=False, prefetch_factor=2)

    # Build a DataLoader from the shared kwargs, adding the worker kwargs
    # only when requested and available.
    def _instantiate_loader(dataset: Dataset, *, shuffle: bool, use_workers: bool) -> DataLoader:
        kwargs = dict(loader_common)
        if use_workers and worker_kwargs:
            kwargs.update(worker_kwargs)
        kwargs['shuffle'] = shuffle
        return DataLoader(dataset, **kwargs)

    # Return a loader for `dataset`. When workers are configured, first probe
    # a worker-backed loader by pulling one batch; if anything raises, warn
    # and fall back to a loader without workers. `role` only labels the
    # warning message.
    def _build_loader(dataset: Dataset, *, shuffle: bool, role: str) -> DataLoader:
        if not worker_kwargs:
            return _instantiate_loader(dataset, shuffle=shuffle, use_workers=False)
        test_iter = None
        try:
            test_loader = _instantiate_loader(dataset, shuffle=shuffle, use_workers=True)
            test_iter = iter(test_loader)
            next(test_iter)
        except Exception as exc:
            # Stop the probe's workers if the iterator exposes the hook
            # (`_shutdown_workers` is a private attribute, hence the hasattr).
            if test_iter is not None and hasattr(test_iter, '_shutdown_workers'):
                test_iter._shutdown_workers()
            print(f"[WARN] {role} DataLoader workers failed ({exc}); falling back to workers=0.")
            return _instantiate_loader(dataset, shuffle=shuffle, use_workers=False)
        else:
            # Probe succeeded: discard it and hand back a fresh worker-backed loader.
            if test_iter is not None and hasattr(test_iter, '_shutdown_workers'):
                test_iter._shutdown_workers()
            return _instantiate_loader(dataset, shuffle=shuffle, use_workers=True)

    train_loader = _build_loader(train_ds, shuffle=True, role='Train')
    # Eval is not shuffled, so predictions stay aligned with cond_ev.
    eval_loader = _build_loader(eval_ds, shuffle=False, role='Eval')
    train_workers = getattr(train_loader, 'num_workers', 0)
    eval_workers = getattr(eval_loader, 'num_workers', 0)
    loader_note = (
        f"Train/Eval DataLoader workers: {train_workers}/{eval_workers} (pin_memory={loader_common.get('pin_memory', False)})"
    )
    if USE_CUDA and train_workers == 0:
        loader_note += " | Multiprocessing disabled for CUDA stability; set LOCALIZED_ENTROPY_NUM_WORKERS>0 to retry."

print(loader_note)

# Instantiate the network with the configured sizes and move it to the device.
model = ConditionProbNet(
    num_conditions=num_conditions,
    embed_dim=model_embed_dim,
    hidden_sizes=model_hidden_sizes,
    p_drop=model_dropout,
).to(device)
# Bare expression: in a notebook cell this displays the model structure.
model



In [None]:

# Train
# condition_weights / nw_threshold / nw_multiplier are not passed, so the
# function's defaults (None / None / 1.0) apply.
train_losses, eval_losses = train_with_epoch_plots(
    model=model,
    train_loader=train_loader,
    val_loader=eval_loader,
    device=device,
    epochs=epochs,
    lr=lr,
    non_blocking=NON_BLOCKING,
    plot_eval_hist_epochs=PLOT_EVAL_HIST_EPOCHS,
)
# Bare expression: in a notebook cell this displays both per-epoch loss lists.
train_losses, eval_losses


In [None]:
# Plot BCE curves
# The plotted values are the per-epoch losses returned by
# train_with_epoch_plots; the x-axis is the 0-based list index.
plt.figure(figsize=(8, 5))
plt.plot(train_losses, label='Train BCE')
plt.plot(eval_losses, label='Eval BCE')
plt.xlabel('Epoch')
plt.ylabel('BCE Loss')
plt.title('Training vs Evaluation BCE over Epochs')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Last-epoch values (indexing [-1] requires at least one completed epoch).
print(f'Final Train BCE: {train_losses[-1]:.10f}')
print(f'Final Eval  BCE: {eval_losses[-1]:.10f}')

In [None]:

# Final evaluation and plots analogous to input
# Re-run evaluation on the trained model; eval_preds are sigmoid probabilities.
eval_bce, eval_preds = evaluate(model, eval_loader, device, non_blocking=NON_BLOCKING)
# NOTE(review): eval_probs is not referenced again in this cell — confirm
# whether a later cell uses it.
eval_probs = np.clip(eval_preds, 1e-12, 1 - 1e-12)
plot_eval_log10p_hist(eval_preds.astype(np.float32), epoch=epochs)

# Distribution plots for evaluation predictions vs conditions
# The eval loader is built with shuffle=False, so eval_preds is in the same
# row order as cond_ev.
_density_lines(
    values=eval_preds,
    groups=cond_ev,
    num_conditions=num_conditions,
    bins=120,
    transform='log10',
    value_range=(-12, 0),
    title='Eval Predictions: Distribution by Condition (log10(pred probability))',
    x_label='log10(pred p)'
)

print(f'Final Evaluation BCE: {eval_bce:.10f}')


In [None]:
# Optional: render training data distributions after training
# Same three charts as the pre-training option, drawn from the training split.
if PLOT_DATA_AFTER_TRAINING:
    plot_training_distributions(nw_tr, age_tr, p_tr, cond_tr, num_conditions)
else:
    print('Post-training training data plots are disabled. Set PLOT_DATA_AFTER_TRAINING=True to enable.')


In [None]:
# LE per-condition stats + plotting
import math
import numpy as np
import matplotlib.pyplot as plt

def collect_le_stats_per_condition_torch(logits: torch.Tensor,
                                        targets: torch.Tensor,
                                        conditions: torch.Tensor,
                                        eps: float = 1e-12) -> dict:
    """Compute LE numerator/denominator terms per condition using logits.

    For every distinct condition id the numerator is the summed
    BCE-with-logits of the model over that condition's samples, and the
    denominator is the summed BCE of a constant predictor equal to the
    condition's empirical positive rate ``p_j`` (clamped to
    ``[eps, 1 - eps]``).

    Args:
        logits: Raw model outputs, any shape; flattened internally.
        targets: Binary labels with the same number of elements as ``logits``.
        conditions: Integer condition ids, one per sample.
        eps: Clamp for ``p_j`` and floor for the denominator in the ratio.

    Returns:
        Dict mapping ``cond_id`` (int) to a dict with the keys
        ``'Numerator'``, ``'Denominator'``,
        ``'Average prediction for denominator'``,
        ``'Number of samples with label 1'``,
        ``'Number of samples with label 0'`` and ``'Numerator/denominator'``.
        Empty inputs yield an empty dict.

    Raises:
        ValueError: If the three inputs do not have the same number of elements.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
    z = logits.reshape(-1)
    y = targets.reshape(-1)
    c = conditions.reshape(-1)
    if not (z.numel() == y.numel() == c.numel()):
        # Without this check a length-1 tensor would silently broadcast.
        raise ValueError(
            'logits, targets and conditions must have the same number of elements; '
            f'got {z.numel()}, {y.numel()} and {c.numel()}'
        )
    y = y.to(device=z.device, dtype=z.dtype)
    c = c.to(device=z.device, dtype=torch.long)

    # Stable BCE-with-logits per-sample: max(z, 0) - z*y + log(1 + exp(-|z|))
    bce_per = torch.clamp_min(z, 0) - z * y + torch.log1p(torch.exp(-torch.abs(z)))

    stats = {}
    # torch.unique only yields ids that occur, so every group is non-empty.
    for cid in torch.unique(c).tolist():
        mask = (c == cid)
        n = int(mask.sum().item())
        num = float(bce_per[mask].sum().item())
        ones = float(y[mask].sum().item())
        zeros = float(n) - ones
        # Empirical base rate, kept strictly inside (0, 1) so both logs are finite.
        pj = max(eps, min(1.0 - eps, ones / n))
        den = ones * (-math.log(pj)) + zeros * (-math.log1p(-pj))
        # Floor the denominator to avoid division by (near) zero.
        ratio = num / (den if den > eps else eps)
        stats[int(cid)] = {
            'Numerator': num,
            'Denominator': den,
            'Average prediction for denominator': pj,
            'Number of samples with label 1': int(round(ones)),
            'Number of samples with label 0': int(round(zeros)),
            'Numerator/denominator': ratio,
        }
    return stats

def plot_le_stats_per_condition(stats: dict, title: str = 'Localized Entropy terms per condition'):
    """Render a 2x2 grid of bar charts summarizing per-condition LE statistics.

    Panels: numerator/denominator ratio (top-left), numerator next to
    denominator (top-right), stacked label counts (bottom-left) and the
    average prediction used for the denominator, p_j (bottom-right).

    Args:
        stats: Mapping of condition id -> metrics dict, keyed as produced by
            ``collect_le_stats_per_condition_torch``.
        title: Figure-level title.
    """
    cond_ids = sorted(stats.keys())

    def column(key):
        # One metric across all conditions, in sorted condition order.
        return [stats[cid][key] for cid in cond_ids]

    numerators = column('Numerator')
    denominators = column('Denominator')
    ratio_vals = column('Numerator/denominator')
    base_rates = column('Average prediction for denominator')
    pos_counts = column('Number of samples with label 1')
    neg_counts = column('Number of samples with label 0')

    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    ax_ratio, ax_pair = axs[0, 0], axs[0, 1]
    ax_counts, ax_rate = axs[1, 0], axs[1, 1]

    # Top-left: ratio per condition.
    ax_ratio.bar(cond_ids, ratio_vals, color='#4477aa')
    ax_ratio.set_title('Numerator / Denominator')
    ax_ratio.set_xlabel('Condition')
    ax_ratio.set_ylabel('Ratio')
    ax_ratio.grid(True, alpha=0.3)

    # Top-right: grouped bars at integer slots, labelled with the condition ids.
    slots = np.arange(len(cond_ids))
    bar_w = 0.4
    ax_pair.bar(slots - bar_w/2, numerators, width=bar_w, label='Numerator', color='#66c2a5')
    ax_pair.bar(slots + bar_w/2, denominators, width=bar_w, label='Denominator', color='#fc8d62')
    ax_pair.set_xticks(slots)
    ax_pair.set_xticklabels(cond_ids)
    ax_pair.set_title('Numerator vs Denominator')
    ax_pair.legend()
    ax_pair.grid(True, alpha=0.3)

    # Bottom-left: label-1 counts stacked on top of label-0 counts.
    ax_counts.bar(cond_ids, neg_counts, label='Label 0 count', color='#999999')
    ax_counts.bar(cond_ids, pos_counts, bottom=neg_counts, label='Label 1 count', color='#1b9e77')
    ax_counts.set_title('Label counts per condition')
    ax_counts.set_xlabel('Condition')
    ax_counts.set_ylabel('Count')
    ax_counts.legend()
    ax_counts.grid(True, alpha=0.3)

    # Bottom-right: p_j per condition on a fixed [0, 1] axis.
    ax_rate.bar(cond_ids, base_rates, color='#8da0cb')
    ax_rate.set_title('Average prediction for denominator (p_j)')
    ax_rate.set_xlabel('Condition')
    ax_rate.set_ylabel('p_j')
    ax_rate.set_ylim(0, 1)
    ax_rate.grid(True, alpha=0.3)

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

In [None]:

# Collect & plot LE stats per condition on eval set (using logits)
model.eval()
all_logits, all_targets, all_conditions = [], [], []
with torch.no_grad():
    # The fourth batch element is not needed here, so it is never moved to the device.
    for xb, cb, yb, _nw in eval_loader:
        xb = xb.to(device, non_blocking=NON_BLOCKING)
        cb_dev = cb.to(device, non_blocking=NON_BLOCKING)
        zb = model(xb, cb_dev)
        # Only the model inputs go to the device. Labels and condition ids are
        # gathered straight from the batch; .cpu() is a no-op for tensors
        # already on the CPU. No .detach() is needed under torch.no_grad().
        all_logits.append(zb.cpu())
        all_targets.append(yb.cpu())
        all_conditions.append(cb.cpu())
z_all = torch.cat(all_logits).view(-1)
y_all = torch.cat(all_targets).view(-1)
c_all = torch.cat(all_conditions).view(-1)

le_stats = collect_le_stats_per_condition_torch(z_all, y_all, c_all, eps=1e-12)

# Print table (tab-separated; one row per condition, sorted by condition id)
print('cond\tnum\tden\tavg_p\t#y=1\t#y=0\tratio')
for cond in sorted(le_stats.keys()):
    s = le_stats[cond]
    print(
        f"{cond}\t{s['Numerator']:.6g}\t{s['Denominator']:.6g}\t"
        f"{s['Average prediction for denominator']:.6g}\t"
        f"{s['Number of samples with label 1']}\t"
        f"{s['Number of samples with label 0']}\t"
        f"{s['Numerator/denominator']:.6g}"
    )

plot_le_stats_per_condition(le_stats, title='Localized Entropy terms per condition - Eval set')
