In [1]:
import math, os, time, json, shutil, random, pathlib
from dataclasses import dataclass, asdict
from typing import Tuple, Optional, Dict, Any, List

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from torch.nn.utils.parametrizations import weight_norm as pn_weight_norm

try:
    import optuna
    from optuna.pruners import MedianPruner
    _HAVE_OPTUNA = True
except Exception:
    _HAVE_OPTUNA = False

try:
    from datasets import load_dataset
    from transformers import AutoTokenizer
except Exception:
    load_dataset = None
    AutoTokenizer = None
    print("[Avertissement] Installez datasets et transformers : pip install datasets transformers")


class LMBlocks(Dataset):
    def __init__(self, texts, tokenizer, block_size: int = 64):
        ids = []
        for t in texts:
            ids.extend(tokenizer.encode(t))
        n = max(1, len(ids) // block_size)
        ids = ids[: n * block_size]
        self.block_size = block_size
        self.data = torch.tensor(ids, dtype=torch.long).view(n, block_size)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]


def load_wikitext2(block_size=64, split='train'):
    assert load_dataset is not None and AutoTokenizer is not None, "Installez datasets/transformers"
    ds = load_dataset('wikitext', 'wikitext-2-raw-v1')[split]
    try:
        from transformers.utils import hub as _hf_hub_utils
        _hf_hub_utils.list_repo_templates = lambda *a, **k: []
        import transformers.tokenization_utils_base as _tub
        _tub.list_repo_templates = lambda *a, **k: []
    except Exception:
        pass
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    texts = ds['text']
    dataset = LMBlocks(texts, tokenizer, block_size)
    return dataset, tokenizer


def plot_curve(values, title, ylabel, out_png):
    plt.figure()
    plt.plot(values)
    plt.xlabel('Step')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def reliability_diagram_data(logits: torch.Tensor, x: torch.Tensor, n_bins: int = 15):
    with torch.no_grad():
        probs = torch.softmax(logits, dim=-1)
        conf, pred = probs.max(dim=-1)
        corr = (pred == x).float()
        conf = conf.detach().cpu().view(-1).numpy()
        corr = corr.detach().cpu().view(-1).numpy()
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(conf, bins) - 1
    bin_conf, bin_acc, bin_count = [], [], []
    ece = 0.0
    N = len(conf)
    for b in range(n_bins):
        m = (bin_ids == b)
        cnt = m.sum()
        if cnt == 0:
            bin_conf.append(0.0)
            bin_acc.append(0.0)
            bin_count.append(0)
            continue
        c = conf[m].mean()
        a = corr[m].mean()
        w = cnt / N
        ece += w * abs(a - c)
        bin_conf.append(c)
        bin_acc.append(a)
        bin_count.append(cnt)
    return np.array(bin_conf), np.array(bin_acc), float(ece)


def plot_reliability(confidences, accuracy, ece, out_png):
    plt.figure()
    plt.plot(confidences, accuracy, 'o-')
    plt.plot([0, 1], [0, 1], '--')
    plt.xlabel('Confiance')
    plt.ylabel('Précision')
    plt.title(f'Diagramme de fiabilité — ECE={ece:.4f}')
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def nll_per_token(logits: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    B, T, V = logits.shape
    ce = F.cross_entropy(logits.reshape(B * T, V), x.reshape(B * T), reduction='none').view(B, T)
    return ce.mean(dim=0).detach().cpu()


def plot_nll_per_position(nll_per_pos, out_png):
    plt.figure()
    plt.plot(nll_per_pos)
    plt.xlabel('Position t')
    plt.ylabel('Token NLL')
    plt.title('NLL par position')
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def rbf_kernel(t: torch.Tensor, lengthscale: torch.Tensor, variance: torch.Tensor) -> torch.Tensor:
    t = t.view(-1, 1)
    d2 = (t - t.T) ** 2
    return variance * torch.exp(-0.5 * d2 / (lengthscale ** 2))


def mvn_logpdf_zero_mean(z: torch.Tensor, K: torch.Tensor, jitter: float = 1e-5) -> torch.Tensor:
    B, T, D = z.shape
    I = torch.eye(T, device=K.device, dtype=K.dtype)
    L = torch.linalg.cholesky(K + jitter * I)
    log_det = 2.0 * torch.sum(torch.log(torch.diag(L)))
    z_bd = z.permute(0, 2, 1).reshape(B * D, T)
    y_bd = torch.cholesky_solve(z_bd.unsqueeze(-1), L).squeeze(-1)
    quad = (z_bd * y_bd).sum(dim=-1).view(B, D).sum(dim=-1)
    const = T * math.log(2 * math.pi)
    logp = -0.5 * (D * (log_det + const) + quad)
    return logp


def diag_mvn_logpdf(z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    var = torch.exp(logvar)
    const = math.log(2 * math.pi)
    logp = -0.5 * (const + logvar + (z - mu) ** 2 / var)
    return logp.sum(dim=(1, 2))


class SinusoidalPos(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.d = d

    def forward(self, T: int, device):
        D = self.d
        pe = torch.zeros(T, D, device=device)
        position = torch.arange(T, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, D, 2, device=device) * (-math.log(10000.0) / D))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return (self.scale * pe).unsqueeze(0)


class CausalDepthwiseSeparableConv(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=1, out_channels=None):
        super().__init__()
        pad = (kernel_size - 1) * dilation
        self.left_pad = pad
        out_channels = out_channels if out_channels is not None else channels
        dw = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            groups=channels,
            dilation=dilation,
            padding=pad,
            bias=True,
        )
        pw = nn.Conv1d(channels, out_channels, kernel_size=1, bias=True)
        self.dw = pn_weight_norm(dw)
        self.pw = pn_weight_norm(pw)

    def forward(self, x):
        h = self.dw(x)
        if self.left_pad > 0:
            h = h[..., :-self.left_pad]
        h = self.pw(h)
        return h


class SqueezeExcite(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(channels, hidden, kernel_size=1)
        self.fc2 = nn.Conv1d(hidden, channels, kernel_size=1)

    def forward(self, x):
        s = self.pool(x)
        s = F.gelu(self.fc1(s))
        s = torch.sigmoid(self.fc2(s))
        return x * s


class DropPath1D(nn.Module):
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob <= 0.0:
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.size(0),) + (1,) * (x.ndim - 1)
        mask = (
            torch.empty(shape, dtype=x.dtype, device=x.device)
            .bernoulli_(keep)
            / keep
        )
        return x * mask


class TCNBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=1, dropout=0.05, droppath=0.05):
        super().__init__()
        self.gn1 = nn.GroupNorm(1, channels)
        self.conv1 = CausalDepthwiseSeparableConv(
            channels, kernel_size, dilation, out_channels=2 * channels
        )
        self.gn2 = nn.GroupNorm(1, channels)
        self.conv2 = CausalDepthwiseSeparableConv(
            channels, kernel_size, dilation, out_channels=2 * channels
        )
        self.se = SqueezeExcite(channels, reduction=8)
        self.drop = nn.Dropout(dropout)
        self.droppath = DropPath1D(drop_prob=droppath)

    def forward(self, x):
        h = self.gn1(x)
        h = self.conv1(h)
        h = F.glu(h, dim=1)
        h = self.drop(h)
        h = self.gn2(h)
        h = self.conv2(h)
        h = F.glu(h, dim=1)
        h = self.se(h)
        h = self.drop(h)
        return x + self.droppath(h)


class HierarchicalTCN(nn.Module):
    def __init__(
        self,
        in_ch,
        hidden_ch,
        n_stacks=3,
        blocks_per_stack=2,
        kernel_size=3,
        base_dilations=(1, 2, 4, 8, 16),
        dropout=0.05,
        droppath=0.05,
    ):
        super().__init__()
        self.proj = nn.Conv1d(in_ch, hidden_ch, 1)
        self.blocks = nn.ModuleList()
        dil = list(base_dilations)
        idx = 0
        for _ in range(n_stacks * blocks_per_stack):
            d = dil[idx % len(dil)]
            self.blocks.append(
                TCNBlock(hidden_ch, kernel_size, d, dropout, droppath)
            )
            idx += 1

    def forward(self, x):
        h = self.proj(x)
        for blk in self.blocks:
            h = blk(h)
        return h


@dataclass
class Config:
    vocab_size: int
    d_model: int = 256
    d_latent: int = 64
    block_size: int = 64
    emb_dim: int = 256
    gp_lengthscale_init: float = 8.0
    gp_variance_init: float = 1.0
    lr: float = 2e-4
    label_smoothing: float = 0.03
    weight_decay: float = 0.01
    grad_clip: float = 1.0
    free_bits_nats: float = 0.5
    kl_cap_nats: float = 12.0
    kl_target_nats: float = 8.0
    beta_init: float = 1e-3
    beta_max: float = 0.5
    beta_adapt_rate: float = 0.05
    use_adaptive_beta: bool = True
    K_multi: int = 3
    multi_lambda_scheme: str = "harmonic"
    gamma_multi: float = 0.5
    logvar_min: float = -6.0
    logvar_max: float = 2.0
    logvar_init: float = -4.0
    tcn_stacks: int = 3
    tcn_blocks_per_stack: int = 3
    tcn_kernel: int = 3
    tcn_base_dilations: tuple = (1, 2, 4, 8, 16)
    tcn_dropout: float = 0.05
    tcn_droppath: float = 0.05
    embed_reg_weight: float = 0.2
    embed_reg_mode: str = "cos"
    dec_use_attn: bool = True
    dec_n_attn_layers: int = 1
    dec_n_heads: int = 4
    dec_attn_dropout: float = 0.1
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class TCNEncoder(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim)
        self.tcn = HierarchicalTCN(
            in_ch=cfg.emb_dim,
            hidden_ch=cfg.d_model,
            n_stacks=cfg.tcn_stacks,
            blocks_per_stack=cfg.tcn_blocks_per_stack,
            kernel_size=cfg.tcn_kernel,
            base_dilations=cfg.tcn_base_dilations,
            dropout=cfg.tcn_dropout,
            droppath=cfg.tcn_droppath,
        )
        self.to_mu = nn.Conv1d(cfg.d_model, cfg.d_latent, kernel_size=1)
        self.to_logvar = nn.Conv1d(cfg.d_model, cfg.d_latent, kernel_size=1)
        nn.init.zeros_(self.to_mu.weight)
        nn.init.zeros_(self.to_mu.bias)
        nn.init.zeros_(self.to_logvar.weight)
        nn.init.constant_(self.to_logvar.bias, cfg.logvar_init)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.tok_emb(x)
        h = h.transpose(1, 2)
        h = self.tcn(h)
        mu = self.to_mu(h).transpose(1, 2)
        logvar = self.to_logvar(h).transpose(1, 2)
        logvar = logvar.clamp(self.cfg.logvar_min, self.cfg.logvar_max)
        return mu, logvar


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        att = att.masked_fill(mask, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        y = self.proj(y)
        return y


class TokenDecoder(nn.Module):
    def __init__(self, cfg: Config, tied_weight: torch.Tensor):
        super().__init__()
        self.cfg = cfg
        self.pe = SinusoidalPos(cfg.d_latent)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.d_latent, cfg.d_model),
            nn.GELU(),
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.GELU(),
        )
        self.ln = nn.LayerNorm(cfg.d_model)
        self.post = nn.Conv1d(cfg.d_model, cfg.d_model, kernel_size=3, padding=1)

        self.attn_blocks = None
        if cfg.dec_use_attn and cfg.dec_n_attn_layers > 0:
            self.attn_blocks = nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln": nn.LayerNorm(cfg.d_model),
                            "attn": CausalSelfAttention(
                                cfg.d_model, cfg.dec_n_heads, cfg.dec_attn_dropout
                            ),
                            "drop": nn.Dropout(cfg.dec_attn_dropout),
                        }
                    )
                    for _ in range(cfg.dec_n_attn_layers)
                ]
            )

        self.to_emb = nn.Linear(cfg.d_model, cfg.emb_dim, bias=False)
        self.tied_weight = tied_weight
        self.bias = nn.Parameter(torch.zeros(cfg.vocab_size))
        self.sem_head = nn.Linear(cfg.d_model, cfg.emb_dim, bias=False)

    def forward(self, z: torch.Tensor):
        z = z + self.pe(T=z.size(1), device=z.device)
        h = self.mlp(z)
        h = self.ln(h)
        h2 = self.post(h.transpose(1, 2)).transpose(1, 2)
        h = h + h2

        if self.attn_blocks is not None:
            for blk in self.attn_blocks:
                h_norm = blk["ln"](h)
                h_att = blk["attn"](h_norm)
                h = h + blk["drop"](h_att)

        e_proj = self.to_emb(h)
        tw = F.normalize(self.tied_weight, dim=-1)
        logits = torch.matmul(e_proj, tw.t()) + self.bias
        e_hat = self.sem_head(h)
        return logits, e_hat


class GPVAE_TCN(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.encoder = TCNEncoder(cfg)
        self.decoder = TokenDecoder(cfg, tied_weight=self.encoder.tok_emb.weight)
        self.register_buffer(
            "t_train", torch.arange(cfg.block_size, dtype=torch.float32)
        )
        self._ls_unc = nn.Parameter(torch.tensor(cfg.gp_lengthscale_init).log())
        self._var_unc = nn.Parameter(torch.tensor(cfg.gp_variance_init).log())
        self.softplus = nn.Softplus(beta=1.0)

    def gp_hypers(self):
        ls = self.softplus(self._ls_unc) + 1e-6
        var = self.softplus(self._var_unc) + 1e-6
        return ls, var

    def K_tt(self, t: torch.Tensor) -> torch.Tensor:
        ls, var = self.gp_hypers()
        return rbf_kernel(t, ls, var)

    def _label_smoothing_ce(self, logits, targets, eps):
        log_probs = F.log_softmax(logits, dim=-1)
        nll = F.nll_loss(log_probs.transpose(1, 2), targets, reduction="mean")
        smooth = -log_probs.mean(dim=-1).mean()
        return (1 - eps) * nll + eps * smooth

    def _multi_horizon_ll(self, logits, x) -> torch.Tensor:
        K = self.cfg.K_multi
        if K <= 0:
            return torch.zeros((), device=logits.device)
        B, T, V = logits.shape
        total = 0.0
        for k in range(1, K + 1):
            if T - k <= 0:
                break
            lam = 1.0 / k if self.cfg.multi_lambda_scheme == "harmonic" else 1.0
            ll_k = -self._label_smoothing_ce(
                logits[:, : T - k, :], x[:, k:], self.cfg.label_smoothing
            )
            total = total + self.cfg.gamma_multi * lam * ll_k
        return total

    def _embed_reg_losses(self, e_hat, x, token_emb_table: torch.Tensor):
        with torch.no_grad():
            e_tgt = token_emb_table[x]
        loss = 0.0
        if self.cfg.embed_reg_mode in ("mse", "cos+mse"):
            loss = loss + F.mse_loss(e_hat, e_tgt, reduction="mean")
        if self.cfg.embed_reg_mode in ("cos", "cos+mse"):
            loss = loss + (1.0 - F.cosine_similarity(e_hat, e_tgt, dim=-1).mean())
        return torch.as_tensor(loss, device=e_hat.device)

    def elbo(self, x: torch.Tensor, beta_override: Optional[float] = None):
        cfg = self.cfg
        B, T = x.shape
        mu, logvar = self.encoder(x)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        logits, e_hat = self.decoder(z)
        ll_0 = -self._label_smoothing_ce(logits, x, cfg.label_smoothing)
        ll_multi = self._multi_horizon_ll(logits, x)
        emb_reg = self._embed_reg_losses(
            e_hat, x, self.encoder.tok_emb.weight.detach()
        )
        tvec = torch.arange(T, dtype=torch.float32, device=x.device)
        K_full = self.K_tt(tvec)
        log_pz_b = mvn_logpdf_zero_mean(z, K_full)
        log_qz_b = diag_mvn_logpdf(z, mu, logvar)
        kl_tok_raw_t = (log_qz_b - log_pz_b).mean() / T
        kl_tok_capped_t = torch.clamp(
            torch.clamp(kl_tok_raw_t, min=cfg.free_bits_nats), max=cfg.kl_cap_nats
        )
        beta = (
            beta_override
            if beta_override is not None
            else getattr(self, "_beta_state", cfg.beta_init)
        )
        elbo_tok_t = (
            ll_0 + ll_multi
            - beta * kl_tok_capped_t
            - cfg.embed_reg_weight * emb_reg
        )
        stats = {
            "ll0_tok": float(ll_0.detach().item()),
            "ll_multi_tok": float(ll_multi.detach().item()),
            "emb_reg": float(emb_reg.detach().item())
            if torch.is_tensor(emb_reg)
            else float(emb_reg),
            "kl_tok_raw": float(kl_tok_raw_t.detach().item()),
            "kl_tok_cap": float(kl_tok_capped_t.detach().item()),
            "beta": float(beta),
            "elbo_tok": float(elbo_tok_t.detach().item()),
        }
        comps = {
            "ll0_tok_t": ll_0,
            "ll_multi_tok_t": ll_multi,
            "emb_reg_t": emb_reg,
            "kl_tok_raw_t": kl_tok_raw_t,
            "kl_tok_cap_t": kl_tok_capped_t,
            "elbo_tok_t": elbo_tok_t,
            "beta_t": torch.tensor(beta, device=x.device),
        }
        return elbo_tok_t, stats, comps, logits

    @torch.no_grad()
    def _gp_conditional_step(self, K: torch.Tensor, z_prefix: torch.Tensor, t_next: int):
        B, tp, Dz = z_prefix.shape
        if tp == 0:
            var0 = K[0, 0].clamp_min(1e-9)
            return torch.randn(B, Dz, device=z_prefix.device) * var0.sqrt()
        K_11 = K[:tp, :tp]
        k_12 = K[:tp, t_next]
        K_22 = K[t_next, t_next]
        jitter = 1e-6
        for _ in range(5):
            try:
                L_11 = torch.linalg.cholesky(
                    K_11 + jitter * torch.eye(tp, device=K.device, dtype=K.dtype)
                )
                break
            except RuntimeError:
                jitter *= 10.0
        alpha = torch.cholesky_solve(k_12.view(tp, 1), L_11).view(tp)
        mean_t = torch.einsum("t, btd -> bd", alpha, z_prefix)
        var_t = (K_22 - (k_12 * alpha).sum()).clamp_min(1e-9)
        return mean_t + torch.randn(B, Dz, device=z_prefix.device) * var_t.sqrt()

    @torch.no_grad()
    def _sample_gp_prior_block(self, T: int, batch_size: int) -> torch.Tensor:
        device = self.t_train.device
        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        B = batch_size
        Dz = self.cfg.d_latent
        jitter = 1e-6
        I = torch.eye(T, device=device, dtype=K.dtype)
        L = torch.linalg.cholesky(K + jitter * I)

        eps = torch.randn(B * Dz, T, device=device)
        z_bd = (L @ eps.T).T
        z = z_bd.view(B, Dz, T).permute(0, 2, 1).contiguous()
        return z

    @torch.no_grad()
    def _gp_conditional_block(
        self, K: torch.Tensor, z_prefix: torch.Tensor, T0: int
    ) -> torch.Tensor:
        B, T0_check, Dz = z_prefix.shape
        assert T0_check == T0
        T = K.size(0)
        Tf = T - T0
        device = z_prefix.device
        dtype = K.dtype

        if Tf <= 0:
            return z_prefix.new_zeros(B, 0, Dz)

        K_pp = K[:T0, :T0]
        K_pf = K[:T0, T0:]
        K_fp = K_pf.transpose(0, 1)
        K_ff = K[T0:, T0:]

        jitter = 1e-6
        I_pp = torch.eye(T0, device=device, dtype=dtype)
        L_pp = torch.linalg.cholesky(K_pp + jitter * I_pp)

        z_p_bd = z_prefix.permute(0, 2, 1).reshape(B * Dz, T0)
        alpha_bd = torch.cholesky_solve(z_p_bd.unsqueeze(-1), L_pp).squeeze(-1)
        alpha = alpha_bd.view(B, Dz, T0).permute(0, 2, 1)

        mean_fut = torch.einsum("ft, btd -> bfd", K_fp, alpha)

        C = torch.cholesky_solve(K_pf, L_pp)
        Sigma_f = K_ff - K_fp @ C
        Sigma_f = 0.5 * (Sigma_f + Sigma_f.T)

        I_f = torch.eye(Tf, device=device, dtype=dtype)
        L_f = torch.linalg.cholesky(Sigma_f + jitter * I_f)

        eps = torch.randn(B * Dz, Tf, device=device)
        z_noise_bd = (L_f @ eps.T).T
        z_noise = z_noise_bd.view(B, Dz, Tf).permute(0, 2, 1)

        z_fut = mean_fut + z_noise
        return z_fut

    @torch.no_grad()
    def generate(
        self,
        T: int,
        batch_size: int = 1,
        top_k: int = 50,
        top_p: float = 0.9,
        temperature: float = 0.9,
    ):
        device = self.t_train.device
        self.eval()
        z = self._sample_gp_prior_block(T, batch_size)
        logits, _ = self.decoder(z)
        return sample_logits_from_timewise_logits(
            logits, top_k=top_k, top_p=top_p, temperature=temperature
        )

    @torch.no_grad()
    def generate_with_prompt(
        self,
        prompt_ids: torch.Tensor,
        total_len: int,
        eos_id: int,
        top_k: int = 50,
        top_p: float = 0.9,
        temperature: float = 0.9,
    ):
        device = self.t_train.device
        self.eval()
        B, T0 = prompt_ids.shape
        T = total_len

        x_in = torch.full(
            (B, T), fill_value=eos_id, dtype=torch.long, device=device
        )
        x_in[:, :T0] = prompt_ids.to(device)

        if T0 > 0:
            mu, logvar = self.encoder(x_in[:, :T0])
            std = torch.exp(0.5 * logvar)
            z_prompt = mu + std * torch.randn_like(std)
        else:
            z_prompt = None

        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        Dz = self.cfg.d_latent

        z = torch.zeros(B, T, Dz, device=device)
        if T0 > 0:
            z[:, :T0, :] = z_prompt
            z_fut = self._gp_conditional_block(K, z_prompt, T0)
            z[:, T0:, :] = z_fut
        else:
            z = self._sample_gp_prior_block(T, B)

        logits, _ = self.decoder(z)
        new_ids = sample_logits_from_timewise_logits(
            logits[:, T0:, :],
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        x_out = x_in.clone()
        if T > T0:
            x_out[:, T0:] = new_ids
        return x_out, logits


@torch.no_grad()
def sample_logits_from_timewise_logits(
    logits: torch.Tensor,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 1.0,
):
    B, T, V = logits.shape
    x = logits / max(1e-6, float(temperature))
    if top_k is not None and 0 < top_k < V:
        kth = torch.topk(x, top_k, dim=-1).values[..., -1:].expand(B, T, V)
        x = torch.where(x < kth, torch.full_like(x, -1e10), x)
    if top_p is not None and 0.0 < top_p < 1.0:
        probs = torch.softmax(x, dim=-1)
        sorted_probs, sorted_idx = torch.sort(
            probs, dim=-1, descending=True
        )
        cum = torch.cumsum(sorted_probs, dim=-1)
        cutoff = cum > top_p
        cutoff[..., 1:] = cutoff[..., :-1].clone()
        cutoff[..., 0] = False
        sorted_logits = torch.log(sorted_probs.clamp_min(1e-20))
        sorted_logits[cutoff] = -1e10
        x = torch.full_like(x, -1e10).scatter(-1, sorted_idx, sorted_logits)
    ids = torch.distributions.Categorical(logits=x).sample()
    return ids


@torch.no_grad()
def score_continuation_gpvae_from_logits(model, x_full, T0):
    B, T = x_full.shape
    mu, logvar = model.encoder(x_full)
    z = mu
    logits, _ = model.decoder(z)
    targets = x_full[:, T0:]
    logits_c = logits[:, T0:, :]
    V = logits_c.size(-1)
    nll = F.cross_entropy(
        logits_c.reshape(-1, V),
        targets.reshape(-1),
        reduction="mean",
    ).item()
    ppl = math.exp(nll) if nll < 50 else float("inf")
    return {"nll": nll, "ppl": ppl}


def apply_overrides(cfg: Config, overrides: Optional[Dict[str, Any]] = None) -> Config:
    if not overrides:
        return cfg
    for k, v in overrides.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
    return cfg


def train_gpvae(
    block_size=64,
    batch_size=16,
    steps=1000,
    lr=2e-4,
    log_every=50,
    cfg_overrides: Optional[Dict[str, Any]] = None,
):
    train_ds, tok = load_wikitext2(block_size=block_size, split="train")
    val_ds, _ = load_wikitext2(block_size=block_size, split="validation")

    cfg = Config(vocab_size=tok.vocab_size, block_size=block_size, lr=lr)
    cfg = apply_overrides(cfg, cfg_overrides)
    device = cfg.device

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, drop_last=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, drop_last=True
    )

    model = GPVAE_TCN(cfg).to(device)
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        betas=(0.9, 0.95),
        weight_decay=cfg.weight_decay,
    )
    scaler = torch.amp.GradScaler("cuda", enabled=device.startswith("cuda"))
    warmup_steps = int(os.getenv("KL_WARMUP_STEPS", "4000"))

    hist_elbo, hist_ll0, hist_llm, hist_kl = [], [], [], []
    hist_kl_cap, hist_beta, hist_embreg, hist_tokps = [], [], [], []

    model.train()
    t0 = time.perf_counter()
    beta = cfg.beta_init
    model._beta_state = beta

    for step, x in enumerate(train_loader, 1):
        x = x.to(device)
        optim.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=device.startswith("cuda")):
            elbo_tok_t, stats, comps, _ = model.elbo(x, beta_override=beta)
            if step < warmup_steps:
                beta = min(
                    cfg.beta_max,
                    beta
                    + (cfg.beta_max - cfg.beta_init) / max(1, warmup_steps),
                )
            elif cfg.use_adaptive_beta:
                diff = float(
                    comps["kl_tok_cap_t"].detach().item()
                ) - cfg.kl_target_nats
                beta = float(beta) * (
                    1.0 + cfg.beta_adapt_rate * (1.0 if diff > 0 else -1.0)
                )
                beta = float(np.clip(beta, 1e-4, cfg.beta_max))
            model._beta_state = beta
            loss = -elbo_tok_t

        scaler.scale(loss).backward()

        gc = getattr(cfg, "grad_clip", 0.0)
        if gc and gc > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gc)

        scaler.step(optim)
        scaler.update()

        hist_elbo.append(float(stats["elbo_tok"]))
        hist_ll0.append(float(stats["ll0_tok"]))
        hist_llm.append(float(stats["ll_multi_tok"]))
        hist_kl.append(float(stats["kl_tok_raw"]))
        hist_kl_cap.append(float(stats["kl_tok_cap"]))
        hist_beta.append(float(stats["beta"]))
        hist_embreg.append(float(stats["emb_reg"]))

        if step % log_every == 0:
            dt = time.perf_counter() - t0
            tokps = (batch_size * block_size * log_every) / max(1e-9, dt)
            hist_tokps.append(float(tokps))
            print(
                f"step {step} | elbo/tok {stats['elbo_tok']:.3f} | ll0 {stats['ll0_tok']:.3f} | "
                f"ll_multi {stats['ll_multi_tok']:.3f} | kl_raw {stats['kl_tok_raw']:.3f} | "
                f"kl_cap {stats['kl_tok_cap']:.3f} | beta {stats['beta']:.3f} | tok/s {tokps:,.0f}"
            )
            t0 = time.perf_counter()
        if step >= steps:
            break

    plot_curve(hist_elbo, "Évolution ELBO/token", "ELBO/token", "elbo_token.png")
    plot_curve(
        hist_ll0,
        "Évolution LL0/token (k=0)",
        "LL0/token",
        "ll0_token.png",
    )
    plot_curve(
        hist_llm,
        "Évolution LL_multi/token (k=1..K)",
        "LL_multi/token",
        "llmulti_token.png",
    )
    plot_curve(
        hist_kl,
        "Évolution KL/token (raw)",
        "KL/token",
        "kl_token_raw.png",
    )
    plot_curve(
        hist_kl_cap,
        "Évolution KL/token (cap)",
        "KL_cap/token",
        "kl_token_cap.png",
    )
    plot_curve(hist_beta, "Évolution β", "β", "beta.png")
    plot_curve(
        hist_embreg,
        "Évolution régularisation sémantique",
        "emb_reg",
        "emb_reg.png",
    )
    if len(hist_tokps) > 0:
        plot_curve(
            hist_tokps,
            "Évolution débit tokens/s",
            "tok/s",
            "tok_per_s.png",
        )

    model.eval()
    with torch.no_grad():
        try:
            x_val = next(iter(val_loader)).to(device)
        except StopIteration:
            x_val = next(iter(train_loader)).to(device)
        elbo_tok_v, st_v, comps_v, logits_v = model.elbo(
            x_val, beta_override=model._beta_state
        )
        V = logits_v.size(-1)
        nll = F.cross_entropy(
            logits_v.reshape(-1, V),
            x_val.reshape(-1),
            reduction="mean",
        ).item()
        ppl = math.exp(nll) if nll < 50 else float("inf")
        bin_conf, bin_acc, ece = reliability_diagram_data(
            logits_v, x_val, n_bins=15
        )
        nll_pos = nll_per_token(logits_v, x_val)
    plot_reliability(bin_conf, bin_acc, ece, "reliability.png")
    plot_nll_per_position(nll_pos, "nll_per_position.png")

    print(
        f"[VAL] elbo/tok {st_v['elbo_tok']:.3f} | ll0 {st_v['ll0_tok']:.3f} | "
        f"ll_multi {st_v['ll_multi_tok']:.3f} | kl_cap {st_v['kl_tok_cap']:.3f} | ppl {ppl:.2f}"
    )

    T = x_val.shape[1]
    T0 = int(float(os.getenv("T0_FRAC", "0.5")) * T)
    sc = score_continuation_gpvae_from_logits(model, x_val, T0)

    ckpt_path = os.path.abspath("gpvae_best.pt")
    torch.save(
        {"model_state_dict": model.state_dict(), "config": asdict(cfg)},
        ckpt_path,
    )

    metrics = {
        "val_elbo_tok": float(st_v["elbo_tok"]),
        "val_ppl": float(ppl),
        "cont_nll": float(sc["nll"]),
        "cont_ppl": float(sc["ppl"]),
        "kl_eff_nats": float(st_v["kl_tok_cap"]),
        "beta_final": float(model._beta_state),
        "tok_s": float(hist_tokps[-1] if len(hist_tokps) > 0 else 0.0),
        "checkpoint_path": ckpt_path,
    }

    figs = {
        "elbo_token_png": "elbo_token.png",
        "ll0_token_png": "ll0_token.png",
        "llmulti_token_png": "llmulti_token.png",
        "kl_token_raw_png": "kl_token_raw.png",
        "kl_token_cap_png": "kl_token_cap.png",
        "beta_png": "beta.png",
        "emb_reg_png": "emb_reg.png",
        "tok_per_s_png": "tok_per_s.png",
        "reliability_png": "reliability.png",
        "nll_per_position_png": "nll_per_position.png",
    }
    return model, tok, metrics, figs


@dataclass
class SearchSpace:
    metric: str = "cont_nll"
    n_trials: int = 30
    algo: str = "optuna"
    seed: int = 42
    outdir: str = "hpsearch_runs/gpvae_tcn_plus"
    block_size: int = 64
    batch_size: int = 16
    max_steps: int = 1000
    val_every: int = 200
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_log10_min: float = -5.0
    lr_log10_max: float = -3.0
    beta_max_min: float = 0.05
    beta_max_max: float = 0.8
    kl_cap_choices: List[float] = None
    free_bits_choices: List[float] = None
    kl_target_min: float = 4.0
    kl_target_max: float = 14.0
    tcn_stacks_min: int = 2
    tcn_stacks_max: int = 4
    tcn_blocks_min: int = 2
    tcn_blocks_max: int = 4
    d_model_min: int = 192
    d_model_max: int = 512
    kernel_choices: List[int] = None
    dropout_min: float = 0.0
    dropout_max: float = 0.25
    droppath_min: float = 0.0
    droppath_max: float = 0.2
    gp_lengthscale_min: float = 2.0
    gp_lengthscale_max: float = 24.0
    gp_var_min: float = 0.2
    gp_var_max: float = 3.0
    gp_noise_min: float = 1e-5
    gp_noise_max: float = 1e-2
    t0_frac_min: float = 0.3
    t0_frac_max: float = 0.8
    weight_decay_min: float = 0.0
    weight_decay_max: float = 0.1
    grad_clip_choices: List[float] = None

    def __post_init__(self):
        if self.kl_cap_choices is None:
            self.kl_cap_choices = [8.0, 10.0, 12.0, 14.0, 16.0]
        if self.free_bits_choices is None:
            self.free_bits_choices = [0.0, 0.1, 0.2, 0.3]
        if self.kernel_choices is None:
            self.kernel_choices = [3, 5, 7]
        if self.grad_clip_choices is None:
            self.grad_clip_choices = [0.0, 0.5, 1.0, 2.0]


def ensure_outdir(p: str) -> str:
    pth = pathlib.Path(p)
    pth.mkdir(parents=True, exist_ok=True)
    return str(pth)


def sample_random(space: SearchSpace, rng: random.Random) -> Dict[str, Any]:
    def rfloat(a, b):
        return a + (b - a) * rng.random()

    cfg_over = {
        "lr": 10 ** rfloat(space.lr_log10_min, space.lr_log10_max),
        "weight_decay": rfloat(
            space.weight_decay_min, space.weight_decay_max
        ),
        "beta_init": 5e-4,
        "beta_max": rfloat(space.beta_max_min, space.beta_max_max),
        "kl_cap_nats": rng.choice(space.kl_cap_choices),
        "free_bits_nats": rng.choice(space.free_bits_choices),
        "kl_target_nats": rfloat(space.kl_target_min, space.kl_target_max),
        "tcn_stacks": rng.randint(
            space.tcn_stacks_min, space.tcn_stacks_max
        ),
        "tcn_blocks_per_stack": rng.randint(
            space.tcn_blocks_min, space.tcn_blocks_max
        ),
        "d_model": int(
            rfloat(space.d_model_min, space.d_model_max) // 32 * 32
        ),
        "tcn_kernel": rng.choice(space.kernel_choices),
        "tcn_dropout": rfloat(space.dropout_min, space.dropout_max),
        "tcn_droppath": rfloat(space.droppath_min, space.droppath_max),
        "gp_lengthscale_init": rfloat(
            space.gp_lengthscale_min, space.gp_lengthscale_max
        ),
        "gp_variance_init": rfloat(space.gp_var_min, space.gp_var_max),
        "_t0_frac": rfloat(space.t0_frac_min, space.t0_frac_max),
        "grad_clip": rng.choice(space.grad_clip_choices),
        "_block_size": space.block_size,
        "_batch_size": space.batch_size,
        "_max_steps": space.max_steps,
        "_device": space.device,
    }
    cfg_over["_gp_noise"] = 10 ** rfloat(
        math.log10(space.gp_noise_min), math.log10(space.gp_noise_max)
    )
    return cfg_over


def run_one_trial(overrides: Dict[str, Any]) -> Dict[str, Any]:
    block_size = overrides.pop("_block_size", 64)
    batch_size = overrides.pop("_batch_size", 16)
    max_steps = overrides.pop("_max_steps", 1000)
    device = overrides.pop(
        "_device", "cuda" if torch.cuda.is_available() else "cpu"
    )
    t0_frac = float(overrides.pop("_t0_frac", 0.5))
    gp_noise = float(overrides.pop("_gp_noise", 1e-5))

    os.environ["T0_FRAC"] = str(t0_frac)

    model, tok, metrics, figs = train_gpvae(
        block_size=block_size,
        batch_size=batch_size,
        steps=max_steps,
        lr=overrides.get("lr", 2e-4),
        cfg_overrides=overrides,
    )

    metrics["t0_frac"] = t0_frac
    metrics["device"] = device
    metrics["notes"] = ""
    return {"metrics": metrics, "cfg_overrides": overrides}


CSV_NAME = "trials.csv"
BEST_JSON = "best_config.json"
BEST_METRICS_JSON = "best_metrics.json"
BEST_CKPT = "best.ckpt"


def write_rows_csv(path: str, rows: List[Dict[str, Any]]):
    import csv

    if not rows:
        return
    keys = sorted({k for r in rows for k in r.keys()})
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=keys)
        w.writeheader()
        [w.writerow(r) for r in rows]


def append_row_csv(path: str, row: Dict[str, Any]):
    if not os.path.exists(path):
        write_rows_csv(path, [row])
    else:
        import csv

        with open(path, "r", encoding="utf-8") as f:
            header = f.readline().strip().split(",")
        for k in row.keys():
            if k not in header:
                try:
                    import pandas as pd

                    df = pd.read_csv(path)
                    df = pd.concat(
                        [df, pd.DataFrame([row])], ignore_index=True
                    )
                    df.to_csv(path, index=False)
                    return
                except Exception:
                    pass
        with open(path, "a", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=header)
            w.writerow(row)


def save_best(outdir: str, best: Dict[str, Any]):
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, BEST_JSON), "w", encoding="utf-8") as f:
        json.dump(best["cfg"], f, indent=2)
    with open(
        os.path.join(outdir, BEST_METRICS_JSON), "w", encoding="utf-8"
    ) as f:
        json.dump(best["metrics"], f, indent=2)
    ckpt = best["metrics"].get("checkpoint_path")
    if ckpt and os.path.exists(ckpt):
        dst = os.path.join(outdir, BEST_CKPT)
        try:
            if os.path.exists(dst):
                os.remove(dst)
            os.symlink(os.path.abspath(ckpt), dst)
        except Exception:
            shutil.copy2(ckpt, dst)


def run_random_search(space: SearchSpace) -> Dict[str, Any]:
    rng = random.Random(space.seed)
    outdir = ensure_outdir(space.outdir)
    rows: List[Dict[str, Any]] = []
    best: Optional[Dict[str, Any]] = None
    for t in range(space.n_trials):
        overrides = sample_random(space, rng)
        print(f"[RandomSearch] Trial {t+1}/{space.n_trials}: {overrides}")
        try:
            ret = run_one_trial(overrides)
        except Exception as e:
            print(f"Trial {t+1} failed: {e}")
            continue
        metrics = ret["metrics"]
        cfg_over = ret["cfg_overrides"]
        score = metrics.get(space.metric, float("inf"))
        row = {**cfg_over, **metrics, "trial": t, space.metric: score}
        append_row_csv(os.path.join(outdir, CSV_NAME), row)
        rows.append(row)
        if (best is None) or (score < best["score"]):
            best = {"score": score, "cfg": cfg_over, "metrics": metrics}
            save_best(outdir, best)
    return best or {}


def run_optuna_search(space: SearchSpace) -> Dict[str, Any]:
    if not _HAVE_OPTUNA:
        print("[INFO] Optuna indisponible → fallback random.")
        return run_random_search(space)
    outdir = ensure_outdir(space.outdir)

    def objective(trial: "optuna.Trial"):
        cfg_over = {
            "lr": 10
            ** trial.suggest_float(
                "lr_log10", space.lr_log10_min, space.lr_log10_max
            ),
            "weight_decay": trial.suggest_float(
                "weight_decay",
                space.weight_decay_min,
                space.weight_decay_max,
            ),
            "beta_init": 5e-4,
            "beta_max": trial.suggest_float(
                "beta_max", space.beta_max_min, space.beta_max_max
            ),
            "kl_cap_nats": trial.suggest_categorical(
                "kl_cap_nats", space.kl_cap_choices
            ),
            "free_bits_nats": trial.suggest_categorical(
                "free_bits_nats", space.free_bits_choices
            ),
            "kl_target_nats": trial.suggest_float(
                "kl_target_nats",
                space.kl_target_min,
                space.kl_target_max,
            ),
            "tcn_stacks": trial.suggest_int(
                "tcn_stacks",
                space.tcn_stacks_min,
                space.tcn_stacks_max,
            ),
            "tcn_blocks_per_stack": trial.suggest_int(
                "tcn_blocks_per_stack",
                space.tcn_blocks_min,
                space.tcn_blocks_max,
            ),
            "d_model": trial.suggest_int(
                "d_model", space.d_model_min, space.d_model_max, step=32
            ),
            "tcn_kernel": trial.suggest_categorical(
                "tcn_kernel", space.kernel_choices
            ),
            "tcn_dropout": trial.suggest_float(
                "tcn_dropout", space.dropout_min, space.dropout_max
            ),
            "tcn_droppath": trial.suggest_float(
                "tcn_droppath", space.droppath_min, space.droppath_max
            ),
            "gp_lengthscale_init": trial.suggest_float(
                "gp_lengthscale_init",
                space.gp_lengthscale_min,
                space.gp_lengthscale_max,
            ),
            "gp_variance_init": trial.suggest_float(
                "gp_variance_init", space.gp_var_min, space.gp_var_max
            ),
            "_t0_frac": trial.suggest_float(
                "t0_frac", space.t0_frac_min, space.t0_frac_max
            ),
            "grad_clip": trial.suggest_categorical(
                "grad_clip", space.grad_clip_choices
            ),
            "_block_size": space.block_size,
            "_batch_size": space.batch_size,
            "_max_steps": space.max_steps,
            "_device": space.device,
            "_gp_noise": 10
            ** trial.suggest_float(
                "gp_noise_log10",
                math.log10(space.gp_noise_min),
                math.log10(space.gp_noise_max),
            ),
        }
        ret = run_one_trial(cfg_over)
        metrics = ret["metrics"]
        score = metrics.get(space.metric)
        if score is None:
            raise RuntimeError(f"Metric {space.metric} manquante")
        row = {**cfg_over, **metrics, "trial": trial.number, space.metric: score}
        append_row_csv(os.path.join(outdir, CSV_NAME), row)
        trial.report(score, step=1)
        return score

    study = optuna.create_study(
        direction="minimize",
        pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=1),
    )
    study.optimize(objective, n_trials=space.n_trials, show_progress_bar=True)
    bt = study.best_trial
    best_cfg = {
        "lr": 10 ** bt.params.get("lr_log10", -4.0),
        "weight_decay": bt.params.get("weight_decay"),
        "beta_init": 5e-4,
        "beta_max": bt.params.get("beta_max"),
        "kl_cap_nats": bt.params.get("kl_cap_nats"),
        "free_bits_nats": bt.params.get("free_bits_nats"),
        "kl_target_nats": bt.params.get("kl_target_nats"),
        "tcn_stacks": bt.params.get("tcn_stacks"),
        "tcn_blocks_per_stack": bt.params.get("tcn_blocks_per_stack"),
        "d_model": bt.params.get("d_model"),
        "tcn_kernel": bt.params.get("tcn_kernel"),
        "tcn_dropout": bt.params.get("tcn_dropout"),
        "tcn_droppath": bt.params.get("tcn_droppath"),
        "gp_lengthscale_init": bt.params.get("gp_lengthscale_init"),
        "gp_variance_init": bt.params.get("gp_variance_init"),
        "grad_clip": bt.params.get("grad_clip"),
    }
    ret = run_one_trial(
        {
            **best_cfg,
            "_t0_frac": 0.6,
            "_block_size": space.block_size,
            "_batch_size": space.batch_size,
            "_max_steps": space.max_steps,
            "_device": space.device,
            "_gp_noise": 1e-5,
        }
    )
    best = {
        "score": ret["metrics"][space.metric],
        "cfg": best_cfg,
        "metrics": ret["metrics"],
    }
    save_best(space.outdir, best)
    return best


import argparse, sys


def parse_args(argv=None):
    p = argparse.ArgumentParser(
        description="GP-VAE-TCN+ avec GP latent et décodeur renforcé"
    )
    p.add_argument(
        "--search", type=str, default="none", choices=["none", "random", "optuna"]
    )
    p.add_argument("--n-trials", type=int, default=20)
    p.add_argument("--metric", type=str, default="cont_nll")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument(
        "--outdir", type=str, default="hpsearch_runs/gpvae_tcn_plus"
    )
    p.add_argument("--block-size", type=int, default=64)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    p.add_argument(
        "--grad-clip",
        type=float,
        default=None,
        help="Override de Config.grad_clip",
    )
    p.add_argument(
        "--kl-cap",
        type=float,
        default=None,
        help="Override de Config.kl_cap_nats",
    )
    p.add_argument(
        "--beta-max",
        type=float,
        default=None,
        help="Override de Config.beta_max",
    )
    args, _ = p.parse_known_args(argv)
    return args


def main():
    args = parse_args()
    if args.search == "none":
        overrides = {}
        if args.grad_clip is not None:
            overrides["grad_clip"] = float(args.grad_clip)
        if args.kl_cap is not None:
            overrides["kl_cap_nats"] = float(args.kl_cap)
        if args.beta_max is not None:
            overrides["beta_max"] = float(args.beta_max)

        model, tok, metrics, figs = train_gpvae(
            block_size=args.block_size,
            batch_size=args.batch_size,
            steps=args.steps,
            cfg_overrides=overrides if overrides else None,
        )
        print("\n[VAL]", metrics)
        model.eval()
        device = next(model.parameters()).device
        prompt_text = "The meaning of life"
        ids = tok.encode(prompt_text)[: args.block_size]
        eos_id = getattr(tok, "eos_token_id", 0)
        ids = ids if len(ids) > 0 else [eos_id]
        prompt_ids = torch.tensor([ids], dtype=torch.long, device=device)
        total_len = min(args.block_size * 2, len(ids) + 64)
        x_out = model.generate(
            T=64, batch_size=1, top_k=50, top_p=0.9, temperature=0.9
        )
        print("\n[Sample] Non-conditionné\n", tok.decode(x_out[0].tolist()))
        x_out2, logits_prompt = model.generate_with_prompt(
            prompt_ids, total_len, eos_id
        )
        print(
            "\n[Sample] Prompt-conditionné\n",
            tok.decode(x_out2[0].tolist()),
        )
        T0 = prompt_ids.size(1)
        sc = score_continuation_gpvae_from_logits(model, x_out2, T0)
        print(
            f"[GP-VAE-TCN+] NLL(cont)={sc['nll']:.4f} | PPL(cont)={sc['ppl']:.2f}"
        )
        return

    space = SearchSpace(
        metric=args.metric,
        n_trials=args.n_trials,
        algo=args.search,
        seed=args.seed,
        outdir=args.outdir,
        block_size=args.block_size,
        batch_size=args.batch_size,
        max_steps=args.steps,
        device=args.device,
    )
    os.makedirs(space.outdir, exist_ok=True)
    with open(
        pathlib.Path(space.outdir) / "search_space.json",
        "w",
        encoding="utf-8",
    ) as f:
        json.dump(asdict(space), f, indent=2)
    if args.search == "random":
        best = run_random_search(space)
    else:
        best = run_optuna_search(space)
    if best:
        print("\n=== BEST ===")
        print(json.dumps(best["cfg"], indent=2))
        print("Metric", space.metric, "=", best["score"])
        print("Checkpoint:", best["metrics"].get("checkpoint_path"))
    else:
        print("Aucun essai concluant.")


if __name__ == "__main__":
    main()


README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

step 50 | elbo/tok -18.638 | ll0 -9.178 | ll_multi -9.180 | kl_raw 36837.410 | kl_cap 12.000 | beta 0.007 | tok/s 6,799
step 100 | elbo/tok -16.126 | ll0 -6.874 | ll_multi -8.902 | kl_raw 107043.148 | kl_cap 12.000 | beta 0.013 | tok/s 8,959
step 150 | elbo/tok -15.028 | ll0 -5.782 | ll_multi -8.828 | kl_raw 134674.438 | kl_cap 12.000 | beta 0.020 | tok/s 8,973
step 200 | elbo/tok -13.670 | ll0 -4.643 | ll_multi -8.538 | kl_raw 168386.312 | kl_cap 12.000 | beta 0.026 | tok/s 8,957
step 250 | elbo/tok -12.861 | ll0 -3.972 | ll_multi -8.330 | kl_raw 178654.531 | kl_cap 12.000 | beta 0.032 | tok/s 8,942
step 300 | elbo/tok -12.435 | ll0 -3.376 | ll_multi -8.430 | kl_raw 195583.547 | kl_cap 12.000 | beta 0.038 | tok/s 8,979
step 350 | elbo/tok -11.804 | ll0 -2.846 | ll_multi -8.256 | kl_raw 200799.656 | kl_cap 12.000 | beta 0.045 | tok/s 8,854
step 400 | elbo/tok -11.139 | ll0 -2.427 | ll_multi -7.939 | kl_raw 201319.469 | kl_cap 12.000 | beta 0.051 | tok/s 8,816
step 450 | elbo/tok -11.06