## VAE-AR vs VAE-GP — Summary

We compare two VAE architectures that differ only in how temporal structure is encoded in the latent space.  
VAE-GP relies on a smooth Gaussian-process prior, capturing global correlations but no causal direction, while VAE-AR induces an explicit latent autoregressive structure through Gaussian conditioning.

Empirically, both models achieve similar perplexities, but their behavior diverges at long horizons: the non-AR GP variant tends to collapse, whereas the latent-autoregressive model maintains coherent generations.  
This suggests that while smooth latent geometry is sufficient for local modeling, **causal structure in the latent space is crucial for long-range sequence stability**, even with a non-autoregressive decoder.


# VAE-AR vs VAE-GP  
## Latent Autoregression versus Smooth Gaussian-Process Priors

This repository compares two closely related variational autoencoder (VAE) architectures that differ **only in the way temporal structure is modeled in the latent space**:

- **VAE-AR**: a latent autoregressive model induced by a causal Gaussian-process prior.
- **VAE-GP (non-AR)**: a VAE with a smooth Gaussian-process prior but no latent autoregression.

The goal is not to rank models by a single metric, but to understand **where their behaviors differ, what each model is good at, and which assumptions they encode about sequential structure**.

---

## 1. Conceptual difference

### VAE-AR (Latent Autoregression)

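- The latent variables form a **causal sequence**:
  \[
  p(z_{1:L}) = \prod_{t=1}^L p(z_t \mid z_{<t})
  \]
  In the implementation below the prior is first-order Markov: \(p(z_1) = \mathcal{N}(0, I)\) and \(p(z_t \mid z_{<t}) = p(z_t \mid z_{t-1}) = \mathcal{N}(\rho\, z_{t-1}, \sigma^2 I)\) for \(t > 1\), with a learnable scalar \(\rho\).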
- Autoregression is **not implemented by a neural recurrence**, but emerges from **Gaussian conditioning** under a GP prior.
- Temporal directionality is an explicit **probabilistic property of the latent path**.

**Interpretation:**  
Sequential structure is carried by the *latent trajectory itself*, independently of the decoder.
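To illustrate how autoregression can emerge from Gaussian conditioning, the sketch below builds the joint covariance of a one-dimensional Gaussian AR(1) prior of the same form as `PriorAR` in the code below (\(z_1 \sim \mathcal{N}(0, 1)\), \(z_t = \rho z_{t-1} + \sigma \varepsilon_t\)). It then recovers \(p(z_t \mid z_{<t})\) from the joint Gaussian via the standard conditioning formula. The values \(\rho = 0.95\), \(\sigma = 0.5\) are the config defaults (`ar_init_rho`, `ar_sigma`); the helper names are hypothetical, and this only checks the AR(1) case, not the full VAE-AR model.

```python
import numpy as np

def ar1_joint_cov(T: int, rho: float, sigma: float) -> np.ndarray:
    """Covariance of z_{1:T} for z_1 ~ N(0, 1), z_t = rho * z_{t-1} + sigma * eps_t (1-D)."""
    var = np.empty(T)
    var[0] = 1.0
    for t in range(1, T):
        var[t] = rho**2 * var[t - 1] + sigma**2
    K = np.empty((T, T))
    for i in range(T):
        for j in range(T):
            lo, hi = min(i, j), max(i, j)
            # Cov(z_hi, z_lo) = rho^(hi - lo) * Var(z_lo)
            K[i, j] = rho ** (hi - lo) * var[lo]
    return K

def condition_on_past(K: np.ndarray, t: int):
    """Return (weights, variance) such that E[z_t | z_{<t}] = weights @ z_{<t}."""
    K_pp = K[:t, :t]                        # covariance of the past
    k_tp = K[t, :t]                         # cross-covariance of z_t with the past
    weights = np.linalg.solve(K_pp, k_tp)   # K_pp^{-1} k_tp
    variance = K[t, t] - k_tp @ weights     # Schur complement
    return weights, variance

rho, sigma = 0.95, 0.5
K = ar1_joint_cov(T=8, rho=rho, sigma=sigma)
w, v = condition_on_past(K, t=5)
print(np.round(w, 6))      # only the last weight (on z_{t-1}) is non-zero, equal to rho
print(round(float(v), 6))  # conditional variance sigma**2 = 0.25
```

The conditional depends only on the immediately preceding latent, which is exactly the Markov transition \(\mathcal{N}(\rho z_{t-1}, \sigma^2)\) used by the prior.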

---

### VAE-GP (Non-Autoregressive)

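- The latent variables are jointly distributed under a **smooth Gaussian-process prior** over positions:
  \[
  p(z_{1:L}) = \mathcal{N}(0, K), \qquad K_{ij} = k(i, j),
  \]
  where \(k\) is a smooth covariance kernel.
- No causal factorization is enforced in the latent space.
- Correlations exist, but they are **global and non-directional**: \(K\) is symmetric, so the prior couples positions \(i\) and \(j\) identically regardless of which one comes first.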

**Interpretation:**  
The model captures *global smoothness*, but not causal or step-by-step dependence.
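For contrast, a minimal sketch of a smooth GP prior over positions. The squared-exponential (RBF) kernel, the lengthscale, and the jitter value are assumptions for illustration; the kernel actually used by VAE-GP is not specified here. Sampling and the log-density both go through a Cholesky factor of \(K\).

```python
import numpy as np

def rbf_kernel_matrix(T: int, lengthscale: float, jitter: float = 1e-6) -> np.ndarray:
    """K_ij = exp(-(i - j)^2 / (2 * lengthscale^2)) over positions 0..T-1, plus diagonal jitter."""
    pos = np.arange(T, dtype=float)
    d2 = (pos[:, None] - pos[None, :]) ** 2
    return np.exp(-0.5 * d2 / lengthscale**2) + jitter * np.eye(T)

def gp_log_prior(z: np.ndarray, K: np.ndarray) -> float:
    """log N(z; 0, K) for one latent path z of shape (T,), computed via Cholesky."""
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, z)  # L^{-1} z, so alpha @ alpha = z^T K^{-1} z
    T = z.shape[0]
    return float(-0.5 * alpha @ alpha - np.log(np.diag(L)).sum() - 0.5 * T * np.log(2 * np.pi))

def sample_gp_paths(n: int, K: np.ndarray, seed: int = 0) -> np.ndarray:
    """Draw n latent paths z ~ N(0, K); returns an array of shape (n, T)."""
    rng = np.random.default_rng(seed)
    L = np.linalg.cholesky(K)
    return rng.standard_normal((n, K.shape[0])) @ L.T

K = rbf_kernel_matrix(T=32, lengthscale=4.0)
paths = sample_gp_paths(3, K)
print(paths.shape)          # (3, 32)
print(np.allclose(K, K.T))  # True: Cov(z_i, z_j) == Cov(z_j, z_i)
print(gp_log_prior(paths[0], K))
```

Because \(K\) is symmetric, nothing in this prior distinguishes "earlier" from "later" positions; smoothness comes entirely from the kernel's lengthscale.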

---

## 2. Empirical comparison (WikiText-2)

| Aspect | VAE-AR | VAE-GP (non-AR) |
|------|-------|----------------|
| Validation NLL (per token) | comparable | comparable |
| Perplexity (exp of per-token NLL) | comparable | comparable |
| KL divergence | **high** | **low** |
| Latent autocorrelation | **causal, directional** | smooth, non-causal |
| Mutual information (latent) | higher | lower |
| Short-range generation | good | slightly better |
| Long-range generation | **stable** | prone to collapse |
| Training cost | similar | similar |
| Inference speed | slightly slower | slightly faster |
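The latent-autocorrelation row can be made concrete with a lag-\(k\) autocorrelation, pooled over sequences and averaged over latent dimensions, in the spirit of `latent_autocorr` in the code below. As a sanity check on simulated (not trained) latents, a stationary AR(1) path should give values close to \(\rho^k\); the simulation sizes and \(\rho = 0.9\) are illustrative assumptions. Lag autocorrelation by itself is symmetric in time, so it measures how strongly nearby latents co-vary.

```python
import numpy as np

def lag_autocorr(z: np.ndarray, lag: int) -> float:
    """Mean over latent dims of corr(z[:, t], z[:, t + lag]); z has shape (B, T, Dz)."""
    B, T, Dz = z.shape
    if T <= lag:
        return float("nan")
    a = z[:, :-lag, :].reshape(-1, Dz)
    b = z[:, lag:, :].reshape(-1, Dz)
    a = a - a.mean(axis=0, keepdims=True)
    b = b - b.mean(axis=0, keepdims=True)
    cov = (a * b).mean(axis=0)
    denom = np.sqrt((a * a).mean(axis=0) * (b * b).mean(axis=0))
    return float((cov / np.maximum(denom, 1e-8)).mean())

def simulate_stationary_ar1(B: int, T: int, Dz: int, rho: float, sigma: float, seed: int = 0) -> np.ndarray:
    """z_1 from the stationary marginal N(0, sigma^2 / (1 - rho^2)); z_t = rho z_{t-1} + sigma eps_t."""
    rng = np.random.default_rng(seed)
    z = np.empty((B, T, Dz))
    z[:, 0] = rng.standard_normal((B, Dz)) * sigma / np.sqrt(1.0 - rho**2)
    for t in range(1, T):
        z[:, t] = rho * z[:, t - 1] + sigma * rng.standard_normal((B, Dz))
    return z

z = simulate_stationary_ar1(B=64, T=256, Dz=8, rho=0.9, sigma=0.5)
for k in (1, 5):
    print(k, round(lag_autocorr(z, k), 3), round(0.9**k, 3))  # estimate vs. rho^k
```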

---

## 3. What really differs

### Perplexity is *not* the discriminating factor
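Both models achieve similar token-level perplexities.  
This suggests that **local likelihood metrics are insufficient** to assess sequential coherence.
Note that the perplexity logged by the code below is \(\exp\) of the per-token reconstruction NLL with \(z \sim q(z \mid x)\). It excludes the KL term, so, unlike \(\exp(\text{NLL} + \text{KL})\) per token, it is not an upper bound on the model's true perplexity.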
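A minimal sketch, assuming per-token averages `nll_tok` and `kl_tok` in nats as logged by the code below, of the two perplexities one can report: the reconstruction perplexity, and the ELBO-based one (\(\beta = 1\)), which by the ELBO inequality upper-bounds the true per-token perplexity. The numbers in the usage line are made up for illustration and are not results from this study.

```python
import math

def perplexities(nll_tok: float, kl_tok: float) -> dict:
    """nll_tok, kl_tok: average per-token reconstruction NLL and KL, in nats."""
    return {
        # exp of the reconstruction NLL with z ~ q(z|x): what the code reports as "ppl"
        "ppl_reconstruction": math.exp(nll_tok),
        # exp of the negative beta = 1 ELBO per token: an upper bound on the true perplexity
        "ppl_elbo_bound": math.exp(nll_tok + kl_tok),
    }

# Illustrative values only.
print(perplexities(nll_tok=3.0, kl_tok=0.5))
```

With equal reconstruction NLL, a model with a larger KL has a larger bound-based perplexity, which is why the KL row of the table matters when comparing likelihoods.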

### Latent structure is the key difference
- **VAE-GP** minimizes KL by keeping latents close to a smooth prior.
- **VAE-AR** accepts a higher KL cost to encode **causal latent trajectories**.

This trade-off directly affects long-horizon behavior.
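The KL cost in this trade-off is estimated in the code by Monte Carlo (`kl = logq - logp`, averaged over `n_z_samples` draws). A sketch of that estimator for a single diagonal-Gaussian posterior against a standard-normal prior (the AR prior's marginal for \(z_1\)), checked against the closed form; the \(\mu\) and \(\log\sigma^2\) values are arbitrary.

```python
import numpy as np

def kl_closed_form(mu: np.ndarray, logvar: np.ndarray) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over dimensions."""
    return float(0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar))

def kl_monte_carlo(mu: np.ndarray, logvar: np.ndarray, n_samples: int, seed: int = 0) -> float:
    """Mean of log q(z) - log p(z) over z ~ q, using z = mu + std * eps."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal((n_samples,) + mu.shape)
    z = mu + eps * std
    log2pi = np.log(2 * np.pi)
    log_q = -0.5 * (((z - mu) ** 2) / np.exp(logvar) + logvar + log2pi).sum(axis=-1)
    log_p = -0.5 * (z**2 + log2pi).sum(axis=-1)
    return float(np.mean(log_q - log_p))

mu = np.array([0.5, -1.0, 0.0])
logvar = np.array([-1.0, 0.0, 0.5])
print(kl_closed_form(mu, logvar))
print(kl_monte_carlo(mu, logvar, n_samples=1))        # noisy, like n_z_samples = 1
print(kl_monte_carlo(mu, logvar, n_samples=200_000))  # close to the closed form
```

A single-sample estimate is unbiased but noisy; per-step training noise in `kl_tok` is therefore expected when `n_z_samples = 1`.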

---

## 4. Strengths and limitations

### VAE-AR — strengths
- Enforces **latent causality** without token-level autoregression.
- Produces **stable long-range generations**.
- Explicitly models temporal structure in latent space.

### VAE-AR — limitations
- Higher KL cost.
- Slightly more complex inference.
- Perplexity remains higher than that of highly optimized Transformer language models.

---

### VAE-GP — strengths
- Simple, smooth latent geometry.
- Lower KL divergence.
- Efficient for short-range or interpolation tasks.

### VAE-GP — limitations
- No causal latent mechanism.
- Long-range generations may collapse despite good local metrics.
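If collapse shows up as degenerate, repetitive output, it can be quantified with simple n-gram statistics. The sketch below mirrors `distinct_ngrams` and `repetition_rate` in the code below: distinct-n is the fraction of unique n-grams, and the repetition rate is the share of n-gram occurrences whose n-gram appears more than once. The toy token sequences are illustrative.

```python
from collections import Counter
from typing import List

def ngrams(tokens: List[int], n: int) -> List[tuple]:
    """All contiguous n-grams of a token list (empty if n <= 0 or the list is too short)."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)] if n > 0 else []

def distinct_n(tokens: List[int], n: int) -> float:
    """Unique n-grams / total n-grams (1.0 means no repeated n-gram)."""
    grams = ngrams(tokens, n)
    return len(set(grams)) / len(grams) if grams else 0.0

def repetition_rate(tokens: List[int], n: int) -> float:
    """Share of n-gram occurrences whose n-gram occurs more than once."""
    grams = ngrams(tokens, n)
    if not grams:
        return 0.0
    counts = Counter(grams)
    return sum(c for c in counts.values() if c > 1) / len(grams)

looping = [1, 2, 1, 2, 1, 2]  # a degenerate, repetitive continuation
varied = [1, 2, 3, 4, 5, 6]
print(distinct_n(looping, 2), repetition_rate(looping, 2))  # 0.4 1.0
print(distinct_n(varied, 2), repetition_rate(varied, 2))    # 1.0 0.0
```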

---

## 5. Takeaway

**VAE-AR and VAE-GP do not solve the same problem.**

- VAE-GP models *smoothness*.
- VAE-AR models *sequential causality*.

This comparison supports the idea that **part of sequence modeling can be shifted from token-level autoregression to the probabilistic geometry of latent space**, but only if that geometry encodes causal structure.

---

## 6. Status

This comparison is intended as a **proof-of-concept study**.
The models are deliberately constrained to highlight structural effects rather than to compete with large-scale Transformer baselines.



# VAE on WikiText-2 blocks — AR PRIOR ONLY
# "NO-DRAMA" + MANY METRICS + PERF/GPU
# Run this cell alone in a fresh notebook runtime for clean GPU measurements.

import math
import os
import time
import json
import random
from dataclasses import dataclass, asdict
from typing import Dict, Tuple, List, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from datasets import load_dataset
from transformers import AutoTokenizer


# Environment / warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



# Config

@dataclass
class CFG:
    seed: int = 0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Data
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    tokenizer_name: str = "gpt2"
    block_size: int = 256
    batch_size: int = 16
    num_workers: int = 2
    pin_memory: bool = True

    # Model
    vocab_size: int = 50257  # overwritten after tokenizer load
    d_model: int = 384
    n_layers: int = 6
    n_heads: int = 6
    dropout: float = 0.1

    # Latent
    z_dim: int = 64
    n_z_samples: int = 1  # MC samples of z per sequence (used for both the NLL and KL estimates)

    # Training
    lr: float = 3e-4
    weight_decay: float = 0.01
    max_steps: int = 6000
    warmup_steps: int = 300
    grad_clip: float = 1.0
    eval_every: int = 400
    log_every: int = 50
    amp: bool = True

    # KL anneal
    beta_start: float = 0.0
    beta_end: float = 1.0
    beta_warmup_steps: int = 1500

    # AR prior
    ar_init_rho: float = 0.95
    ar_sigma: float = 0.5

    # Eval controls
    eval_max_batches: int = 80
    eval_train_batches: int = 20
    eval_gen_prompts: int = 12
    eval_gen_max_new: int = 96
    eval_gen_top_k: int = 50

    # ECE calibration
    ece_bins: int = 15

    # Bits-back diagnostics
    kldim_eps: float = 0.01  # threshold for "inactive" dims (per-token)

    # MI proxy (batch-based, subsampled for cost)
    mi_max_components: int = 512   # max components in mixture for q(z)
    mi_max_points: int = 256       # max sample points z to evaluate
    mi_seed: int = 0

    # Rate–Distortion points
    rd_betas: Tuple[float, ...] = (0.0, 0.1, 0.5, 1.0, 2.0)

    # Logging / saving
    out_dir: str = "runs"
    run_name: str = "wt2_latentAR"


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)



# Dataset

class LMBlocks(Dataset):
    def __init__(self, token_ids: List[int], block_size: int):
        self.block_size = block_size
        n = (len(token_ids) - 1) // block_size
        self.data = token_ids[: n * block_size + 1]

    def __len__(self):
        return (len(self.data) - 1) // self.block_size

    def __getitem__(self, idx):
        i = idx * self.block_size
        x = torch.tensor(self.data[i : i + self.block_size], dtype=torch.long)
        y = torch.tensor(self.data[i + 1 : i + self.block_size + 1], dtype=torch.long)
        return x, y


def tokenize_split_streaming(tokenizer, texts: List[str], chunk_chars: int = 200_000) -> List[int]:
    ids: List[int] = []
    buf: List[str] = []
    cur_len = 0
    sep = "\n\n"

    for t in texts:
        if not t:
            continue
        add = t + sep
        buf.append(add)
        cur_len += len(add)
        if cur_len >= chunk_chars:
            chunk = "".join(buf)
            ids.extend(tokenizer.encode(chunk))
            buf = []
            cur_len = 0

    if buf:
        chunk = "".join(buf)
        ids.extend(tokenizer.encode(chunk))

    return ids


def load_wt2_blocks(cfg: CFG, tokenizer) -> Tuple[DataLoader, DataLoader]:
    ds = load_dataset(cfg.dataset_name, cfg.dataset_config)

    train_ids = tokenize_split_streaming(tokenizer, ds["train"]["text"])
    val_ids = tokenize_split_streaming(tokenizer, ds["validation"]["text"])

    train_set = LMBlocks(train_ids, cfg.block_size)
    val_set = LMBlocks(val_ids, cfg.block_size)

    train_loader = DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=True,
        persistent_workers=(cfg.num_workers > 0),
    )
    val_loader = DataLoader(
        val_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=False,
        persistent_workers=(cfg.num_workers > 0),
    )
    return train_loader, val_loader



# Transformer blocks

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_layers, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x, src_key_padding_mask=None):
        return self.enc(x, src_key_padding_mask=src_key_padding_mask)


class TransformerDecoderLM(nn.Module):
    def __init__(self, d_model, n_heads, n_layers, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.dec = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):
        _, T, _ = x.shape
        attn_mask = torch.full((T, T), float("-inf"), device=x.device)
        attn_mask = torch.triu(attn_mask, diagonal=1)
        return self.dec(x, mask=attn_mask)



# Priors: AR only

class PriorBase(nn.Module):
    def log_p(self, z: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def log_p_per_dim(self, z: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @torch.no_grad()
    def sample(self, B: int, T: int, Dz: int, device: str) -> torch.Tensor:
        raise NotImplementedError


class PriorAR(PriorBase):
    """
    p(z1)=N(0,I), p(zt|z_{t-1})=N(rho z_{t-1}, sigma^2 I)
    rho learnable scalar.
    """
    def __init__(self, Dz: int, rho_init: float, sigma: float):
        super().__init__()
        self.logit_rho = nn.Parameter(torch.logit(torch.tensor(float(rho_init))))
        self.sigma = float(sigma)
        self.Dz = Dz

    def rho(self):
        return torch.sigmoid(self.logit_rho).clamp(1e-4, 0.9999)

    def log_p(self, z):
        B, T, Dz = z.shape
        lp1 = (-0.5 * (z[:, 0] ** 2 + math.log(2 * math.pi))).sum(dim=1)
        if T == 1:
            return lp1
        rho = self.rho()
        sigma2 = self.sigma ** 2
        resid = z[:, 1:] - rho * z[:, :-1]
        lp = (-0.5 * ((resid**2) / sigma2 + math.log(2 * math.pi * sigma2))).sum(dim=(1, 2))
        return lp1 + lp

    def log_p_per_dim(self, z):
        B, T, Dz = z.shape
        lp1 = (-0.5 * (z[:, 0] ** 2 + math.log(2 * math.pi)))  # (B,Dz)
        if T == 1:
            return lp1
        rho = self.rho()
        sigma2 = self.sigma ** 2
        resid = z[:, 1:] - rho * z[:, :-1]  # (B,T-1,Dz)
        lp = -0.5 * ((resid**2) / sigma2 + math.log(2 * math.pi * sigma2))  # (B,T-1,Dz)
        return lp1 + lp.sum(dim=1)

    @torch.no_grad()
    def sample(self, B, T, Dz, device):
        rho = float(self.rho().item())
        z = torch.zeros(B, T, Dz, device=device)
        z[:, 0] = torch.randn(B, Dz, device=device)
        for t in range(1, T):
            z[:, t] = rho * z[:, t - 1] + self.sigma * torch.randn(B, Dz, device=device)
        return z



# VAE model

class VAETextLM(nn.Module):
    def __init__(self, cfg: CFG, prior: PriorBase):
        super().__init__()
        self.cfg = cfg
        self.prior = prior

        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)

        self.encoder = TransformerEncoder(cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.dropout)

        self.to_mu = nn.Linear(cfg.d_model, cfg.z_dim)
        self.to_logvar = nn.Linear(cfg.d_model, cfg.z_dim)

        self.z_proj = nn.Linear(cfg.z_dim, cfg.d_model)
        self.decoder = TransformerDecoderLM(cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.dropout)

        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

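    def encode(self, x_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Amortized diagonal-Gaussian posterior q(z_t | x), one per position.
        # The encoder has no causal mask, so each z_t can depend on future tokens,
        # including the next-token target y_t = x_{t+1}.
        B, T = x_ids.shape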
        tok = self.tok_emb(x_ids)
        pos = self.pos_emb(torch.arange(T, device=x_ids.device))[None, :, :]
        h = self.drop(tok + pos)
        h = self.encoder(h)
        mu = self.to_mu(h)
        logvar = self.to_logvar(h).clamp(-12.0, 6.0)
        return mu, logvar

    def reparam(self, mu, logvar, n_samples: int):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn((n_samples,) + mu.shape, device=mu.device)
        return mu[None] + eps * std[None]

    def log_q_total(self, z, mu, logvar) -> torch.Tensor:
        var = torch.exp(logvar)
        log2pi = math.log(2 * math.pi)
        diff2 = (z - mu[None]) ** 2
        lq = -0.5 * (diff2 / var[None] + logvar[None] + log2pi)
        return lq.sum(dim=(2, 3))

    def log_q_per_dim(self, z_btD: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        var = torch.exp(logvar)
        log2pi = math.log(2 * math.pi)
        diff2 = (z_btD - mu) ** 2
        lq = -0.5 * (diff2 / var + logvar + log2pi)
        return lq.sum(dim=1)

    def decode_logits(self, x_ids: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        B, T = x_ids.shape
        tok = self.tok_emb(x_ids)
        pos = self.pos_emb(torch.arange(T, device=x_ids.device))[None, :, :]
        h = self.drop(tok + pos + self.z_proj(z))
        h = self.decoder(h)
        return self.lm_head(h)

    def forward(self, x_ids: torch.Tensor, y_ids: torch.Tensor, beta: float) -> Dict[str, torch.Tensor]:
        B, T = x_ids.shape
        mu, logvar = self.encode(x_ids)
        zS = self.reparam(mu, logvar, self.cfg.n_z_samples)

        nll_list, logq_list, logp_list = [], [], []
        last_logits = None
        last_z = None

        for s in range(self.cfg.n_z_samples):
            z = zS[s]
            logits = self.decode_logits(x_ids, z)
            last_logits = logits
            last_z = z

            nll = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y_ids.view(-1),
                reduction="none",
            ).view(B, T).sum(dim=1)
            nll_list.append(nll)

            logq = self.log_q_total(zS[s:s+1], mu, logvar)[0]
            logq_list.append(logq)

            logp = self.prior.log_p(z)
            logp_list.append(logp)

        nll = torch.stack(nll_list, dim=0).mean(dim=0)
        logq = torch.stack(logq_list, dim=0).mean(dim=0)
        logp = torch.stack(logp_list, dim=0).mean(dim=0)

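        # Monte Carlo KL estimate (averaged over n_z_samples draws of z ~ q(z|x))
        kl = (logq - logp)
        # beta-weighted objective; equals the true ELBO only when beta == 1
        elbo = -(nll + beta * kl)

        nll_tok = nll.mean() / T
        kl_tok = kl.mean() / T
        elbo_tok = elbo.mean() / T

        # Reconstruction perplexity: exp of per-token NLL given z ~ q(z|x).
        # It excludes the KL term, so it is not a bound on the true perplexity.
        ppl = torch.exp(nll_tok.detach())
        bits_per_tok = (nll_tok.detach() / math.log(2.0))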

        loss = -(elbo_tok)
        return {
            "loss": loss,
            "elbo_tok": elbo_tok.detach(),
            "nll_tok": nll_tok.detach(),
            "kl_tok": kl_tok.detach(),
            "ppl": ppl.detach(),
            "bits_per_tok": bits_per_tok.detach(),
            "mu": mu,
            "logvar": logvar,
            "z_last": last_z,
            "logits_last": last_logits,
        }

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_new_tokens: int = 80, top_k: int = 0) -> torch.Tensor:
        self.eval()
        device = prompt_ids.device
        B, Tp = prompt_ids.shape
        total_T = min(self.cfg.block_size, Tp + max_new_tokens)

        out = prompt_ids.clone()
        z = self.prior.sample(B, total_T, self.cfg.z_dim, device=device)

        for _ in range(max_new_tokens):
            Tcur = out.size(1)
            if Tcur >= total_T:
                break
            x = out[:, :Tcur]
            zcur = z[:, :Tcur]
            logits = self.decode_logits(x, zcur)
            next_logits = logits[:, -1, :]

            if top_k and top_k > 0:
                vals, idx = torch.topk(next_logits, k=top_k, dim=-1)
                probs = torch.zeros_like(next_logits).scatter_(-1, idx, F.softmax(vals, dim=-1))
            else:
                probs = F.softmax(next_logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_id], dim=1)

        return out



# Perf / GPU utilities

def _now():
    return time.perf_counter()

def sync_if_cuda(device: str):
    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()

def cuda_mem_reset(device: str):
    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

@torch.no_grad()
def cuda_mem_snapshot_mb(device: str) -> Dict[str, float]:
    if not (isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available()):
        return {"peak_alloc_MB": float("nan"), "peak_reserved_MB": float("nan")}
    peak_alloc = torch.cuda.max_memory_allocated() / (1024**2)
    peak_reserved = torch.cuda.max_memory_reserved() / (1024**2)
    return {"peak_alloc_MB": float(peak_alloc), "peak_reserved_MB": float(peak_reserved)}

class RunningMean:
    def __init__(self):
        self.sum = 0.0
        self.n = 0
    def add(self, x: float):
        if math.isfinite(x):
            self.sum += float(x)
            self.n += 1
    def mean(self) -> float:
        return self.sum / max(1, self.n)



# Metric helpers

@torch.no_grad()
def logits_entropy_and_topk_mass(logits: torch.Tensor, top_k: int = 50) -> Dict[str, float]:
    probs = F.softmax(logits, dim=-1)
    logp = torch.log(probs.clamp_min(1e-12))
    ent = -(probs * logp).sum(dim=-1)
    ent_mean = ent.mean().item()

    if top_k and top_k > 0:
        V = probs.size(-1)
        vals, _ = torch.topk(probs, k=min(top_k, V), dim=-1)
        topk_mass = vals.sum(dim=-1).mean().item()
    else:
        topk_mass = float("nan")

    return {"tok_entropy_mean": float(ent_mean), "topk_mass_mean": float(topk_mass)}

@torch.no_grad()
def ece_from_logits(logits: torch.Tensor, targets: torch.Tensor, n_bins: int = 15) -> float:
    probs = F.softmax(logits, dim=-1)
    conf, pred = probs.max(dim=-1)
    acc = (pred == targets).float()

    conf = conf.reshape(-1)
    acc = acc.reshape(-1)

    bins = torch.linspace(0, 1, n_bins + 1, device=logits.device)
    ece = torch.zeros((), device=logits.device)
    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        mask = (conf > lo) & (conf <= hi) if i > 0 else (conf >= lo) & (conf <= hi)
        if mask.any():
            ece = ece + (mask.float().mean()) * (acc[mask].mean() - conf[mask].mean()).abs()
    return float(ece.item())

@torch.no_grad()
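def posterior_collapse_ratio(mu: torch.Tensor, logvar: torch.Tensor) -> Dict[str, float]:
    # Heuristic: a latent dim counts as "collapsed" if, averaged over batch and time,
    # its posterior looks like N(0, 1) (|mu| < 0.02 and std within 0.05 of 1).
    # Under the AR prior only z_1 is generally N(0, 1) marginally, so this is a rough proxy.
    std = torch.exp(0.5 * logvar)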
    mu_abs_mean = mu.abs().mean(dim=(0, 1))
    std_mean = std.mean(dim=(0, 1))
    collapsed = ((mu_abs_mean < 0.02) & ((std_mean - 1.0).abs() < 0.05)).float().mean().item()
    return {
        "post_std_mean": float(std.mean().item()),
        "post_mu_abs_mean": float(mu.abs().mean().item()),
        "collapse_dim_frac": float(collapsed),
    }

@torch.no_grad()
def latent_autocorr(x: torch.Tensor, lag: int = 1) -> float:
    if x.size(1) <= lag:
        return float("nan")
    a = x[:, :-lag, :].reshape(-1, x.size(-1))
    b = x[:, lag:, :].reshape(-1, x.size(-1))
    a = a - a.mean(dim=0, keepdim=True)
    b = b - b.mean(dim=0, keepdim=True)
    cov = (a * b).mean(dim=0)
    va = (a * a).mean(dim=0).clamp_min(1e-8)
    vb = (b * b).mean(dim=0).clamp_min(1e-8)
    return float((cov / torch.sqrt(va * vb)).mean().item())

@torch.no_grad()
def kl_per_dim_stats(model: VAETextLM, mu: torch.Tensor, logvar: torch.Tensor, z: torch.Tensor, eps: float) -> Dict[str, float]:
    B, T, Dz = z.shape
    logq_bd = model.log_q_per_dim(z, mu, logvar)      # (B,Dz)
    logp_bd = model.prior.log_p_per_dim(z)            # (B,Dz)
    kld_bd = (logq_bd - logp_bd) / max(1, T)          # per-token
    kld_d = kld_bd.mean(dim=0).detach().cpu().numpy()
    return {
        "kldim_mean": float(np.mean(kld_d)),
        "kldim_median": float(np.median(kld_d)),
        "kldim_max": float(np.max(kld_d)),
        "kldim_frac_below_eps": float(np.mean(kld_d < eps)),
    }

@torch.no_grad()
def mi_proxy_batch(mu: torch.Tensor, logvar: torch.Tensor, z: torch.Tensor, max_components: int, max_points: int, seed: int = 0) -> float:
    B, T, Dz = z.shape
    C = B * T
    if C == 0:
        return float("nan")
    rng = np.random.default_rng(seed)

    mu_c = mu.reshape(C, Dz)
    lv_c = logvar.reshape(C, Dz)
    z_c  = z.reshape(C, Dz)

    comp_idx = np.arange(C)
    pt_idx = np.arange(C)
    if C > max_components:
        comp_idx = rng.choice(comp_idx, size=max_components, replace=False)
    if C > max_points:
        pt_idx = rng.choice(pt_idx, size=max_points, replace=False)

    mu_s = mu_c[torch.as_tensor(comp_idx, device=mu.device)]
    lv_s = lv_c[torch.as_tensor(comp_idx, device=mu.device)]
    var_s = torch.exp(lv_s).clamp_min(1e-12)

    z_pts  = z_c [torch.as_tensor(pt_idx, device=mu.device)]
    mu_pts = mu_c[torch.as_tensor(pt_idx, device=mu.device)]
    lv_pts = lv_c[torch.as_tensor(pt_idx, device=mu.device)]
    var_pts = torch.exp(lv_pts).clamp_min(1e-12)

    log2pi = math.log(2 * math.pi)
    lq_cond = -0.5 * (((z_pts - mu_pts) ** 2) / var_pts + lv_pts + log2pi).sum(dim=-1)
    eq_logq_z_given_x = lq_cond.mean()

    z_exp = z_pts[:, None, :]
    mu_exp = mu_s[None, :, :]
    lv_exp = lv_s[None, :, :]
    var_exp = var_s[None, :, :]

    lcomp = -0.5 * (((z_exp - mu_exp) ** 2) / var_exp + lv_exp + log2pi).sum(dim=-1)
    lmix = torch.logsumexp(lcomp, dim=1) - math.log(lcomp.size(1))
    eq_logq_z = lmix.mean()

    return float((eq_logq_z_given_x - eq_logq_z).item())

def rd_points(nll_tok: float, kl_tok: float, betas: Tuple[float, ...]) -> Dict[str, float]:
    out = {}
    for b in betas:
        out[f"rd_elbo_tok_beta_{b:g}"] = -(nll_tok + b * kl_tok)
    out["rd_nll_tok"] = float(nll_tok)
    out["rd_kl_tok"] = float(kl_tok)
    return out


def distinct_ngrams(token_seq: List[int], n: int) -> float:
    if len(token_seq) < n or n <= 0:
        return 0.0
    total = 0
    uniq = set()
    for i in range(len(token_seq) - n + 1):
        uniq.add(tuple(token_seq[i:i+n]))
        total += 1
    return (len(uniq) / total) if total > 0 else 0.0

def repetition_rate(token_seq: List[int], n: int) -> float:
    if len(token_seq) < n or n <= 0:
        return 0.0
    total = 0
    counts = {}
    for i in range(len(token_seq) - n + 1):
        ng = tuple(token_seq[i:i+n])
        counts[ng] = counts.get(ng, 0) + 1
        total += 1
    repeated = sum(c for c in counts.values() if c > 1)
    return repeated / total if total > 0 else 0.0


@torch.no_grad()
def generation_metrics(model: VAETextLM, tokenizer, device: str, n_prompts: int, max_new: int, top_k: int) -> Dict[str, Any]:
    model.eval()
    prompts = [
        "The meaning of life is",
        "In the middle of the night",
        "The government announced that",
        "A new theory suggests",
        "Once upon a time",
        "The experiment shows",
        "In a shocking discovery",
        "The book describes",
        "Scientists found that",
        "The president said",
        "In the future,",
        "The story begins",
    ]

    cuda_mem_reset(device)
    sync_if_cuda(device)
    t0 = _now()
    gen_new_tokens_total = 0

    per_prompt: List[Dict[str, Any]] = []
    decoded_samples: List[str] = []

    for i in range(n_prompts):
        p = prompts[i % len(prompts)]
        p_ids = tokenizer(p, return_tensors="pt").input_ids.to(device)
        if p_ids.size(1) > model.cfg.block_size // 2:
            p_ids = p_ids[:, : model.cfg.block_size // 2]

        out_ids = model.generate(p_ids, max_new_tokens=max_new, top_k=top_k)
        seq = out_ids[0].tolist()

        prompt_len = int(p_ids.size(1))
        new_tokens = max(0, len(seq) - prompt_len)
        gen_new_tokens_total += new_tokens

        st = {
            "prompt": p,
            "prompt_len": prompt_len,
            "total_len": len(seq),
            "new_tokens": new_tokens,
            "rep2": repetition_rate(seq, 2),
            "rep3": repetition_rate(seq, 3),
            "distinct2": distinct_ngrams(seq, 2),
            "distinct3": distinct_ngrams(seq, 3),
        }

        txt = tokenizer.decode(seq[: min(len(seq), 250)], skip_special_tokens=True)
        st["sample"] = txt
        per_prompt.append(st)

        if i < 3:
            decoded_samples.append(txt)

    sync_if_cuda(device)
    dt = max(1e-9, (_now() - t0))
    mem = cuda_mem_snapshot_mb(device)

    rep2 = float(np.mean([p["rep2"] for p in per_prompt])) if per_prompt else 0.0
    rep3 = float(np.mean([p["rep3"] for p in per_prompt])) if per_prompt else 0.0
    distinct2 = float(np.mean([p["distinct2"] for p in per_prompt])) if per_prompt else 0.0
    distinct3 = float(np.mean([p["distinct3"] for p in per_prompt])) if per_prompt else 0.0
    lengths = [p["total_len"] for p in per_prompt]
    newlens = [p["new_tokens"] for p in per_prompt]

    return {
        "gen_distinct2": distinct2,
        "gen_distinct3": distinct3,
        "gen_rep2": rep2,
        "gen_rep3": rep3,
        "gen_len_mean": float(np.mean(lengths)) if lengths else 0.0,
        "gen_len_std": float(np.std(lengths)) if lengths else 0.0,
        "gen_new_tokens_mean": float(np.mean(newlens)) if newlens else 0.0,
        "gen_new_tokens_total": int(gen_new_tokens_total),

        "time_gen_seconds_total": float(dt),
        "time_gen_tokens_per_s": float(gen_new_tokens_total / dt),
        "gpu_gen_peak_alloc_MB": mem["peak_alloc_MB"],
        "gpu_gen_peak_reserved_MB": mem["peak_reserved_MB"],

        "gen_per_prompt": per_prompt,
        "gen_sample_0": decoded_samples[0] if len(decoded_samples) > 0 else "",
        "gen_sample_1": decoded_samples[1] if len(decoded_samples) > 1 else "",
        "gen_sample_2": decoded_samples[2] if len(decoded_samples) > 2 else "",
    }


@torch.no_grad()
def eval_many_metrics(
    model: VAETextLM,
    loader: DataLoader,
    device: str,
    beta: float,
    max_batches: int,
    ece_bins: int,
    gen_do: bool,
    tokenizer=None,
    gen_prompts: int = 12,
    gen_max_new: int = 96,
    gen_top_k: int = 50,
    rd_betas: Tuple[float, ...] = (0.0, 1.0),
    measure_perf: bool = True,
    kldim_eps: float = 0.01,
    mi_max_components: int = 512,
    mi_max_points: int = 256,
    mi_seed: int = 0,
) -> Dict[str, Any]:
    model.eval()

    acc = {
        "nll_tok": 0.0,
        "kl_tok": 0.0,
        "elbo_tok": 0.0,
        "tok_entropy_mean": 0.0,
        "topk_mass_mean": 0.0,
        "ece": 0.0,
        "post_std_mean": 0.0,
        "post_mu_abs_mean": 0.0,
        "collapse_dim_frac": 0.0,
        "mu_ac1": 0.0,
        "mu_ac5": 0.0,
        "z_ac1": 0.0,
        "z_ac5": 0.0,
        "kldim_mean": 0.0,
        "kldim_median": 0.0,
        "kldim_max": 0.0,
        "kldim_frac_below_eps": 0.0,
        "mi_proxy": 0.0,
    }
    n = 0

    step_ms = RunningMean()
    tokens_per_s = RunningMean()
    total_tokens = 0

    if measure_perf:
        cuda_mem_reset(device)

    for i, (x, y) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(device)
        y = y.to(device)

        if measure_perf:
            sync_if_cuda(device)
            t0 = _now()

        out = model(x, y, beta=beta)

        if measure_perf:
            sync_if_cuda(device)
            dt = max(1e-9, (_now() - t0))
            step_ms.add(1000.0 * dt)
            B, T = x.shape
            tok = int(B * T)
            total_tokens += tok
            tokens_per_s.add(tok / dt)

        nll_tok = float(out["nll_tok"])
        kl_tok = float(out["kl_tok"])
        elbo_tok = float(out["elbo_tok"])
        if not (math.isfinite(nll_tok) and math.isfinite(kl_tok) and math.isfinite(elbo_tok)):
            continue

        logits = out["logits_last"]
        mu = out["mu"].detach()
        logvar = out["logvar"].detach()
        zlast = out["z_last"].detach()

        ent_topk = logits_entropy_and_topk_mass(logits, top_k=gen_top_k)
        ece = ece_from_logits(logits, y, n_bins=ece_bins)
        post = posterior_collapse_ratio(mu, logvar)

        mu_ac1 = latent_autocorr(mu, lag=1)
        mu_ac5 = latent_autocorr(mu, lag=5)
        z_ac1 = latent_autocorr(zlast, lag=1)
        z_ac5 = latent_autocorr(zlast, lag=5)

        kld = kl_per_dim_stats(model, mu, logvar, zlast, eps=kldim_eps)
        mi = mi_proxy_batch(mu, logvar, zlast, max_components=mi_max_components, max_points=mi_max_points, seed=mi_seed)

        acc["nll_tok"] += nll_tok
        acc["kl_tok"] += kl_tok
        acc["elbo_tok"] += elbo_tok
        acc["tok_entropy_mean"] += ent_topk["tok_entropy_mean"]
        acc["topk_mass_mean"] += ent_topk["topk_mass_mean"]
        acc["ece"] += ece
        acc["post_std_mean"] += post["post_std_mean"]
        acc["post_mu_abs_mean"] += post["post_mu_abs_mean"]
        acc["collapse_dim_frac"] += post["collapse_dim_frac"]
        acc["mu_ac1"] += mu_ac1
        acc["mu_ac5"] += mu_ac5
        acc["z_ac1"] += z_ac1
        acc["z_ac5"] += z_ac5
        acc["kldim_mean"] += kld["kldim_mean"]
        acc["kldim_median"] += kld["kldim_median"]
        acc["kldim_max"] += kld["kldim_max"]
        acc["kldim_frac_below_eps"] += kld["kldim_frac_below_eps"]
        acc["mi_proxy"] += mi
        n += 1

    if n == 0:
        return {k: float("nan") for k in acc.keys()}

    for k in acc:
        acc[k] /= n

    acc["ppl"] = math.exp(acc["nll_tok"])
    acc["bits_per_tok"] = acc["nll_tok"] / math.log(2.0)
    acc.update(rd_points(acc["nll_tok"], acc["kl_tok"], rd_betas))

    if measure_perf:
        mem = cuda_mem_snapshot_mb(device)
        acc["time_eval_step_ms_mean"] = step_ms.mean()
        acc["time_eval_tokens_per_s_mean"] = tokens_per_s.mean()
        acc["gpu_eval_peak_alloc_MB"] = mem["peak_alloc_MB"]
        acc["gpu_eval_peak_reserved_MB"] = mem["peak_reserved_MB"]
        acc["eval_total_tokens_measured"] = int(total_tokens)

    if gen_do and tokenizer is not None:
        acc.update(
            generation_metrics(model, tokenizer, device=device, n_prompts=gen_prompts, max_new=gen_max_new, top_k=gen_top_k)
        )

    return acc



# Schedules

def lr_schedule(step, cfg: CFG):
    if step < cfg.warmup_steps:
        return cfg.lr * (step / max(1, cfg.warmup_steps))
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    return cfg.lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))

def beta_schedule(step, cfg: CFG):
    if step >= cfg.beta_warmup_steps:
        return cfg.beta_end
    a = step / max(1, cfg.beta_warmup_steps)
    return cfg.beta_start + a * (cfg.beta_end - cfg.beta_start)
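To check the schedules against the `lr` and `beta` columns of the training log, here is a minimal plain-Python sketch that re-evaluates them. `SchedCfg` is a hypothetical stand-in for `CFG` that carries only the schedule fields. Its defaults assume the values in the GP cell's `CFG` (lr 3e-4, 300 warmup steps, 6000 steps, beta 0 to 1 over 1500 steps), which are consistent with the AR log.

```python
import math
from dataclasses import dataclass


@dataclass
class SchedCfg:
    # Hypothetical stand-in for CFG: only the fields the two schedules read.
    lr: float = 3e-4
    warmup_steps: int = 300
    max_steps: int = 6000
    beta_start: float = 0.0
    beta_end: float = 1.0
    beta_warmup_steps: int = 1500


def lr_at(step: int, cfg: SchedCfg) -> float:
    # Linear warmup, then cosine decay from lr down to a floor of 0.1 * lr.
    if step < cfg.warmup_steps:
        return cfg.lr * (step / max(1, cfg.warmup_steps))
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    return cfg.lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))


def beta_at(step: int, cfg: SchedCfg) -> float:
    # Linear KL annealing from beta_start to beta_end, constant afterwards.
    if step >= cfg.beta_warmup_steps:
        return cfg.beta_end
    a = step / max(1, cfg.beta_warmup_steps)
    return cfg.beta_start + a * (cfg.beta_end - cfg.beta_start)


cfg = SchedCfg()
for step in (400, 1200, 1600, 6000):
    print(f"step={step:5d} lr={lr_at(step, cfg):.1e} beta={beta_at(step, cfg):.2f}")
```

Under these assumptions the learning rate ends at 0.1 * lr (3e-5) rather than zero, and beta reaches 1.0 at step 1500. Both behaviours match the last and the step-1600 progress-bar entries.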



# Train (AR only)

def train_ar(cfg: CFG, train_loader, val_loader, tokenizer) -> Dict[str, Any]:
    device = cfg.device
    prior = PriorAR(cfg.z_dim, rho_init=cfg.ar_init_rho, sigma=cfg.ar_sigma)
    model = VAETextLM(cfg, prior).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    use_amp = bool(cfg.amp and device.startswith("cuda"))
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    os.makedirs(cfg.out_dir, exist_ok=True)
    run_stamp = int(time.time())
    log_path = os.path.join(cfg.out_dir, f"{cfg.run_name}_vae_ar_{run_stamp}.jsonl")

    best_val = float("inf")
    best: Dict[str, Any] = {}

    t0 = time.time()
    pbar = tqdm(total=cfg.max_steps, desc="train[vae_ar]")

    it = iter(train_loader)

    train_step_ms = RunningMean()
    train_tokens_per_s = RunningMean()
    train_tokens_total = 0
    cuda_mem_reset(device)

    for step in range(1, cfg.max_steps + 1):
        try:
            x, y = next(it)
        except StopIteration:
            it = iter(train_loader)
            x, y = next(it)

        x, y = x.to(device), y.to(device)

        beta = beta_schedule(step, cfg)
        lr = lr_schedule(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr

        model.train()
        opt.zero_grad(set_to_none=True)

        # Synchronize before and after the step so the wall-clock time covers the
        # queued CUDA work (sync_if_cuda is a no-op on CPU), with or without AMP.
        sync_if_cuda(device)
        t_step0 = _now()

        with torch.amp.autocast("cuda", enabled=use_amp):
            out = model(x, y, beta=beta)
            loss = out["loss"]

        scaler.scale(loss).backward()
        if cfg.grad_clip > 0:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt)
        scaler.update()

        sync_if_cuda(device)
        dt = max(1e-9, (_now() - t_step0))
        train_step_ms.add(1000.0 * dt)
        B, T = x.shape
        tok = int(B * T)
        train_tokens_total += tok
        train_tokens_per_s.add(tok / dt)

        if step % cfg.log_every == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "nll_tok": f"{float(out['nll_tok']):.4f}",
                "kl_tok": f"{float(out['kl_tok']):.4f}",
                "ppl": f"{float(out['ppl']):.2f}",
                "beta": f"{beta:.2f}",
                "lr": f"{lr:.1e}",
            })

        if step % cfg.eval_every == 0 or step == 1:
            mem_train = cuda_mem_snapshot_mb(device)
            train_perf = {
                "time_train_step_ms_mean_window": train_step_ms.mean(),
                "time_train_tokens_per_s_mean_window": train_tokens_per_s.mean(),
                "gpu_train_peak_alloc_MB_window": mem_train["peak_alloc_MB"],
                "gpu_train_peak_reserved_MB_window": mem_train["peak_reserved_MB"],
                "train_tokens_measured_window": int(train_tokens_total),
            }

            train_eval = eval_many_metrics(
                model, train_loader, device=device, beta=1.0,
                max_batches=min(cfg.eval_train_batches, cfg.eval_max_batches),
                ece_bins=cfg.ece_bins,
                gen_do=False,
                rd_betas=cfg.rd_betas,
                measure_perf=True,
                kldim_eps=cfg.kldim_eps,
                mi_max_components=cfg.mi_max_components,
                mi_max_points=cfg.mi_max_points,
                mi_seed=cfg.mi_seed,
                gen_top_k=cfg.eval_gen_top_k,
            )
            val_eval = eval_many_metrics(
                model, val_loader, device=device, beta=1.0,
                max_batches=cfg.eval_max_batches,
                ece_bins=cfg.ece_bins,
                gen_do=True, tokenizer=tokenizer,
                gen_prompts=cfg.eval_gen_prompts,
                gen_max_new=cfg.eval_gen_max_new,
                gen_top_k=cfg.eval_gen_top_k,
                rd_betas=cfg.rd_betas,
                measure_perf=True,
                kldim_eps=cfg.kldim_eps,
                mi_max_components=cfg.mi_max_components,
                mi_max_points=cfg.mi_max_points,
                mi_seed=cfg.mi_seed,
            )

            val_nll = val_eval.get("nll_tok", float("inf"))
            if math.isfinite(val_nll) and (val_nll < best_val):
                best_val = float(val_nll)
                best = {
                    "variant": "vae_ar",
                    "step": step,
                    "wall_s": time.time() - t0,
                    "best_val_nll_tok": best_val,
                    "train_perf_window": train_perf,
                    "train": train_eval,
                    "val": val_eval,
                }

            rec = {
                "variant": "vae_ar",
                "step": step,
                "wall_s": time.time() - t0,
                "lr": float(lr),
                "beta": float(beta),
                "train_perf_window": train_perf,
                "train": train_eval,
                "val": val_eval,
                "best_val_nll_tok_so_far": float(best_val),
            }
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")

            print(
                f"[eval] vae_ar step={step:5d} "
                f"train_tok/s={train_eval.get('time_eval_tokens_per_s_mean', float('nan')):.1f} "
                f"val_tok/s={val_eval.get('time_eval_tokens_per_s_mean', float('nan')):.1f} "
                f"gen_tok/s={val_eval.get('time_gen_tokens_per_s', float('nan')):.1f} "
                f"val_ppl={val_eval.get('ppl', float('nan')):.2f} "
                f"val_nll={val_eval.get('nll_tok', float('nan')):.4f} val_kl={val_eval.get('kl_tok', float('nan')):.4f} "
                f"ece={val_eval.get('ece', float('nan')):.3f} ent={val_eval.get('tok_entropy_mean', float('nan')):.3f} "
                f"mi={val_eval.get('mi_proxy', float('nan')):.3f} "
                f"kldim_med={val_eval.get('kldim_median', float('nan')):.3f} "
                f"train_peakMB={train_perf.get('gpu_train_peak_alloc_MB_window', float('nan')):.1f}"
            )

            train_step_ms = RunningMean()
            train_tokens_per_s = RunningMean()
            train_tokens_total = 0
            cuda_mem_reset(device)

        pbar.update(1)

    pbar.close()

    if not best:
        best = {"variant": "vae_ar", "step": cfg.max_steps, "wall_s": time.time() - t0, "best_val_nll_tok": best_val}
    best["log_path"] = log_path
    return best
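Each eval appends one JSON record to `log_path`, with validation metrics stored under `rec["val"]`. The following is a minimal sketch for recovering the best step from that file afterwards. `best_record` is a hypothetical helper that is not part of the script, and it assumes the record layout written above (`step`, `val.nll_tok`).

```python
import json
import math
from typing import Any, Dict, Iterable, Optional


def best_record(lines: Iterable[str]) -> Optional[Dict[str, Any]]:
    """Pick the eval record with the lowest finite val nll_tok from JSONL lines."""
    best: Optional[Dict[str, Any]] = None
    best_nll = math.inf
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        rec = json.loads(line)
        nll = float(rec.get("val", {}).get("nll_tok", math.inf))
        # Same rule as train_ar: keep strict improvements on finite values only.
        if math.isfinite(nll) and nll < best_nll:
            best_nll, best = nll, rec
    return best


# Usage (path taken from the run output below; adjust to your own run):
# with open("runs/wt2_latentAR_vae_ar_1765783041.jsonl", encoding="utf-8") as f:
#     print(best_record(f)["step"])
```

The helper takes an iterable of lines instead of a path, so it works on an open file or on an in-memory list.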


def main():
    cfg = CFG()
    set_seed(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
    tokenizer.model_max_length = int(1e9)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    cfg.vocab_size = len(tokenizer)

    train_loader, val_loader = load_wt2_blocks(cfg, tokenizer)

    best = train_ar(cfg, train_loader, val_loader, tokenizer)

    os.makedirs(cfg.out_dir, exist_ok=True)
    out_json = os.path.join(cfg.out_dir, f"{cfg.run_name}_BEST_{int(time.time())}.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump({"cfg": asdict(cfg), "best": best}, f, indent=2, ensure_ascii=False)

    print(f"\nSaved BEST summary JSON: {out_json}")
    print(f"Best log JSONL: {best.get('log_path','')}")


if __name__ == "__main__":
    main()



train[vae_ar]:   0%|          | 1/6000 [00:15<25:02:07, 15.02s/it]

[eval] vae_ar step=    1 train_tok/s=64052.0 val_tok/s=65908.0 gen_tok/s=276.9 val_ppl=111782.79 val_nll=11.6243 val_kl=648.1745 ece=0.002 ent=10.022 mi=49.299 kldim_med=9.817 train_peakMB=3511.3


train[vae_ar]:   7%|▋         | 400/6000 [01:51<6:18:10,  4.05s/it, loss=12.2158, nll_tok=6.4745, kl_tok=21.5299, ppl=648.41, beta=0.27, lr=3.0e-04]

[eval] vae_ar step=  400 train_tok/s=66070.3 val_tok/s=66008.1 gen_tok/s=287.8 val_ppl=798.93 val_nll=6.6833 val_kl=20.9481 ece=0.018 ent=6.394 mi=0.301 kldim_med=0.327 train_peakMB=3981.0


train[vae_ar]:  13%|█▎        | 800/6000 [03:28<5:44:56,  3.98s/it, loss=17.0916, nll_tok=6.1767, kl_tok=20.4656, ppl=481.39, beta=0.53, lr=2.9e-04]

[eval] vae_ar step=  800 train_tok/s=66113.6 val_tok/s=66118.6 gen_tok/s=305.5 val_ppl=566.15 val_nll=6.3389 val_kl=20.2432 ece=0.022 ent=6.004 mi=0.105 kldim_med=0.316 train_peakMB=3981.0


train[vae_ar]:  20%|██        | 1200/6000 [05:05<5:16:18,  3.95s/it, loss=21.9683, nll_tok=5.9479, kl_tok=20.0255, ppl=382.96, beta=0.80, lr=2.8e-04]

[eval] vae_ar step= 1200 train_tok/s=66086.8 val_tok/s=66174.9 gen_tok/s=310.9 val_ppl=447.52 val_nll=6.1037 val_kl=19.7930 ece=0.022 ent=5.759 mi=0.044 kldim_med=0.309 train_peakMB=3981.0


train[vae_ar]:  27%|██▋       | 1600/6000 [06:42<4:52:11,  3.98s/it, loss=25.1890, nll_tok=5.6690, kl_tok=19.5200, ppl=289.74, beta=1.00, lr=2.7e-04]

[eval] vae_ar step= 1600 train_tok/s=66151.9 val_tok/s=66090.1 gen_tok/s=304.8 val_ppl=390.57 val_nll=5.9676 val_kl=19.4372 ece=0.022 ent=5.632 mi=0.032 kldim_med=0.304 train_peakMB=3981.0


train[vae_ar]:  33%|███▎      | 2000/6000 [08:19<4:23:38,  3.95s/it, loss=24.3795, nll_tok=5.3694, kl_tok=19.0101, ppl=214.73, beta=1.00, lr=2.4e-04]

[eval] vae_ar step= 2000 train_tok/s=66230.1 val_tok/s=66170.3 gen_tok/s=311.0 val_ppl=353.85 val_nll=5.8689 val_kl=19.0480 ece=0.028 ent=5.225 mi=0.028 kldim_med=0.298 train_peakMB=3981.0


train[vae_ar]:  40%|████      | 2400/6000 [09:57<3:59:51,  4.00s/it, loss=23.6080, nll_tok=4.9795, kl_tok=18.6285, ppl=145.40, beta=1.00, lr=2.2e-04]

[eval] vae_ar step= 2400 train_tok/s=65896.5 val_tok/s=65920.1 gen_tok/s=306.8 val_ppl=331.72 val_nll=5.8043 val_kl=18.6132 ece=0.024 ent=5.235 mi=0.029 kldim_med=0.291 train_peakMB=3981.0


train[vae_ar]:  47%|████▋     | 2800/6000 [11:34<3:33:55,  4.01s/it, loss=23.3647, nll_tok=5.0411, kl_tok=18.3235, ppl=154.65, beta=1.00, lr=1.9e-04]

[eval] vae_ar step= 2800 train_tok/s=65994.0 val_tok/s=65938.0 gen_tok/s=302.6 val_ppl=317.58 val_nll=5.7607 val_kl=18.2731 ece=0.031 ent=4.905 mi=0.026 kldim_med=0.285 train_peakMB=3981.0


train[vae_ar]:  53%|█████▎    | 3200/6000 [13:12<3:08:01,  4.03s/it, loss=22.7308, nll_tok=4.9556, kl_tok=17.7752, ppl=141.96, beta=1.00, lr=1.6e-04]

[eval] vae_ar step= 3200 train_tok/s=65985.1 val_tok/s=65989.8 gen_tok/s=298.8 val_ppl=309.96 val_nll=5.7365 val_kl=17.9617 ece=0.030 ent=4.813 mi=0.025 kldim_med=0.281 train_peakMB=3981.0


train[vae_ar]:  60%|██████    | 3600/6000 [14:50<2:41:02,  4.03s/it, loss=22.3814, nll_tok=4.7536, kl_tok=17.6278, ppl=116.00, beta=1.00, lr=1.3e-04]

[eval] vae_ar step= 3600 train_tok/s=65804.8 val_tok/s=65872.5 gen_tok/s=300.6 val_ppl=301.18 val_nll=5.7077 val_kl=17.6996 ece=0.033 ent=4.718 mi=0.025 kldim_med=0.277 train_peakMB=3981.0


train[vae_ar]:  67%|██████▋   | 4000/6000 [16:28<2:13:19,  4.00s/it, loss=22.0798, nll_tok=4.4993, kl_tok=17.5805, ppl=89.95, beta=1.00, lr=1.0e-04] 

[eval] vae_ar step= 4000 train_tok/s=65972.9 val_tok/s=66002.2 gen_tok/s=306.5 val_ppl=300.70 val_nll=5.7061 val_kl=17.4905 ece=0.037 ent=4.618 mi=0.027 kldim_med=0.273 train_peakMB=3981.0


train[vae_ar]:  73%|███████▎  | 4400/6000 [18:06<1:47:10,  4.02s/it, loss=21.8012, nll_tok=4.4362, kl_tok=17.3650, ppl=84.45, beta=1.00, lr=7.9e-05]

[eval] vae_ar step= 4400 train_tok/s=65913.4 val_tok/s=65889.1 gen_tok/s=302.8 val_ppl=301.00 val_nll=5.7071 val_kl=17.3283 ece=0.041 ent=4.491 mi=0.024 kldim_med=0.271 train_peakMB=3981.0


train[vae_ar]:  80%|████████  | 4800/6000 [19:44<1:21:19,  4.07s/it, loss=21.5226, nll_tok=4.3356, kl_tok=17.1870, ppl=76.37, beta=1.00, lr=5.8e-05]

[eval] vae_ar step= 4800 train_tok/s=65712.7 val_tok/s=65836.7 gen_tok/s=293.0 val_ppl=299.34 val_nll=5.7016 val_kl=17.1759 ece=0.039 ent=4.484 mi=0.026 kldim_med=0.268 train_peakMB=3981.0


train[vae_ar]:  87%|████████▋ | 5200/6000 [21:22<53:49,  4.04s/it, loss=21.3161, nll_tok=4.3390, kl_tok=16.9771, ppl=76.63, beta=1.00, lr=4.3e-05]  

[eval] vae_ar step= 5200 train_tok/s=65796.9 val_tok/s=65949.7 gen_tok/s=298.9 val_ppl=301.90 val_nll=5.7101 val_kl=17.0925 ece=0.047 ent=4.362 mi=0.024 kldim_med=0.267 train_peakMB=3981.0


train[vae_ar]:  93%|█████████▎| 5600/6000 [23:00<26:48,  4.02s/it, loss=21.4050, nll_tok=4.4308, kl_tok=16.9742, ppl=84.00, beta=1.00, lr=3.3e-05]

[eval] vae_ar step= 5600 train_tok/s=65874.8 val_tok/s=65855.0 gen_tok/s=303.2 val_ppl=301.89 val_nll=5.7101 val_kl=17.0195 ece=0.045 ent=4.401 mi=0.025 kldim_med=0.266 train_peakMB=3981.0


train[vae_ar]: 100%|██████████| 6000/6000 [24:38<00:00,  4.06it/s, loss=21.1580, nll_tok=4.1740, kl_tok=16.9840, ppl=64.98, beta=1.00, lr=3.0e-05]

[eval] vae_ar step= 6000 train_tok/s=65738.4 val_tok/s=65848.4 gen_tok/s=306.6 val_ppl=302.95 val_nll=5.7136 val_kl=16.9583 ece=0.044 ent=4.379 mi=0.024 kldim_med=0.265 train_peakMB=3981.0

Saved BEST summary JSON: runs/wt2_latentAR_BEST_1765784520.json
Best log JSONL: runs/wt2_latentAR_vae_ar_1765783041.jsonl
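The printed `[eval]` lines can be checked against the definitions in `eval_many_metrics`: `ppl = exp(nll_tok)`, `bits_per_tok = nll_tok / ln 2`, and `rd_points`, which reports `-(nll_tok + beta * kl_tok)` for each beta. Below is a minimal sketch that parses the step-4800 line (the lowest printed `val_nll`). It assumes only the `key=value` format shown above. `parse_eval_line` and `derived` are hypothetical helpers, and the beta grid assumes the `rd_betas` default from the GP cell's `CFG`.

```python
import math
import re
from typing import Dict, Tuple

# key=value pairs such as "val_nll=5.7016", "step= 4800" or "gen_tok/s=293.0"
_KV = re.compile(r"(\w+(?:/\w+)?)=\s*([-+0-9.eE]+|nan|inf)")


def parse_eval_line(line: str) -> Dict[str, float]:
    # Extract numeric fields from one "[eval] ..." line printed by train_ar.
    return {k: float(v) for k, v in _KV.findall(line)}


def derived(nll_tok: float, kl_tok: float,
            betas: Tuple[float, ...] = (0.0, 0.1, 0.5, 1.0, 2.0)) -> Dict[str, float]:
    # Same formulas as eval_many_metrics / rd_points.
    out = {"ppl": math.exp(nll_tok), "bits_per_tok": nll_tok / math.log(2.0)}
    for b in betas:
        out[f"rd_elbo_tok_beta_{b:g}"] = -(nll_tok + b * kl_tok)
    return out


line = ("[eval] vae_ar step= 4800 train_tok/s=65712.7 val_tok/s=65836.7 gen_tok/s=293.0 "
        "val_ppl=299.34 val_nll=5.7016 val_kl=17.1759 ece=0.039 ent=4.484 mi=0.026 "
        "kldim_med=0.268 train_peakMB=3981.0")
m = parse_eval_line(line)
d = derived(m["val_nll"], m["val_kl"])
print(f"ppl={d['ppl']:.2f} (logged {m['val_ppl']:.2f}) bits/tok={d['bits_per_tok']:.3f}")
print(f"rd_elbo_tok at beta=1: {d['rd_elbo_tok_beta_1']:.4f}")
```

The recomputed perplexity matches the logged value only up to the 4-decimal rounding of the printed `val_nll`.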





# GP-VAE on WikiText-2 blocks: GP prior only (no latent autoregression)
# Minimal ("no-drama") training loop with extended metrics and perf/GPU logging

import math
import os
import time
import json
import random
from dataclasses import dataclass, asdict
from typing import Dict, Tuple, List, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from datasets import load_dataset
from transformers import AutoTokenizer


# Environment / warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



# Config

@dataclass
class CFG:
    seed: int = 0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Data
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    tokenizer_name: str = "gpt2"
    block_size: int = 256
    batch_size: int = 16
    num_workers: int = 2
    pin_memory: bool = True

    # Model
    vocab_size: int = 50257  # overwritten after tokenizer load
    d_model: int = 384
    n_layers: int = 6
    n_heads: int = 6
    dropout: float = 0.1

    # Latent
    z_dim: int = 64
    n_z_samples: int = 1  # MC samples of z per sequence (used for both the NLL and KL estimates)

    # Training
    lr: float = 3e-4
    weight_decay: float = 0.01
    max_steps: int = 6000
    warmup_steps: int = 300
    grad_clip: float = 1.0
    eval_every: int = 400
    log_every: int = 50
    amp: bool = True

    # KL anneal
    beta_start: float = 0.0
    beta_end: float = 1.0
    beta_warmup_steps: int = 1500

    # GP prior (RBF kernel)
    gp_lengthscale: float = 25.0
    gp_sigma: float = 1.0
    gp_jitter: float = 1e-3

    # Eval controls
    eval_max_batches: int = 80
    eval_train_batches: int = 20
    eval_gen_prompts: int = 12
    eval_gen_max_new: int = 96
    eval_gen_top_k: int = 50

    # ECE calibration
    ece_bins: int = 15

    # Per-dimension KL diagnostics: dims with KL/token below kldim_eps count as inactive
    kldim_eps: float = 0.01

    # MI proxy
    mi_max_components: int = 512
    mi_max_points: int = 256
    mi_seed: int = 0

    # Rate–Distortion points
    rd_betas: Tuple[float, ...] = (0.0, 0.1, 0.5, 1.0, 2.0)

    # Logging / saving
    out_dir: str = "runs"
    run_name: str = "wt2_latentGP"


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)



# Dataset

class LMBlocks(Dataset):
    def __init__(self, token_ids: List[int], block_size: int):
        self.block_size = block_size
        n = (len(token_ids) - 1) // block_size
        self.data = token_ids[: n * block_size + 1]

    def __len__(self):
        return (len(self.data) - 1) // self.block_size

    def __getitem__(self, idx):
        i = idx * self.block_size
        x = torch.tensor(self.data[i : i + self.block_size], dtype=torch.long)
        y = torch.tensor(self.data[i + 1 : i + self.block_size + 1], dtype=torch.long)
        return x, y
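`LMBlocks` cuts the token stream into non-overlapping windows and shifts the targets by one position. The sketch below reproduces the same indexing in plain Python. `make_blocks` is a hypothetical helper that returns lists instead of tensors.

```python
from typing import List, Tuple


def make_blocks(token_ids: List[int], block_size: int) -> List[Tuple[List[int], List[int]]]:
    # Mirrors LMBlocks: keep n full blocks plus one extra token for the final target.
    n = (len(token_ids) - 1) // block_size
    data = token_ids[: n * block_size + 1]
    pairs = []
    for idx in range(n):
        i = idx * block_size
        x = data[i : i + block_size]          # inputs
        y = data[i + 1 : i + block_size + 1]  # next-token targets, shifted by one
        pairs.append((x, y))
    return pairs


print(make_blocks(list(range(11)), block_size=4))
```

With 11 tokens and `block_size=4`, two blocks are produced. Token 8 appears only as the last target, and tokens 9 and 10 are dropped, as in `LMBlocks`.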


def tokenize_split_streaming(tokenizer, texts: List[str], chunk_chars: int = 200_000) -> List[int]:
    ids: List[int] = []
    buf: List[str] = []
    cur_len = 0
    sep = "\n\n"

    for t in texts:
        if not t:
            continue
        add = t + sep
        buf.append(add)
        cur_len += len(add)
        if cur_len >= chunk_chars:
            chunk = "".join(buf)
            ids.extend(tokenizer.encode(chunk))
            buf = []
            cur_len = 0

    if buf:
        chunk = "".join(buf)
        ids.extend(tokenizer.encode(chunk))

    return ids


def load_wt2_blocks(cfg: CFG, tokenizer) -> Tuple[DataLoader, DataLoader]:
    ds = load_dataset(cfg.dataset_name, cfg.dataset_config)

    train_ids = tokenize_split_streaming(tokenizer, ds["train"]["text"])
    val_ids = tokenize_split_streaming(tokenizer, ds["validation"]["text"])

    train_set = LMBlocks(train_ids, cfg.block_size)
    val_set = LMBlocks(val_ids, cfg.block_size)

    train_loader = DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=True,
        persistent_workers=(cfg.num_workers > 0),
    )
    val_loader = DataLoader(
        val_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=False,
        persistent_workers=(cfg.num_workers > 0),
    )
    return train_loader, val_loader



# Transformer blocks

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_layers, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x, src_key_padding_mask=None):
        return self.enc(x, src_key_padding_mask=src_key_padding_mask)


class TransformerDecoderLM(nn.Module):
    def __init__(self, d_model, n_heads, n_layers, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.dec = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):
        _, T, _ = x.shape
        attn_mask = torch.full((T, T), float("-inf"), device=x.device)
        attn_mask = torch.triu(attn_mask, diagonal=1)
        return self.dec(x, mask=attn_mask)
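The decoder's additive mask is `triu(full(-inf), diagonal=1)`. `triu` zeroes everything on and below the diagonal and keeps `-inf` strictly above it, so query position `i` can attend only to keys `j <= i`. Here is a plain-Python sketch of the same mask (the helper names are hypothetical):

```python
import math
from typing import List


def causal_mask(T: int) -> List[List[float]]:
    # Same values as torch.triu(torch.full((T, T), -inf), diagonal=1):
    # 0.0 on and below the diagonal (allowed), -inf above it (future positions blocked).
    return [[-math.inf if j > i else 0.0 for j in range(T)] for i in range(T)]


def visible_positions(mask: List[List[float]], i: int) -> List[int]:
    # Key positions that query position i can attend to under the additive mask.
    return [j for j, v in enumerate(mask[i]) if v != -math.inf]


m = causal_mask(4)
print([visible_positions(m, i) for i in range(4)])
```

In recent PyTorch versions, `nn.Transformer.generate_square_subsequent_mask(T)` should build the same float mask.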



# Prior: GP only

class PriorBase(nn.Module):
    def log_p(self, z: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def log_p_per_dim(self, z: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @torch.no_grad()
    def sample(self, B: int, T: int, Dz: int, device: str) -> torch.Tensor:
        raise NotImplementedError


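class PriorGP(PriorBase):
    """
    GP prior per latent dim: z[:, :, d] ~ N(0, K_T), with an RBF kernel over time positions.
    The kernel is rebuilt for each sequence length T; lengthscale and sigma are learned
    (stored as logs), and jitter is added to the diagonal for a stable Cholesky factor.
    """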
    def __init__(self, lengthscale: float, sigma: float, jitter: float):
        super().__init__()
        self.log_lengthscale = nn.Parameter(torch.log(torch.tensor(float(lengthscale))))
        self.log_sigma = nn.Parameter(torch.log(torch.tensor(float(sigma))))
        self.jitter = float(jitter)

    def _kernel(self, T: int, device):
        t = torch.arange(T, device=device).float()
        dt2 = (t[:, None] - t[None, :]) ** 2
        ell = torch.exp(self.log_lengthscale).clamp(1e-3, 1e6)
        sig = torch.exp(self.log_sigma).clamp(1e-6, 1e6)
        K = (sig**2) * torch.exp(-0.5 * dt2 / (ell**2))
        return K + self.jitter * torch.eye(T, device=device)

    def _chol(self, K: torch.Tensor) -> torch.Tensor:
        try:
            return torch.linalg.cholesky(K)
        except RuntimeError:
            T = K.size(0)
            return torch.linalg.cholesky(K + (10.0 * self.jitter) * torch.eye(T, device=K.device))

    def log_p(self, z):
        # Total log-density per sequence: sum of the per-dimension GP log-densities.
        return self.log_p_per_dim(z).sum(dim=1)

    def log_p_per_dim(self, z):
        B, T, Dz = z.shape
        K = self._kernel(T, z.device)
        L = self._chol(K)
        logdet = 2.0 * torch.log(torch.diagonal(L)).sum()
        const = T * math.log(2 * math.pi)

        z_td = z.transpose(1, 2).contiguous().view(B * Dz, T, 1)
        alpha = torch.cholesky_solve(z_td, L)
        quad = (z_td * alpha).sum(dim=(1, 2))
        lp = -0.5 * (quad + logdet + const)
        return lp.view(B, Dz)

    @torch.no_grad()
    def sample(self, B, T, Dz, device):
        K = self._kernel(T, device)
        L = self._chol(K)
        eps = torch.randn(B, Dz, T, device=device)
        z = torch.matmul(eps, L.T)
        return z.permute(0, 2, 1).contiguous()
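`PriorGP` evaluates `log N(z; 0, K)` through a Cholesky factor `L`: the quadratic form `z^T K^{-1} z` equals `|L^{-1} z|^2`, and `log det K = 2 * sum(log diag L)`. The following is a minimal plain-Python sketch (hypothetical helper names) that checks this route against the closed form for `T = 2`. It uses the `CFG` defaults `gp_lengthscale=25`, `gp_sigma=1`, and `gp_jitter=1e-3`.

```python
import math
from typing import List

Matrix = List[List[float]]


def rbf_kernel(T: int, ell: float, sigma: float, jitter: float) -> Matrix:
    # K[i][j] = sigma^2 * exp(-0.5 * (i - j)^2 / ell^2) + jitter * [i == j], as in PriorGP._kernel
    return [[sigma**2 * math.exp(-0.5 * (i - j) ** 2 / ell**2) + (jitter if i == j else 0.0)
             for j in range(T)] for i in range(T)]


def cholesky(K: Matrix) -> Matrix:
    # Lower-triangular L with L @ L^T = K (Cholesky-Banachiewicz ordering).
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def gp_log_p(z: List[float], K: Matrix) -> float:
    # log N(z; 0, K) via Cholesky: quad = |L^{-1} z|^2, logdet = 2 * sum(log diag(L)).
    L = cholesky(K)
    n = len(z)
    a = [0.0] * n
    for i in range(n):  # forward substitution solves L a = z
        a[i] = (z[i] - sum(L[i][k] * a[k] for k in range(i))) / L[i][i]
    quad = sum(v * v for v in a)
    logdet = 2.0 * sum(math.log(L[i][i]) for i in range(n))
    return -0.5 * (quad + logdet + n * math.log(2 * math.pi))


def gp_log_p_2x2(z: List[float], K: Matrix) -> float:
    # Closed form for T = 2, used as a reference.
    det = K[0][0] * K[1][1] - K[0][1] ** 2
    quad = (K[1][1] * z[0] ** 2 - 2 * K[0][1] * z[0] * z[1] + K[0][0] * z[1] ** 2) / det
    return -0.5 * (quad + math.log(det) + 2 * math.log(2 * math.pi))


K = rbf_kernel(2, ell=25.0, sigma=1.0, jitter=1e-3)
z = [0.3, -0.1]
print(gp_log_p(z, K), gp_log_p_2x2(z, K))
```

With a lengthscale of 25 positions, neighbouring latents are almost perfectly correlated under the prior. A path that jumps between adjacent steps is therefore much less likely than a smooth one.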



# VAE model

class VAETextLM(nn.Module):
    def __init__(self, cfg: CFG, prior: PriorBase):
        super().__init__()
        self.cfg = cfg
        self.prior = prior

        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)

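        # The encoder runs without an attention mask, so mu_t / logvar_t attend to the whole block,
        # including x_{t+1} (= target y_t). Information reaching the decoder through z is paid for
        # by the KL term: nll_tok / ppl are reconstruction terms under q(z|x), and only nll + KL
        # (the negative ELBO) bounds the data likelihood.
        self.encoder = TransformerEncoder(cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.dropout)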

        self.to_mu = nn.Linear(cfg.d_model, cfg.z_dim)
        self.to_logvar = nn.Linear(cfg.d_model, cfg.z_dim)

        self.z_proj = nn.Linear(cfg.z_dim, cfg.d_model)
        self.decoder = TransformerDecoderLM(cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.dropout)

        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def encode(self, x_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T = x_ids.shape
        tok = self.tok_emb(x_ids)
        pos = self.pos_emb(torch.arange(T, device=x_ids.device))[None, :, :]
        h = self.drop(tok + pos)
        h = self.encoder(h)
        mu = self.to_mu(h)
        logvar = self.to_logvar(h).clamp(-12.0, 6.0)
        return mu, logvar

    def reparam(self, mu, logvar, n_samples: int):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn((n_samples,) + mu.shape, device=mu.device)
        return mu[None] + eps * std[None]

    def log_q_total(self, z, mu, logvar) -> torch.Tensor:
        var = torch.exp(logvar)
        log2pi = math.log(2 * math.pi)
        diff2 = (z - mu[None]) ** 2
        lq = -0.5 * (diff2 / var[None] + logvar[None] + log2pi)
        return lq.sum(dim=(2, 3))

    def log_q_per_dim(self, z_btD: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        var = torch.exp(logvar)
        log2pi = math.log(2 * math.pi)
        diff2 = (z_btD - mu) ** 2
        lq = -0.5 * (diff2 / var + logvar + log2pi)
        return lq.sum(dim=1)

    def decode_logits(self, x_ids: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        B, T = x_ids.shape
        tok = self.tok_emb(x_ids)
        pos = self.pos_emb(torch.arange(T, device=x_ids.device))[None, :, :]
        h = self.drop(tok + pos + self.z_proj(z))
        h = self.decoder(h)
        return self.lm_head(h)

    def forward(self, x_ids: torch.Tensor, y_ids: torch.Tensor, beta: float) -> Dict[str, torch.Tensor]:
        B, T = x_ids.shape
        mu, logvar = self.encode(x_ids)
        zS = self.reparam(mu, logvar, self.cfg.n_z_samples)

        nll_list, logq_list, logp_list = [], [], []
        last_logits = None
        last_z = None

        for s in range(self.cfg.n_z_samples):
            z = zS[s]
            logits = self.decode_logits(x_ids, z)
            last_logits = logits
            last_z = z

            nll = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y_ids.view(-1),
                reduction="none",
            ).view(B, T).sum(dim=1)
            nll_list.append(nll)

            logq = self.log_q_total(zS[s:s+1], mu, logvar)[0]
            logq_list.append(logq)

            logp = self.prior.log_p(z)
            logp_list.append(logp)

        nll = torch.stack(nll_list, dim=0).mean(dim=0)
        logq = torch.stack(logq_list, dim=0).mean(dim=0)
        logp = torch.stack(logp_list, dim=0).mean(dim=0)

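        kl = (logq - logp)  # MC estimate of KL(q || p) per sequence (averaged over n_z_samples draws)
        # beta-weighted objective; it is the true ELBO only when beta == 1 (eval uses beta=1.0)
        elbo = -(nll + beta * kl)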

        nll_tok = nll.mean() / T
        kl_tok = kl.mean() / T
        elbo_tok = elbo.mean() / T

        ppl = torch.exp(nll_tok.detach())
        bits_per_tok = (nll_tok.detach() / math.log(2.0))

        loss = -(elbo_tok)
        return {
            "loss": loss,
            "elbo_tok": elbo_tok.detach(),
            "nll_tok": nll_tok.detach(),
            "kl_tok": kl_tok.detach(),
            "ppl": ppl.detach(),
            "bits_per_tok": bits_per_tok.detach(),
            "mu": mu,
            "logvar": logvar,
            "z_last": last_z,
            "logits_last": last_logits,
        }

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_new_tokens: int = 80, top_k: int = 0) -> torch.Tensor:
        self.eval()
        device = prompt_ids.device
        B, Tp = prompt_ids.shape
        total_T = min(self.cfg.block_size, Tp + max_new_tokens)

        out = prompt_ids.clone()
        z = self.prior.sample(B, total_T, self.cfg.z_dim, device=device)

        for _ in range(max_new_tokens):
            Tcur = out.size(1)
            if Tcur >= total_T:
                break
            x = out[:, :Tcur]
            zcur = z[:, :Tcur]
            logits = self.decode_logits(x, zcur)
            next_logits = logits[:, -1, :]

            if top_k and top_k > 0:
                vals, idx = torch.topk(next_logits, k=top_k, dim=-1)
                probs = torch.zeros_like(next_logits).scatter_(-1, idx, F.softmax(vals, dim=-1))
            else:
                probs = F.softmax(next_logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_id], dim=1)

        return out
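`forward` estimates the KL term as `log q(z|x) - log p(z)` at the sampled `z`. With the default `n_z_samples = 1`, that is one draw per sequence. This is an unbiased Monte Carlo estimate of KL(q || p), so averaging it over many draws should approach the closed form. The 1-D sketch below simplifies the prior to a standard normal instead of the GP. The helper names and the example parameters are illustrative assumptions.

```python
import math
import random


def log_normal(x: float, mu: float, var: float) -> float:
    return -0.5 * ((x - mu) ** 2 / var + math.log(var) + math.log(2 * math.pi))


def kl_closed_form(mu_q: float, var_q: float, mu_p: float = 0.0, var_p: float = 1.0) -> float:
    # KL(N(mu_q, var_q) || N(mu_p, var_p)) for univariate Gaussians.
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)


def kl_mc(mu_q: float, var_q: float, n: int, seed: int = 0) -> float:
    # Average of single-sample estimates log q(z) - log p(z), with z = mu + std * eps (reparameterized).
    rng = random.Random(seed)
    std = math.sqrt(var_q)
    total = 0.0
    for _ in range(n):
        z = mu_q + std * rng.gauss(0.0, 1.0)
        total += log_normal(z, mu_q, var_q) - log_normal(z, 0.0, 1.0)
    return total / n


print(kl_closed_form(0.5, 0.64), kl_mc(0.5, 0.64, n=100_000))
```

A single draw per sequence is noisy but unbiased. The per-token `kl_tok` reported in the logs averages these estimates over a batch and over eval batches.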



# Perf / GPU utilities

def _now():
    return time.perf_counter()

def sync_if_cuda(device: str):
    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()

def cuda_mem_reset(device: str):
    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

@torch.no_grad()
def cuda_mem_snapshot_mb(device: str) -> Dict[str, float]:
    if not (isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available()):
        return {"peak_alloc_MB": float("nan"), "peak_reserved_MB": float("nan")}
    peak_alloc = torch.cuda.max_memory_allocated() / (1024**2)
    peak_reserved = torch.cuda.max_memory_reserved() / (1024**2)
    return {"peak_alloc_MB": float(peak_alloc), "peak_reserved_MB": float(peak_reserved)}

class RunningMean:
    def __init__(self):
        self.sum = 0.0
        self.n = 0
    def add(self, x: float):
        if math.isfinite(x):
            self.sum += float(x)
            self.n += 1
    def mean(self) -> float:
        return self.sum / max(1, self.n)



# Metric helpers (same as AR)

@torch.no_grad()
def logits_entropy_and_topk_mass(logits: torch.Tensor, top_k: int = 50) -> Dict[str, float]:
    probs = F.softmax(logits, dim=-1)
    logp = torch.log(probs.clamp_min(1e-12))
    ent = -(probs * logp).sum(dim=-1)
    ent_mean = ent.mean().item()
    if top_k and top_k > 0:
        V = probs.size(-1)
        vals, _ = torch.topk(probs, k=min(top_k, V), dim=-1)
        topk_mass = vals.sum(dim=-1).mean().item()
    else:
        topk_mass = float("nan")
    return {"tok_entropy_mean": float(ent_mean), "topk_mass_mean": float(topk_mass)}

@torch.no_grad()
def ece_from_logits(logits: torch.Tensor, targets: torch.Tensor, n_bins: int = 15) -> float:
    probs = F.softmax(logits, dim=-1)
    conf, pred = probs.max(dim=-1)
    acc = (pred == targets).float()
    conf = conf.reshape(-1)
    acc = acc.reshape(-1)
    bins = torch.linspace(0, 1, n_bins + 1, device=logits.device)
    ece = torch.zeros((), device=logits.device)
    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        mask = (conf > lo) & (conf <= hi) if i > 0 else (conf >= lo) & (conf <= hi)
        if mask.any():
            ece = ece + (mask.float().mean()) * (acc[mask].mean() - conf[mask].mean()).abs()
    return float(ece.item())
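ECE is the bin-weighted mean gap between accuracy and confidence of the top-1 prediction. The following plain-Python sketch uses the same binning as `ece_from_logits`: the first bin is closed on both sides, and the others are `(lo, hi]`. It runs on a hand-checkable toy input. This hypothetical `ece` takes top-1 confidences and correctness flags directly instead of logits. Its bin edges are `i / n_bins`, which may differ from `torch.linspace` edges by floating-point rounding.

```python
from typing import List


def ece(conf: List[float], correct: List[bool], n_bins: int = 15) -> float:
    n = len(conf)
    total = 0.0
    for i in range(n_bins):
        lo, hi = i / n_bins, (i + 1) / n_bins
        # First bin is [lo, hi]; later bins are (lo, hi], as in ece_from_logits.
        idx = [k for k, c in enumerate(conf) if (lo < c <= hi) or (i == 0 and c == lo)]
        if idx:
            acc = sum(correct[k] for k in idx) / len(idx)
            avg_conf = sum(conf[k] for k in idx) / len(idx)
            total += (len(idx) / n) * abs(acc - avg_conf)
    return total


# Two over-confident predictions at 0.9 (one right) and two under-confident at 0.6 (both right).
print(ece([0.9, 0.9, 0.6, 0.6], [True, False, True, True], n_bins=10))
```

Each of the two occupied bins holds half the samples and has a 0.4 gap, so the toy ECE is 0.4.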

@torch.no_grad()
def posterior_collapse_ratio(mu: torch.Tensor, logvar: torch.Tensor) -> Dict[str, float]:
    std = torch.exp(0.5 * logvar)
    mu_abs_mean = mu.abs().mean(dim=(0, 1))
    std_mean = std.mean(dim=(0, 1))
    collapsed = ((mu_abs_mean < 0.02) & ((std_mean - 1.0).abs() < 0.05)).float().mean().item()
    return {
        "post_std_mean": float(std.mean().item()),
        "post_mu_abs_mean": float(mu.abs().mean().item()),
        "collapse_dim_frac": float(collapsed),
    }

@torch.no_grad()
def latent_autocorr(x: torch.Tensor, lag: int = 1) -> float:
    if x.size(1) <= lag:
        return float("nan")
    a = x[:, :-lag, :].reshape(-1, x.size(-1))
    b = x[:, lag:, :].reshape(-1, x.size(-1))
    a = a - a.mean(dim=0, keepdim=True)
    b = b - b.mean(dim=0, keepdim=True)
    cov = (a * b).mean(dim=0)
    va = (a * a).mean(dim=0).clamp_min(1e-8)
    vb = (b * b).mean(dim=0).clamp_min(1e-8)
    return float((cov / torch.sqrt(va * vb)).mean().item())

@torch.no_grad()
def kl_per_dim_stats(model: VAETextLM, mu: torch.Tensor, logvar: torch.Tensor, z: torch.Tensor, eps: float) -> Dict[str, float]:
    B, T, Dz = z.shape
    logq_bd = model.log_q_per_dim(z, mu, logvar)
    logp_bd = model.prior.log_p_per_dim(z)
    kld_bd = (logq_bd - logp_bd) / max(1, T)
    kld_d = kld_bd.mean(dim=0).detach().cpu().numpy()
    return {
        "kldim_mean": float(np.mean(kld_d)),
        "kldim_median": float(np.median(kld_d)),
        "kldim_max": float(np.max(kld_d)),
        "kldim_frac_below_eps": float(np.mean(kld_d < eps)),
    }

@torch.no_grad()
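def mi_proxy_batch(mu: torch.Tensor, logvar: torch.Tensor, z: torch.Tensor, max_components: int, max_points: int, seed: int = 0) -> float:
    """
    MI proxy E[log q(z|x)] - E[log q(z)], where q(z) is approximated by a uniform mixture of the
    per-position posteriors in the batch (subsampled to max_components) and the expectation uses
    up to max_points sampled positions. Components and points are subsampled independently, so a
    point's own component may be missing from the mixture; the estimate is then not bounded by
    log(max_components).
    """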
    B, T, Dz = z.shape
    C = B * T
    if C == 0:
        return float("nan")
    rng = np.random.default_rng(seed)

    mu_c = mu.reshape(C, Dz)
    lv_c = logvar.reshape(C, Dz)
    z_c  = z.reshape(C, Dz)

    comp_idx = np.arange(C)
    pt_idx = np.arange(C)
    if C > max_components:
        comp_idx = rng.choice(comp_idx, size=max_components, replace=False)
    if C > max_points:
        pt_idx = rng.choice(pt_idx, size=max_points, replace=False)

    mu_s = mu_c[torch.as_tensor(comp_idx, device=mu.device)]
    lv_s = lv_c[torch.as_tensor(comp_idx, device=mu.device)]
    var_s = torch.exp(lv_s).clamp_min(1e-12)

    z_pts  = z_c [torch.as_tensor(pt_idx, device=mu.device)]
    mu_pts = mu_c[torch.as_tensor(pt_idx, device=mu.device)]
    lv_pts = lv_c[torch.as_tensor(pt_idx, device=mu.device)]
    var_pts = torch.exp(lv_pts).clamp_min(1e-12)

    log2pi = math.log(2 * math.pi)
    lq_cond = -0.5 * (((z_pts - mu_pts) ** 2) / var_pts + lv_pts + log2pi).sum(dim=-1)
    eq_logq_z_given_x = lq_cond.mean()

    z_exp = z_pts[:, None, :]
    mu_exp = mu_s[None, :, :]
    lv_exp = lv_s[None, :, :]
    var_exp = var_s[None, :, :]

    lcomp = -0.5 * (((z_exp - mu_exp) ** 2) / var_exp + lv_exp + log2pi).sum(dim=-1)
    lmix = torch.logsumexp(lcomp, dim=1) - math.log(lcomp.size(1))
    eq_logq_z = lmix.mean()

    return float((eq_logq_z_given_x - eq_logq_z).item())
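`mi_proxy_batch` treats the aggregated posterior q(z) as a uniform mixture of the per-position Gaussians q(z|x_c) and reports E[log q(z|x)] − E[log q(z)]. If each evaluation point's own component is part of the mixture, every term is bounded above by log C, where C is the number of mixture components. The 1-D toy below shows the two extremes. Identical posteriors give 0, and well-separated posteriors give log C. This is a sketch under stated assumptions: it uses all components with no subsampling, a shared variance, and pure Python. `mi_mixture_proxy` is a hypothetical helper.

```python
import math
import random
from typing import List

LOG2PI = math.log(2.0 * math.pi)


def log_normal(z: float, mu: float, var: float) -> float:
    return -0.5 * ((z - mu) ** 2 / var + math.log(var) + LOG2PI)


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def mi_mixture_proxy(mus: List[float], var: float, seed: int = 0) -> float:
    """Draw z_i ~ q(z|x_i) for every component i, then average
    log q(z_i|x_i) - log[(1/C) * sum_j q(z_i|x_j)]."""
    rng = random.Random(seed)
    C = len(mus)
    total = 0.0
    for mu_i in mus:
        z = mu_i + math.sqrt(var) * rng.gauss(0.0, 1.0)
        log_cond = log_normal(z, mu_i, var)
        log_mix = logsumexp([log_normal(z, mu_j, var) for mu_j in mus]) - math.log(C)
        total += log_cond - log_mix
    return total / C


same = mi_mixture_proxy([0.0] * 8, var=1.0)                         # uninformative posteriors
apart = mi_mixture_proxy([10.0 * k for k in range(8)], var=0.01)    # fully separated
print(same, apart, math.log(8))
```

When every point's own component is in the mixture, the proxy is bounded by log C. A subsampled mixture of 512 components should then report at most log 512 ≈ 6.24 nats.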

def rd_points(nll_tok: float, kl_tok: float, betas: Tuple[float, ...]) -> Dict[str, float]:
    out = {}
    for b in betas:
        out[f"rd_elbo_tok_beta_{b:g}"] = -(nll_tok + b * kl_tok)
    out["rd_nll_tok"] = float(nll_tok)
    out["rd_kl_tok"] = float(kl_tok)
    return out


def distinct_ngrams(token_seq: List[int], n: int) -> float:
    if len(token_seq) < n or n <= 0:
        return 0.0
    total = 0
    uniq = set()
    for i in range(len(token_seq) - n + 1):
        uniq.add(tuple(token_seq[i:i+n]))
        total += 1
    return (len(uniq) / total) if total > 0 else 0.0

def repetition_rate(token_seq: List[int], n: int) -> float:
    if len(token_seq) < n or n <= 0:
        return 0.0
    total = 0
    counts = {}
    for i in range(len(token_seq) - n + 1):
        ng = tuple(token_seq[i:i+n])
        counts[ng] = counts.get(ng, 0) + 1
        total += 1
    repeated = sum(c for c in counts.values() if c > 1)
    return repeated / total if total > 0 else 0.0
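The two generation metrics are not complements. `repetition_rate` counts every occurrence of an n-gram that appears more than once, including its first occurrence. A sequence that alternates two tokens therefore scores 1.0 on `rep2`, while `distinct2` is 0.4. Note also that `generation_metrics` applies both metrics to the full `seq`, so the prompt is included. The compact restatement below (pure Python, using `collections.Counter`) reproduces the same definitions on small inputs.

```python
from collections import Counter
from typing import List, Tuple


def ngrams(seq: List[int], n: int) -> List[Tuple[int, ...]]:
    """All contiguous n-grams; empty when n <= 0 or the sequence is shorter than n."""
    return [tuple(seq[i:i + n]) for i in range(len(seq) - n + 1)] if 0 < n <= len(seq) else []


def distinct_ngrams(seq: List[int], n: int) -> float:
    grams = ngrams(seq, n)
    return len(set(grams)) / len(grams) if grams else 0.0


def repetition_rate(seq: List[int], n: int) -> float:
    grams = ngrams(seq, n)
    counts = Counter(grams)
    repeated = sum(c for c in counts.values() if c > 1)  # all occurrences of repeated n-grams
    return repeated / len(grams) if grams else 0.0


for seq in ([1, 2, 1, 2, 1, 2], [1, 2, 3, 4], [5, 5, 5, 5]):
    print(seq, distinct_ngrams(seq, 2), repetition_rate(seq, 2))
```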


@torch.no_grad()
def generation_metrics(model: VAETextLM, tokenizer, device: str, n_prompts: int, max_new: int, top_k: int) -> Dict[str, Any]:
    model.eval()
    prompts = [
        "The meaning of life is",
        "In the middle of the night",
        "The government announced that",
        "A new theory suggests",
        "Once upon a time",
        "The experiment shows",
        "In a shocking discovery",
        "The book describes",
        "Scientists found that",
        "The president said",
        "In the future,",
        "The story begins",
    ]

    cuda_mem_reset(device)
    sync_if_cuda(device)
    t0 = _now()
    gen_new_tokens_total = 0

    per_prompt: List[Dict[str, Any]] = []
    decoded_samples: List[str] = []

    for i in range(n_prompts):
        p = prompts[i % len(prompts)]
        p_ids = tokenizer(p, return_tensors="pt").input_ids.to(device)
        if p_ids.size(1) > model.cfg.block_size // 2:
            p_ids = p_ids[:, : model.cfg.block_size // 2]

        out_ids = model.generate(p_ids, max_new_tokens=max_new, top_k=top_k)
        seq = out_ids[0].tolist()

        prompt_len = int(p_ids.size(1))
        new_tokens = max(0, len(seq) - prompt_len)
        gen_new_tokens_total += new_tokens

        st = {
            "prompt": p,
            "prompt_len": prompt_len,
            "total_len": len(seq),
            "new_tokens": new_tokens,
            "rep2": repetition_rate(seq, 2),
            "rep3": repetition_rate(seq, 3),
            "distinct2": distinct_ngrams(seq, 2),
            "distinct3": distinct_ngrams(seq, 3),
        }

        txt = tokenizer.decode(seq[: min(len(seq), 250)], skip_special_tokens=True)
        st["sample"] = txt
        per_prompt.append(st)

        if i < 3:
            decoded_samples.append(txt)

    sync_if_cuda(device)
    dt = max(1e-9, (_now() - t0))
    mem = cuda_mem_snapshot_mb(device)

    rep2 = float(np.mean([p["rep2"] for p in per_prompt])) if per_prompt else 0.0
    rep3 = float(np.mean([p["rep3"] for p in per_prompt])) if per_prompt else 0.0
    distinct2 = float(np.mean([p["distinct2"] for p in per_prompt])) if per_prompt else 0.0
    distinct3 = float(np.mean([p["distinct3"] for p in per_prompt])) if per_prompt else 0.0
    lengths = [p["total_len"] for p in per_prompt]
    newlens = [p["new_tokens"] for p in per_prompt]

    return {
        "gen_distinct2": distinct2,
        "gen_distinct3": distinct3,
        "gen_rep2": rep2,
        "gen_rep3": rep3,
        "gen_len_mean": float(np.mean(lengths)) if lengths else 0.0,
        "gen_len_std": float(np.std(lengths)) if lengths else 0.0,
        "gen_new_tokens_mean": float(np.mean(newlens)) if newlens else 0.0,
        "gen_new_tokens_total": int(gen_new_tokens_total),

        "time_gen_seconds_total": float(dt),
        "time_gen_tokens_per_s": float(gen_new_tokens_total / dt),
        "gpu_gen_peak_alloc_MB": mem["peak_alloc_MB"],
        "gpu_gen_peak_reserved_MB": mem["peak_reserved_MB"],

        "gen_per_prompt": per_prompt,
        "gen_sample_0": decoded_samples[0] if len(decoded_samples) > 0 else "",
        "gen_sample_1": decoded_samples[1] if len(decoded_samples) > 1 else "",
        "gen_sample_2": decoded_samples[2] if len(decoded_samples) > 2 else "",
    }


@torch.no_grad()
def eval_many_metrics(
    model: VAETextLM,
    loader: DataLoader,
    device: str,
    beta: float,
    max_batches: int,
    ece_bins: int,
    gen_do: bool,
    tokenizer=None,
    gen_prompts: int = 12,
    gen_max_new: int = 96,
    gen_top_k: int = 50,
    rd_betas: Tuple[float, ...] = (0.0, 1.0),
    measure_perf: bool = True,
    kldim_eps: float = 0.01,
    mi_max_components: int = 512,
    mi_max_points: int = 256,
    mi_seed: int = 0,
) -> Dict[str, Any]:
    model.eval()

    acc = {
        "nll_tok": 0.0,
        "kl_tok": 0.0,
        "elbo_tok": 0.0,
        "tok_entropy_mean": 0.0,
        "topk_mass_mean": 0.0,
        "ece": 0.0,
        "post_std_mean": 0.0,
        "post_mu_abs_mean": 0.0,
        "collapse_dim_frac": 0.0,
        "mu_ac1": 0.0,
        "mu_ac5": 0.0,
        "z_ac1": 0.0,
        "z_ac5": 0.0,
        "kldim_mean": 0.0,
        "kldim_median": 0.0,
        "kldim_max": 0.0,
        "kldim_frac_below_eps": 0.0,
        "mi_proxy": 0.0,
    }
    n = 0

    step_ms = RunningMean()
    tokens_per_s = RunningMean()
    total_tokens = 0

    if measure_perf:
        cuda_mem_reset(device)

    for i, (x, y) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(device)
        y = y.to(device)

        if measure_perf:
            sync_if_cuda(device)
            t0 = _now()

        out = model(x, y, beta=beta)

        if measure_perf:
            sync_if_cuda(device)
            dt = max(1e-9, (_now() - t0))
            step_ms.add(1000.0 * dt)
            B, T = x.shape
            tok = int(B * T)
            total_tokens += tok
            tokens_per_s.add(tok / dt)

        nll_tok = float(out["nll_tok"])
        kl_tok = float(out["kl_tok"])
        elbo_tok = float(out["elbo_tok"])
        if not (math.isfinite(nll_tok) and math.isfinite(kl_tok) and math.isfinite(elbo_tok)):
            continue

        logits = out["logits_last"]
        mu = out["mu"].detach()
        logvar = out["logvar"].detach()
        zlast = out["z_last"].detach()

        ent_topk = logits_entropy_and_topk_mass(logits, top_k=gen_top_k)
        ece = ece_from_logits(logits, y, n_bins=ece_bins)
        post = posterior_collapse_ratio(mu, logvar)

        mu_ac1 = latent_autocorr(mu, lag=1)
        mu_ac5 = latent_autocorr(mu, lag=5)
        z_ac1 = latent_autocorr(zlast, lag=1)
        z_ac5 = latent_autocorr(zlast, lag=5)

        kld = kl_per_dim_stats(model, mu, logvar, zlast, eps=kldim_eps)
        mi = mi_proxy_batch(mu, logvar, zlast, max_components=mi_max_components, max_points=mi_max_points, seed=mi_seed)

        acc["nll_tok"] += nll_tok
        acc["kl_tok"] += kl_tok
        acc["elbo_tok"] += elbo_tok
        acc["tok_entropy_mean"] += ent_topk["tok_entropy_mean"]
        acc["topk_mass_mean"] += ent_topk["topk_mass_mean"]
        acc["ece"] += ece
        acc["post_std_mean"] += post["post_std_mean"]
        acc["post_mu_abs_mean"] += post["post_mu_abs_mean"]
        acc["collapse_dim_frac"] += post["collapse_dim_frac"]
        acc["mu_ac1"] += mu_ac1
        acc["mu_ac5"] += mu_ac5
        acc["z_ac1"] += z_ac1
        acc["z_ac5"] += z_ac5
        acc["kldim_mean"] += kld["kldim_mean"]
        acc["kldim_median"] += kld["kldim_median"]
        acc["kldim_max"] += kld["kldim_max"]
        acc["kldim_frac_below_eps"] += kld["kldim_frac_below_eps"]
        acc["mi_proxy"] += mi
        n += 1

    if n == 0:
        return {k: float("nan") for k in acc.keys()}

    for k in acc:
        acc[k] /= n

    acc["ppl"] = math.exp(acc["nll_tok"])
    acc["bits_per_tok"] = acc["nll_tok"] / math.log(2.0)
    acc.update(rd_points(acc["nll_tok"], acc["kl_tok"], rd_betas))

    if measure_perf:
        mem = cuda_mem_snapshot_mb(device)
        acc["time_eval_step_ms_mean"] = step_ms.mean()
        acc["time_eval_tokens_per_s_mean"] = tokens_per_s.mean()
        acc["gpu_eval_peak_alloc_MB"] = mem["peak_alloc_MB"]
        acc["gpu_eval_peak_reserved_MB"] = mem["peak_reserved_MB"]
        acc["eval_total_tokens_measured"] = int(total_tokens)

    if gen_do and tokenizer is not None:
        acc.update(
            generation_metrics(model, tokenizer, device=device, n_prompts=gen_prompts, max_new=gen_max_new, top_k=gen_top_k)
        )

    return acc



# Schedules

def lr_schedule(step, cfg: CFG):
    if step < cfg.warmup_steps:
        return cfg.lr * (step / max(1, cfg.warmup_steps))
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    return cfg.lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))

def beta_schedule(step, cfg: CFG):
    if step >= cfg.beta_warmup_steps:
        return cfg.beta_end
    a = step / max(1, cfg.beta_warmup_steps)
    return cfg.beta_start + a * (cfg.beta_end - cfg.beta_start)
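Here are the two schedules evaluated at a few steps. Learning rate ramps up linearly for `warmup_steps`, then follows a cosine decay from `lr` down to `0.1 * lr`. β rises linearly until `beta_warmup_steps` and then stays at `beta_end`. The stand-in `SchedCfg` is hypothetical. It holds only the fields the schedules read and uses the values of the `CFG` below (lr 3e-4, warmup 300, 6000 steps, β warm-up 1500). With these values the sketch reproduces the logged `beta=0.27` at step 400, `beta=0.80` at step 1200, and `lr=3.0e-05` at step 6000.

```python
import math
from dataclasses import dataclass


@dataclass
class SchedCfg:
    """Hypothetical stand-in holding only the CFG fields the schedules use."""
    lr: float = 3e-4
    warmup_steps: int = 300
    max_steps: int = 6000
    beta_start: float = 0.0
    beta_end: float = 1.0
    beta_warmup_steps: int = 1500


def lr_schedule(step: int, cfg: SchedCfg) -> float:
    if step < cfg.warmup_steps:
        return cfg.lr * (step / max(1, cfg.warmup_steps))  # linear warm-up
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    return cfg.lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))  # cosine to 0.1*lr


def beta_schedule(step: int, cfg: SchedCfg) -> float:
    if step >= cfg.beta_warmup_steps:
        return cfg.beta_end
    a = step / max(1, cfg.beta_warmup_steps)
    return cfg.beta_start + a * (cfg.beta_end - cfg.beta_start)  # linear KL anneal


cfg = SchedCfg()
for step in (1, 300, 400, 1200, 1500, 6000):
    print(step, f"lr={lr_schedule(step, cfg):.1e}", f"beta={beta_schedule(step, cfg):.2f}")
```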



# Train (GP only)

def train_gp(cfg: CFG, train_loader, val_loader, tokenizer) -> Dict[str, Any]:
    device = cfg.device
    prior = PriorGP(cfg.gp_lengthscale, cfg.gp_sigma, cfg.gp_jitter)
    model = VAETextLM(cfg, prior).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    use_amp = bool(cfg.amp and device.startswith("cuda"))
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    os.makedirs(cfg.out_dir, exist_ok=True)
    run_stamp = int(time.time())
    log_path = os.path.join(cfg.out_dir, f"{cfg.run_name}_vae_gp_{run_stamp}.jsonl")

    best_val = float("inf")
    best: Dict[str, Any] = {}

    t0 = time.time()
    pbar = tqdm(total=cfg.max_steps, desc="train[vae_gp]")

    it = iter(train_loader)

    train_step_ms = RunningMean()
    train_tokens_per_s = RunningMean()
    train_tokens_total = 0
    cuda_mem_reset(device)

    for step in range(1, cfg.max_steps + 1):
        try:
            x, y = next(it)
        except StopIteration:
            it = iter(train_loader)
            x, y = next(it)

        x, y = x.to(device), y.to(device)

        beta = beta_schedule(step, cfg)
        lr = lr_schedule(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr

        model.train()
        opt.zero_grad(set_to_none=True)

        # Synchronize on CUDA regardless of AMP so step timing is accurate.
        sync_if_cuda(device)
        t_step0 = _now()

        with torch.amp.autocast("cuda", enabled=use_amp):
            out = model(x, y, beta=beta)
            loss = out["loss"]

        scaler.scale(loss).backward()
        if cfg.grad_clip > 0:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt)
        scaler.update()

        sync_if_cuda(device)
        dt = max(1e-9, (_now() - t_step0))
        train_step_ms.add(1000.0 * dt)
        B, T = x.shape
        tok = int(B * T)
        train_tokens_total += tok
        train_tokens_per_s.add(tok / dt)

        if step % cfg.log_every == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "nll_tok": f"{float(out['nll_tok']):.4f}",
                "kl_tok": f"{float(out['kl_tok']):.4f}",
                "ppl": f"{float(out['ppl']):.2f}",
                "beta": f"{beta:.2f}",
                "lr": f"{lr:.1e}",
            })

        if step % cfg.eval_every == 0 or step == 1:
            mem_train = cuda_mem_snapshot_mb(device)
            train_perf = {
                "time_train_step_ms_mean_window": train_step_ms.mean(),
                "time_train_tokens_per_s_mean_window": train_tokens_per_s.mean(),
                "gpu_train_peak_alloc_MB_window": mem_train["peak_alloc_MB"],
                "gpu_train_peak_reserved_MB_window": mem_train["peak_reserved_MB"],
                "train_tokens_measured_window": int(train_tokens_total),
            }

            train_eval = eval_many_metrics(
                model, train_loader, device=device, beta=1.0,
                max_batches=min(cfg.eval_train_batches, cfg.eval_max_batches),
                ece_bins=cfg.ece_bins,
                gen_do=False,
                rd_betas=cfg.rd_betas,
                measure_perf=True,
                kldim_eps=cfg.kldim_eps,
                mi_max_components=cfg.mi_max_components,
                mi_max_points=cfg.mi_max_points,
                mi_seed=cfg.mi_seed,
                gen_top_k=cfg.eval_gen_top_k,
            )
            val_eval = eval_many_metrics(
                model, val_loader, device=device, beta=1.0,
                max_batches=cfg.eval_max_batches,
                ece_bins=cfg.ece_bins,
                gen_do=True, tokenizer=tokenizer,
                gen_prompts=cfg.eval_gen_prompts,
                gen_max_new=cfg.eval_gen_max_new,
                gen_top_k=cfg.eval_gen_top_k,
                rd_betas=cfg.rd_betas,
                measure_perf=True,
                kldim_eps=cfg.kldim_eps,
                mi_max_components=cfg.mi_max_components,
                mi_max_points=cfg.mi_max_points,
                mi_seed=cfg.mi_seed,
            )

            val_nll = val_eval.get("nll_tok", float("inf"))
            if math.isfinite(val_nll) and (val_nll < best_val):
                best_val = float(val_nll)
                best = {
                    "variant": "vae_gp",
                    "step": step,
                    "wall_s": time.time() - t0,
                    "best_val_nll_tok": best_val,
                    "train_perf_window": train_perf,
                    "train": train_eval,
                    "val": val_eval,
                }

            rec = {
                "variant": "vae_gp",
                "step": step,
                "wall_s": time.time() - t0,
                "lr": float(lr),
                "beta": float(beta),
                "train_perf_window": train_perf,
                "train": train_eval,
                "val": val_eval,
                "best_val_nll_tok_so_far": float(best_val),
            }
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")

            print(
                f"[eval] vae_gp step={step:5d} "
                f"train_tok/s={train_eval.get('time_eval_tokens_per_s_mean', float('nan')):.1f} "
                f"val_tok/s={val_eval.get('time_eval_tokens_per_s_mean', float('nan')):.1f} "
                f"gen_tok/s={val_eval.get('time_gen_tokens_per_s', float('nan')):.1f} "
                f"val_ppl={val_eval.get('ppl', float('nan')):.2f} "
                f"val_nll={val_eval.get('nll_tok', float('nan')):.4f} val_kl={val_eval.get('kl_tok', float('nan')):.4f} "
                f"ece={val_eval.get('ece', float('nan')):.3f} ent={val_eval.get('tok_entropy_mean', float('nan')):.3f} "
                f"mi={val_eval.get('mi_proxy', float('nan')):.3f} "
                f"kldim_med={val_eval.get('kldim_median', float('nan')):.3f} "
                f"train_peakMB={train_perf.get('gpu_train_peak_alloc_MB_window', float('nan')):.1f}"
            )

            train_step_ms = RunningMean()
            train_tokens_per_s = RunningMean()
            train_tokens_total = 0
            cuda_mem_reset(device)

        pbar.update(1)

    pbar.close()

    if not best:
        best = {"variant": "vae_gp", "step": cfg.max_steps, "wall_s": time.time() - t0, "best_val_nll_tok": best_val}
    best["log_path"] = log_path
    return best


def main():
    cfg = CFG()
    set_seed(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
    tokenizer.model_max_length = int(1e9)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    cfg.vocab_size = len(tokenizer)

    train_loader, val_loader = load_wt2_blocks(cfg, tokenizer)

    best = train_gp(cfg, train_loader, val_loader, tokenizer)

    os.makedirs(cfg.out_dir, exist_ok=True)
    out_json = os.path.join(cfg.out_dir, f"{cfg.run_name}_BEST_{int(time.time())}.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump({"cfg": asdict(cfg), "best": best}, f, indent=2, ensure_ascii=False)

    print(f"\nSaved BEST summary JSON: {out_json}")
    print(f"Best log JSONL: {best.get('log_path','')}")


if __name__ == "__main__":
    main()



train[vae_gp]:   0%|          | 1/6000 [00:14<24:12:31, 14.53s/it]

[eval] vae_gp step=    1 train_tok/s=60832.1 val_tok/s=62765.0 gen_tok/s=301.4 val_ppl=112179.78 val_nll=11.6279 val_kl=89006.5137 ece=0.002 ent=10.022 mi=49.252 kldim_med=1353.414 train_peakMB=3519.4


train[vae_gp]:   7%|▋         | 400/6000 [01:58<6:23:37,  4.11s/it, loss=11.8596, nll_tok=6.3095, kl_tok=20.8127, ppl=549.79, beta=0.27, lr=3.0e-04]   

[eval] vae_gp step=  400 train_tok/s=62750.8 val_tok/s=62846.8 gen_tok/s=315.6 val_ppl=704.34 val_nll=6.5573 val_kl=18.4760 ece=0.018 ent=6.313 mi=2.280 kldim_med=0.289 train_peakMB=3981.2


train[vae_gp]:  13%|█▎        | 800/6000 [03:42<5:56:51,  4.12s/it, loss=12.7300, nll_tok=5.9386, kl_tok=12.7339, ppl=379.40, beta=0.53, lr=2.9e-04]

[eval] vae_gp step=  800 train_tok/s=62762.2 val_tok/s=62674.2 gen_tok/s=315.0 val_ppl=488.03 val_nll=6.1904 val_kl=12.7276 ece=0.022 ent=5.817 mi=0.406 kldim_med=0.199 train_peakMB=3981.2


train[vae_gp]:  20%|██        | 1200/6000 [05:26<5:29:33,  4.12s/it, loss=14.0158, nll_tok=5.7446, kl_tok=10.3391, ppl=312.49, beta=0.80, lr=2.8e-04]

[eval] vae_gp step= 1200 train_tok/s=62807.7 val_tok/s=62742.0 gen_tok/s=315.1 val_ppl=408.35 val_nll=6.0121 val_kl=10.1886 ece=0.021 ent=5.658 mi=0.153 kldim_med=0.159 train_peakMB=3981.2


train[vae_gp]:  27%|██▋       | 1600/6000 [07:11<5:03:36,  4.14s/it, loss=14.1751, nll_tok=5.4694, kl_tok=8.7057, ppl=237.31, beta=1.00, lr=2.7e-04] 

[eval] vae_gp step= 1600 train_tok/s=62903.6 val_tok/s=62791.2 gen_tok/s=307.2 val_ppl=364.89 val_nll=5.8996 val_kl=8.2840 ece=0.024 ent=5.516 mi=0.068 kldim_med=0.129 train_peakMB=3981.2


train[vae_gp]:  33%|███▎      | 2000/6000 [08:56<4:36:11,  4.14s/it, loss=12.7825, nll_tok=5.1551, kl_tok=7.6274, ppl=173.32, beta=1.00, lr=2.4e-04]

[eval] vae_gp step= 2000 train_tok/s=62663.1 val_tok/s=62723.5 gen_tok/s=309.3 val_ppl=337.58 val_nll=5.8218 val_kl=7.1116 ece=0.033 ent=5.047 mi=0.040 kldim_med=0.111 train_peakMB=3981.2


train[vae_gp]:  40%|████      | 2400/6000 [10:40<4:08:25,  4.14s/it, loss=11.5538, nll_tok=4.8017, kl_tok=6.7520, ppl=121.72, beta=1.00, lr=2.2e-04]

[eval] vae_gp step= 2400 train_tok/s=62591.2 val_tok/s=62742.2 gen_tok/s=311.8 val_ppl=321.30 val_nll=5.7724 val_kl=6.3504 ece=0.029 ent=5.002 mi=0.029 kldim_med=0.099 train_peakMB=3981.2


train[vae_gp]:  47%|████▋     | 2800/6000 [12:25<3:40:47,  4.14s/it, loss=10.9547, nll_tok=4.8506, kl_tok=6.1041, ppl=127.82, beta=1.00, lr=1.9e-04]

[eval] vae_gp step= 2800 train_tok/s=62692.0 val_tok/s=62761.9 gen_tok/s=310.8 val_ppl=314.71 val_nll=5.7516 val_kl=5.8163 ece=0.038 ent=4.716 mi=0.025 kldim_med=0.091 train_peakMB=3981.2


train[vae_gp]:  53%|█████▎    | 3200/6000 [14:10<3:13:22,  4.14s/it, loss=10.2535, nll_tok=4.7033, kl_tok=5.5503, ppl=110.31, beta=1.00, lr=1.6e-04]

[eval] vae_gp step= 3200 train_tok/s=62595.2 val_tok/s=62512.2 gen_tok/s=312.6 val_ppl=309.94 val_nll=5.7364 val_kl=5.3650 ece=0.036 ent=4.634 mi=0.023 kldim_med=0.084 train_peakMB=3981.2


train[vae_gp]:  60%|██████    | 3600/6000 [15:55<2:45:02,  4.13s/it, loss=9.6672, nll_tok=4.5063, kl_tok=5.1609, ppl=90.59, beta=1.00, lr=1.3e-04]  

[eval] vae_gp step= 3600 train_tok/s=62803.6 val_tok/s=62863.7 gen_tok/s=312.3 val_ppl=305.89 val_nll=5.7232 val_kl=5.0337 ece=0.040 ent=4.543 mi=0.022 kldim_med=0.079 train_peakMB=3981.2


train[vae_gp]:  67%|██████▋   | 4000/6000 [17:39<2:17:04,  4.11s/it, loss=9.0829, nll_tok=4.2055, kl_tok=4.8774, ppl=67.06, beta=1.00, lr=1.0e-04]

[eval] vae_gp step= 4000 train_tok/s=62708.9 val_tok/s=62869.2 gen_tok/s=316.2 val_ppl=310.82 val_nll=5.7392 val_kl=4.8337 ece=0.049 ent=4.378 mi=0.013 kldim_med=0.076 train_peakMB=3981.2


train[vae_gp]:  73%|███████▎  | 4400/6000 [19:24<1:50:01,  4.13s/it, loss=8.8465, nll_tok=4.1410, kl_tok=4.7056, ppl=62.86, beta=1.00, lr=7.9e-05]

[eval] vae_gp step= 4400 train_tok/s=62639.1 val_tok/s=62802.6 gen_tok/s=313.0 val_ppl=313.63 val_nll=5.7482 val_kl=4.6256 ece=0.051 ent=4.284 mi=0.013 kldim_med=0.072 train_peakMB=3981.2


train[vae_gp]:  80%|████████  | 4800/6000 [21:09<1:22:40,  4.13s/it, loss=8.5889, nll_tok=4.0517, kl_tok=4.5372, ppl=57.49, beta=1.00, lr=5.8e-05]

[eval] vae_gp step= 4800 train_tok/s=62712.4 val_tok/s=62664.5 gen_tok/s=314.0 val_ppl=314.73 val_nll=5.7517 val_kl=4.5375 ece=0.051 ent=4.242 mi=0.013 kldim_med=0.071 train_peakMB=3981.2


train[vae_gp]:  87%|████████▋ | 5200/6000 [22:53<54:57,  4.12s/it, loss=8.5168, nll_tok=4.0733, kl_tok=4.4435, ppl=58.75, beta=1.00, lr=4.3e-05]  

[eval] vae_gp step= 5200 train_tok/s=62801.5 val_tok/s=62724.8 gen_tok/s=313.3 val_ppl=318.36 val_nll=5.7632 val_kl=4.4181 ece=0.059 ent=4.143 mi=0.013 kldim_med=0.069 train_peakMB=3981.2


train[vae_gp]:  93%|█████████▎| 5600/6000 [24:38<27:27,  4.12s/it, loss=8.5446, nll_tok=4.1746, kl_tok=4.3700, ppl=65.01, beta=1.00, lr=3.3e-05]

[eval] vae_gp step= 5600 train_tok/s=62720.8 val_tok/s=62764.2 gen_tok/s=313.9 val_ppl=320.58 val_nll=5.7701 val_kl=4.3187 ece=0.059 ent=4.143 mi=0.014 kldim_med=0.068 train_peakMB=3981.2


train[vae_gp]: 100%|██████████| 6000/6000 [26:23<00:00,  3.79it/s, loss=8.2149, nll_tok=3.9117, kl_tok=4.3032, ppl=49.98, beta=1.00, lr=3.0e-05]

[eval] vae_gp step= 6000 train_tok/s=62855.9 val_tok/s=62756.8 gen_tok/s=316.0 val_ppl=322.84 val_nll=5.7772 val_kl=4.2508 ece=0.057 ent=4.135 mi=0.012 kldim_med=0.066 train_peakMB=3981.2

Saved BEST summary JSON: runs/wt2_latentGP_BEST_1765782489.json
Best log JSONL: runs/wt2_latentGP_vae_gp_1765780906.jsonl





# ============================================================
# GP-VAE (iid / GP / AR priors) on WikiText-2 blocks
# "NO-DRAMA" FULL VERSION + MANY METRICS
#
# compute-cost metrics:
# - Training time: step_ms, tokens/s (windowed between evals)
# - Inference time: eval tokens/s (teacher forcing) + generation tokens/s
# - GPU memory: peak allocated + reserved (train/eval/gen)
#
# extra metrics:
# - Bits-back style diagnostics: KL per dim stats (mean/median/max) + frac(KLdim < eps)
# - MI(x,z) proxy: E_q[log q(z|x)] - E_q[log q(z)] approx (batch-based, subsampled)
# - Rate–Distortion points: log (NLL, KL) + ELBO(beta) for multiple fixed betas
# - Generation: per-prompt length + repetition (not just aggregated) + samples
#
# Notes:
# - Tokenization long-seq warning suppressed by setting tokenizer.model_max_length huge.
# - GP prior supports variable T safely; includes cholesky fallback with extra jitter.
# ============================================================

import math
import os
import time
import json
import random
from dataclasses import dataclass, asdict
from typing import Dict, Tuple, List, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from datasets import load_dataset
from transformers import AutoTokenizer


# Environment / warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# More deterministic-ish behavior (still not perfect with transformers + GPU)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



# Config

@dataclass
class CFG:
    seed: int = 0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Data
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    tokenizer_name: str = "gpt2"
    block_size: int = 256
    batch_size: int = 16
    num_workers: int = 2
    pin_memory: bool = True

    # Model
    vocab_size: int = 50257  # overwritten after tokenizer load
    d_model: int = 384
    n_layers: int = 6
    n_heads: int = 6
    dropout: float = 0.1

    # Latent
    z_dim: int = 64
    n_z_samples: int = 1  # MC samples for KL estimate

    # Training
    lr: float = 3e-4
    weight_decay: float = 0.01
    max_steps: int = 6000
    warmup_steps: int = 300
    grad_clip: float = 1.0
    eval_every: int = 400
    log_every: int = 50
    amp: bool = True

    # KL anneal
    beta_start: float = 0.0
    beta_end: float = 1.0
    beta_warmup_steps: int = 1500

    # GP prior (RBF kernel)
    gp_lengthscale: float = 25.0
    gp_sigma: float = 1.0
    gp_jitter: float = 1e-3  # safer default than 1e-4

    # AR prior
    ar_init_rho: float = 0.95
    ar_sigma: float = 0.5

    # Eval controls
    eval_max_batches: int = 80
    eval_train_batches: int = 20
    eval_gen_prompts: int = 12
    eval_gen_max_new: int = 96
    eval_gen_top_k: int = 50

    # ECE calibration
    ece_bins: int = 15

    # Bits-back diagnostics
    kldim_eps: float = 0.01  # threshold for "inactive" dims (per-token)

    # MI proxy (batch-based, subsampled for cost)
    mi_max_components: int = 512   # max components in mixture for q(z)
    mi_max_points: int = 256       # max sample points z to evaluate
    mi_seed: int = 0

    # Rate–Distortion points (log derived ELBO for multiple betas)
    rd_betas: Tuple[float, ...] = (0.0, 0.1, 0.5, 1.0, 2.0)

    # Logging / saving
    out_dir: str = "runs"
    run_name: str = "wt2_gpvae_compare"


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)



# Dataset: GPT-2 tokenizer blocks

class LMBlocks(Dataset):
    def __init__(self, token_ids: List[int], block_size: int):
        self.block_size = block_size
        n = (len(token_ids) - 1) // block_size  # keep only full blocks
        self.data = token_ids[: n * block_size + 1]

    def __len__(self):
        return (len(self.data) - 1) // self.block_size

    def __getitem__(self, idx):
        i = idx * self.block_size
        x = torch.tensor(self.data[i : i + self.block_size], dtype=torch.long)
        y = torch.tensor(self.data[i + 1 : i + self.block_size + 1], dtype=torch.long)
        return x, y
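`LMBlocks` cuts the token stream into non-overlapping blocks. Each block needs `block_size + 1` tokens because the target `y` is `x` shifted by one position. As a result, consecutive blocks share one boundary token, and the ragged tail is dropped. The pure-Python mirror below is a hypothetical helper that returns lists instead of tensors, and it shows the indexing on 11 tokens.

```python
from typing import List, Tuple


def lm_blocks(token_ids: List[int], block_size: int) -> List[Tuple[List[int], List[int]]]:
    """(x, y) pairs with y = x shifted right by one; incomplete tail is discarded."""
    n = (len(token_ids) - 1) // block_size  # full blocks that still have a next-token target
    data = token_ids[: n * block_size + 1]
    pairs = []
    for idx in range(n):
        i = idx * block_size
        pairs.append((data[i:i + block_size], data[i + 1:i + block_size + 1]))
    return pairs


blocks = lm_blocks(list(range(11)), block_size=3)
for x, y in blocks:
    print(x, y)  # last y token of block k equals first x token of block k+1
```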


def tokenize_split_streaming(tokenizer, texts: List[str], chunk_chars: int = 200_000) -> List[int]:
    """
    Tokenize a large split without building one gigantic string.

    HF "sequence length > model max length" warnings are avoided because
    main() sets tokenizer.model_max_length to a huge value.
    """
    ids: List[int] = []
    buf: List[str] = []
    cur_len = 0
    sep = "\n\n"

    for t in texts:
        if not t:
            continue
        add = t + sep
        buf.append(add)
        cur_len += len(add)
        if cur_len >= chunk_chars:
            chunk = "".join(buf)
            ids.extend(tokenizer.encode(chunk))
            buf = []
            cur_len = 0

    if buf:
        chunk = "".join(buf)
        ids.extend(tokenizer.encode(chunk))

    return ids
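The streaming tokenizer buffers non-empty lines, appends `"\n\n"` to each, and calls `encode` once the buffer reaches `chunk_chars`, so chunks always end on a paragraph separator. The sketch below uses a hypothetical whitespace `ToyTokenizer` in place of a Hugging Face tokenizer, and a tiny `chunk_chars` so that several flushes happen. For this toy tokenizer, the streamed ids equal a single `encode` of the joined text. Whether that holds exactly for GPT-2's BPE depends on how tokens behave at the `"\n\n"` boundaries, and this sketch does not test it.

```python
from typing import Dict, List


class ToyTokenizer:
    """Hypothetical stand-in: whitespace split, ids assigned by first appearance."""
    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}
        self.calls = 0

    def encode(self, text: str) -> List[int]:
        self.calls += 1
        return [self.vocab.setdefault(w, len(self.vocab)) for w in text.split()]


def tokenize_split_streaming(tokenizer, texts: List[str], chunk_chars: int = 200_000) -> List[int]:
    ids: List[int] = []
    buf: List[str] = []
    cur_len = 0
    for t in texts:
        if not t:
            continue  # skip empty lines
        buf.append(t + "\n\n")
        cur_len += len(t) + 2
        if cur_len >= chunk_chars:  # flush a chunk ending on a paragraph separator
            ids.extend(tokenizer.encode("".join(buf)))
            buf, cur_len = [], 0
    if buf:
        ids.extend(tokenizer.encode("".join(buf)))
    return ids


texts = ["the cat sat", "", "on the mat", "the end"]
tok_stream, tok_full = ToyTokenizer(), ToyTokenizer()
streamed = tokenize_split_streaming(tok_stream, texts, chunk_chars=12)
full = tok_full.encode("".join(t + "\n\n" for t in texts if t))
print(streamed == full, tok_stream.calls)
```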


def load_wt2_blocks(cfg: CFG, tokenizer) -> Tuple[DataLoader, DataLoader]:
    ds = load_dataset(cfg.dataset_name, cfg.dataset_config)

    train_ids = tokenize_split_streaming(tokenizer, ds["train"]["text"])
    val_ids = tokenize_split_streaming(tokenizer, ds["validation"]["text"])

    train_set = LMBlocks(train_ids, cfg.block_size)
    val_set = LMBlocks(val_ids, cfg.block_size)

    train_loader = DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=True,
        persistent_workers=(cfg.num_workers > 0),
    )
    val_loader = DataLoader(
        val_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=False,
        persistent_workers=(cfg.num_workers > 0),
    )
    return train_loader, val_loader



# Transformer blocks

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_layers, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x, src_key_padding_mask=None):
        return self.enc(x, src_key_padding_mask=src_key_padding_mask)


class TransformerDecoderLM(nn.Module):
    """
    Causal LM decoder implemented as TransformerEncoder with a causal mask.
    """
    def __init__(self, d_model, n_heads, n_layers, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.dec = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):
        _, T, _ = x.shape
        attn_mask = torch.full((T, T), float("-inf"), device=x.device)
        attn_mask = torch.triu(attn_mask, diagonal=1)
        return self.dec(x, mask=attn_mask)



# Priors

class PriorBase(nn.Module):
    def log_p(self, z: torch.Tensor) -> torch.Tensor:
        """ z: (B,T,Dz) -> log p(z): (B,) """
        raise NotImplementedError

    def log_p_per_dim(self, z: torch.Tensor) -> torch.Tensor:
        """
        z: (B,T,Dz) -> returns per-dim log p(z_{:, :, d}) summed over time: (B,Dz)
        Used for KL-per-dim diagnostics. Must respect factorization over dims.
        """
        raise NotImplementedError

    @torch.no_grad()
    def sample(self, B: int, T: int, Dz: int, device: str) -> torch.Tensor:
        raise NotImplementedError


class PriorIID(PriorBase):
    def log_p(self, z):
        return (-0.5 * (z**2 + math.log(2 * math.pi))).sum(dim=(1, 2))

    def log_p_per_dim(self, z):
        # sum over time per dim
        # (B,T,Dz) -> (B,Dz)
        return (-0.5 * (z**2 + math.log(2 * math.pi))).sum(dim=1)

    @torch.no_grad()
    def sample(self, B, T, Dz, device):
        return torch.randn(B, T, Dz, device=device)


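class PriorAR(PriorBase):
    """
    First-order autoregressive (AR(1)) prior, applied independently to each latent dim:
        p(z_1) = N(0, I),    p(z_t | z_{t-1}) = N(rho * z_{t-1}, sigma^2 I)
    rho is a learnable scalar shared across dims, kept in (0, 1) by a sigmoid;
    sigma is fixed.
    """
    def __init__(self, Dz: int, rho_init: float, sigma: float):
        super().__init__()
        # clamp before the logit so rho_init = 0 or 1 does not give an infinite parameter
        rho0 = min(max(float(rho_init), 1e-4), 1.0 - 1e-4)
        self.logit_rho = nn.Parameter(torch.logit(torch.tensor(rho0)))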
        self.sigma = float(sigma)
        self.Dz = Dz

    def rho(self):
        return torch.sigmoid(self.logit_rho).clamp(1e-4, 0.9999)

    def log_p(self, z):
        B, T, Dz = z.shape
        lp1 = (-0.5 * (z[:, 0] ** 2 + math.log(2 * math.pi))).sum(dim=1)
        if T == 1:
            return lp1
        rho = self.rho()
        sigma2 = self.sigma ** 2
        resid = z[:, 1:] - rho * z[:, :-1]
        lp = (-0.5 * ((resid**2) / sigma2 + math.log(2 * math.pi * sigma2))).sum(dim=(1, 2))
        return lp1 + lp

    def log_p_per_dim(self, z):
        # (B,T,Dz) -> (B,Dz)
        B, T, Dz = z.shape
        lp1 = (-0.5 * (z[:, 0] ** 2 + math.log(2 * math.pi)))  # (B,Dz)
        if T == 1:
            return lp1
        rho = self.rho()
        sigma2 = self.sigma ** 2
        resid = z[:, 1:] - rho * z[:, :-1]  # (B,T-1,Dz)
        lp = -0.5 * ((resid**2) / sigma2 + math.log(2 * math.pi * sigma2))  # (B,T-1,Dz)
        return lp1 + lp.sum(dim=1)

    @torch.no_grad()
    def sample(self, B, T, Dz, device):
        rho = float(self.rho().item())
        z = torch.zeros(B, T, Dz, device=device)
        z[:, 0] = torch.randn(B, Dz, device=device)
        for t in range(1, T):
            z[:, t] = rho * z[:, t - 1] + self.sigma * torch.randn(B, Dz, device=device)
        return z

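# Illustrative sketch (hypothetical helpers, not called anywhere in this script):
# when sigma**2 == 1 - rho**2, the AR(1) prior above is stationary with unit
# marginal variance, and its chain-rule density p(z_1) * prod_t p(z_t | z_{t-1})
# equals a zero-mean Gaussian over the whole path with covariance
# K[i, j] = rho**|i - j| (an exponential, Ornstein-Uhlenbeck-type kernel on integer
# time). This is one concrete reading of "latent autoregression as a causal
# Gaussian prior". PriorAR keeps sigma fixed while rho is learned, so in general
# the trained prior is not exactly stationary. One latent dim only; plain Python
# for the AR side, NumPy for the joint Gaussian.
def _sketch_ar1_logp(z, rho, sigma):
    import math

    lp = -0.5 * (z[0] ** 2 + math.log(2 * math.pi))            # p(z_1) = N(0, 1)
    for t in range(1, len(z)):
        r = z[t] - rho * z[t - 1]                                # residual of the AR step
        lp += -0.5 * (r ** 2 / sigma ** 2 + math.log(2 * math.pi * sigma ** 2))
    return lp


def _sketch_mvn_logp_exp_kernel(z, rho):
    import numpy as np

    z = np.asarray(z, dtype=float)
    T = z.shape[0]
    idx = np.arange(T)
    K = rho ** np.abs(idx[:, None] - idx[None, :])               # K[i, j] = rho**|i - j|
    _, logdet = np.linalg.slogdet(K)
    quad = z @ np.linalg.solve(K, z)
    return float(-0.5 * (quad + logdet + T * np.log(2 * np.pi)))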

class PriorGP(PriorBase):
    """
    GP prior per latent dim: z_{:, :, d} ~ N(0, K_T) with RBF kernel over time.
    Kernel built dynamically per T (variable-length safe).
    """
    def __init__(self, lengthscale: float, sigma: float, jitter: float):
        super().__init__()
        self.log_lengthscale = nn.Parameter(torch.log(torch.tensor(float(lengthscale))))
        self.log_sigma = nn.Parameter(torch.log(torch.tensor(float(sigma))))
        self.jitter = float(jitter)

    def _kernel(self, T: int, device):
        t = torch.arange(T, device=device).float()
        dt2 = (t[:, None] - t[None, :]) ** 2
        ell = torch.exp(self.log_lengthscale).clamp(1e-3, 1e6)
        sig = torch.exp(self.log_sigma).clamp(1e-6, 1e6)
        K = (sig**2) * torch.exp(-0.5 * dt2 / (ell**2))
        K = K + self.jitter * torch.eye(T, device=device)
        return K

    def _chol(self, K: torch.Tensor) -> torch.Tensor:
        try:
            return torch.linalg.cholesky(K)
        except RuntimeError:
            # defensive: add more jitter and retry
            T = K.size(0)
            K2 = K + (10.0 * self.jitter) * torch.eye(T, device=K.device)
            return torch.linalg.cholesky(K2)

    def log_p(self, z):
        # (B,T,Dz) -> (B,): dims are independent, so sum the per-dim log-densities
        return self.log_p_per_dim(z).sum(dim=1)

    def log_p_per_dim(self, z):
        # (B,T,Dz) -> (B,Dz): each dim is an independent MVN over time with covariance K_T
        z = z.float()  # Cholesky ops need float32/64; z may arrive as float16 under autocast
        B, T, Dz = z.shape
        K = self._kernel(T, z.device)
        L = self._chol(K)
        logdet = 2.0 * torch.log(torch.diagonal(L)).sum()   # log|K| = 2 * sum(log diag L)
        const = T * math.log(2 * math.pi)

        # Solve K^{-1} z for all (B,Dz) streams at once: reshape to (B*Dz,T,1)
        z_td = z.transpose(1, 2).contiguous().view(B * Dz, T, 1)
        alpha = torch.cholesky_solve(z_td, L)            # (B*Dz,T,1)
        quad = (z_td * alpha).sum(dim=(1, 2))            # z^T K^{-1} z, (B*Dz,)
        lp = -0.5 * (quad + logdet + const)              # (B*Dz,)
        return lp.view(B, Dz)

    @torch.no_grad()
    def sample(self, B, T, Dz, device):
        K = self._kernel(T, device)
        L = self._chol(K)
        eps = torch.randn(B, Dz, T, device=device)
        z = torch.matmul(eps, L.T)                       # (B,Dz,T)
        return z.permute(0, 2, 1).contiguous()           # (B,T,Dz)

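# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# the identities PriorGP.log_p_per_dim relies on, for one latent stream in NumPy.
# With K = L L^T, log|K| = 2 * sum(log diag L), and K^{-1} z comes from two
# triangular solves (the job torch.cholesky_solve does above). It uses the same
# RBF kernel and jitter convention as PriorGP._kernel, without the clamps;
# np.linalg.solve stands in for a dedicated triangular solver.
def _sketch_gp_logp_cholesky(z, lengthscale, sigma, jitter):
    import numpy as np

    z = np.asarray(z, dtype=float)
    T = z.shape[0]
    t = np.arange(T, dtype=float)
    K = sigma ** 2 * np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2 / lengthscale ** 2)
    K = K + jitter * np.eye(T)
    L = np.linalg.cholesky(K)                               # lower-triangular, K = L @ L.T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, z))     # K^{-1} z via L then L^T
    quad = float(z @ alpha)
    logdet = 2.0 * float(np.log(np.diag(L)).sum())
    return float(-0.5 * (quad + logdet + T * np.log(2 * np.pi)))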


# VAE for LM

class VAETextLM(nn.Module):
    def __init__(self, cfg: CFG, prior: PriorBase):
        super().__init__()
        self.cfg = cfg
        self.prior = prior

        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)

        self.encoder = TransformerEncoder(cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.dropout)

        self.to_mu = nn.Linear(cfg.d_model, cfg.z_dim)
        self.to_logvar = nn.Linear(cfg.d_model, cfg.z_dim)

        self.z_proj = nn.Linear(cfg.z_dim, cfg.d_model)
        self.decoder = TransformerDecoderLM(cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.dropout)

        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def encode(self, x_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T = x_ids.shape
        tok = self.tok_emb(x_ids)
        pos = self.pos_emb(torch.arange(T, device=x_ids.device))[None, :, :]
        h = self.drop(tok + pos)
        h = self.encoder(h)
        mu = self.to_mu(h)
        logvar = self.to_logvar(h).clamp(-12.0, 6.0)
        return mu, logvar

    def reparam(self, mu, logvar, n_samples: int):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn((n_samples,) + mu.shape, device=mu.device)
        return mu[None] + eps * std[None]  # (S,B,T,Dz)

    def log_q_total(self, z, mu, logvar) -> torch.Tensor:
        """
        Total log q(z|x) for sampled z.
        z: (S,B,T,Dz); mu/logvar: (B,T,Dz)
        returns: (S,B)
        """
        var = torch.exp(logvar)
        log2pi = math.log(2 * math.pi)
        diff2 = (z - mu[None]) ** 2
        lq = -0.5 * (diff2 / var[None] + logvar[None] + log2pi)
        return lq.sum(dim=(2, 3))  # (S,B)

    def log_q_per_dim(self, z_btD: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Per-dim log q(z|x) summed over time.
        z_btD: (B,T,Dz); mu/logvar: (B,T,Dz)
        returns: (B,Dz)
        """
        var = torch.exp(logvar)
        log2pi = math.log(2 * math.pi)
        diff2 = (z_btD - mu) ** 2
        lq = -0.5 * (diff2 / var + logvar + log2pi)  # (B,T,Dz)
        return lq.sum(dim=1)  # (B,Dz)

    def decode_logits(self, x_ids: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        B, T = x_ids.shape
        tok = self.tok_emb(x_ids)
        pos = self.pos_emb(torch.arange(T, device=x_ids.device))[None, :, :]
        h = self.drop(tok + pos + self.z_proj(z))
        h = self.decoder(h)
        logits = self.lm_head(h)
        return logits

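    def forward(self, x_ids: torch.Tensor, y_ids: torch.Tensor, beta: float) -> Dict[str, torch.Tensor]:
        """
        Returns per-token averages (loss, ELBO, NLL, KL, perplexity, bits/token) plus
        extra tensors for metric computation. The KL term is a Monte Carlo estimate,
        log q(z|x) - log p(z) averaged over cfg.n_z_samples draws, not the analytic KL.
        """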
        B, T = x_ids.shape
        mu, logvar = self.encode(x_ids)
        zS = self.reparam(mu, logvar, self.cfg.n_z_samples)

        nll_list, logq_list, logp_list = [], [], []
        last_logits = None
        last_z = None

        for s in range(self.cfg.n_z_samples):
            z = zS[s]
            logits = self.decode_logits(x_ids, z)
            last_logits = logits
            last_z = z

            nll = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y_ids.view(-1),
                reduction="none",
            ).view(B, T).sum(dim=1)  # (B,)
            nll_list.append(nll)

            logq = self.log_q_total(zS[s:s+1], mu, logvar)[0]  # (B,)
            logq_list.append(logq)

            logp = self.prior.log_p(z)  # (B,)
            logp_list.append(logp)

        nll = torch.stack(nll_list, dim=0).mean(dim=0)   # (B,)
        logq = torch.stack(logq_list, dim=0).mean(dim=0)
        logp = torch.stack(logp_list, dim=0).mean(dim=0)

        kl = (logq - logp)                               # (B,)
        elbo = -(nll + beta * kl)                        # (B,)

        n_tokens = T
        nll_tok = nll.mean() / n_tokens
        kl_tok = kl.mean() / n_tokens
        elbo_tok = elbo.mean() / n_tokens

        ppl = torch.exp(nll_tok.detach())
        bits_per_tok = (nll_tok.detach() / math.log(2.0))

        loss = -(elbo_tok)
        return {
            "loss": loss,
            "elbo_tok": elbo_tok.detach(),
            "nll_tok": nll_tok.detach(),
            "kl_tok": kl_tok.detach(),
            "ppl": ppl.detach(),
            "bits_per_tok": bits_per_tok.detach(),
            # extras
            "mu": mu,
            "logvar": logvar,
            "z_last": last_z,
            "logits_last": last_logits,
        }

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_new_tokens: int = 80, top_k: int = 0) -> torch.Tensor:
        self.eval()
        device = prompt_ids.device
        B, Tp = prompt_ids.shape
        total_T = min(self.cfg.block_size, Tp + max_new_tokens)

        out = prompt_ids.clone()
        z = self.prior.sample(B, total_T, self.cfg.z_dim, device=device)

        for _ in range(max_new_tokens):
            Tcur = out.size(1)
            if Tcur >= total_T:
                break

            x = out[:, :Tcur]
            zcur = z[:, :Tcur]
            logits = self.decode_logits(x, zcur)
            next_logits = logits[:, -1, :]

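            if top_k and top_k > 0:
                # keep the k most likely tokens and renormalize over them (k capped at vocab size)
                k = min(top_k, next_logits.size(-1))
                vals, idx = torch.topk(next_logits, k=k, dim=-1)
                probs = torch.zeros_like(next_logits).scatter_(-1, idx, F.softmax(vals, dim=-1))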
            else:
                probs = F.softmax(next_logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_id], dim=1)

        return out

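# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# VAETextLM.forward computes kl = log q(z|x) - log p(z) at sampled z, i.e. a Monte
# Carlo estimate of KL(q || p) rather than the analytic value. For one scalar
# latent with q = N(mu, s^2) and the IID prior p = N(0, 1), the closed form is
# 0.5 * (mu^2 + s^2 - 1 - log s^2); averaging the single-sample estimator over
# many reparameterized draws z = mu + s * eps should approach it. Individual
# single-sample values (and hence per-batch kl_tok) can be negative.
def _sketch_kl_mc_vs_closed_form(mu, s, n_samples=100_000, seed=0):
    import math
    import random

    rng = random.Random(seed)
    log2pi = math.log(2 * math.pi)
    logvar = 2.0 * math.log(s)
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + s * eps                                              # reparameterization trick
        log_q = -0.5 * ((z - mu) ** 2 / s ** 2 + logvar + log2pi)     # log N(z | mu, s^2)
        log_p = -0.5 * (z ** 2 + log2pi)                              # log N(z | 0, 1)
        total += log_q - log_p
    closed = 0.5 * (mu ** 2 + s ** 2 - 1.0 - logvar)
    return total / n_samples, closed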


# Perf / GPU utilities

def _now():
    return time.perf_counter()

def sync_if_cuda(device: str):
    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()

def cuda_mem_reset(device: str):
    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

@torch.no_grad()
def cuda_mem_snapshot_mb(device: str) -> Dict[str, float]:
    if not (isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available()):
        return {"peak_alloc_MB": float("nan"), "peak_reserved_MB": float("nan")}
    peak_alloc = torch.cuda.max_memory_allocated() / (1024**2)
    peak_reserved = torch.cuda.max_memory_reserved() / (1024**2)
    return {"peak_alloc_MB": float(peak_alloc), "peak_reserved_MB": float(peak_reserved)}

class RunningMean:
    def __init__(self):
        self.sum = 0.0
        self.n = 0
    def add(self, x: float):
        if math.isfinite(x):
            self.sum += float(x)
            self.n += 1
    def mean(self) -> float:
        return self.sum / max(1, self.n)



# Metric helpers

def safe_float(x: Any) -> float:
    try:
        return float(x)
    except Exception:
        return float("nan")

@torch.no_grad()
def logits_entropy_and_topk_mass(logits: torch.Tensor, top_k: int = 50) -> Dict[str, float]:
    """
    logits: (B,T,V)
    Returns mean token entropy and mean probability mass of top-k.
    """
    probs = F.softmax(logits, dim=-1)
    logp = torch.log(probs.clamp_min(1e-12))
    ent = -(probs * logp).sum(dim=-1)  # (B,T)
    ent_mean = ent.mean().item()

    if top_k and top_k > 0:
        V = probs.size(-1)
        vals, _ = torch.topk(probs, k=min(top_k, V), dim=-1)
        topk_mass = vals.sum(dim=-1).mean().item()
    else:
        topk_mass = float("nan")

    return {"tok_entropy_mean": float(ent_mean), "topk_mass_mean": float(topk_mass)}

@torch.no_grad()
def ece_from_logits(logits: torch.Tensor, targets: torch.Tensor, n_bins: int = 15) -> float:
    """
    Expected Calibration Error for next-token prediction.
    logits: (B,T,V), targets: (B,T)
    """
    probs = F.softmax(logits, dim=-1)
    conf, pred = probs.max(dim=-1)      # (B,T)
    acc = (pred == targets).float()

    conf = conf.reshape(-1)
    acc = acc.reshape(-1)

    bins = torch.linspace(0, 1, n_bins + 1, device=logits.device)
    ece = torch.zeros((), device=logits.device)
    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        mask = (conf > lo) & (conf <= hi) if i > 0 else (conf >= lo) & (conf <= hi)
        if mask.any():
            bin_acc = acc[mask].mean()
            bin_conf = conf[mask].mean()
            ece = ece + (mask.float().mean()) * (bin_acc - bin_conf).abs()
    return float(ece.item())

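# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# the binning rule of ece_from_logits on plain (confidence, correct) lists.
# Bins are (lo, hi], except the first, which is [0, 1/n_bins]; ECE is the
# bin-size-weighted mean of |accuracy - confidence| over non-empty bins.
def _sketch_ece(confidences, correct, n_bins=15):
    n = len(confidences)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = i / n_bins, (i + 1) / n_bins
        idx = [k for k, c in enumerate(confidences)
               if (lo < c <= hi) or (i == 0 and c == lo)]
        if idx:
            bin_acc = sum(correct[k] for k in idx) / len(idx)
            bin_conf = sum(confidences[k] for k in idx) / len(idx)
            ece += (len(idx) / n) * abs(bin_acc - bin_conf)
    return ece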
@torch.no_grad()
def posterior_collapse_ratio(mu: torch.Tensor, logvar: torch.Tensor) -> Dict[str, float]:
    """
    Simple collapse heuristics:
    - mean posterior std and mean |mu|
    - fraction of latent dims whose posterior looks like N(0, 1), i.e.
      mean |mu| < 0.02 and |mean std - 1| < 0.05 (averaged over batch and time)
    """
    std = torch.exp(0.5 * logvar)
    mu_abs_mean = mu.abs().mean(dim=(0, 1))   # (Dz,)
    std_mean = std.mean(dim=(0, 1))           # (Dz,)

    collapsed = ((mu_abs_mean < 0.02) & ((std_mean - 1.0).abs() < 0.05)).float().mean().item()
    return {
        "post_std_mean": float(std.mean().item()),
        "post_mu_abs_mean": float(mu.abs().mean().item()),
        "collapse_dim_frac": float(collapsed),
    }

@torch.no_grad()
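def latent_autocorr(x: torch.Tensor, lag: int = 1) -> float:
    """
    x: (B,T,D) -> lag-`lag` autocorrelation per dim, pooled over batch and time
    (pairs x[:, t] / x[:, t + lag] stay within the same sequence), averaged across dims.
    Returns NaN if T <= lag.
    """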
    if x.size(1) <= lag:
        return float("nan")
    a = x[:, :-lag, :].reshape(-1, x.size(-1))
    b = x[:, lag:, :].reshape(-1, x.size(-1))
    a = a - a.mean(dim=0, keepdim=True)
    b = b - b.mean(dim=0, keepdim=True)
    cov = (a * b).mean(dim=0)
    va = (a * a).mean(dim=0).clamp_min(1e-8)
    vb = (b * b).mean(dim=0).clamp_min(1e-8)
    corr = (cov / torch.sqrt(va * vb)).mean().item()
    return float(corr)

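# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# for a stationary AR(1) latent path (sigma**2 = 1 - rho**2, i.e. PriorAR with a
# matching sigma) the lag-k autocorrelation is rho**k, so latent_autocorr on such
# samples should read roughly rho at lag 1 and rho**5 at lag 5. Pure Python, one
# long path, with the same x[t] vs x[t + lag] pairing as latent_autocorr.
def _sketch_ar1_autocorr(rho, lag, n_steps=20_000, seed=0):
    import random

    rng = random.Random(seed)
    sigma = (1.0 - rho ** 2) ** 0.5
    z = [rng.gauss(0.0, 1.0)]                      # z_1 ~ N(0, 1): already stationary
    for _ in range(n_steps - 1):
        z.append(rho * z[-1] + sigma * rng.gauss(0.0, 1.0))
    a, b = z[:-lag], z[lag:]
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b)) / len(a)
    va = sum((x - ma) ** 2 for x in a) / len(a)
    vb = sum((y - mb) ** 2 for y in b) / len(b)
    return cov / (va * vb) ** 0.5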
def distinct_ngrams(token_seq: List[int], n: int) -> float:
    if len(token_seq) < n or n <= 0:
        return 0.0
    total = 0
    uniq = set()
    for i in range(len(token_seq) - n + 1):
        uniq.add(tuple(token_seq[i:i+n]))
        total += 1
    return (len(uniq) / total) if total > 0 else 0.0

def repetition_rate(token_seq: List[int], n: int) -> float:
    if len(token_seq) < n or n <= 0:
        return 0.0
    total = 0
    counts = {}
    for i in range(len(token_seq) - n + 1):
        ng = tuple(token_seq[i:i+n])
        counts[ng] = counts.get(ng, 0) + 1
        total += 1
    if total == 0:
        return 0.0
    repeated = sum(c for c in counts.values() if c > 1)
    return repeated / total

@torch.no_grad()
def kl_per_dim_stats(
    model: VAETextLM,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    z: torch.Tensor,
    eps: float,
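) -> Dict[str, float]:
    """
    Bits-back style diagnostics, per latent dim and per token:
        KL_dim[d] ≈ E_batch[ log q_d(z|x) - log p_d(z) ] / T
    where log q_d and log p_d are each summed over time (all priors factorize over dims).
    Estimated from the single sample z, so individual values can be negative.
    """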
    B, T, Dz = z.shape
    logq_bd = model.log_q_per_dim(z, mu, logvar)               # (B,Dz)
    logp_bd = model.prior.log_p_per_dim(z)                     # (B,Dz)
    kld_bd = (logq_bd - logp_bd) / max(1, T)                   # per-token
    kld_d = kld_bd.mean(dim=0)                                 # (Dz,)

    kld = kld_d.detach().cpu().numpy()
    kld_mean = float(np.mean(kld))
    kld_med = float(np.median(kld))
    kld_max = float(np.max(kld))
    frac_small = float(np.mean(kld < eps))

    return {
        "kldim_mean": kld_mean,
        "kldim_median": kld_med,
        "kldim_max": kld_max,
        "kldim_frac_below_eps": frac_small,
    }

@torch.no_grad()
def mi_proxy_batch(
    mu: torch.Tensor,
    logvar: torch.Tensor,
    z: torch.Tensor,
    max_components: int,
    max_points: int,
    seed: int = 0,
) -> float:
    """
    MI(x,z) proxy (batch-based):
        I ≈ E_q[log q(z|x)] - E_q[log q(z)]
    where q(z) approximated by mixture of batch posteriors.

    Implementation notes:
    - We use diagonal Gaussians; treat each (b,t) as a mixture component.
    - We subsample components and points for cost control.
    - We use z points coming from the current batch (z sample).
    """
    # Flatten components: C = B*T
    B, T, Dz = z.shape
    C = B * T
    if C == 0:
        return float("nan")

    rng = np.random.default_rng(seed)

    # Components: (C,Dz)
    mu_c = mu.reshape(C, Dz)
    lv_c = logvar.reshape(C, Dz)

    # Points: the sampled z at every (b,t); subsampled below
    z_c = z.reshape(C, Dz)

    # Subsample components and points
    comp_idx = np.arange(C)
    pt_idx = np.arange(C)
    if C > max_components:
        comp_idx = rng.choice(comp_idx, size=max_components, replace=False)
    if C > max_points:
        pt_idx = rng.choice(pt_idx, size=max_points, replace=False)

    mu_s = mu_c[torch.as_tensor(comp_idx, device=mu.device)]
    lv_s = lv_c[torch.as_tensor(comp_idx, device=mu.device)]
    var_s = torch.exp(lv_s).clamp_min(1e-12)

    z_pts = z_c[torch.as_tensor(pt_idx, device=mu.device)]
    mu_pts = mu_c[torch.as_tensor(pt_idx, device=mu.device)]
    lv_pts = lv_c[torch.as_tensor(pt_idx, device=mu.device)]
    var_pts = torch.exp(lv_pts).clamp_min(1e-12)

    log2pi = math.log(2 * math.pi)

    # E_q[log q(z|x)] approximated by average over chosen points using their own component params
    # log N(z | mu, var)
    lq_cond = -0.5 * (((z_pts - mu_pts) ** 2) / var_pts + lv_pts + log2pi).sum(dim=-1)  # (P,)
    eq_logq_z_given_x = lq_cond.mean()

    # E_q[log q(z)] with q(z) as mixture of components
    # log q(z) = log mean_k N(z | mu_k, var_k)
    # For each point, compute logsumexp over components.
    # The (P,K,D) intermediate costs P*K*Dz memory; P and K are capped by max_points / max_components.
    z_exp = z_pts[:, None, :]              # (P,1,D)
    mu_exp = mu_s[None, :, :]              # (1,K,D)
    lv_exp = lv_s[None, :, :]              # (1,K,D)
    var_exp = var_s[None, :, :]            # (1,K,D)

    lcomp = -0.5 * (((z_exp - mu_exp) ** 2) / var_exp + lv_exp + log2pi).sum(dim=-1)  # (P,K)
    lmix = torch.logsumexp(lcomp, dim=1) - math.log(lcomp.size(1))                     # (P,)
    eq_logq_z = lmix.mean()

    mi = (eq_logq_z_given_x - eq_logq_z).item()
    return float(mi)

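# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# a 1-D version of the estimator in mi_proxy_batch,
#     I ≈ mean_i log q(z_i | x_i) - mean_i log[(1/K) sum_k q(z_i | x_k)],
# where points[i] is a draw from component i (so len(points) == K here, and every
# point's own component is in the mixture). If all posteriors are identical
# (collapse) the two terms coincide and the proxy is 0; if the posteriors are far
# apart, each point is explained by its own component only and the proxy reaches
# log K, which is its upper bound under this construction.
def _sketch_mi_proxy_1d(mus, variances, points):
    import math

    def log_normal(x, m, v):
        return -0.5 * ((x - m) ** 2 / v + math.log(v) + math.log(2 * math.pi))

    K = len(mus)
    # E_q[log q(z|x)]: each point scored under its own component
    cond = [log_normal(points[i], mus[i], variances[i]) for i in range(K)]
    # E_q[log q(z)]: each point scored under the equal-weight mixture of all components
    mix = []
    for x in points:
        terms = [log_normal(x, m, v) for m, v in zip(mus, variances)]
        top = max(terms)                                   # log-sum-exp for stability
        mix.append(top + math.log(sum(math.exp(a - top) for a in terms)) - math.log(K))
    return sum(cond) / K - sum(mix) / K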
def rd_points(nll_tok: float, kl_tok: float, betas: Tuple[float, ...]) -> Dict[str, float]:
    """
    Rate–distortion logging. With distortion D = nll_tok and rate R = kl_tok,
    report the beta-weighted ELBO per token, -(D + beta * R), for each beta.
    """
    out = {}
    for b in betas:
        # elbo_tok(beta) = -(nll_tok + beta*kl_tok)
        out[f"rd_elbo_tok_beta_{b:g}"] = -(nll_tok + b * kl_tok)
    # Also log the base R and D
    out["rd_nll_tok"] = float(nll_tok)
    out["rd_kl_tok"] = float(kl_tok)
    return out



# Generation metrics (aggregated + per prompt)

@torch.no_grad()
def generation_metrics(
    model: VAETextLM,
    tokenizer,
    device: str,
    n_prompts: int,
    max_new: int,
    top_k: int,
) -> Dict[str, Any]:
    model.eval()
    prompts = [
        "The meaning of life is",
        "In the middle of the night",
        "The government announced that",
        "A new theory suggests",
        "Once upon a time",
        "The experiment shows",
        "In a shocking discovery",
        "The book describes",
        "Scientists found that",
        "The president said",
        "In the future,",
        "The story begins",
    ]

    # perf
    cuda_mem_reset(device)
    sync_if_cuda(device)
    t0 = _now()
    gen_new_tokens_total = 0

    per_prompt: List[Dict[str, Any]] = []
    decoded_samples: List[str] = []

    for i in range(n_prompts):
        p = prompts[i % len(prompts)]
        p_ids = tokenizer(p, return_tensors="pt").input_ids.to(device)
        if p_ids.size(1) > model.cfg.block_size // 2:
            p_ids = p_ids[:, : model.cfg.block_size // 2]

        out_ids = model.generate(p_ids, max_new_tokens=max_new, top_k=top_k)
        seq = out_ids[0].tolist()

        prompt_len = int(p_ids.size(1))
        new_tokens = max(0, len(seq) - prompt_len)
        gen_new_tokens_total += new_tokens

        # per-prompt stats, computed on the full sequence (prompt + generated tokens)
        st = {
            "prompt": p,
            "prompt_len": prompt_len,
            "total_len": len(seq),
            "new_tokens": new_tokens,
            "rep2": repetition_rate(seq, 2),
            "rep3": repetition_rate(seq, 3),
            "distinct2": distinct_ngrams(seq, 2),
            "distinct3": distinct_ngrams(seq, 3),
        }

        # short sample text (truncate to be JSON-friendly)
        txt = tokenizer.decode(seq[: min(len(seq), 250)], skip_special_tokens=True)
        st["sample"] = txt
        per_prompt.append(st)

        if i < 3:
            decoded_samples.append(txt)

    sync_if_cuda(device)
    dt = max(1e-9, (_now() - t0))
    mem = cuda_mem_snapshot_mb(device)

    # Aggregate across prompts
    rep2 = float(np.mean([p["rep2"] for p in per_prompt])) if per_prompt else 0.0
    rep3 = float(np.mean([p["rep3"] for p in per_prompt])) if per_prompt else 0.0
    distinct2 = float(np.mean([p["distinct2"] for p in per_prompt])) if per_prompt else 0.0
    distinct3 = float(np.mean([p["distinct3"] for p in per_prompt])) if per_prompt else 0.0
    lengths = [p["total_len"] for p in per_prompt]
    newlens = [p["new_tokens"] for p in per_prompt]

    return {
        # aggregate
        "gen_distinct2": distinct2,
        "gen_distinct3": distinct3,
        "gen_rep2": rep2,
        "gen_rep3": rep3,
        "gen_len_mean": float(np.mean(lengths)) if lengths else 0.0,
        "gen_len_std": float(np.std(lengths)) if lengths else 0.0,
        "gen_new_tokens_mean": float(np.mean(newlens)) if newlens else 0.0,
        "gen_new_tokens_total": int(gen_new_tokens_total),

        # perf
        "time_gen_seconds_total": float(dt),
        "time_gen_tokens_per_s": float(gen_new_tokens_total / dt),
        "gpu_gen_peak_alloc_MB": mem["peak_alloc_MB"],
        "gpu_gen_peak_reserved_MB": mem["peak_reserved_MB"],

        # per-prompt breakdown, in addition to the aggregates above
        "gen_per_prompt": per_prompt,

        # a few samples
        "gen_sample_0": decoded_samples[0] if len(decoded_samples) > 0 else "",
        "gen_sample_1": decoded_samples[1] if len(decoded_samples) > 1 else "",
        "gen_sample_2": decoded_samples[2] if len(decoded_samples) > 2 else "",
    }

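# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# the top-k filtering used in VAETextLM.generate, on a plain list of logits.
# Only the k largest logits are kept, their softmax is taken, and every other
# token gets probability 0; k is capped at the vocabulary size.
def _sketch_top_k_probs(logits, k):
    import math

    k = min(k, len(logits))
    keep = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in keep)                  # subtract the max for a stable softmax
    exps = {i: math.exp(logits[i] - m) for i in keep}
    total = sum(exps.values())
    return [exps[i] / total if i in exps else 0.0 for i in range(len(logits))]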


# Eval (many metrics + perf + bits-back + MI + RD)

@torch.no_grad()
def eval_many_metrics(
    model: VAETextLM,
    loader: DataLoader,
    device: str,
    beta: float,
    max_batches: int,
    ece_bins: int,
    gen_do: bool,
    tokenizer=None,
    gen_prompts: int = 12,
    gen_max_new: int = 96,
    gen_top_k: int = 50,
    rd_betas: Tuple[float, ...] = (0.0, 1.0),
    measure_perf: bool = True,
    kldim_eps: float = 0.01,
    mi_max_components: int = 512,
    mi_max_points: int = 256,
    mi_seed: int = 0,
) -> Dict[str, Any]:
    model.eval()

    # basic averages
    acc = {
        "nll_tok": 0.0,
        "kl_tok": 0.0,
        "elbo_tok": 0.0,
        "tok_entropy_mean": 0.0,
        "topk_mass_mean": 0.0,
        "ece": 0.0,
        "post_std_mean": 0.0,
        "post_mu_abs_mean": 0.0,
        "collapse_dim_frac": 0.0,
        "mu_ac1": 0.0,
        "mu_ac5": 0.0,
        "z_ac1": 0.0,
        "z_ac5": 0.0,
        # bits-back
        "kldim_mean": 0.0,
        "kldim_median": 0.0,
        "kldim_max": 0.0,
        "kldim_frac_below_eps": 0.0,
        # MI
        "mi_proxy": 0.0,
    }
    n = 0

    # perf accumulators
    step_ms = RunningMean()
    tokens_per_s = RunningMean()
    total_tokens = 0

    if measure_perf:
        cuda_mem_reset(device)

    for i, (x, y) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(device)
        y = y.to(device)

        if measure_perf:
            sync_if_cuda(device)
            t0 = _now()

        out = model(x, y, beta=beta)

        if measure_perf:
            sync_if_cuda(device)
            dt = max(1e-9, (_now() - t0))
            step_ms.add(1000.0 * dt)
            B, T = x.shape
            tok = int(B * T)
            total_tokens += tok
            tokens_per_s.add(tok / dt)

        nll_tok = safe_float(out["nll_tok"])
        kl_tok = safe_float(out["kl_tok"])
        elbo_tok = safe_float(out["elbo_tok"])
        if not (math.isfinite(nll_tok) and math.isfinite(kl_tok) and math.isfinite(elbo_tok)):
            continue

        logits = out["logits_last"]          # (B,T,V)
        mu = out["mu"]                       # (B,T,Dz)
        logvar = out["logvar"]               # (B,T,Dz)
        zlast = out["z_last"]                # (B,T,Dz)

        ent_topk = logits_entropy_and_topk_mass(logits, top_k=gen_top_k)
        ece = ece_from_logits(logits, y, n_bins=ece_bins)
        post = posterior_collapse_ratio(mu, logvar)

        mu_ac1 = latent_autocorr(mu.detach(), lag=1)
        mu_ac5 = latent_autocorr(mu.detach(), lag=5)
        z_ac1 = latent_autocorr(zlast.detach(), lag=1)
        z_ac5 = latent_autocorr(zlast.detach(), lag=5)

        # bits-back KL-per-dim
        kld = kl_per_dim_stats(model, mu.detach(), logvar.detach(), zlast.detach(), eps=kldim_eps)

        # MI proxy (batch-based)
        mi = mi_proxy_batch(
            mu.detach(), logvar.detach(), zlast.detach(),
            max_components=mi_max_components,
            max_points=mi_max_points,
            seed=mi_seed,
        )

        acc["nll_tok"] += nll_tok
        acc["kl_tok"] += kl_tok
        acc["elbo_tok"] += elbo_tok
        acc["tok_entropy_mean"] += ent_topk["tok_entropy_mean"]
        acc["topk_mass_mean"] += ent_topk["topk_mass_mean"]
        acc["ece"] += ece
        acc["post_std_mean"] += post["post_std_mean"]
        acc["post_mu_abs_mean"] += post["post_mu_abs_mean"]
        acc["collapse_dim_frac"] += post["collapse_dim_frac"]
        acc["mu_ac1"] += mu_ac1
        acc["mu_ac5"] += mu_ac5
        acc["z_ac1"] += z_ac1
        acc["z_ac5"] += z_ac5

        acc["kldim_mean"] += kld["kldim_mean"]
        acc["kldim_median"] += kld["kldim_median"]
        acc["kldim_max"] += kld["kldim_max"]
        acc["kldim_frac_below_eps"] += kld["kldim_frac_below_eps"]

        acc["mi_proxy"] += mi
        n += 1

    if n == 0:
        return {k: float("nan") for k in acc.keys()}

    for k in acc:
        acc[k] /= n

    acc["ppl"] = math.exp(acc["nll_tok"])
    acc["bits_per_tok"] = acc["nll_tok"] / math.log(2.0)

    # Rate–Distortion derived points
    acc.update(rd_points(acc["nll_tok"], acc["kl_tok"], rd_betas))

    # Perf snapshots
    if measure_perf:
        mem = cuda_mem_snapshot_mb(device)
        acc["time_eval_step_ms_mean"] = step_ms.mean()
        acc["time_eval_tokens_per_s_mean"] = tokens_per_s.mean()
        acc["gpu_eval_peak_alloc_MB"] = mem["peak_alloc_MB"]
        acc["gpu_eval_peak_reserved_MB"] = mem["peak_reserved_MB"]
        acc["eval_total_tokens_measured"] = int(total_tokens)

    # Generation metrics (includes its own perf + GPU peaks)
    if gen_do and tokenizer is not None:
        gen = generation_metrics(
            model, tokenizer, device=device,
            n_prompts=gen_prompts, max_new=gen_max_new, top_k=gen_top_k
        )
        acc.update(gen)

    return acc



# Schedules

def lr_schedule(step, cfg: CFG):
    """Linear warmup from 0 to cfg.lr, then cosine decay to 0.1 * cfg.lr at cfg.max_steps."""
    if step < cfg.warmup_steps:
        return cfg.lr * (step / max(1, cfg.warmup_steps))
    progress = (step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps)
    return cfg.lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))

def beta_schedule(step, cfg: CFG):
    """KL annealing: linear ramp from cfg.beta_start to cfg.beta_end over cfg.beta_warmup_steps, then constant."""
    if step >= cfg.beta_warmup_steps:
        return cfg.beta_end
    a = step / max(1, cfg.beta_warmup_steps)
    return cfg.beta_start + a * (cfg.beta_end - cfg.beta_start)



# Model builder

def make_model(cfg: CFG, variant: str) -> VAETextLM:
    if variant == "vae":
        prior = PriorIID()
    elif variant == "vae_ar":
        prior = PriorAR(cfg.z_dim, rho_init=cfg.ar_init_rho, sigma=cfg.ar_sigma)
    elif variant == "vae_gp":
        prior = PriorGP(cfg.gp_lengthscale, cfg.gp_sigma, cfg.gp_jitter)
    else:
        raise ValueError("variant must be one of: vae, vae_gp, vae_ar")
    return VAETextLM(cfg, prior)



# Train one variant (robust + perf/gpu)

def train_one_variant(cfg: CFG, variant: str, train_loader, val_loader, tokenizer) -> Dict[str, Any]:
    device = cfg.device
    model = make_model(cfg, variant).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    use_amp = bool(cfg.amp and device.startswith("cuda"))
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    os.makedirs(cfg.out_dir, exist_ok=True)
    run_stamp = int(time.time())
    log_path = os.path.join(cfg.out_dir, f"{cfg.run_name}_{variant}_{run_stamp}.jsonl")

    best_val = float("inf")
    best: Dict[str, Any] = {}

    t0 = time.time()
    pbar = tqdm(total=cfg.max_steps, desc=f"train[{variant}]")

    it = iter(train_loader)

    # windowed training perf
    train_step_ms = RunningMean()
    train_tokens_per_s = RunningMean()
    train_tokens_total = 0
    cuda_mem_reset(device)

    for step in range(1, cfg.max_steps + 1):
        try:
            x, y = next(it)
        except StopIteration:
            it = iter(train_loader)
            x, y = next(it)

        x, y = x.to(device), y.to(device)

        beta = beta_schedule(step, cfg)
        lr = lr_schedule(step, cfg)
        for pg in opt.param_groups:
            pg["lr"] = lr

        model.train()
        opt.zero_grad(set_to_none=True)

        # timing per step: synchronize on CUDA (no-op otherwise) so the wall-clock
        # interval covers the queued GPU work, with or without AMP
        sync_if_cuda(device)
        t_step0 = _now()

        with torch.amp.autocast("cuda", enabled=use_amp):
            out = model(x, y, beta=beta)
            loss = out["loss"]

        scaler.scale(loss).backward()
        if cfg.grad_clip > 0:
            scaler.unscale_(opt)  # unscale first so clipping sees the true gradient norms
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt)
        scaler.update()

        sync_if_cuda(device)
        dt = max(1e-9, (_now() - t_step0))
        train_step_ms.add(1000.0 * dt)
        B, T = x.shape
        tok = int(B * T)
        train_tokens_total += tok
        train_tokens_per_s.add(tok / dt)

        if step % cfg.log_every == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "nll_tok": f"{float(out['nll_tok']):.4f}",
                "kl_tok": f"{float(out['kl_tok']):.4f}",
                "ppl": f"{float(out['ppl']):.2f}",
                "beta": f"{beta:.2f}",
                "lr": f"{lr:.1e}",
            })

        if step % cfg.eval_every == 0 or step == 1:
            # snapshot perf for the train window (memory peak is since the last reset)
            mem_train = cuda_mem_snapshot_mb(device)
            train_perf = {
                "time_train_step_ms_mean_window": train_step_ms.mean(),
                "time_train_tokens_per_s_mean_window": train_tokens_per_s.mean(),
                "gpu_train_peak_alloc_MB_window": mem_train["peak_alloc_MB"],
                "gpu_train_peak_reserved_MB_window": mem_train["peak_reserved_MB"],
                "train_tokens_measured_window": int(train_tokens_total),
            }
            # start a fresh window so the next record averages only the steps since this eval
            train_step_ms = RunningMean()
            train_tokens_per_s = RunningMean()
            train_tokens_total = 0

            # eval train (few batches) + val (more batches)
            train_eval = eval_many_metrics(
                model, train_loader, device=device, beta=1.0,
                max_batches=min(cfg.eval_train_batches, cfg.eval_max_batches),
                ece_bins=cfg.ece_bins,
                gen_do=False,
                tokenizer=None,
                gen_prompts=0,
                gen_max_new=0,
                gen_top_k=cfg.eval_gen_top_k,
                rd_betas=cfg.rd_betas,
                measure_perf=True,
                kldim_eps=cfg.kldim_eps,
                mi_max_components=cfg.mi_max_components,
                mi_max_points=cfg.mi_max_points,
                mi_seed=cfg.mi_seed,
            )
            val_eval = eval_many_metrics(
                model, val_loader, device=device, beta=1.0,
                max_batches=cfg.eval_max_batches,
                ece_bins=cfg.ece_bins,
                gen_do=True,
                tokenizer=tokenizer,
                gen_prompts=cfg.eval_gen_prompts,
                gen_max_new=cfg.eval_gen_max_new,
                gen_top_k=cfg.eval_gen_top_k,
                rd_betas=cfg.rd_betas,
                measure_perf=True,
                kldim_eps=cfg.kldim_eps,
                mi_max_components=cfg.mi_max_components,
                mi_max_points=cfg.mi_max_points,
                mi_seed=cfg.mi_seed,
            )

            # best logic on val nll
            val_nll = val_eval.get("nll_tok", float("inf"))
            if math.isfinite(val_nll) and (val_nll < best_val):
                best_val = float(val_nll)
                best = {
                    "variant": variant,
                    "step": step,
                    "wall_s": time.time() - t0,
                    "best_val_nll_tok": best_val,
                    "train_perf_window": train_perf,
                    "train": train_eval,
                    "val": val_eval,
                }

            rec = {
                "variant": variant,
                "step": step,
                "wall_s": time.time() - t0,
                "lr": float(lr),
                "beta": float(beta),
                "train_perf_window": train_perf,
                "train": train_eval,
                "val": val_eval,
                "best_val_nll_tok_so_far": float(best_val),
            }
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")

            print(
                f"[eval] {variant} step={step:5d} "
                f"train_tok/s={train_eval.get('time_eval_tokens_per_s_mean', float('nan')):.1f} "
                f"val_tok/s={val_eval.get('time_eval_tokens_per_s_mean', float('nan')):.1f} "
                f"gen_tok/s={val_eval.get('time_gen_tokens_per_s', float('nan')):.1f} "
                f"val_ppl={val_eval.get('ppl', float('nan')):.2f} "
                f"val_nll={val_eval.get('nll_tok', float('nan')):.4f} val_kl={val_eval.get('kl_tok', float('nan')):.4f} "
                f"ece={val_eval.get('ece', float('nan')):.3f} ent={val_eval.get('tok_entropy_mean', float('nan')):.3f} "
                f"mi={val_eval.get('mi_proxy', float('nan')):.3f} "
                f"kldim_med={val_eval.get('kldim_median', float('nan')):.3f} "
                f"train_peakMB={train_perf.get('gpu_train_peak_alloc_MB_window', float('nan')):.1f}"
            )

            # reset window perf + GPU peaks
            train_step_ms = RunningMean()
            train_tokens_per_s = RunningMean()
            train_tokens_total = 0
            cuda_mem_reset(device)

        pbar.update(1)

    pbar.close()

    if not best:
        best = {"variant": variant, "step": cfg.max_steps, "wall_s": time.time() - t0, "best_val_nll_tok": best_val}
    best["log_path"] = log_path
    return best



# Main

def main():
    cfg = CFG()
    set_seed(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
    # Raise model_max_length so HF does not warn that the "sequence length is longer than the
    # specified maximum sequence length" when long text chunks are tokenized in one call:
    tokenizer.model_max_length = int(1e9)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    cfg.vocab_size = len(tokenizer)

    train_loader, val_loader = load_wt2_blocks(cfg, tokenizer)

    results = []
    for variant in ["vae", "vae_gp", "vae_ar"]:
        best = train_one_variant(cfg, variant, train_loader, val_loader, tokenizer)
        results.append(best)

    print("\n=== SUMMARY (best checkpoints) ===")
    results = sorted(results, key=lambda r: r.get("best_val_nll_tok", float("inf")))

    for r in results:
        variant = r.get("variant", "?")
        step = r.get("step", -1)
        best_val_nll = r.get("best_val_nll_tok", float("nan"))
        wall = r.get("wall_s", float("nan"))
        log_path = r.get("log_path", "")

        val = r.get("val", {})
        trp = r.get("train_perf_window", {})

        print(
            f"{variant:7s} | step={step:5d} | wall_s={wall:8.1f} | best_val_nll/tok={best_val_nll:.4f} | "
            f"val_ppl={val.get('ppl', float('nan')):.2f} | val_kl/tok={val.get('kl_tok', float('nan')):.4f} | "
            f"eval_tok/s={val.get('time_eval_tokens_per_s_mean', float('nan')):.1f} | "
            f"gen_tok/s={val.get('time_gen_tokens_per_s', float('nan')):.1f} | "
            f"train_step_ms={trp.get('time_train_step_ms_mean_window', float('nan')):.2f} | "
            f"train_peakMB={trp.get('gpu_train_peak_alloc_MB_window', float('nan')):.1f} | "
            f"eval_peakMB={val.get('gpu_eval_peak_alloc_MB', float('nan')):.1f} | "
            f"gen_peakMB={val.get('gpu_gen_peak_alloc_MB', float('nan')):.1f} | "
            f"mi={val.get('mi_proxy', float('nan')):.3f} | "
            f"kldim_med={val.get('kldim_median', float('nan')):.3f} frac<eps={val.get('kldim_frac_below_eps', float('nan')):.2f} | "
            f"d2={val.get('gen_distinct2', float('nan')):.3f} rep2={val.get('gen_rep2', float('nan')):.3f} | "
            f"log={log_path}"
        )

    os.makedirs(cfg.out_dir, exist_ok=True)
    out_json = os.path.join(cfg.out_dir, f"{cfg.run_name}_BEST_{int(time.time())}.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(
            {"cfg": asdict(cfg), "results": results},
            f,
            indent=2,
            ensure_ascii=False,
        )
    print(f"\nSaved best summary JSON: {out_json}")


if __name__ == "__main__":
    main()
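During training, the loop appends one JSON record per evaluation to `log_path`. Each record has the keys `variant`, `step`, `train`, `val` and `best_val_nll_tok_so_far`. The sketch below shows one way to recover the best-validation record per variant from those lines afterwards. It assumes every non-blank line is a record shaped like the `rec` dict above. `best_from_lines` is a hypothetical helper and is not part of the script.

```python
import json
import math
from typing import Dict, Iterable


def best_from_lines(lines: Iterable[str]) -> Dict[str, dict]:
    """Per variant, return the eval record with the lowest finite val['nll_tok'].

    Assumes one JSON object per line, shaped like the `rec` dicts written by
    the training loop ('variant', 'step', 'val' -> {'nll_tok', ...}, ...).
    """
    best: Dict[str, dict] = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        rec = json.loads(line)
        nll = rec.get("val", {}).get("nll_tok", float("inf"))
        if not math.isfinite(nll):
            continue  # like the loop, never select a non-finite val NLL
        cur = best.get(rec["variant"])
        # strict '<' keeps the earliest step on ties, as the loop does
        if cur is None or nll < cur["val"]["nll_tok"]:
            best[rec["variant"]] = rec
    return best
```

For a single log, `with open(path, encoding="utf-8") as f: best_from_lines(f)` is enough. Records for several variants can be merged by chaining the file iterators.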


Token indices sequence length is longer than the specified maximum sequence length for this model (43766 > 1024). Running this sequence through the model will result in indexing errors
train[vae]:   0%|          | 1/6000 [00:13<22:43:02, 13.63s/it]

[eval] vae step=    1 train_ppl=112344.95 val_ppl=111763.03 val_nll=11.6241 val_kl=66.2300 ece=0.002 ent=10.022 collapse=0.00 mu_ac1=0.005


train[vae]:   7%|▋         | 400/6000 [01:49<6:01:39,  3.87s/it, loss=6.5331, nll_tok=6.4655, kl_tok=0.2535, ppl=642.60, beta=0.27, lr=3.0e-04] 

[eval] vae step=  400 train_ppl=678.52 val_ppl=795.32 val_nll=6.6787 val_kl=0.1531 ece=0.018 ent=6.375 collapse=0.00 mu_ac1=0.014


train[vae]:  13%|█▎        | 800/6000 [03:24<5:33:43,  3.85s/it, loss=6.0963, nll_tok=6.0539, kl_tok=0.0796, ppl=425.75, beta=0.53, lr=2.9e-04]

[eval] vae step=  800 train_ppl=352.26 val_ppl=521.88 val_nll=6.2574 val_kl=0.0551 ece=0.022 ent=5.985 collapse=0.04 mu_ac1=0.015


train[vae]:  20%|██        | 1200/6000 [05:00<5:08:29,  3.86s/it, loss=5.8712, nll_tok=5.8347, kl_tok=0.0455, ppl=341.97, beta=0.80, lr=2.8e-04]

[eval] vae step= 1200 train_ppl=217.39 val_ppl=418.75 val_nll=6.0373 val_kl=0.0385 ece=0.021 ent=5.754 collapse=0.62 mu_ac1=0.017


train[vae]:  27%|██▋       | 1600/6000 [06:36<4:41:56,  3.84s/it, loss=5.6021, nll_tok=5.5583, kl_tok=0.0438, ppl=259.38, beta=1.00, lr=2.7e-04]

[eval] vae step= 1600 train_ppl=160.63 val_ppl=369.98 val_nll=5.9135 val_kl=0.0282 ece=0.024 ent=5.606 collapse=0.74 mu_ac1=0.005


train[vae]:  33%|███▎      | 2000/6000 [08:11<4:16:55,  3.85s/it, loss=5.2466, nll_tok=5.2144, kl_tok=0.0322, ppl=183.90, beta=1.00, lr=2.4e-04]

[eval] vae step= 2000 train_ppl=128.15 val_ppl=339.01 val_nll=5.8260 val_kl=0.0189 ece=0.030 ent=5.108 collapse=0.88 mu_ac1=0.010


train[vae]:  40%|████      | 2400/6000 [09:47<3:51:31,  3.86s/it, loss=4.8903, nll_tok=4.8700, kl_tok=0.0204, ppl=130.32, beta=1.00, lr=2.2e-04]

[eval] vae step= 2400 train_ppl=105.06 val_ppl=318.71 val_nll=5.7643 val_kl=0.0165 ece=0.025 ent=5.084 collapse=0.88 mu_ac1=0.008


train[vae]:  47%|████▋     | 2800/6000 [11:23<3:26:36,  3.87s/it, loss=4.9312, nll_tok=4.9159, kl_tok=0.0154, ppl=136.44, beta=1.00, lr=1.9e-04]

[eval] vae step= 2800 train_ppl=87.31 val_ppl=309.29 val_nll=5.7343 val_kl=0.0116 ece=0.032 ent=4.821 collapse=0.95 mu_ac1=0.010


train[vae]:  53%|█████▎    | 3200/6000 [12:59<3:00:08,  3.86s/it, loss=4.8145, nll_tok=4.8016, kl_tok=0.0128, ppl=121.71, beta=1.00, lr=1.6e-04]

[eval] vae step= 3200 train_ppl=80.68 val_ppl=304.46 val_nll=5.7186 val_kl=0.0088 ece=0.032 ent=4.727 collapse=0.99 mu_ac1=0.012


train[vae]:  60%|██████    | 3600/6000 [14:35<2:33:53,  3.85s/it, loss=4.6418, nll_tok=4.6297, kl_tok=0.0121, ppl=102.48, beta=1.00, lr=1.3e-04]

[eval] vae step= 3600 train_ppl=64.77 val_ppl=296.85 val_nll=5.6932 val_kl=0.0074 ece=0.034 ent=4.642 collapse=0.99 mu_ac1=0.017


train[vae]:  67%|██████▋   | 4000/6000 [16:11<2:08:19,  3.85s/it, loss=4.3487, nll_tok=4.3384, kl_tok=0.0104, ppl=76.58, beta=1.00, lr=1.0e-04] 

[eval] vae step= 4000 train_ppl=60.26 val_ppl=300.13 val_nll=5.7042 val_kl=0.0059 ece=0.041 ent=4.503 collapse=1.00 mu_ac1=0.011


train[vae]:  73%|███████▎  | 4400/6000 [17:47<1:43:12,  3.87s/it, loss=4.3338, nll_tok=4.3252, kl_tok=0.0086, ppl=75.58, beta=1.00, lr=7.9e-05]

[eval] vae step= 4400 train_ppl=55.25 val_ppl=300.82 val_nll=5.7065 val_kl=0.0045 ece=0.042 ent=4.423 collapse=1.00 mu_ac1=0.016


train[vae]:  80%|████████  | 4800/6000 [19:23<1:17:13,  3.86s/it, loss=4.2107, nll_tok=4.2030, kl_tok=0.0078, ppl=66.88, beta=1.00, lr=5.8e-05]

[eval] vae step= 4800 train_ppl=50.27 val_ppl=299.69 val_nll=5.7028 val_kl=0.0031 ece=0.042 ent=4.393 collapse=1.00 mu_ac1=0.013


train[vae]:  87%|████████▋ | 5200/6000 [20:59<51:33,  3.87s/it, loss=4.2219, nll_tok=4.2185, kl_tok=0.0034, ppl=67.93, beta=1.00, lr=4.3e-05]  

[eval] vae step= 5200 train_ppl=48.05 val_ppl=302.03 val_nll=5.7105 val_kl=0.0028 ece=0.049 ent=4.281 collapse=1.00 mu_ac1=0.016


train[vae]:  93%|█████████▎| 5600/6000 [22:35<25:45,  3.86s/it, loss=4.3254, nll_tok=4.3224, kl_tok=0.0030, ppl=75.37, beta=1.00, lr=3.3e-05]

[eval] vae step= 5600 train_ppl=47.49 val_ppl=302.83 val_nll=5.7132 val_kl=0.0026 ece=0.048 ent=4.302 collapse=1.00 mu_ac1=0.019


train[vae]: 100%|██████████| 6000/6000 [24:11<00:00,  4.13it/s, loss=4.1004, nll_tok=4.0973, kl_tok=0.0031, ppl=60.18, beta=1.00, lr=3.0e-05]

[eval] vae step= 6000 train_ppl=46.26 val_ppl=304.70 val_nll=5.7193 val_kl=0.0018 ece=0.047 ent=4.292 collapse=1.00 mu_ac1=0.018
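In the plain `vae` run, the `collapse` column climbs from 0.00 to 1.00 while `val_kl` falls from 0.1531 to 0.0018. This is the usual signature of posterior collapse. In both GP-prior runs, `collapse` stays at 0.00. The metric itself is computed outside this excerpt. The evaluator receives a `kldim_eps` threshold and reports `kldim_frac_below_eps`, which suggests a per-dimension KL statistic, but that is an assumption. The sketch below computes one common version: the fraction of latent dimensions whose mean KL to a standard-normal prior is below a small `eps`. This only fits the plain `vae` variant, because the GP variants measure KL against a GP prior instead. `kl_per_dim`, `collapse_fraction` and the default `eps` are hypothetical.

```python
import math
from typing import List, Sequence


def kl_per_dim(mu: Sequence[Sequence[float]],
               logvar: Sequence[Sequence[float]]) -> List[float]:
    """Mean KL(N(mu, exp(logvar)) || N(0, 1)) per latent dimension, averaged over rows."""
    n_rows, n_dims = len(mu), len(mu[0])
    totals = [0.0] * n_dims
    for m_row, lv_row in zip(mu, logvar):
        for d, (m, lv) in enumerate(zip(m_row, lv_row)):
            # closed form for a diagonal Gaussian: 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2)
            totals[d] += 0.5 * (m * m + math.exp(lv) - 1.0 - lv)
    return [t / n_rows for t in totals]


def collapse_fraction(mu: Sequence[Sequence[float]],
                      logvar: Sequence[Sequence[float]],
                      eps: float = 0.01) -> float:
    """Fraction of latent dims whose mean KL is below eps (hypothetical collapse metric)."""
    kls = kl_per_dim(mu, logvar)
    return sum(k < eps for k in kls) / len(kls)
```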



train[vae_gp]:   0%|          | 1/6000 [00:12<20:54:01, 12.54s/it]

[eval] vae_gp step=    1 train_ppl=109728.96 val_ppl=110570.83 val_nll=11.6134 val_kl=83677.1656 ece=0.002 ent=10.018 collapse=0.00 mu_ac1=0.001


train[vae_gp]:   7%|▋         | 400/6000 [01:55<6:08:57,  3.95s/it, loss=12.4869, nll_tok=6.7556, kl_tok=21.4926, ppl=858.83, beta=0.27, lr=3.0e-04]   

[eval] vae_gp step=  400 train_ppl=536.56 val_ppl=702.82 val_nll=6.5551 val_kl=19.3924 ece=0.019 ent=6.429 collapse=0.00 mu_ac1=0.180


train[vae_gp]:  13%|█▎        | 800/6000 [03:38<5:44:28,  3.97s/it, loss=12.7934, nll_tok=5.9730, kl_tok=12.7881, ppl=392.70, beta=0.53, lr=2.9e-04]

[eval] vae_gp step=  800 train_ppl=295.82 val_ppl=490.15 val_nll=6.1947 val_kl=12.4319 ece=0.023 ent=5.970 collapse=0.00 mu_ac1=0.067


train[vae_gp]:  20%|██        | 1200/6000 [05:21<5:15:40,  3.95s/it, loss=13.9344, nll_tok=5.7330, kl_tok=10.2518, ppl=308.88, beta=0.80, lr=2.8e-04]

[eval] vae_gp step= 1200 train_ppl=209.67 val_ppl=408.05 val_nll=6.0114 val_kl=9.9512 ece=0.022 ent=5.703 collapse=0.00 mu_ac1=0.039


train[vae_gp]:  27%|██▋       | 1600/6000 [07:04<4:50:01,  3.95s/it, loss=14.0886, nll_tok=5.4634, kl_tok=8.6252, ppl=235.91, beta=1.00, lr=2.7e-04] 

[eval] vae_gp step= 1600 train_ppl=154.29 val_ppl=363.03 val_nll=5.8945 val_kl=8.1948 ece=0.024 ent=5.379 collapse=0.00 mu_ac1=0.041


train[vae_gp]:  33%|███▎      | 2000/6000 [08:48<4:25:35,  3.98s/it, loss=12.9881, nll_tok=5.3696, kl_tok=7.6186, ppl=214.77, beta=1.00, lr=2.4e-04]

[eval] vae_gp step= 2000 train_ppl=125.62 val_ppl=342.02 val_nll=5.8349 val_kl=7.0848 ece=0.026 ent=5.115 collapse=0.00 mu_ac1=0.058


train[vae_gp]:  40%|████      | 2400/6000 [10:31<3:57:19,  3.96s/it, loss=11.7545, nll_tok=5.0593, kl_tok=6.6952, ppl=157.48, beta=1.00, lr=2.2e-04]

[eval] vae_gp step= 2400 train_ppl=99.62 val_ppl=322.43 val_nll=5.7759 val_kl=6.9172 ece=0.034 ent=4.911 collapse=0.00 mu_ac1=0.069


train[vae_gp]:  47%|████▋     | 2800/6000 [12:14<3:32:25,  3.98s/it, loss=11.0530, nll_tok=5.0094, kl_tok=6.0436, ppl=149.81, beta=1.00, lr=1.9e-04]

[eval] vae_gp step= 2800 train_ppl=72.43 val_ppl=312.41 val_nll=5.7443 val_kl=6.3137 ece=0.034 ent=4.804 collapse=0.00 mu_ac1=0.087


train[vae_gp]:  53%|█████▎    | 3200/6000 [13:58<3:06:11,  3.99s/it, loss=10.2458, nll_tok=4.6757, kl_tok=5.5702, ppl=107.30, beta=1.00, lr=1.6e-04]

[eval] vae_gp step= 3200 train_ppl=64.02 val_ppl=311.46 val_nll=5.7413 val_kl=5.3648 ece=0.044 ent=4.579 collapse=0.00 mu_ac1=0.089


train[vae_gp]:  60%|██████    | 3600/6000 [15:42<2:39:49,  4.00s/it, loss=9.5473, nll_tok=4.3592, kl_tok=5.1881, ppl=78.20, beta=1.00, lr=1.3e-04]  

[eval] vae_gp step= 3600 train_ppl=54.71 val_ppl=309.23 val_nll=5.7341 val_kl=5.0114 ece=0.047 ent=4.439 collapse=0.00 mu_ac1=0.086


train[vae_gp]:  67%|██████▋   | 4000/6000 [17:26<2:12:56,  3.99s/it, loss=9.2051, nll_tok=4.3007, kl_tok=4.9044, ppl=73.75, beta=1.00, lr=1.0e-04]

[eval] vae_gp step= 4000 train_ppl=51.64 val_ppl=309.32 val_nll=5.7344 val_kl=4.8131 ece=0.046 ent=4.382 collapse=0.00 mu_ac1=0.101


train[vae_gp]:  73%|███████▎  | 4400/6000 [19:09<1:46:22,  3.99s/it, loss=8.8193, nll_tok=4.0793, kl_tok=4.7400, ppl=59.11, beta=1.00, lr=7.9e-05]

[eval] vae_gp step= 4400 train_ppl=43.38 val_ppl=310.20 val_nll=5.7372 val_kl=4.8346 ece=0.051 ent=4.293 collapse=0.00 mu_ac1=0.099


train[vae_gp]:  80%|████████  | 4800/6000 [20:53<1:19:42,  3.99s/it, loss=8.5039, nll_tok=3.9182, kl_tok=4.5856, ppl=50.31, beta=1.00, lr=5.8e-05]

[eval] vae_gp step= 4800 train_ppl=43.38 val_ppl=314.46 val_nll=5.7508 val_kl=4.6392 ece=0.052 ent=4.244 collapse=0.00 mu_ac1=0.114


train[vae_gp]:  87%|████████▋ | 5200/6000 [22:37<53:06,  3.98s/it, loss=8.5154, nll_tok=4.0564, kl_tok=4.4590, ppl=57.76, beta=1.00, lr=4.3e-05]  

[eval] vae_gp step= 5200 train_ppl=38.31 val_ppl=317.71 val_nll=5.7611 val_kl=4.3875 ece=0.057 ent=4.176 collapse=0.00 mu_ac1=0.121


train[vae_gp]:  93%|█████████▎| 5600/6000 [24:21<26:35,  3.99s/it, loss=8.4395, nll_tok=4.0517, kl_tok=4.3878, ppl=57.50, beta=1.00, lr=3.3e-05]

[eval] vae_gp step= 5600 train_ppl=38.47 val_ppl=319.32 val_nll=5.7662 val_kl=4.2987 ece=0.056 ent=4.139 collapse=0.00 mu_ac1=0.115


train[vae_gp]: 100%|██████████| 6000/6000 [26:05<00:00,  3.83it/s, loss=8.3038, nll_tok=3.9715, kl_tok=4.3323, ppl=53.07, beta=1.00, lr=3.0e-05]

[eval] vae_gp step= 6000 train_ppl=38.35 val_ppl=321.93 val_nll=5.7743 val_kl=4.2492 ece=0.056 ent=4.146 collapse=0.00 mu_ac1=0.109
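The `mu_ac1` column is also defined outside this excerpt. Its name suggests a lag-1 autocorrelation of the posterior means along the sequence. Under that reading, the logged values fit priors that correlate neighbouring latents:

- The plain `vae` stays below 0.02 throughout.
- `vae_gp` and `vae_ar` stay between roughly 0.04 and 0.18.

The sketch below is a minimal pooled estimator for a single latent path `mu[t][d]`. Both the definition and the pooling over dimensions are assumptions. Averaging over a batch of paths would be one more assumption on top.

```python
from typing import Sequence


def lag1_autocorr(mu: Sequence[Sequence[float]]) -> float:
    """Lag-1 autocorrelation of a latent path mu[t][d], pooled over dimensions.

    Each dimension is centred on its own time-average; the statistic is
    sum_t <c_t, c_{t+1}> / sum_t <c_t, c_t> over the centred path c.
    """
    T, D = len(mu), len(mu[0])
    means = [sum(mu[t][d] for t in range(T)) / T for d in range(D)]
    c = [[mu[t][d] - means[d] for d in range(D)] for t in range(T)]
    num = sum(c[t][d] * c[t + 1][d] for t in range(T - 1) for d in range(D))
    den = sum(c[t][d] ** 2 for t in range(T) for d in range(D))
    return num / den if den > 0 else 0.0  # a constant path has no variation
```

A smooth path such as `[[1.0], [2.0], [3.0], [4.0]]` gives a positive value (0.25). An alternating path gives a negative one.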



train[vae_ar]:   0%|          | 1/6000 [00:12<20:38:21, 12.39s/it]

[eval] vae_ar step=    1 train_ppl=116526.99 val_ppl=116056.98 val_nll=11.6618 val_kl=609.7574 ece=0.002 ent=10.034 collapse=0.00 mu_ac1=-0.008


train[vae_ar]:   7%|▋         | 400/6000 [01:48<6:04:30,  3.91s/it, loss=12.3916, nll_tok=6.6738, kl_tok=21.4420, ppl=791.36, beta=0.27, lr=3.0e-04]

[eval] vae_ar step=  400 train_ppl=691.37 val_ppl=806.56 val_nll=6.6928 val_kl=20.8259 ece=0.022 ent=6.187 collapse=0.00 mu_ac1=0.119


train[vae_ar]:  13%|█▎        | 800/6000 [03:25<5:39:33,  3.92s/it, loss=17.1375, nll_tok=6.2190, kl_tok=20.4722, ppl=502.18, beta=0.53, lr=2.9e-04]

[eval] vae_ar step=  800 train_ppl=381.44 val_ppl=565.15 val_nll=6.3371 val_kl=20.2787 ece=0.022 ent=5.853 collapse=0.00 mu_ac1=0.102


train[vae_ar]:  20%|██        | 1200/6000 [05:02<5:13:53,  3.92s/it, loss=21.7502, nll_tok=5.9460, kl_tok=19.7552, ppl=382.23, beta=0.80, lr=2.8e-04]

[eval] vae_ar step= 1200 train_ppl=268.60 val_ppl=450.04 val_nll=6.1093 val_kl=19.8438 ece=0.019 ent=5.804 collapse=0.00 mu_ac1=0.084


train[vae_ar]:  27%|██▋       | 1600/6000 [06:38<4:46:26,  3.91s/it, loss=24.9564, nll_tok=5.4433, kl_tok=19.5130, ppl=231.21, beta=1.00, lr=2.7e-04]

[eval] vae_ar step= 1600 train_ppl=194.08 val_ppl=390.58 val_nll=5.9676 val_kl=19.3929 ece=0.021 ent=5.486 collapse=0.00 mu_ac1=0.078


train[vae_ar]:  33%|███▎      | 2000/6000 [08:15<4:19:14,  3.89s/it, loss=24.3278, nll_tok=5.2522, kl_tok=19.0755, ppl=190.99, beta=1.00, lr=2.4e-04]

[eval] vae_ar step= 2000 train_ppl=155.48 val_ppl=361.36 val_nll=5.8899 val_kl=18.9786 ece=0.025 ent=5.359 collapse=0.00 mu_ac1=0.057


train[vae_ar]:  40%|████      | 2400/6000 [09:51<3:55:18,  3.92s/it, loss=23.7695, nll_tok=5.2087, kl_tok=18.5608, ppl=182.85, beta=1.00, lr=2.2e-04]

[eval] vae_ar step= 2400 train_ppl=121.36 val_ppl=337.88 val_nll=5.8227 val_kl=18.6014 ece=0.035 ent=5.054 collapse=0.00 mu_ac1=0.064


train[vae_ar]:  47%|████▋     | 2800/6000 [11:28<3:28:05,  3.90s/it, loss=23.2906, nll_tok=5.0194, kl_tok=18.2712, ppl=151.32, beta=1.00, lr=1.9e-04]

[eval] vae_ar step= 2800 train_ppl=101.41 val_ppl=316.83 val_nll=5.7584 val_kl=18.2395 ece=0.033 ent=4.908 collapse=0.00 mu_ac1=0.068


train[vae_ar]:  53%|█████▎    | 3200/6000 [13:05<3:02:17,  3.91s/it, loss=22.6476, nll_tok=4.7702, kl_tok=17.8775, ppl=117.94, beta=1.00, lr=1.6e-04]

[eval] vae_ar step= 3200 train_ppl=89.43 val_ppl=312.33 val_nll=5.7441 val_kl=17.9213 ece=0.035 ent=4.784 collapse=0.00 mu_ac1=0.068


train[vae_ar]:  60%|██████    | 3600/6000 [14:41<2:36:51,  3.92s/it, loss=22.3376, nll_tok=4.5387, kl_tok=17.7988, ppl=93.57, beta=1.00, lr=1.3e-04] 

[eval] vae_ar step= 3600 train_ppl=79.16 val_ppl=302.14 val_nll=5.7109 val_kl=17.6603 ece=0.033 ent=4.726 collapse=0.00 mu_ac1=0.071


train[vae_ar]:  67%|██████▋   | 4000/6000 [16:18<2:10:09,  3.90s/it, loss=22.0842, nll_tok=4.6136, kl_tok=17.4706, ppl=100.85, beta=1.00, lr=1.0e-04]

[eval] vae_ar step= 4000 train_ppl=68.63 val_ppl=301.11 val_nll=5.7075 val_kl=17.4411 ece=0.041 ent=4.569 collapse=0.00 mu_ac1=0.083


train[vae_ar]:  73%|███████▎  | 4400/6000 [17:54<1:43:57,  3.90s/it, loss=21.8404, nll_tok=4.6090, kl_tok=17.2314, ppl=100.39, beta=1.00, lr=7.9e-05]

[eval] vae_ar step= 4400 train_ppl=61.35 val_ppl=301.58 val_nll=5.7090 val_kl=17.2873 ece=0.040 ent=4.566 collapse=0.00 mu_ac1=0.078


train[vae_ar]:  80%|████████  | 4800/6000 [19:31<1:17:48,  3.89s/it, loss=21.3793, nll_tok=4.2449, kl_tok=17.1343, ppl=69.75, beta=1.00, lr=5.8e-05] 

[eval] vae_ar step= 4800 train_ppl=58.81 val_ppl=301.09 val_nll=5.7074 val_kl=17.1505 ece=0.040 ent=4.508 collapse=0.00 mu_ac1=0.081


train[vae_ar]:  87%|████████▋ | 5200/6000 [21:08<52:03,  3.90s/it, loss=21.2509, nll_tok=4.1653, kl_tok=17.0856, ppl=64.41, beta=1.00, lr=4.3e-05]  

[eval] vae_ar step= 5200 train_ppl=57.03 val_ppl=301.99 val_nll=5.7104 val_kl=17.0499 ece=0.045 ent=4.412 collapse=0.00 mu_ac1=0.082


train[vae_ar]:  93%|█████████▎| 5600/6000 [22:44<26:08,  3.92s/it, loss=21.0963, nll_tok=4.1598, kl_tok=16.9365, ppl=64.06, beta=1.00, lr=3.3e-05]

[eval] vae_ar step= 5600 train_ppl=53.44 val_ppl=302.30 val_nll=5.7114 val_kl=16.9846 ece=0.044 ent=4.384 collapse=0.00 mu_ac1=0.082


train[vae_ar]: 100%|██████████| 6000/6000 [24:21<00:00,  4.11it/s, loss=21.0421, nll_tok=4.2128, kl_tok=16.8293, ppl=67.54, beta=1.00, lr=3.0e-05]

[eval] vae_ar step= 6000 train_ppl=51.36 val_ppl=303.85 val_nll=5.7165 val_kl=16.9258 ece=0.045 ent=4.370 collapse=0.00 mu_ac1=0.075
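The progress-bar postfixes report `loss`, `nll_tok`, `kl_tok` and `beta` for the latest training batch. For the two GP-prior variants, the KL is large enough to make the arithmetic well-conditioned. Their logged values satisfy `loss = nll_tok + beta * kl_tok` to within rounding. The implied `beta` is consistent with a linear KL warm-up, `beta = min(1, step / 1500)`, matching the displayed 0.27, 0.53, 0.80 and 1.00. The check below reconstructs this. The 1500-step warm-up is inferred from the logs rather than read from `CFG`, and `linear_warmup_beta` is a hypothetical helper.

```python
from typing import List, Tuple

# (variant, step, loss, nll_tok, kl_tok), copied from the progress-bar postfixes above.
LOGGED: List[Tuple[str, int, float, float, float]] = [
    ("vae_gp", 400, 12.4869, 6.7556, 21.4926),
    ("vae_gp", 800, 12.7934, 5.9730, 12.7881),
    ("vae_gp", 1200, 13.9344, 5.7330, 10.2518),
    ("vae_gp", 1600, 14.0886, 5.4634, 8.6252),
    ("vae_ar", 400, 12.3916, 6.6738, 21.4420),
    ("vae_ar", 800, 17.1375, 6.2190, 20.4722),
    ("vae_ar", 1200, 21.7502, 5.9460, 19.7552),
    ("vae_ar", 1600, 24.9564, 5.4433, 19.5130),
]


def implied_beta(loss: float, nll_tok: float, kl_tok: float) -> float:
    """Solve loss = nll_tok + beta * kl_tok for beta."""
    return (loss - nll_tok) / kl_tok


def linear_warmup_beta(step: int, warmup_steps: int = 1500) -> float:
    """Hypothetical schedule: beta rises linearly from 0 and is capped at 1."""
    return min(1.0, step / warmup_steps)


for variant, step, loss, nll, kl in LOGGED:
    print(f"{variant:7s} step={step:5d} "
          f"implied={implied_beta(loss, nll, kl):.4f} "
          f"warmup={linear_warmup_beta(step):.4f}")
```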

=== SUMMARY (best checkpoints) ===
vae     | step= 3600 | best_val_nll/tok=5.6932 | val_ppl=296.85 | val_kl/tok=0.0074 | ECE=0.034 | ent=4.642 | collapse=0.99 | mu_ac1=0.017 | d2=0.796 rep2=0.293 | log=runs/wt2_gpvae_compare_vae_1765723152.jsonl
vae_ar  | step= 4800 | best_val_nll/tok=5.7074 | val_ppl=301.09 | val_kl/tok=17.1505 | ECE=0.040 | ent=4.508 | collapse=0.00 | mu_ac1=0.081 | d2=0.803 rep2=0.289 | log=runs/wt2_gpvae_compare_vae_ar_1765726171.jsonl
vae_gp  | step= 3600 | best_val_nll/tok=5.7341 | val_ppl=309.23 | val_kl/tok=5.0114 | ECE=0.047 | ent=4.439 | collapse=0.00 | mu_ac1=0.086 | d2=0.763 rep2=0.331 | log=runs/wt2_gpvae_compare_vae_gp_1765724605.jsonl

Saved best summary JSON: runs/wt2_gpvae_compare_BEST_1765727632.json
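The reported `val_ppl` equals `exp(val_nll)` up to rounding (exp(5.6932) ≈ 296.85, exp(5.7074) ≈ 301.09). It therefore scores reconstruction alone and leaves out the KL term. Suppose `kl_tok` is the KL normalized per token in the same nats as `nll_tok`; this is an assumption, since the normalization is done outside this excerpt. Then the per-token negative ELBO at beta = 1 is `nll_tok + kl_tok`, an upper bound on the per-token negative log-likelihood. The snippet below does this arithmetic for the three best checkpoints.

```python
import math

# (variant, best val nll/tok, val kl/tok, reported val_ppl) from the summary above.
SUMMARY = [
    ("vae", 5.6932, 0.0074, 296.85),
    ("vae_ar", 5.7074, 17.1505, 301.09),
    ("vae_gp", 5.7341, 5.0114, 309.23),
]

for variant, nll, kl, ppl in SUMMARY:
    recon_ppl = math.exp(nll)  # reconstruction-only perplexity
    neg_elbo = nll + kl        # per-token negative ELBO at beta = 1 (assumes matching units)
    print(f"{variant:7s} exp(nll)={recon_ppl:7.2f} reported={ppl:7.2f} "
          f"-ELBO/tok={neg_elbo:.4f} exp(-ELBO/tok)={math.exp(neg_elbo):.3g}")
```

Under that assumption, the bounds come out at about 5.70 (`vae`), 10.75 (`vae_gp`) and 22.86 (`vae_ar`) nats per token. The near-identical `val_ppl` values therefore compare reconstruction terms, not likelihood bounds.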



