# Latent-Autoregressive GP-VAE Language Model with a Dilated TCN Encoder

This encoder is a **dilated causal TCN**, not a “pyramidal” network.

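- It stacks **residual TCN blocks** built from **causal 1D convolutions with increasing dilations** (`base_dilations = 1,2,4,8,16`). Only the convolutions are causal: each block's GroupNorm and squeeze-and-excitation gate use statistics from the whole sequence.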
- The **time resolution stays constant** (no stride, pooling, or downsampling/upsampling).
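
To make the two bullets above concrete, here is a minimal pure-Python sketch of a single-channel causal dilated convolution. It left-pads by `(kernel_size - 1) * dilation`. This is equivalent to what `CausalDepthwiseSeparableConv` does: it pads both sides and then trims the right side. As a result, the output keeps the input length, and output `t` reads only inputs `t, t - d, t - 2d, ...`. The function `causal_dilated_conv1d` and the list-based, single-channel setting are illustrative assumptions, not part of the codebase.

```python
from typing import List

def causal_dilated_conv1d(x: List[float], w: List[float], dilation: int) -> List[float]:
    """Single-channel causal dilated convolution (no stride, no pooling).

    y[t] = sum_j w[j] * x[t - (k - 1 - j) * dilation], with zeros before t = 0,
    so the kernel's last tap is aligned with the current position t.
    """
    k = len(w)
    pad = (k - 1) * dilation          # left padding only -> causal
    xp = [0.0] * pad + list(x)        # zero "history" before the first token
    y = []
    for t in range(len(x)):
        # padded index t + j*d corresponds to original index t - (k - 1 - j)*d
        y.append(sum(w[j] * xp[t + j * dilation] for j in range(k)))
    return y

x = [float(i) for i in range(10)]
w = [1.0, 1.0, 1.0]
for d in (1, 2, 4):
    y = causal_dilated_conv1d(x, w, dilation=d)
    assert len(y) == len(x)  # time resolution is unchanged at every dilation

# Causality: perturbing a "future" input leaves all earlier outputs untouched.
x_future = list(x)
x_future[7] = 100.0
assert causal_dilated_conv1d(x_future, w, 4)[:7] == causal_dilated_conv1d(x, w, 4)[:7]
```

A pyramidal encoder would instead use a stride or pooling here. That would shrink `len(y)` at each level.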

In the literature, “pyramidal” usually implies **progressive temporal downsampling** (e.g., T → T/2 → T/4). Since this encoder never downsamples, a more accurate name is:
**Hierarchical / Dilated Causal TCN encoder** (or **TCN+** for the enhanced block design).
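
Because the encoder never downsamples, it gains temporal context only through dilation. The back-of-the-envelope sketch below computes the receptive field of the convolution path. It assumes the `Config` defaults (`tcn_stacks=3`, `tcn_blocks_per_stack=3`, `tcn_kernel=3`, `tcn_base_dilations=(1, 2, 4, 8, 16)`) and the schedule used by `HierarchicalTCN`:
- dilations cycle through `base_dilations` across all blocks;
- each `TCNBlock` contains two causal convolutions with the same dilation;
- each convolution widens the window by `(kernel_size - 1) * dilation`.

The helper `tcn_receptive_field` is hypothetical. It deliberately ignores the GroupNorm and squeeze-and-excitation layers, whose statistics span the whole sequence.

```python
from typing import Sequence

def tcn_receptive_field(kernel_size: int, base_dilations: Sequence[int],
                        n_stacks: int, blocks_per_stack: int,
                        convs_per_block: int = 2) -> int:
    """Receptive field (in tokens) of the causal convolution path of a dilated TCN.

    Follows HierarchicalTCN's schedule: block i uses base_dilations[i % len(base_dilations)]
    and contains `convs_per_block` causal convolutions with that dilation.
    """
    n_blocks = n_stacks * blocks_per_stack
    dilations = [base_dilations[i % len(base_dilations)] for i in range(n_blocks)]
    # Each causal conv adds (kernel_size - 1) * dilation past tokens; 1x1 convs, GLU and
    # dropout do not widen the window (GroupNorm/SE statistics are excluded on purpose).
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)

# Config defaults: 3 stacks x 3 blocks, kernel 3 -> dilations 1,2,4,8,16,1,2,4,8
rf = tcn_receptive_field(3, (1, 2, 4, 8, 16), n_stacks=3, blocks_per_stack=3)
print("receptive field:", rf)  # receptive field: 185

# Sequence length per level for T = block_size = 256
T = 256
pyramid_lengths = [T // 2 ** level for level in range(3)]  # strided/pooled pyramid
tcn_lengths = [T for _ in range(3)]                        # dilated TCN: unchanged
print(pyramid_lengths, tcn_lengths)
```

With these defaults, the convolutions see the current token plus the previous 184 tokens. Every layer still keeps all 256 positions of a `block_size=256` sequence.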


## Purpose of the code

This codebase was created as a **controlled proof-of-concept** to investigate whether
**sequential structure in a language model can be shifted from token-level autoregression
to a continuous latent space**.

It implements a **GP-VAE language model** in which:
- temporal dependencies are enforced by a **causal Gaussian Process prior** over latent variables,
- the encoder maps token sequences to latent distributions using a **dilated causal TCN**,
- the decoder generates tokens **in parallel**, without token-level autoregression.
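
Under the GP prior, latents at different positions are coupled only through the kernel. `rbf_kernel` sets k(t, t') = σ² exp(-(t - t')² / (2ℓ²)), and the same covariance is used for every latent dimension. The pure-Python sketch below prints the implied prior correlation between z_t and z_{t+lag}. Its helper names are hypothetical, and ℓ = 8 and σ² = 1 are illustrative values.

Temporal structure enters generation through this covariance. The decoder then maps all positions to logits in one parallel pass, and `sample_logits_from_timewise_logits` draws every token independently given those logits.

```python
import math

def rbf_cov(t1: float, t2: float, lengthscale: float, variance: float) -> float:
    """RBF covariance k(t1, t2) = variance * exp(-(t1 - t2)^2 / (2 * lengthscale^2))."""
    return variance * math.exp(-0.5 * (t1 - t2) ** 2 / lengthscale ** 2)

def latent_correlation(lag: int, lengthscale: float) -> float:
    """Prior correlation between z_t and z_{t+lag} (identical for every latent dimension)."""
    # The RBF kernel is stationary, so the variance cancels in k(t, t+lag) / k(t, t).
    return rbf_cov(0.0, float(lag), lengthscale, 1.0) / rbf_cov(0.0, 0.0, lengthscale, 1.0)

for lag in (1, 4, 8, 16, 32):
    print(f"lag {lag:2d}: corr = {latent_correlation(lag, lengthscale=8.0):.4f}")
```

With ℓ = 8, the correlation is about 0.99 at lag 1, 0.61 at lag 8, and below 0.001 at lag 32. The prior therefore favors latent trajectories that vary smoothly over a few lengthscales.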

---

## What is being tested

The code is designed to perform a **systematic ablation of latent autoregression**:

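- **Latent AR mode**: latent variables are drawn with the full GP covariance, either from the
  GP prior or, given a prompt, from the GP conditional on the prompt latents. A joint draw is
  equivalent in distribution to sampling each latent sequentially from its GP conditional,
  so temporal coherence is enforced.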
- **Latent non-AR mode**: temporal correlations are removed by sampling latent variables
  independently from diagonal marginals.
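
Here is a minimal NumPy sketch of the two sampling modes. It rests on illustrative assumptions: one latent dimension, T = 64, ℓ = 8, σ² = 1, and float64 rather than the model's float32.
- **Latent AR mode** draws all positions jointly through a Cholesky factor of the full kernel matrix, as `_sample_gp_prior_block` does.
- **Latent non-AR mode** keeps only the diagonal, as `generate_no_ar` does (without its extra temperature scaling of the latents).

The helper `gp_next_step` shows the Gaussian conditional used by a sequential sampler, mirroring `_gp_conditional_step`. All function names here are hypothetical.

```python
import numpy as np

def rbf_kernel_np(T: int, lengthscale: float, variance: float) -> np.ndarray:
    """K[i, j] = variance * exp(-(i - j)^2 / (2 * lengthscale^2)) on positions 0..T-1."""
    t = np.arange(T, dtype=np.float64)
    d2 = (t[:, None] - t[None, :]) ** 2
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)

def sample_latent_ar(K: np.ndarray, n: int, rng, jitter: float = 1e-6) -> np.ndarray:
    """Latent AR mode: joint draw z ~ N(0, K), so temporal correlations are kept."""
    L = np.linalg.cholesky(K + jitter * np.eye(K.shape[0]))
    eps = rng.standard_normal((n, K.shape[0]))
    return eps @ L.T  # row i equals L @ eps[i], hence Cov = L @ L.T

def sample_latent_non_ar(K: np.ndarray, n: int, rng) -> np.ndarray:
    """Latent non-AR mode: z_t ~ N(0, K[t, t]) independently (same marginals, no correlations)."""
    std = np.sqrt(np.clip(np.diag(K), 1e-9, None))
    return rng.standard_normal((n, K.shape[0])) * std

def gp_next_step(K: np.ndarray, z_prefix: np.ndarray, jitter: float = 1e-6):
    """Mean and variance of z_{t0} given z_0..z_{t0-1} for a single latent dimension."""
    t0 = z_prefix.shape[0]
    K11 = K[:t0, :t0] + jitter * np.eye(t0)
    k12 = K[:t0, t0]
    alpha = np.linalg.solve(K11, k12)  # K11^{-1} k12
    mean = float(alpha @ z_prefix)
    var = max(float(K[t0, t0] - k12 @ alpha), 1e-9)
    return mean, var

def lag1_corr(z: np.ndarray) -> float:
    """Pooled correlation between z_t and z_{t+1} across all samples and positions."""
    return float(np.corrcoef(z[:, :-1].ravel(), z[:, 1:].ravel())[0, 1])

rng = np.random.default_rng(0)
K = rbf_kernel_np(T=64, lengthscale=8.0, variance=1.0)
z_ar = sample_latent_ar(K, n=2000, rng=rng)
z_iid = sample_latent_non_ar(K, n=2000, rng=rng)
print(f"lag-1 correlation, latent AR:     {lag1_corr(z_ar):.3f}")
print(f"lag-1 correlation, latent non-AR: {lag1_corr(z_iid):.3f}")

mean16, var16 = gp_next_step(K, z_ar[0, :16])
print(f"z_16 given z_0..z_15: mean {mean16:.3f}, var {var16:.2e} (prior var {K[16, 16]:.1f})")
```

Conditioning on 16 previous latents leaves a variance below 2% of the prior variance. A single lag-1 neighbor already reduces it to 1 - e^{-1/64}, about 1.6%. This is why latent AR mode yields smooth trajectories, while the diagonal mode yields uncorrelated noise with the same per-position spread.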

This setup isolates the effect of **autoregressive structure in latent space**, independently
of token-level autoregression.

---

## Why this is useful

This framework allows us to:
- verify that the GP-VAE model can be **trained stably** on real language data (WikiText-103),
- analyze how latent autoregression affects **long-horizon generation stability**,
- compare against a **standard autoregressive Transformer baseline**,
- and highlight the limitations of **classical AR metrics** when applied to
  parallel latent generative models.

Overall, the code serves as an **experimental testbed** for assessing whether part of a
language model’s temporal structure can be supported by the **probabilistic geometry of
latent space**, rather than by explicit token-level autoregressive neural operations.


# Latent-Autoregressive GP-VAE Language Model with a Dilated TCN Encoder

import math, os, time, json, shutil, random, pathlib
from dataclasses import dataclass, asdict
from typing import Tuple, Optional, Dict, Any, List

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from torch.nn.utils.parametrizations import weight_norm as pn_weight_norm

try:
    import optuna
    from optuna.pruners import MedianPruner
    _HAVE_OPTUNA = True
except Exception:
    _HAVE_OPTUNA = False

try:
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM
except Exception:
    load_dataset = None
    AutoTokenizer = None
    AutoModelForCausalLM = None
    print("[Warning] Please install datasets and transformers: pip install datasets transformers")

# Dataset utilities

class LMBlocks(Dataset):
    def __init__(self, texts, tokenizer, block_size: int = 256):
        ids = []
        for t in texts:
            ids.extend(tokenizer.encode(t))
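        n = len(ids) // block_size
        if n == 0:
            raise ValueError(f"Corpus has {len(ids)} tokens, fewer than block_size={block_size}")
        ids = ids[: n * block_size]  # drop the trailing partial block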
        self.block_size = block_size
        self.data = torch.tensor(ids, dtype=torch.long).view(n, block_size)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]

def load_wikitext103(block_size=256, split='train'):
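    """Load one WikiText-103 split from the Kaggle input directory and cut it into
    GPT-2-tokenized blocks of `block_size` tokens.

    Split -> file:
      - 'train'                -> /kaggle/input/wikitext/wikitext-103/wiki.train.tokens
      - 'valid' / 'validation' -> /kaggle/input/wikitext/wikitext-103/wiki.valid.tokens
      - 'test'                 -> /kaggle/input/wikitext/wikitext-103/wiki.test.tokens

    Returns (LMBlocks dataset, tokenizer).
    """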
    path_map = {
        'train': "/kaggle/input/wikitext/wikitext-103/wiki.train.tokens",
        'validation': "/kaggle/input/wikitext/wikitext-103/wiki.valid.tokens",
        'valid': "/kaggle/input/wikitext/wikitext-103/wiki.valid.tokens",
        'test': "/kaggle/input/wikitext/wikitext-103/wiki.test.tokens",
    }
    if split not in path_map:
        raise ValueError(f"Unknown WikiText-103 split: {split}")

    path = path_map[split]
    if not os.path.exists(path):
        raise FileNotFoundError(f"WikiText-103 file not found: {path}")

    texts = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if line.strip() == "":
                continue
            texts.append(line)

    # Hack to avoid Hugging Face warnings: stub out list_repo_templates
    try:
        from transformers.utils import hub as _hf_hub_utils
        _hf_hub_utils.list_repo_templates = lambda *a, **k: []
        import transformers.tokenization_utils_base as _tub
        _tub.list_repo_templates = lambda *a, **k: []
    except Exception:
        pass

    assert AutoTokenizer is not None, "Please install transformers"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    # Avoid "sequence length > 1024" warnings while encoding the whole corpus
    tokenizer.model_max_length = int(1e9)
    if hasattr(tokenizer, "init_kwargs"):
        tokenizer.init_kwargs["model_max_length"] = int(1e9)

    dataset = LMBlocks(texts, tokenizer, block_size)
    return dataset, tokenizer


# Plotting helpers

def plot_curve(values, title, ylabel, out_png):
    plt.figure()
    plt.plot(values)
    plt.xlabel('Step')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def reliability_diagram_data(logits: torch.Tensor, x: torch.Tensor, n_bins: int = 15):
    with torch.no_grad():
        probs = torch.softmax(logits, dim=-1)
        conf, pred = probs.max(dim=-1)
        corr = (pred == x).float()
        conf = conf.detach().cpu().view(-1).numpy()
        corr = corr.detach().cpu().view(-1).numpy()
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.clip(np.digitize(conf, bins) - 1, 0, n_bins - 1)  # conf == 1.0 goes to the last bin
    bin_conf, bin_acc, bin_count = [], [], []
    ece = 0.0
    N = len(conf)
    for b in range(n_bins):
        m = (bin_ids == b)
        cnt = m.sum()
        if cnt == 0:
            bin_conf.append(0.0)
            bin_acc.append(0.0)
            bin_count.append(0)
            continue
        c = conf[m].mean()
        a = corr[m].mean()
        w = cnt / N
        ece += w * abs(a - c)
        bin_conf.append(c)
        bin_acc.append(a)
        bin_count.append(cnt)
    return np.array(bin_conf), np.array(bin_acc), float(ece)


def plot_reliability(confidences, accuracy, ece, out_png):
    plt.figure()
    plt.plot(confidences, accuracy, 'o-')
    plt.plot([0, 1], [0, 1], '--')
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.title(f'Reliability diagram — ECE={ece:.4f}')
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def nll_per_token(logits: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    B, T, V = logits.shape
    ce = F.cross_entropy(logits.reshape(B * T, V), x.reshape(B * T), reduction='none').view(B, T)
    return ce.mean(dim=0).detach().cpu()


def plot_nll_per_position(nll_per_pos, out_png):
    plt.figure()
    plt.plot(nll_per_pos)
    plt.xlabel('Position t')
    plt.ylabel('Token NLL')
    plt.title('NLL per position')
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


# GP helpers

def rbf_kernel(t: torch.Tensor, lengthscale: torch.Tensor, variance: torch.Tensor) -> torch.Tensor:
    t = t.view(-1, 1)
    d2 = (t - t.T) ** 2
    return variance * torch.exp(-0.5 * d2 / (lengthscale ** 2))


def mvn_logpdf_zero_mean(z: torch.Tensor, K: torch.Tensor, jitter: float = 1e-5) -> torch.Tensor:
    B, T, D = z.shape
    I = torch.eye(T, device=K.device, dtype=K.dtype)
    L = torch.linalg.cholesky(K + jitter * I)
    log_det = 2.0 * torch.sum(torch.log(torch.diag(L)))
    z_bd = z.permute(0, 2, 1).reshape(B * D, T)
    y_bd = torch.cholesky_solve(z_bd.unsqueeze(-1), L).squeeze(-1)
    quad = (z_bd * y_bd).sum(dim=-1).view(B, D).sum(dim=-1)
    const = T * math.log(2 * math.pi)
    logp = -0.5 * (D * (log_det + const) + quad)
    return logp


def diag_mvn_logpdf(z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    var = torch.exp(logvar)
    const = math.log(2 * math.pi)
    logp = -0.5 * (const + logvar + (z - mu) ** 2 / var)
    return logp.sum(dim=(1, 2))


# Building blocks (positional enc, TCN, etc.)

class SinusoidalPos(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.d = d

    def forward(self, T: int, device):
        D = self.d
        pe = torch.zeros(T, D, device=device)
        position = torch.arange(T, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, D, 2, device=device) * (-math.log(10000.0) / D))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return (self.scale * pe).unsqueeze(0)


class CausalDepthwiseSeparableConv(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=1, out_channels=None):
        super().__init__()
        pad = (kernel_size - 1) * dilation
        self.left_pad = pad
        out_channels = out_channels if out_channels is not None else channels
        dw = nn.Conv1d(channels, channels, kernel_size, groups=channels,
                       dilation=dilation, padding=pad, bias=True)
        pw = nn.Conv1d(channels, out_channels, kernel_size=1, bias=True)
        self.dw = pn_weight_norm(dw)
        self.pw = pn_weight_norm(pw)

        self._pad = pad

    def forward(self, x):
        h = self.dw(x)
        if self.left_pad > 0:
            h = h[..., :-self.left_pad]
        h = self.pw(h)
        return h


class SqueezeExcite(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(channels, hidden, kernel_size=1)
        self.fc2 = nn.Conv1d(hidden, channels, kernel_size=1)

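    def forward(self, x):
        # Global average over the whole time axis: unlike the convolutions, this
        # channel gate is not causal (every position sees a summary of all positions).
        s = self.pool(x)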
        s = F.gelu(self.fc1(s))
        s = torch.sigmoid(self.fc2(s))
        return x * s


class DropPath1D(nn.Module):
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob <= 0.0:
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.size(0),) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep) / keep
        return x * mask


class TCNBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=1, dropout=0.05, droppath=0.05):
        super().__init__()
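        # GroupNorm(1, C) normalizes over all channels and time steps, so (like the
        # SE gate) it mixes statistics across the whole sequence.
        self.gn1 = nn.GroupNorm(1, channels)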
        self.conv1 = CausalDepthwiseSeparableConv(channels, kernel_size, dilation, out_channels=2 * channels)
        self.gn2 = nn.GroupNorm(1, channels)
        self.conv2 = CausalDepthwiseSeparableConv(channels, kernel_size, dilation, out_channels=2 * channels)
        self.se = SqueezeExcite(channels, reduction=8)
        self.drop = nn.Dropout(dropout)
        self.droppath = DropPath1D(drop_prob=droppath)

    def forward(self, x):
        h = self.gn1(x)
        h = self.conv1(h)
        h = F.glu(h, dim=1)
        h = self.drop(h)
        h = self.gn2(h)
        h = self.conv2(h)
        h = F.glu(h, dim=1)
        h = self.se(h)
        h = self.drop(h)
        return x + self.droppath(h)


class HierarchicalTCN(nn.Module):
    def __init__(self, in_ch, hidden_ch, n_stacks=3, blocks_per_stack=2,
                 kernel_size=3, base_dilations=(1, 2, 4, 8, 16), dropout=0.05, droppath=0.05):
        super().__init__()
        self.proj = nn.Conv1d(in_ch, hidden_ch, 1)
        self.blocks = nn.ModuleList()
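        # Dilations cycle through base_dilations across all n_stacks * blocks_per_stack
        # blocks (the cycle does not restart at each stack); time resolution is unchanged.
        dil = list(base_dilations)
        for idx in range(n_stacks * blocks_per_stack):
            d = dil[idx % len(dil)]
            self.blocks.append(TCNBlock(hidden_ch, kernel_size, d, dropout, droppath))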

    def forward(self, x):
        h = self.proj(x)
        for blk in self.blocks:
            h = blk(h)
        return h


# GP-VAE-TCN+ model

@dataclass
class Config:
    vocab_size: int
    d_model: int = 256
    d_latent: int = 64
    block_size: int = 256
    emb_dim: int = 256
    gp_lengthscale_init: float = 8.0
    gp_variance_init: float = 1.0
    lr: float = 2e-4
    label_smoothing: float = 0.03
    weight_decay: float = 0.01
    grad_clip: float = 1.0
    free_bits_nats: float = 0.5
    kl_cap_nats: float = 12.0
    kl_target_nats: float = 8.0
    beta_init: float = 1e-3
    beta_max: float = 0.5
    beta_adapt_rate: float = 0.05
    use_adaptive_beta: bool = True
    K_multi: int = 3
    multi_lambda_scheme: str = "harmonic"
    gamma_multi: float = 0.5
    logvar_min: float = -6.0
    logvar_max: float = 2.0
    logvar_init: float = -4.0
    tcn_stacks: int = 3
    tcn_blocks_per_stack: int = 3
    tcn_kernel: int = 3
    tcn_base_dilations: tuple = (1, 2, 4, 8, 16)
    tcn_dropout: float = 0.05
    tcn_droppath: float = 0.05
    embed_reg_weight: float = 0.2
    embed_reg_mode: str = "cos"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class TCNEncoder(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim)
        self.tcn = HierarchicalTCN(
            in_ch=cfg.emb_dim, hidden_ch=cfg.d_model,
            n_stacks=cfg.tcn_stacks, blocks_per_stack=cfg.tcn_blocks_per_stack,
            kernel_size=cfg.tcn_kernel, base_dilations=cfg.tcn_base_dilations,
            dropout=cfg.tcn_dropout, droppath=cfg.tcn_droppath)
        self.to_mu = nn.Conv1d(cfg.d_model, cfg.d_latent, kernel_size=1)
        self.to_logvar = nn.Conv1d(cfg.d_model, cfg.d_latent, kernel_size=1)
        nn.init.zeros_(self.to_mu.weight)
        nn.init.zeros_(self.to_mu.bias)
        nn.init.zeros_(self.to_logvar.weight)
        nn.init.constant_(self.to_logvar.bias, cfg.logvar_init)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.tok_emb(x)
        h = h.transpose(1, 2)
        h = self.tcn(h)
        mu = self.to_mu(h).transpose(1, 2)
        logvar = self.to_logvar(h).transpose(1, 2)
        logvar = logvar.clamp(self.cfg.logvar_min, self.cfg.logvar_max)
        return mu, logvar


class TokenDecoder(nn.Module):
    def __init__(self, cfg: Config, tied_weight: torch.Tensor):
        super().__init__()
        self.cfg = cfg
        self.pe = SinusoidalPos(cfg.d_latent)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.d_latent, cfg.d_model), nn.GELU(),
            nn.Linear(cfg.d_model, cfg.d_model), nn.GELU())
        self.ln = nn.LayerNorm(cfg.d_model)
        self.post = nn.Conv1d(cfg.d_model, cfg.d_model, kernel_size=3, padding=1)
        self.to_emb = nn.Linear(cfg.d_model, cfg.emb_dim, bias=False)
        self.tied_weight = tied_weight
        self.bias = nn.Parameter(torch.zeros(cfg.vocab_size))
        self.sem_head = nn.Linear(cfg.d_model, cfg.emb_dim, bias=False)

    def forward(self, z: torch.Tensor):
        z = z + self.pe(T=z.size(1), device=z.device)
        h = self.mlp(z)
        h = self.ln(h)
        h2 = self.post(h.transpose(1, 2)).transpose(1, 2)
        h = h + h2
        e_proj = self.to_emb(h)
        tw = F.normalize(self.tied_weight, dim=-1)
        logits = torch.matmul(e_proj, tw.t()) + self.bias
        e_hat = self.sem_head(h)
        return logits, e_hat


class GPVAE_TCN(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.encoder = TCNEncoder(cfg)
        self.decoder = TokenDecoder(cfg, tied_weight=self.encoder.tok_emb.weight)
        self.register_buffer('t_train', torch.arange(cfg.block_size, dtype=torch.float32))
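        # Note: softplus(log(x)) = log(1 + x), so the effective initial hypers are
        # log(1 + gp_lengthscale_init) (about 2.20 for 8.0) and log(1 + gp_variance_init)
        # (about 0.69 for 1.0); log(expm1(x)) would be the exact inverse-softplus init.
        self._ls_unc = nn.Parameter(torch.tensor(cfg.gp_lengthscale_init).log())
        self._var_unc = nn.Parameter(torch.tensor(cfg.gp_variance_init).log())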
        self.softplus = nn.Softplus(beta=1.0)

    def gp_hypers(self):
        ls = self.softplus(self._ls_unc) + 1e-6
        var = self.softplus(self._var_unc) + 1e-6
        return ls, var

    def K_tt(self, t: torch.Tensor) -> torch.Tensor:
        ls, var = self.gp_hypers()
        return rbf_kernel(t, ls, var)

    def _label_smoothing_ce(self, logits, targets, eps):
        log_probs = F.log_softmax(logits, dim=-1)
        nll = F.nll_loss(log_probs.transpose(1, 2), targets, reduction='mean')
        smooth = -log_probs.mean(dim=-1).mean()
        return (1 - eps) * nll + eps * smooth

    def _multi_horizon_ll(self, logits, x) -> torch.Tensor:
        K = self.cfg.K_multi
        if K <= 0:
            return torch.zeros((), device=logits.device)
        B, T, V = logits.shape
        total = torch.zeros((), device=logits.device)  # stays a tensor even if no horizon fits
        for k in range(1, K + 1):
            if T - k <= 0:
                break
            lam = (1.0 / k) if self.cfg.multi_lambda_scheme == "harmonic" else 1.0
            ll_k = -self._label_smoothing_ce(logits[:, :T - k, :], x[:, k:], self.cfg.label_smoothing)
            total = total + self.cfg.gamma_multi * lam * ll_k
        return total

    def _embed_reg_losses(self, e_hat, x, token_emb_table: torch.Tensor):
        with torch.no_grad():
            e_tgt = token_emb_table[x]
        loss = 0.0
        if self.cfg.embed_reg_mode in ("mse", "cos+mse"):
            loss = loss + F.mse_loss(e_hat, e_tgt, reduction="mean")
        if self.cfg.embed_reg_mode in ("cos", "cos+mse"):
            loss = loss + (1.0 - F.cosine_similarity(e_hat, e_tgt, dim=-1).mean())
        return torch.as_tensor(loss, device=e_hat.device)

    def elbo(self, x: torch.Tensor, beta_override: Optional[float] = None):
        cfg = self.cfg
        B, T = x.shape
        mu, logvar = self.encoder(x)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        logits, e_hat = self.decoder(z)
        ll_0 = -self._label_smoothing_ce(logits, x, cfg.label_smoothing)
        ll_multi = self._multi_horizon_ll(logits, x)
        emb_reg = self._embed_reg_losses(e_hat, x, self.encoder.tok_emb.weight.detach())
        tvec = torch.arange(T, dtype=torch.float32, device=x.device)
        K_full = self.K_tt(tvec)
        log_pz_b = mvn_logpdf_zero_mean(z, K_full)
        log_qz_b = diag_mvn_logpdf(z, mu, logvar)
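        # Single-sample Monte Carlo estimate of KL(q || p), in nats per time step
        # (summed over latent dimensions).
        kl_tok_raw_t = (log_qz_b - log_pz_b).mean() / T
        # Free bits (lower clamp) and KL cap (upper clamp): outside
        # [free_bits_nats, kl_cap_nats] the KL term contributes no gradient.
        kl_tok_capped_t = torch.clamp(torch.clamp(kl_tok_raw_t, min=cfg.free_bits_nats), max=cfg.kl_cap_nats)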
        beta = beta_override if beta_override is not None else getattr(self, "_beta_state", cfg.beta_init)
        elbo_tok_t = (ll_0 + ll_multi) - beta * kl_tok_capped_t - cfg.embed_reg_weight * emb_reg
        stats = {
            'll0_tok': float(ll_0.detach().item()),
            'll_multi_tok': float(ll_multi.detach().item()),
            'emb_reg': float(emb_reg.detach().item()) if torch.is_tensor(emb_reg) else float(emb_reg),
            'kl_tok_raw': float(kl_tok_raw_t.detach().item()),
            'kl_tok_cap': float(kl_tok_capped_t.detach().item()),
            'beta': float(beta),
            'elbo_tok': float(elbo_tok_t.detach().item()),
        }
        comps = {
            'll0_tok_t': ll_0,
            'll_multi_tok_t': ll_multi,
            'emb_reg_t': emb_reg,
            'kl_tok_raw_t': kl_tok_raw_t,
            'kl_tok_cap_t': kl_tok_capped_t,
            'elbo_tok_t': elbo_tok_t,
            'beta_t': torch.tensor(beta, device=x.device),
        }
        return elbo_tok_t, stats, comps, logits

    @torch.no_grad()
    def _gp_conditional_step(self, K: torch.Tensor, z_prefix: torch.Tensor, t_next: int):
        B, tp, Dz = z_prefix.shape
        if tp == 0:
            var0 = K[0, 0].clamp_min(1e-9)
            return torch.randn(B, Dz, device=z_prefix.device) * var0.sqrt()
        K_11 = K[:tp, :tp]
        k_12 = K[:tp, t_next]
        K_22 = K[t_next, t_next]
        jitter = 1e-6
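        for _ in range(5):
            try:
                L_11 = torch.linalg.cholesky(K_11 + jitter * torch.eye(tp, device=K.device, dtype=K.dtype))
                break
            except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
                jitter *= 10.0
        else:
            raise RuntimeError("Cholesky of K_11 failed even with jitter up to 1e-2")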
        alpha = torch.cholesky_solve(k_12.view(tp, 1), L_11).view(tp)
        mean_t = torch.einsum('t, btd -> bd', alpha, z_prefix)
        var_t = (K_22 - (k_12 * alpha).sum()).clamp_min(1e-9)
        return mean_t + torch.randn(B, Dz, device=z_prefix.device) * var_t.sqrt()

    @torch.no_grad()
    def _sample_gp_prior_block(self, T: int, batch_size: int) -> torch.Tensor:
        device = self.t_train.device
        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        B = batch_size
        Dz = self.cfg.d_latent
        jitter = 1e-6
        I = torch.eye(T, device=device, dtype=K.dtype)
        L = torch.linalg.cholesky(K + jitter * I)

        eps = torch.randn(B * Dz, T, device=device)
        z_bd = (L @ eps.T).T
        z = z_bd.view(B, Dz, T).permute(0, 2, 1).contiguous()
        return z

    @torch.no_grad()
    def _gp_conditional_block(self, K: torch.Tensor, z_prefix: torch.Tensor, T0: int) -> torch.Tensor:
        B, T0_check, Dz = z_prefix.shape
        assert T0_check == T0
        T = K.size(0)
        Tf = T - T0
        device = z_prefix.device
        dtype = K.dtype

        if Tf <= 0:
            return z_prefix.new_zeros(B, 0, Dz)

        K_pp = K[:T0, :T0]
        K_pf = K[:T0, T0:]
        K_fp = K_pf.transpose(0, 1)
        K_ff = K[T0:, T0:]

        jitter = 1e-6
        I_pp = torch.eye(T0, device=device, dtype=dtype)
        L_pp = torch.linalg.cholesky(K_pp + jitter * I_pp)

        z_p_bd = z_prefix.permute(0, 2, 1).reshape(B * Dz, T0)
        alpha_bd = torch.cholesky_solve(z_p_bd.unsqueeze(-1), L_pp).squeeze(-1)
        alpha = alpha_bd.view(B, Dz, T0).permute(0, 2, 1)

        mean_fut = torch.einsum('ft, btd -> bfd', K_fp, alpha)

        C = torch.cholesky_solve(K_pf, L_pp)
        Sigma_f = K_ff - K_fp @ C
        Sigma_f = 0.5 * (Sigma_f + Sigma_f.T)

        I_f = torch.eye(Tf, device=device, dtype=dtype)
        L_f = torch.linalg.cholesky(Sigma_f + jitter * I_f)

        eps = torch.randn(B * Dz, Tf, device=device)
        z_noise_bd = (L_f @ eps.T).T
        z_noise = z_noise_bd.view(B, Dz, Tf).permute(0, 2, 1)

        z_fut = mean_fut + z_noise
        return z_fut

    @torch.no_grad()
    def generate(self, T: int, batch_size: int = 1, top_k: int = 50,
                 top_p: float = 0.9, temperature: float = 0.9):
        device = self.t_train.device
        self.eval()
        z = self._sample_gp_prior_block(T, batch_size)
        logits, _ = self.decoder(z)
        return sample_logits_from_timewise_logits(
            logits, top_k=top_k, top_p=top_p, temperature=temperature
        )

    @torch.no_grad()
    def generate_with_prompt(self, prompt_ids: torch.Tensor, total_len: int, eos_id: int,
                             top_k: int = 50, top_p: float = 0.9, temperature: float = 0.9):
        device = self.t_train.device
        self.eval()
        B, T0 = prompt_ids.shape
        T = total_len

        x_in = torch.full((B, T), fill_value=eos_id, dtype=torch.long, device=device)
        x_in[:, :T0] = prompt_ids.to(device)

        if T0 > 0:
            mu, logvar = self.encoder(x_in[:, :T0])
            std = torch.exp(0.5 * logvar)
            z_prompt = mu + std * torch.randn_like(std)
        else:
            z_prompt = None

        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        Dz = self.cfg.d_latent

        z = torch.zeros(B, T, Dz, device=device)
        if T0 > 0:
            z[:, :T0, :] = z_prompt
            z_fut = self._gp_conditional_block(K, z_prompt, T0)
            z[:, T0:, :] = z_fut
        else:
            z = self._sample_gp_prior_block(T, B)

        logits, _ = self.decoder(z)
        new_ids = sample_logits_from_timewise_logits(
            logits[:, T0:, :], top_k=top_k, top_p=top_p, temperature=temperature
        )
        x_out = x_in.clone()
        if T > T0:
            x_out[:, T0:] = new_ids
        return x_out, logits

    @torch.no_grad()
    def generate_no_ar(self, T: int, batch_size: int = 1,
                       top_k: int = 50, top_p: float = 0.9,
                       temperature: float = 0.9):
        device = self.t_train.device
        self.eval()
        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        Dz = self.cfg.d_latent
        B = batch_size

        var_diag = torch.diag(K).clamp_min(1e-9)
        std_diag = var_diag.sqrt().view(1, T, 1)

        z = torch.randn(B, T, Dz, device=device) * std_diag
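        # Note: unlike generate(), the latents themselves are also scaled by
        # `temperature` here, on top of the logit temperature below.
        z = z * float(temperature)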
        logits, _ = self.decoder(z)
        return sample_logits_from_timewise_logits(
            logits, top_k=top_k, top_p=top_p, temperature=temperature
        )

    @torch.no_grad()
    def generate_with_prompt_no_ar(self, prompt_ids: torch.Tensor, total_len: int, eos_id: int,
                                   top_k: int = 50, top_p: float = 0.9,
                                   temperature: float = 0.9):
        device = self.t_train.device
        self.eval()
        B, T0 = prompt_ids.shape
        T = total_len

        x_in = torch.full((B, T), fill_value=eos_id,
                          dtype=torch.long, device=device)
        x_in[:, :T0] = prompt_ids.to(device)

        Dz = self.cfg.d_latent
        z = torch.zeros(B, T, Dz, device=device)

        if T0 > 0:
            mu, logvar = self.encoder(x_in[:, :T0])
            std = torch.exp(0.5 * logvar)
            z_prompt = mu + std * torch.randn_like(std)
            z[:, :T0, :] = z_prompt

        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        var_diag = torch.diag(K).clamp_min(1e-9)
        if T > T0:
            std_future = var_diag[T0:].sqrt().view(1, T - T0, 1)
            noise_future = torch.randn(B, T - T0, Dz, device=device)
            z[:, T0:, :] = noise_future * std_future

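        # Note: unlike generate_with_prompt(), all latents (prompt included) are also
        # scaled by `temperature` here, on top of the logit temperature below.
        z = z * float(temperature)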
        logits, _ = self.decoder(z)
        new_ids = sample_logits_from_timewise_logits(
            logits[:, T0:, :], top_k=top_k, top_p=top_p, temperature=temperature
        )
        x_out = x_in.clone()
        if T > T0:
            x_out[:, T0:] = new_ids
        return x_out, logits


# Sampling from logits

@torch.no_grad()
def sample_logits_from_timewise_logits(logits: torch.Tensor, top_k: int = 50,
                                       top_p: float = 0.9, temperature: float = 1.0):
    B, T, V = logits.shape
    x = logits / max(1e-6, float(temperature))
    if top_k is not None and 0 < top_k < V:
        kth = torch.topk(x, top_k, dim=-1).values[..., -1:].expand(B, T, V)
        x = torch.where(x < kth, torch.full_like(x, -1e10), x)
    if top_p is not None and 0.0 < top_p < 1.0:
        probs = torch.softmax(x, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        cutoff = cum > top_p
        cutoff[..., 1:] = cutoff[..., :-1].clone()
        cutoff[..., 0] = False
        sorted_logits = torch.log(sorted_probs.clamp_min(1e-20))
        sorted_logits[cutoff] = -1e10
        x = torch.full_like(x, -1e10).scatter(-1, sorted_idx, sorted_logits)
    ids = torch.distributions.Categorical(logits=x).sample()
    return ids


@torch.no_grad()
def score_continuation_gpvae_from_logits(model, x_full, T0):
    B, T = x_full.shape
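    # The encoder sees the full sequence, including the continuation being scored,
    # so this is a reconstruction NLL from the posterior mean, not a predictive NLL.
    mu, logvar = model.encoder(x_full)
    z = mu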
    logits, _ = model.decoder(z)
    targets = x_full[:, T0:]
    logits_c = logits[:, T0:, :]
    V = logits_c.size(-1)
    nll = F.cross_entropy(logits_c.reshape(-1, V), targets.reshape(-1), reduction='mean').item()
    ppl = math.exp(nll) if nll < 50 else float("inf")
    return {"nll": nll, "ppl": ppl}


# Transformer LM baseline (autoregressive, GPT-style)

@dataclass
class TFConfig:
    vocab_size: int
    d_model: int = 640
    n_layer: int = 10
    n_head: int = 10
    d_ff: int = 2560
    block_size: int = 256
    dropout: float = 0.1
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class TransformerBlock(nn.Module):
    def __init__(self, cfg: TFConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_head,
            dropout=cfg.dropout,
            batch_first=True
        )
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_ff),
            nn.GELU(),
            nn.Linear(cfg.d_ff, cfg.d_model),
            nn.Dropout(cfg.dropout),
        )

    def forward(self, x, attn_mask=None):
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + h
        h2 = self.mlp(self.ln2(x))
        x = x + h2
        return x


class TransformerLM(nn.Module):
    def __init__(self, cfg: TFConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(self, x):
        B, T = x.shape
        assert T <= self.cfg.block_size, "Sequence length exceeds block_size"
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok_emb(x) + self.pos_emb(pos)

        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        for blk in self.blocks:
            h = blk(h, attn_mask=mask)

        h = self.ln_f(h)
        logits = self.head(h)
        return logits

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, total_len: int,
                 top_k: int = 50, top_p: float = 0.9, temperature: float = 1.0):
        self.eval()
        x = prompt_ids.clone()
        B, T0 = x.shape
        for t in range(T0, total_len):
            if x.size(1) > self.cfg.block_size:
                x_cond = x[:, -self.cfg.block_size:]
            else:
                x_cond = x
            logits = self.forward(x_cond)
            last_logits = logits[:, -1:, :]
            ids = sample_logits_from_timewise_logits(
                last_logits, top_k=top_k, top_p=top_p, temperature=temperature
            )
            x = torch.cat([x, ids], dim=1)
        return x


def train_transformer_baseline(
    train_ds, val_ds, tok,
    block_size=256,
    batch_size=16,
    steps=2000,
    lr=3e-4,
    d_model=640,
    n_layer=10,
    n_head=10,
    d_ff=2560,
    dropout=0.1,
    log_every=50,
    warmup_frac=0.1,
):

    
    cfg = TFConfig(
        vocab_size=tok.vocab_size,
        d_model=d_model,
        n_layer=n_layer,
        n_head=n_head,
        d_ff=d_ff,
        block_size=block_size,
        dropout=dropout,
    )
    device = cfg.device
    model = TransformerLM(cfg).to(device)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)

    optim = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    scaler = torch.amp.GradScaler('cuda', enabled=device.startswith("cuda"))

    base_lr = lr
    warmup_steps = max(10, int(warmup_frac * steps))

    def lr_at(step_idx: int) -> float:
        if step_idx <= 0:
            return 0.0
        if step_idx < warmup_steps:
            return base_lr * (step_idx / float(warmup_steps))
        t = (step_idx - warmup_steps) / max(1, steps - warmup_steps)
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))

    hist_loss = []
    t0 = time.perf_counter()

    model.train()
    for step, x in enumerate(train_loader, 1):
        x = x.to(device)

        cur_lr = lr_at(step)
        for g in optim.param_groups:
            g["lr"] = cur_lr

        optim.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=device.startswith("cuda")):
            logits = model(x)
            B, T, V = logits.shape
            loss = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, V),
                x[:, 1:].contiguous().view(-1),
                reduction='mean'
            )
        scaler.scale(loss).backward()
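        # Unscale first so the max-norm of 1.0 applies to the true (unscaled) gradients.
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)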
        scaler.step(optim)
        scaler.update()

        hist_loss.append(float(loss.item()))
        if step % log_every == 0:
            dt = time.perf_counter() - t0
            tokps = (batch_size * block_size * log_every) / max(1e-9, dt)
            print(f"[TF] step {step} | loss {loss.item():.3f} | lr {cur_lr:.2e} | tok/s {tokps:,.0f}")
            t0 = time.perf_counter()

        if step >= steps:
            break

    model.eval()
    total_nll = 0.0
    total_tokens = 0
    with torch.no_grad():
        for x_val in val_loader:
            x_val = x_val.to(device)
            logits = model(x_val)
            B, T, V = logits.shape
            nll_batch = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, V),
                x_val[:, 1:].contiguous().view(-1),
                reduction='sum'
            ).item()
            total_nll += nll_batch
            total_tokens += (B * (T - 1))
    if total_tokens == 0:
        nll = float("inf")
    else:
        nll = total_nll / total_tokens
    ppl = math.exp(nll) if nll < 50 else float("inf")
    print(f"[TF VAL] NLL={nll:.4f} | PPL={ppl:.2f}")
    return model, cfg, {"val_nll": nll, "val_ppl": ppl}
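

# Illustrative sketch (hypothetical standalone copy of the `lr_at` closure in
# train_transformer_baseline, not called by the pipeline): linear warm-up from
# 0 to base_lr over `warmup_steps`, then half-cosine decay to 0 at `total_steps`.
import math


def _sketch_warmup_cosine_lr(step: int, base_lr: float,
                             warmup_steps: int, total_steps: int) -> float:
    if step <= 0:
        return 0.0
    if step < warmup_steps:
        # Linear ramp: base_lr * step / warmup_steps.
        return base_lr * (step / float(warmup_steps))
    # Progress through the decay phase: 0 at the end of warm-up, 1 at total_steps.
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))

# With base_lr=3e-4, warmup_steps=200, total_steps=2000: step 100 -> 1.5e-4,
# step 200 -> 3e-4 (peak), step 1100 -> about 1.5e-4, step 2000 -> about 0.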


# GPT-2 judge helpers & repetition metrics

def _repetition_metrics_for_ids(ids: List[int]):
    L = len(ids)
    if L <= 1:
        return 0.0, 0.0, 0.0

    seen2 = set()
    rep2_count = 0
    for i in range(1, L):
        bg = (ids[i - 1], ids[i])
        if bg in seen2:
            rep2_count += 1
        seen2.add(bg)
    rep2 = rep2_count / max(1, L - 1)

    if L > 2:
        seen3 = set()
        rep3_count = 0
        for i in range(2, L):
            tg = (ids[i - 2], ids[i - 1], ids[i])
            if tg in seen3:
                rep3_count += 1
            seen3.add(tg)
        rep3 = rep3_count / max(1, L - 2)
    else:
        rep3 = 0.0

    consec_count = sum(1 for i in range(1, L) if ids[i] == ids[i - 1])
    consec = consec_count / max(1, L - 1)

    return rep2, rep3, consec
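

# Illustrative sketch (hypothetical helper, not used by the pipeline): rep2 and
# rep3 above are special cases of an n-gram repetition rate, i.e. the fraction of
# the len(ids) - n + 1 n-grams (scanned left to right) that already occurred earlier.
from typing import Sequence


def _sketch_ngram_repeat_rate(ids: Sequence[int], n: int) -> float:
    num_ngrams = len(ids) - n + 1
    if n < 1 or num_ngrams <= 0:
        return 0.0
    seen = set()
    repeats = 0
    for i in range(num_ngrams):
        gram = tuple(ids[i:i + n])
        if gram in seen:
            repeats += 1
        seen.add(gram)
    return repeats / num_ngrams

# Example: for ids = [1, 2, 1, 2, 1, 2], n=2 gives 3/5 = 0.6 (same as rep2)
# and n=3 gives 2/4 = 0.5 (same as rep3).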


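def _non_ascii_fraction(text: str) -> float:
    """Fraction of characters outside printable ASCII (codes 32..126).

    Ordinary whitespace (newline, tab, carriage return) is ASCII and is not
    counted, so multi-line continuations are not penalized.
    """
    if len(text) == 0:
        return 0.0
    count = sum(1 for ch in text if not (32 <= ord(ch) <= 126) and ch not in "\n\t\r")
    return count / len(text)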


@torch.no_grad()
def _gpt2_continuation_metrics(
    gpt2_model,
    gpt2_tok,
    prompt_ids: List[int],
    cont_ids: List[int],
    rare_prob_thresh: float = 1e-4,
):
    device = next(gpt2_model.parameters()).device
    prompt_ids = list(prompt_ids)
    cont_ids = list(cont_ids)
    if len(cont_ids) == 0:
        return 0.0, 1.0, 0.0

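    # Token ids are fed to the judge as-is, so `prompt_ids` / `cont_ids` must come
    # from a tokenizer that shares GPT-2's vocabulary (`gpt2_tok` is not used here).
    full = prompt_ids + cont_ids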
    max_ctx = getattr(gpt2_model.config, "n_positions", None)
    if max_ctx is None:
        max_ctx = getattr(gpt2_model.config, "max_position_embeddings", 1024)

    if len(full) > max_ctx:
        full = full[-max_ctx:]
        if len(cont_ids) >= max_ctx:
            prompt_len = 0
        else:
            prompt_len = max_ctx - len(cont_ids)
    else:
        prompt_len = len(prompt_ids)

    input_ids = torch.tensor([full], dtype=torch.long, device=device)
    outputs = gpt2_model(input_ids)
    logits = outputs.logits[0]
    T = logits.size(0)

    shift_logits = logits[:-1, :]
    shift_labels = input_ids[0, 1:]
    start = max(0, prompt_len - 1)
    shift_logits_c = shift_logits[start:, :]
    shift_labels_c = shift_labels[start:]

    if shift_logits_c.numel() == 0:
        return 0.0, 1.0, 0.0

    log_probs = F.log_softmax(shift_logits_c, dim=-1)
    nll_tokens = F.nll_loss(log_probs, shift_labels_c, reduction="none")
    nll = float(nll_tokens.mean().item())
    ppl = math.exp(nll) if nll < 50 else float("inf")

    token_logp = log_probs[torch.arange(shift_labels_c.size(0)), shift_labels_c]
    token_prob = torch.exp(token_logp)
    rare_mask = token_prob < rare_prob_thresh
    rare_frac = float(rare_mask.float().mean().item())
    return nll, ppl, rare_frac
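

# Illustrative sketch (hypothetical helper, not used by the pipeline) of the
# index alignment used in _gpt2_continuation_metrics: logits[t] scores ids[t + 1],
# so the continuation ids[prompt_len:] is scored by logits[prompt_len - 1 : -1].
import torch
import torch.nn.functional as F


def _sketch_continuation_nll(logits: torch.Tensor, ids: torch.Tensor,
                             prompt_len: int) -> float:
    """Mean NLL (nats/token) of ids[prompt_len:] given next-token logits (T, V)."""
    start = max(0, prompt_len - 1)
    shift_logits = logits[:-1][start:]   # predictions for positions 1..T-1, from start
    shift_labels = ids[1:][start:]       # tokens actually observed at those positions
    return F.cross_entropy(shift_logits, shift_labels).item()

# Sanity check: all-zero logits (uniform over V tokens) give an NLL of log(V).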


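@torch.no_grad()
def score_continuation_transformer(model: TransformerLM, x_full: torch.Tensor, T0: int):
    """Mean NLL / PPL of x_full[:, T0:] under the Transformer baseline (teacher forcing)."""
    device = next(model.parameters()).device
    x_full = x_full.to(device)
    B, T = x_full.shape
    max_len = model.cfg.block_size

    if T > max_len:
        # Left-cropping drops T - max_len tokens, so the prompt/continuation
        # boundary T0 moves left by the same amount.
        shift = T - max_len
        x_full = x_full[:, -max_len:]
        T = max_len
        T0 = min(max(0, T0 - shift), T - 1)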

    logits = model(x_full)
    B, Tm, V = logits.shape
    assert Tm == T

    start_idx = max(0, T0 - 1)
    logits_c = logits[:, start_idx:-1, :]
    targets = x_full[:, start_idx + 1:]

    V = logits_c.size(-1)
    nll = F.cross_entropy(
        logits_c.contiguous().view(-1, V),
        targets.contiguous().view(-1),
        reduction='mean'
    ).item()
    ppl = math.exp(nll) if nll < 50 else float("inf")
    return {"nll": nll, "ppl": ppl}
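

# Illustrative sketch (hypothetical helper, not used by the pipeline): when a
# sequence of length T exceeds the context window, keeping only the last
# `max_len` tokens drops T - max_len tokens from the left, so the prompt /
# continuation boundary T0 moves left by the same amount (floored at 0 and kept
# below the new length so at least one continuation token remains).
def _sketch_crop_boundary(T: int, T0: int, max_len: int) -> tuple:
    if T <= max_len:
        return T, T0
    shift = T - max_len
    return max_len, min(max(0, T0 - shift), max_len - 1)

# Example: T=10, T0=6, max_len=8 -> (8, 4).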


def evaluate_all_models(
    gp_model: GPVAE_TCN,
    tf_model: TransformerLM,
    tok,
    gpt2_judge=None,
    gpt2_tok=None,
    block_size: int = 64,
):
    device = gp_model.cfg.device
    gp_model.eval()
    tf_model.eval()

    prompt_text = "The meaning of life"
    ids = tok.encode(prompt_text)[:block_size]
    eos_id = getattr(tok, "eos_token_id", 0)
    if len(ids) == 0:
        ids = [eos_id]
    prompt_ids = torch.tensor([ids], dtype=torch.long, device=device)
    T0 = prompt_ids.size(1)
    total_len = min(block_size * 2, T0 + 64)

    
    # 1) Generations
    
    # GP-VAE-TCN+ AR
    x_gp_ar, logits_gp_ar = gp_model.generate_with_prompt(
        prompt_ids, total_len, eos_id,
        top_k=50, top_p=0.9, temperature=0.9
    )
    cont_gp_ar = x_gp_ar[0, T0:].tolist()
    sc_gp_ar = score_continuation_gpvae_from_logits(gp_model, x_gp_ar, T0)

    # GP-VAE-TCN+ no-AR
    x_gp_na, logits_gp_na = gp_model.generate_with_prompt_no_ar(
        prompt_ids, total_len, eos_id,
        top_k=50, top_p=0.9, temperature=0.9
    )
    cont_gp_na = x_gp_na[0, T0:].tolist()
    sc_gp_na = score_continuation_gpvae_from_logits(gp_model, x_gp_na, T0)

    # Transformer baseline
    x_tf = tf_model.generate(prompt_ids, total_len, top_k=50, top_p=0.9, temperature=0.9)
    cont_tf = x_tf[0, T0:].tolist()
    sc_tf = score_continuation_transformer(tf_model, x_tf, T0)

    
    # 2) Repetition / non-ASCII metrics
    
    rep2_ar, rep3_ar, consec_ar = _repetition_metrics_for_ids(cont_gp_ar)
    rep2_na, rep3_na, consec_na = _repetition_metrics_for_ids(cont_gp_na)
    rep2_tf, rep3_tf, consec_tf = _repetition_metrics_for_ids(cont_tf)

    text_ar = tok.decode(cont_gp_ar)
    text_na = tok.decode(cont_gp_na)
    text_tf = tok.decode(cont_tf)
    na_ar = _non_ascii_fraction(text_ar)
    na_na = _non_ascii_fraction(text_na)
    na_tf = _non_ascii_fraction(text_tf)

    print("\n===== Comparative evaluation (GP-VAE-TCN+ AR vs no-AR vs Transformer) =====")
    print("Own-model continuation metrics:")
    print(f"[GP-AR ] NLL={sc_gp_ar['nll']:.4f} | PPL={sc_gp_ar['ppl']:.2f}")
    print(f"[GP-noA] NLL={sc_gp_na['nll']:.4f} | PPL={sc_gp_na['ppl']:.2f}")
    print(f"[TF    ] NLL={sc_tf['nll']:.4f} | PPL={sc_tf['ppl']:.2f}")

    
    # 3) GPT-2 judge (optional)
    
    if (gpt2_judge is not None) and (gpt2_tok is not None):
        try:
            prompt_list = prompt_ids[0].tolist()
            nll_j_ar, ppl_j_ar, rare_ar = _gpt2_continuation_metrics(gpt2_judge, gpt2_tok, prompt_list, cont_gp_ar)
            nll_j_na, ppl_j_na, rare_na = _gpt2_continuation_metrics(gpt2_judge, gpt2_tok, prompt_list, cont_gp_na)
            nll_j_tf, ppl_j_tf, rare_tf = _gpt2_continuation_metrics(gpt2_judge, gpt2_tok, prompt_list, cont_tf)

            print("\n(Note: GPT-2 is a token-level AR model, so its PPL is biased "
                  "against non-AR models. It is reported here for information only; "
                  "the main comparison relies on the intrinsic perplexities.)")

            print("\nGPT-2 judge metrics:")
            print(f"[GP-AR ] GPT2-PPL={ppl_j_ar:.2f} | GPT2-NLL={nll_j_ar:.4f} | rare_frac={rare_ar:.3f}")
            print(f"[GP-noA] GPT2-PPL={ppl_j_na:.2f} | GPT2-NLL={nll_j_na:.4f} | rare_frac={rare_na:.3f}")
            print(f"[TF    ] GPT2-PPL={ppl_j_tf:.2f} | GPT2-NLL={nll_j_tf:.4f} | rare_frac={rare_tf:.3f}")
        except Exception as e:
            print(f"\n[WARN] GPT-2 judge failed during evaluation: {e}")
    else:
        print("\n[WARN] GPT-2 judge not available, skipping GPT-2-based metrics.")

    print("\nSurface repetition / non-ASCII:")
    print(f"[GP-AR ] rep2={rep2_ar:.3f} | rep3={rep3_ar:.3f} | consec={consec_ar:.3f} | non_ascii={na_ar:.3f}")
    print(f"[GP-noA] rep2={rep2_na:.3f} | rep3={rep3_na:.3f} | consec={consec_na:.3f} | non_ascii={na_na:.3f}")
    print(f"[TF    ] rep2={rep2_tf:.3f} | rep3={rep3_tf:.3f} | consec={consec_tf:.3f} | non_ascii={na_tf:.3f}")
    print("=" * 75 + "\n")


# Training GP-VAE-TCN+ and hyperparameter search


def apply_overrides(cfg: Config, overrides: Optional[Dict[str, Any]] = None) -> Config:
    if not overrides:
        return cfg
    for k, v in overrides.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
    return cfg
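

# Illustrative sketch (hypothetical variant, not used by the pipeline):
# apply_overrides silently skips keys that Config does not define, so a typo or a
# search dimension without a matching Config field becomes a no-op. This variant
# applies the same rule but also returns the skipped keys so they can be logged.
from typing import Any, Dict, List, Optional, Tuple


def _sketch_apply_overrides_report(cfg: Any, overrides: Optional[Dict[str, Any]] = None
                                   ) -> Tuple[Any, List[str]]:
    ignored: List[str] = []
    for k, v in (overrides or {}).items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)   # same behavior as apply_overrides
        else:
            ignored.append(k)    # reported instead of silently dropped
    return cfg, ignored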


def train_gpvae(block_size=256, batch_size=16, steps=1000, lr=2e-4,
                log_every=50, cfg_overrides: Optional[Dict[str, Any]] = None):
    train_ds, tok = load_wikitext103(block_size=block_size, split='train')
    val_ds, _ = load_wikitext103(block_size=block_size, split='validation')

    cfg = Config(vocab_size=tok.vocab_size, block_size=block_size, lr=lr)
    cfg = apply_overrides(cfg, cfg_overrides)
    device = cfg.device

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)

    model = GPVAE_TCN(cfg).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=cfg.weight_decay)
    scaler = torch.amp.GradScaler('cuda', enabled=device.startswith("cuda"))
    warmup_steps = int(os.getenv('KL_WARMUP_STEPS', '4000'))

    hist_elbo, hist_ll0, hist_llm, hist_kl = [], [], [], []
    hist_kl_cap, hist_beta, hist_embreg, hist_tokps = [], [], [], []

    model.train()
    t0 = time.perf_counter()
    beta = cfg.beta_init
    model._beta_state = beta

    for step, x in enumerate(train_loader, 1):
        x = x.to(device)
        optim.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=device.startswith("cuda")):
            elbo_tok_t, stats, comps, _ = model.elbo(x, beta_override=beta)
            if step < warmup_steps:
                beta = min(cfg.beta_max, beta + (cfg.beta_max - cfg.beta_init) / max(1, warmup_steps))
            elif cfg.use_adaptive_beta:
                diff = float(comps['kl_tok_cap_t'].detach().item()) - cfg.kl_target_nats
                beta = float(beta) * (1.0 + cfg.beta_adapt_rate * (1.0 if diff > 0 else -1.0))
                beta = float(np.clip(beta, 1e-4, cfg.beta_max))
            model._beta_state = beta
            loss = -elbo_tok_t

        scaler.scale(loss).backward()

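        gc = getattr(cfg, "grad_clip", 0.0)
        if gc and gc > 0:
            # Unscale first so the clipping threshold applies to the true (unscaled) gradients.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gc)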

        scaler.step(optim)
        scaler.update()

        hist_elbo.append(float(stats['elbo_tok']))
        hist_ll0.append(float(stats['ll0_tok']))
        hist_llm.append(float(stats['ll_multi_tok']))
        hist_kl.append(float(stats['kl_tok_raw']))
        hist_kl_cap.append(float(stats['kl_tok_cap']))
        hist_beta.append(float(stats['beta']))
        hist_embreg.append(float(stats['emb_reg']))

        if step % log_every == 0:
            dt = time.perf_counter() - t0
            tokps = (batch_size * block_size * log_every) / max(1e-9, dt)
            hist_tokps.append(float(tokps))
            print(
                f"step {step} | elbo/tok {stats['elbo_tok']:.3f} | ll0 {stats['ll0_tok']:.3f} | "
                f"ll_multi {stats['ll_multi_tok']:.3f} | kl_raw {stats['kl_tok_raw']:.3f} | "
                f"kl_cap {stats['kl_tok_cap']:.3f} | beta {stats['beta']:.3f} | tok/s {tokps:,.0f}"
            )
            t0 = time.perf_counter()
        if step >= steps:
            break

    plot_curve(hist_elbo, "ELBO/token evolution", "ELBO/token", "elbo_token.png")
    plot_curve(hist_ll0, "LL0/token evolution (k=0)", "LL0/token", "ll0_token.png")
    plot_curve(hist_llm, "LL_multi/token evolution (k=1..K)", "LL_multi/token", "llmulti_token.png")
    plot_curve(hist_kl, "KL/token (raw) evolution", "KL/token", "kl_token_raw.png")
    plot_curve(hist_kl_cap, "KL/token (capped) evolution", "KL_cap/token", "kl_token_cap.png")
    plot_curve(hist_beta, "β evolution", "β", "beta.png")
    plot_curve(hist_embreg, "Semantic regularization evolution", "emb_reg", "emb_reg.png")
    if len(hist_tokps) > 0:
        plot_curve(hist_tokps, "Token throughput evolution", "tok/s", "tok_per_s.png")

    model.eval()
    total_nll = 0.0
    total_tokens = 0
    agg_stats = {
        "elbo_tok": [],
        "ll0_tok": [],
        "ll_multi_tok": [],
        "kl_tok_cap": [],
        "kl_tok_raw": [],
        "emb_reg": [],
        "beta": [],
    }
    logits_for_diag = None
    x_for_diag = None

    with torch.no_grad():
        for i, x_val in enumerate(val_loader):
            x_val = x_val.to(device)
            elbo_tok_v, st_v, comps_v, logits_v = model.elbo(x_val, beta_override=model._beta_state)
            V = logits_v.size(-1)
            nll_batch = F.cross_entropy(
                logits_v.reshape(-1, V),
                x_val.reshape(-1),
                reduction='sum'
            ).item()
            total_nll += nll_batch
            total_tokens += x_val.numel()

            for k in agg_stats.keys():
                agg_stats[k].append(float(st_v[k]))

            if i == 0:
                logits_for_diag = logits_v.detach().clone()
                x_for_diag = x_val.detach().clone()

    if total_tokens == 0:
        nll = float("inf")
    else:
        nll = total_nll / total_tokens
    ppl = math.exp(nll) if nll < 50 else float("inf")

    st_v = {k: float(np.mean(v)) for k, v in agg_stats.items()}

    if logits_for_diag is not None and x_for_diag is not None:
        bin_conf, bin_acc, ece = reliability_diagram_data(logits_for_diag, x_for_diag, n_bins=15)
        nll_pos = nll_per_token(logits_for_diag, x_for_diag)
        plot_reliability(bin_conf, bin_acc, ece, "reliability.png")
        plot_nll_per_position(nll_pos, "nll_per_position.png")
    else:
        print("[WARN] No batch for reliability diagram.")
        bin_conf = bin_acc = nll_pos = None
        ece = 0.0

    print(
        f"[VAL] elbo/tok {st_v['elbo_tok']:.3f} | ll0 {st_v['ll0_tok']:.3f} | "
        f"ll_multi {st_v['ll_multi_tok']:.3f} | kl_cap {st_v['kl_tok_cap']:.3f} | ppl {ppl:.2f}"
    )

    T = x_val.shape[1]
    T0 = int(float(os.getenv('T0_FRAC', '0.5')) * T)
    sc = score_continuation_gpvae_from_logits(model, x_val, T0)

    ckpt_path = os.path.abspath('gpvae_best.pt')
    torch.save({'model_state_dict': model.state_dict(), 'config': asdict(cfg)}, ckpt_path)

    metrics = {
        "val_elbo_tok": float(st_v['elbo_tok']),
        "val_ppl": float(ppl),
        "cont_nll": float(sc['nll']),
        "cont_ppl": float(sc['ppl']),
        "kl_eff_nats": float(st_v['kl_tok_cap']),
        "beta_final": float(model._beta_state),
        "tok_s": float(hist_tokps[-1] if len(hist_tokps) > 0 else 0.0),
        "checkpoint_path": ckpt_path,
    }

    figs = {
        "elbo_token_png": "elbo_token.png",
        "ll0_token_png": "ll0_token.png",
        "llmulti_token_png": "llmulti_token.png",
        "kl_token_raw_png": "kl_token_raw.png",
        "kl_token_cap_png": "kl_token_cap.png",
        "beta_png": "beta.png",
        "emb_reg_png": "emb_reg.png",
        "tok_per_s_png": "tok_per_s.png",
        "reliability_png": "reliability.png",
        "nll_per_position_png": "nll_per_position.png",
    }
    return model, tok, metrics, figs
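

# Illustrative sketch (hypothetical standalone helper mirroring the beta update
# inside train_gpvae's loop; not called by the pipeline):
#   * while step < warmup_steps, beta grows linearly from beta_init toward beta_max;
#   * afterwards, if adaptive, beta is multiplied by (1 + rate) when the capped
#     KL/token exceeds its target and by (1 - rate) otherwise, then clipped to
#     [1e-4, beta_max].
def _sketch_update_beta(beta: float, step: int, warmup_steps: int,
                        beta_init: float, beta_max: float,
                        kl_tok: float, kl_target: float,
                        adapt_rate: float, adaptive: bool = True) -> float:
    if step < warmup_steps:
        return min(beta_max, beta + (beta_max - beta_init) / max(1, warmup_steps))
    if not adaptive:
        return beta
    direction = 1.0 if kl_tok > kl_target else -1.0
    beta = beta * (1.0 + adapt_rate * direction)
    return min(max(beta, 1e-4), beta_max)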


@dataclass
class SearchSpace:
    metric: str = "cont_nll"
    n_trials: int = 30
    algo: str = "optuna"
    seed: int = 42
    outdir: str = "hpsearch_runs/gpvae_tcn_plus"
    block_size: int = 256
    batch_size: int = 16
    max_steps: int = 1000
    val_every: int = 200
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_log10_min: float = -5.0
    lr_log10_max: float = -3.0
    beta_max_min: float = 0.05
    beta_max_max: float = 0.8
    kl_cap_choices: List[float] = None
    free_bits_choices: List[float] = None
    kl_target_min: float = 4.0
    kl_target_max: float = 14.0
    tcn_stacks_min: int = 2
    tcn_stacks_max: int = 4
    tcn_blocks_min: int = 2
    tcn_blocks_max: int = 4
    d_model_min: int = 192
    d_model_max: int = 512
    kernel_choices: List[int] = None
    dropout_min: float = 0.0
    dropout_max: float = 0.25
    droppath_min: float = 0.0
    droppath_max: float = 0.2
    gp_lengthscale_min: float = 2.0
    gp_lengthscale_max: float = 24.0
    gp_var_min: float = 0.2
    gp_var_max: float = 3.0
    gp_noise_min: float = 1e-5
    gp_noise_max: float = 1e-2
    t0_frac_min: float = 0.3
    t0_frac_max: float = 0.8
    weight_decay_min: float = 0.0
    weight_decay_max: float = 0.1
    grad_clip_choices: List[float] = None

    def __post_init__(self):
        if self.kl_cap_choices is None:
            self.kl_cap_choices = [8.0, 10.0, 12.0, 14.0, 16.0]
        if self.free_bits_choices is None:
            self.free_bits_choices = [0.0, 0.1, 0.2, 0.3]
        if self.kernel_choices is None:
            self.kernel_choices = [3, 5, 7]
        if self.grad_clip_choices is None:
            self.grad_clip_choices = [0.0, 0.5, 1.0, 2.0]


def ensure_outdir(p: str) -> str:
    pth = pathlib.Path(p)
    pth.mkdir(parents=True, exist_ok=True)
    return str(pth)


def sample_random(space: SearchSpace, rng: random.Random) -> Dict[str, Any]:
    def rfloat(a, b):
        return a + (b - a) * rng.random()

    cfg_over = {
        "lr": 10 ** rfloat(space.lr_log10_min, space.lr_log10_max),
        "weight_decay": rfloat(space.weight_decay_min, space.weight_decay_max),
        "beta_init": 5e-4,
        "beta_max": rfloat(space.beta_max_min, space.beta_max_max),
        "kl_cap_nats": rng.choice(space.kl_cap_choices),
        "free_bits_nats": rng.choice(space.free_bits_choices),
        "kl_target_nats": rfloat(space.kl_target_min, space.kl_target_max),
        "tcn_stacks": rng.randint(space.tcn_stacks_min, space.tcn_stacks_max),
        "tcn_blocks_per_stack": rng.randint(space.tcn_blocks_min, space.tcn_blocks_max),
        "d_model": int(rfloat(space.d_model_min, space.d_model_max) // 32 * 32),
        "tcn_kernel": rng.choice(space.kernel_choices),
        "tcn_dropout": rfloat(space.dropout_min, space.dropout_max),
        "tcn_droppath": rfloat(space.droppath_min, space.droppath_max),
        "gp_lengthscale_init": rfloat(space.gp_lengthscale_min, space.gp_lengthscale_max),
        "gp_variance_init": rfloat(space.gp_var_min, space.gp_var_max),
        "_t0_frac": rfloat(space.t0_frac_min, space.t0_frac_max),
        "grad_clip": rng.choice(space.grad_clip_choices),
        "_block_size": space.block_size,
        "_batch_size": space.batch_size,
        "_max_steps": space.max_steps,
        "_device": space.device,
    }
    cfg_over["_gp_noise"] = 10 ** rfloat(math.log10(space.gp_noise_min), math.log10(space.gp_noise_max))
    return cfg_over
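

# Illustrative sketch (hypothetical helper, not used by the pipeline):
# sample_random draws `lr` and `_gp_noise` log-uniformly, i.e. it samples the
# base-10 exponent uniformly and returns 10 ** exponent, so every decade in the
# range is equally likely. Here the bounds are given in linear units.
import math
import random


def _sketch_log_uniform(rng: random.Random, lo: float, hi: float) -> float:
    a, b = math.log10(lo), math.log10(hi)
    return 10 ** (a + (b - a) * rng.random())

# For lo=1e-5, hi=1e-3, about half of the draws fall below the geometric mean 1e-4.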


def run_one_trial(overrides: Dict[str, Any]) -> Dict[str, Any]:
    block_size = overrides.pop("_block_size", 256)
    batch_size = overrides.pop("_batch_size", 16)
    max_steps = overrides.pop("_max_steps", 1000)
    device = overrides.pop("_device", "cuda" if torch.cuda.is_available() else "cpu")
    t0_frac = float(overrides.pop("_t0_frac", 0.5))
    # NOTE: gp_noise is sampled by the search but is not forwarded to Config here.
    gp_noise = float(overrides.pop("_gp_noise", 1e-5))

    os.environ['T0_FRAC'] = str(t0_frac)

    model, tok, metrics, figs = train_gpvae(
        block_size=block_size, batch_size=batch_size, steps=max_steps,
        lr=overrides.get('lr', 2e-4), cfg_overrides=overrides)

    metrics["t0_frac"] = t0_frac
    metrics["device"] = device
    metrics["notes"] = ""
    return {"metrics": metrics, "cfg_overrides": overrides}


CSV_NAME = "trials.csv"
BEST_JSON = "best_config.json"
BEST_METRICS_JSON = "best_metrics.json"
BEST_CKPT = "best.ckpt"


def write_rows_csv(path: str, rows: List[Dict[str, Any]]):
    import csv
    if not rows:
        return
    keys = sorted({k for r in rows for k in r.keys()})
    with open(path, 'w', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=keys)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def append_row_csv(path: str, row: Dict[str, Any]):
    import csv
    if not os.path.exists(path):
        write_rows_csv(path, [row])
        return
    with open(path, 'r', newline='', encoding='utf-8') as f:
        header = next(csv.reader(f), [])
    if any(k not in header for k in row.keys()):
        # New columns appeared: rewrite the file with the union of all keys
        # (csv.DictWriter leaves missing values empty).
        with open(path, 'r', newline='', encoding='utf-8') as f:
            rows = list(csv.DictReader(f))
        write_rows_csv(path, rows + [row])
        return
    with open(path, 'a', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=header)
        w.writerow(row)


def save_best(outdir: str, best: Dict[str, Any]):
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, BEST_JSON), 'w', encoding='utf-8') as f:
        json.dump(best["cfg"], f, indent=2)
    with open(os.path.join(outdir, BEST_METRICS_JSON), 'w', encoding='utf-8') as f:
        json.dump(best["metrics"], f, indent=2)
    ckpt = best["metrics"].get("checkpoint_path")
    if ckpt and os.path.exists(ckpt):
        dst = os.path.join(outdir, BEST_CKPT)
        # Copy rather than symlink: every trial writes to the same checkpoint path,
        # so a symlink would end up pointing at the latest trial, not the best one.
        if os.path.lexists(dst):
            os.remove(dst)
        shutil.copy2(ckpt, dst)


def run_random_search(space: SearchSpace) -> Dict[str, Any]:
    rng = random.Random(space.seed)
    outdir = ensure_outdir(space.outdir)
    rows: List[Dict[str, Any]] = []
    best: Optional[Dict[str, Any]] = None
    for t in range(space.n_trials):
        overrides = sample_random(space, rng)
        print(f"[RandomSearch] Trial {t + 1}/{space.n_trials}: {overrides}")
        try:
            ret = run_one_trial(overrides)
        except Exception as e:
            print(f"Trial {t + 1} failed: {e}")
            continue
        metrics = ret["metrics"]
        cfg_over = ret["cfg_overrides"]
        score = metrics.get(space.metric, float('inf'))
        row = {**cfg_over, **metrics, "trial": t, space.metric: score}
        append_row_csv(os.path.join(outdir, CSV_NAME), row)
        rows.append(row)
        if (best is None) or (score < best["score"]):
            best = {"score": score, "cfg": cfg_over, "metrics": metrics}
            save_best(outdir, best)
    return best or {}


def run_optuna_search(space: SearchSpace) -> Dict[str, Any]:
    if not _HAVE_OPTUNA:
        print("[INFO] Optuna not available → falling back to random search.")
        return run_random_search(space)
    outdir = ensure_outdir(space.outdir)

    def objective(trial: 'optuna.Trial'):
        cfg_over = {
            "lr": 10 ** trial.suggest_float("lr_log10", space.lr_log10_min, space.lr_log10_max),
            "weight_decay": trial.suggest_float("weight_decay", space.weight_decay_min, space.weight_decay_max),
            "beta_init": 5e-4,
            "beta_max": trial.suggest_float("beta_max", space.beta_max_min, space.beta_max_max),
            "kl_cap_nats": trial.suggest_categorical("kl_cap_nats", space.kl_cap_choices),
            "free_bits_nats": trial.suggest_categorical("free_bits_nats", space.free_bits_choices),
            "kl_target_nats": trial.suggest_float("kl_target_nats", space.kl_target_min, space.kl_target_max),
            "tcn_stacks": trial.suggest_int("tcn_stacks", space.tcn_stacks_min, space.tcn_stacks_max),
            "tcn_blocks_per_stack": trial.suggest_int("tcn_blocks_per_stack", space.tcn_blocks_min, space.tcn_blocks_max),
            "d_model": trial.suggest_int("d_model", space.d_model_min, space.d_model_max, step=32),
            "tcn_kernel": trial.suggest_categorical("tcn_kernel", space.kernel_choices),
            "tcn_dropout": trial.suggest_float("tcn_dropout", space.dropout_min, space.dropout_max),
            "tcn_droppath": trial.suggest_float("tcn_droppath", space.droppath_min, space.droppath_max),
            "gp_lengthscale_init": trial.suggest_float("gp_lengthscale_init", space.gp_lengthscale_min, space.gp_lengthscale_max),
            "gp_variance_init": trial.suggest_float("gp_variance_init", space.gp_var_min, space.gp_var_max),
            "_t0_frac": trial.suggest_float("t0_frac", space.t0_frac_min, space.t0_frac_max),
            "grad_clip": trial.suggest_categorical("grad_clip", space.grad_clip_choices),
            "_block_size": space.block_size,
            "_batch_size": space.batch_size,
            "_max_steps": space.max_steps,
            "_device": space.device,
            "_gp_noise": 10 ** trial.suggest_float("gp_noise_log10",
                                                   math.log10(space.gp_noise_min),
                                                   math.log10(space.gp_noise_max)),
        }
        ret = run_one_trial(cfg_over)
        metrics = ret["metrics"]
        score = metrics.get(space.metric)
        if score is None:
            raise RuntimeError(f"Metric {space.metric} missing")
        row = {**cfg_over, **metrics, "trial": trial.number, space.metric: score}
        append_row_csv(os.path.join(outdir, CSV_NAME), row)
        trial.report(score, step=1)
        return score

    study = optuna.create_study(direction="minimize", pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=1))
    study.optimize(objective, n_trials=space.n_trials, show_progress_bar=True)
    bt = study.best_trial
    best_cfg = {
        "lr": 10 ** bt.params.get("lr_log10", -4.0),
        "weight_decay": bt.params.get("weight_decay"),
        "beta_init": 5e-4,
        "beta_max": bt.params.get("beta_max"),
        "kl_cap_nats": bt.params.get("kl_cap_nats"),
        "free_bits_nats": bt.params.get("free_bits_nats"),
        "kl_target_nats": bt.params.get("kl_target_nats"),
        "tcn_stacks": bt.params.get("tcn_stacks"),
        "tcn_blocks_per_stack": bt.params.get("tcn_blocks_per_stack"),
        "d_model": bt.params.get("d_model"),
        "tcn_kernel": bt.params.get("tcn_kernel"),
        "tcn_dropout": bt.params.get("tcn_dropout"),
        "tcn_droppath": bt.params.get("tcn_droppath"),
        "gp_lengthscale_init": bt.params.get("gp_lengthscale_init"),
        "gp_variance_init": bt.params.get("gp_variance_init"),
        "grad_clip": bt.params.get("grad_clip"),
    }
    ret = run_one_trial({**best_cfg,
                         "_t0_frac": bt.params.get("t0_frac", 0.6),
                         "_block_size": space.block_size,
                         "_batch_size": space.batch_size,
                         "_max_steps": space.max_steps,
                         "_device": space.device,
                         "_gp_noise": 10 ** bt.params.get("gp_noise_log10", -5.0)})
    best = {"score": ret["metrics"][space.metric], "cfg": best_cfg, "metrics": ret["metrics"]}
    save_best(space.outdir, best)
    return best


# CLI

import argparse, sys


def parse_args(argv=None):
    p = argparse.ArgumentParser(description="GP-VAE-TCN+ with integrated hyperparameter search")
    p.add_argument('--search', type=str, default='none', choices=['none', 'random', 'optuna'])
    p.add_argument('--n-trials', type=int, default=20)
    p.add_argument('--metric', type=str, default='cont_nll')
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--outdir', type=str, default='hpsearch_runs/gpvae_tcn_plus')
    # Default to a longer context for WikiText-103
    p.add_argument('--block-size', type=int, default=3072)
    p.add_argument('--batch-size', type=int, default=1)
    p.add_argument('--steps', type=int, default=1000)
    p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    p.add_argument('--grad-clip', type=float, default=None, help='Override for Config.grad_clip')
    args, _ = p.parse_known_args(argv)
    return args


def main():
    args = parse_args()
    if args.search == 'none':
        overrides = {}
        if args.grad_clip is not None:
            overrides['grad_clip'] = float(args.grad_clip)

        # 1) Train GP-VAE-TCN+ on WikiText-103
        model, tok, metrics, figs = train_gpvae(
            block_size=args.block_size,
            batch_size=args.batch_size,
            steps=args.steps,
            cfg_overrides=overrides if overrides else None
        )
        print("\n[VAL GP-VAE-TCN+]", metrics)
        model.eval()
        device = next(model.parameters()).device

        # 2) Train the Transformer baseline on the same dataset
        train_ds, _ = load_wikitext103(block_size=args.block_size, split='train')
        val_ds, _ = load_wikitext103(block_size=args.block_size, split='validation')

        tf_steps = int(os.getenv("TF_STEPS", args.steps * 2))

        tf_model, tf_cfg, tf_metrics = train_transformer_baseline(
            train_ds, val_ds, tok,
            block_size=args.block_size,
            batch_size=args.batch_size,
            steps=tf_steps,
            lr=3e-4,
            d_model=640,
            n_layer=10,
            n_head=10,
            d_ff=2560,
            dropout=0.1,
            log_every=50,
        )
        print("\n[VAL TF]", tf_metrics)

        # 3) Load the GPT-2 judge
        judge_name = os.getenv("JUDGE_MODEL_NAME", "gpt2")
        try:
            gpt2_tok = AutoTokenizer.from_pretrained(judge_name)
            gpt2_tok.pad_token = gpt2_tok.eos_token
            gpt2_tok.model_max_length = int(1e9)
            if hasattr(gpt2_tok, "init_kwargs"):
                gpt2_tok.init_kwargs["model_max_length"] = int(1e9)
            gpt2_model = AutoModelForCausalLM.from_pretrained(judge_name).to(device)
            gpt2_model.eval()
            print(f"[INFO] GPT-2 judge loaded: {judge_name}")
        except Exception as e:
            print(f"[Warning] Could not load GPT-2 judge ({judge_name}): {e}")
            gpt2_model = None
            gpt2_tok = None

        # 4) Simple generation demo
        prompt_text = "The meaning of life"
        ids = tok.encode(prompt_text)[:args.block_size]
        eos_id = getattr(tok, 'eos_token_id', 0)
        ids = ids if len(ids) > 0 else [eos_id]
        prompt_ids = torch.tensor([ids], dtype=torch.long, device=device)
        total_len = min(args.block_size * 2, len(ids) + 64)

        x_out = model.generate(T=64, batch_size=1, top_k=50, top_p=0.9, temperature=0.9)
        print("\n[Sample GP-VAE-TCN+ Unconditional]\n", tok.decode(x_out[0].tolist()))
        x_out2, _ = model.generate_with_prompt(prompt_ids, total_len, eos_id)
        print("\n[Sample GP-VAE-TCN+ Prompt-conditioned (AR)]\n", tok.decode(x_out2[0].tolist()))
        x_out2_na, _ = model.generate_with_prompt_no_ar(prompt_ids, total_len, eos_id)
        print("\n[Sample GP-VAE-TCN+ Prompt-conditioned (no-AR)]\n", tok.decode(x_out2_na[0].tolist()))
        x_tf = tf_model.generate(prompt_ids, total_len, top_k=50, top_p=0.9, temperature=0.9)
        print("\n[Sample Transformer baseline]\n", tok.decode(x_tf[0].tolist()))

        evaluate_all_models(
            model,
            tf_model,
            tok,
            gpt2_model,   # may be None; the function handles this case
            gpt2_tok,     # likewise
            block_size=args.block_size,
        )
        return


    # Hyperparameter search
    space = SearchSpace(metric=args.metric, n_trials=args.n_trials, algo=args.search,
                        seed=args.seed, outdir=args.outdir, block_size=args.block_size,
                        batch_size=args.batch_size, max_steps=args.steps, device=args.device)
    os.makedirs(space.outdir, exist_ok=True)
    with open(pathlib.Path(space.outdir) / 'search_space.json', 'w', encoding='utf-8') as f:
        json.dump(asdict(space), f, indent=2)
    if args.search == 'random':
        best = run_random_search(space)
    else:
        best = run_optuna_search(space)
    if best:
        print("\n--- BEST ---")
        print(json.dumps(best["cfg"], indent=2))
        print("Metric", space.metric, "=", best["score"])
        print("Checkpoint:", best["metrics"].get("checkpoint_path"))
    else:
        print("No successful trial.")


if __name__ == "__main__":
    main()
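Both `model.generate` and `tf_model.generate` in the demo are called with `top_k=50, top_p=0.9, temperature=0.9`; the two `generate_with_prompt*` calls pass no sampling arguments, so they fall back to whatever defaults those methods define. The internals of these methods are not part of this listing. The sketch below shows one common way the three knobs are combined (temperature scaling, then top-k, then top-p on the renormalised top-k distribution). `filter_probs` and `sample_token` are hypothetical helpers in plain Python for readability, not the repository's implementation.

```python
import math
import random

def filter_probs(logits, top_k=50, top_p=0.9, temperature=0.9):
    """Return [(token_index, prob), ...] left after temperature, top-k and top-p filtering.

    Hypothetical helper: a generic sketch, not the repository's generate() code.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]

    # Top-k: keep the k most probable tokens (k <= 0 disables the filter).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    mass = sum(probs[i] for i in order)

    # Top-p: keep the smallest prefix whose renormalised mass reaches top_p.
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i] / mass
        if cum >= top_p:
            break

    # Renormalise over the surviving tokens.
    kept_mass = sum(probs[i] for i in kept)
    return [(i, probs[i] / kept_mass) for i in kept]

def sample_token(logits, rng, **kwargs):
    """Draw one token index from the filtered distribution."""
    candidates = filter_probs(logits, **kwargs)
    r = rng.random()
    acc = 0.0
    for i, p in candidates:
        acc += p
        if r < acc:
            return i
    return candidates[-1][0]  # guard against floating-point round-off

rng = random.Random(0)
print(filter_probs([2.0, 1.0, 0.5, -1.0], top_k=3, top_p=0.9, temperature=0.9))
print(sample_token([2.0, 1.0, 0.5, -1.0], rng, top_k=50, top_p=0.9, temperature=0.9))
```

Lowering `temperature` sharpens the distribution before either filter is applied, so with these settings top-p usually prunes more aggressively than top-k on a peaked vocabulary.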


step 50 | elbo/tok -18.871 | ll0 -9.399 | ll_multi -9.191 | kl_raw 38419.527 | kl_cap 12.000 | beta 0.007 | tok/s 3,310
step 100 | elbo/tok -17.581 | ll0 -7.912 | ll_multi -9.318 | kl_raw 153320.234 | kl_cap 12.000 | beta 0.013 | tok/s 3,424
step 150 | elbo/tok -15.630 | ll0 -6.397 | ll_multi -8.817 | kl_raw 348843.750 | kl_cap 12.000 | beta 0.020 | tok/s 3,422
step 200 | elbo/tok -14.654 | ll0 -5.498 | ll_multi -8.669 | kl_raw 440211.594 | kl_cap 12.000 | beta 0.026 | tok/s 3,425
step 250 | elbo/tok -14.455 | ll0 -5.262 | ll_multi -8.636 | kl_raw 506316.719 | kl_cap 12.000 | beta 0.032 | tok/s 3,425
step 300 | elbo/tok -13.222 | ll0 -4.350 | ll_multi -8.247 | kl_raw 634882.562 | kl_cap 12.000 | beta 0.038 | tok/s 3,425
step 350 | elbo/tok -12.982 | ll0 -3.891 | ll_multi -8.393 | kl_raw 743491.188 | kl_cap 12.000 | beta 0.045 | tok/s 3,423
step 400 | elbo/tok -12.656 | ll0 -3.639 | ll_multi -8.246 | kl_raw 810965.938 | kl_cap 12.000 | beta 0.051 | tok/s 3,422
step 450 | elbo/tok -12.08
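Each training-log line above follows a fixed `key value | key value | ...` layout (`step`, `elbo/tok`, `ll0`, `ll_multi`, `kl_raw`, `kl_cap`, `beta`, `tok/s`), with thousands separators in `tok/s`. A small parser turns a captured log into rows for plotting or for comparing runs. `parse_log_line` and `parse_log` are hypothetical helpers that only handle the format; they make no assumption about what each metric means. A truncated final line such as `step 450 | elbo/tok -12.08` is parsed as-is, so its last value may be incomplete.

```python
import re

# One "key value" field, e.g. "elbo/tok -18.871" or "tok/s 3,310".
_FIELD = re.compile(r"^\s*([\w/]+)\s+(-?\d[\d,]*(?:\.\d+)?)\s*$")

def parse_log_line(line):
    """Parse 'key value | key value | ...' into {key: float}; unparsable fields are skipped."""
    record = {}
    for chunk in line.split("|"):
        match = _FIELD.match(chunk)
        if match is None:
            continue  # e.g. a field cut off without its value
        key, value = match.groups()
        record[key] = float(value.replace(",", ""))  # drop thousands separators
    return record

def parse_log(text):
    """Parse every line that starts with 'step ' and ignore everything else."""
    return [parse_log_line(line) for line in text.splitlines() if line.startswith("step ")]

log = """step 50 | elbo/tok -18.871 | ll0 -9.399 | ll_multi -9.191 | kl_raw 38419.527 | kl_cap 12.000 | beta 0.007 | tok/s 3,310
step 100 | elbo/tok -17.581 | ll0 -7.912 | ll_multi -9.318 | kl_raw 153320.234 | kl_cap 12.000 | beta 0.013 | tok/s 3,424
[INFO] unrelated line"""
rows = parse_log(log)
print([(r["step"], r["kl_raw"], r["beta"]) for r in rows])
```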

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

[INFO] GPT-2 judge loaded: gpt2
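The GPT-2 judge loaded in step 3 is passed to `evaluate_all_models`, whose scoring code is not shown in this listing. A common use of such a judge is to report its perplexity on generated text: the exponential of the mean negative log-likelihood the judge assigns to each token given the preceding ones. The sketch below is an assumption about that kind of metric, not a description of `evaluate_all_models`; `perplexity_from_logprobs` and `judge_perplexity` are hypothetical names. It relies on the Hugging Face convention that passing `labels=input_ids` to a causal LM returns the mean shifted cross-entropy as `loss`.

```python
import math

def perplexity_from_logprobs(token_logprobs):
    """exp(mean negative log-likelihood) over per-token natural-log probabilities."""
    if not token_logprobs:
        raise ValueError("need at least one scored token")
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def judge_perplexity(text, judge_model, judge_tok):
    """Perplexity of `text` under a Hugging Face causal LM used as a judge (hypothetical helper)."""
    import torch  # lazy import: perplexity_from_logprobs above needs only the stdlib

    device = next(judge_model.parameters()).device
    enc = judge_tok(text, return_tensors="pt").to(device)
    if enc["input_ids"].shape[1] < 2:
        raise ValueError("need at least two tokens: the first one is never predicted")
    with torch.no_grad():
        # labels=input_ids: the model shifts labels internally and returns the mean
        # cross-entropy of predicting token t from tokens < t.
        out = judge_model(**enc, labels=enc["input_ids"])
    return math.exp(out.loss.item())

# A judge that assigns probability 0.25 to every token has perplexity 4 (up to float rounding).
print(perplexity_from_logprobs([math.log(0.25)] * 4))
```

Raising `model_max_length` in step 3 suppresses the tokenizer's length warning but does not extend GPT-2's 1024-position context, so samples longer than that would need truncation or a sliding window before scoring.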

[Sample GP-VAE-TCN+ Unconditional]
 ublebreak 389 GA via via viainninginningINClicts tribal Someone528528528 Feel metab become 205 205 unsett BM Excellenceaocity Sales prohibitionmusic302 Chev6000529 + + Rd to toタ resize needs he areasoriginal single single mayhem from from Luckormonsinning Animal transmission semifinals Kenny . the the the theywareSaudi GG

[Sample GP-VAE-TCN+ Prompt-conditioned (AR)]
 The meaning of life during regulates En Calls during during C wasCBSEmilyrary dissect wal asses assesblockingche tid Clubs Independent improving submarinesaidsaid recession but but Someone stagnant Hud smoked smoked smoked smoked Dana Top distinguishedointmentvoice天 partsJCforcementforcementIALshould boycott strategist Value punished punishedimil ItshouldGaming " shortcuts Yale toice reacts zoning to Foundation

[Sample GP-VAE-TCN+ Prompt-conditioned (no-AR)]
 The meaning of liferary FTC grain laugh testing from can hats Moorilities antagonists in bund r
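The AR and no-AR samples above come from `generate_with_prompt` and `generate_with_prompt_no_ar`, whose internals are not shown here. Per the ablation described at the top of this document, AR mode samples each latent from the GP conditional given the earlier latents, while non-AR mode samples each latent independently from its diagonal marginal. The sketch below illustrates that distinction on a single scalar latent per time step, assuming an RBF kernel with lengthscale 4 (hypothetical; the repository's kernel, hyperparameters and latent dimensionality may differ). It uses the fact that row `t` of the Cholesky factor `L` of `K` encodes the conditional `p(z_t | z_<t)`.

```python
import math
import random

def rbf_kernel(T, lengthscale=4.0, variance=1.0, jitter=1e-6):
    """Hypothetical RBF covariance over time steps 0..T-1 (jitter keeps it positive definite)."""
    return [[variance * math.exp(-0.5 * ((i - j) / lengthscale) ** 2) + (jitter if i == j else 0.0)
             for j in range(T)] for i in range(T)]

def cholesky(K):
    """Lower-triangular L with L L^T == K (Cholesky-Banachiewicz, plain Python)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(K[i][i] - s)
            else:
                L[i][j] = (K[i][j] - s) / L[j][j]
    return L

def sample_latent_ar(L, eps):
    """Latent AR mode: draw z_t from the GP conditional p(z_t | z_<t), one step at a time.

    With z_<t = L[:t,:t] eps_<t, the conditional mean k_t^T K_<t^{-1} z_<t reduces to
    L[t,:t] . eps_<t and the conditional standard deviation is L[t][t].
    """
    z = []
    for t in range(len(L)):
        mean = sum(L[t][k] * eps[k] for k in range(t))  # depends on earlier latents only
        z.append(mean + L[t][t] * eps[t])
    return z

def sample_latent_no_ar(K, eps):
    """Latent non-AR mode: z_t ~ N(0, K[t][t]) independently (diagonal marginals)."""
    return [math.sqrt(K[t][t]) * eps[t] for t in range(len(K))]

def correlation(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

# Same standard-normal noise for both modes, so only the temporal structure differs.
rng = random.Random(0)
T = 8
K = rbf_kernel(T)
L = cholesky(K)
draws = [[rng.gauss(0.0, 1.0) for _ in range(T)] for _ in range(2000)]
ar = [sample_latent_ar(L, e) for e in draws]
iid = [sample_latent_no_ar(K, e) for e in draws]
print("AR     corr(z_3, z_4):", round(correlation([z[3] for z in ar], [z[4] for z in ar]), 3))
print("non-AR corr(z_3, z_4):", round(correlation([z[3] for z in iid], [z[4] for z in iid]), 3))
```

Both modes share the same per-step marginal variance `K[t][t]`; only the correlation across time steps differs, which is exactly the property the latent-AR ablation isolates.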