In [1]:
import math, os, time
from dataclasses import dataclass
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

try:
    from datasets import load_dataset
    from transformers import AutoTokenizer
except Exception:
    load_dataset = None
    AutoTokenizer = None
    print("[Warning] Install datasets and transformers: pip install datasets transformers")


def load_wikitext2(block_size=64, split='train'):
    assert load_dataset is not None and AutoTokenizer is not None, "Install datasets/transformers"
    ds = load_dataset('wikitext', 'wikitext-2-raw-v1')[split]
    try:
        from transformers.utils import hub as _hf_hub_utils
        _hf_hub_utils.list_repo_templates = lambda *a, **k: []
        import transformers.tokenization_utils_base as _tub
        _tub.list_repo_templates = lambda *a, **k: []
    except Exception:
        pass
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    texts = ds['text']
    dataset = LMBlocks(texts, tokenizer, block_size)
    return dataset, tokenizer


class LMBlocks(Dataset):
    def __init__(self, texts, tokenizer, block_size: int = 64):
        ids = []
        for t in texts:
            ids.extend(tokenizer.encode(t))
        n = max(1, len(ids) // block_size)
        ids = ids[: n * block_size]
        self.block_size = block_size
        self.data = torch.tensor(ids, dtype=torch.long).view(n, block_size)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]


def plot_curve(values, title, ylabel, out_png):
    if len(values) == 0:
        return
    plt.figure()
    plt.plot(values)
    plt.xlabel('Step')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def reliability_diagram_data(logits: torch.Tensor, x: torch.Tensor, n_bins: int = 15):
    with torch.no_grad():
        probs = torch.softmax(logits, dim=-1)
        conf, pred = probs.max(dim=-1)
        corr = (pred == x).float()
        conf = conf.detach().cpu().view(-1).numpy()
        corr = corr.detach().cpu().view(-1).numpy()
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(conf, bins) - 1
    bin_conf, bin_acc, bin_count = [], [], []
    ece = 0.0
    N = len(conf)
    for b in range(n_bins):
        m = (bin_ids == b)
        cnt = m.sum()
        if cnt == 0:
            bin_conf.append(0.0)
            bin_acc.append(0.0)
            bin_count.append(0)
            continue
        c = conf[m].mean()
        a = corr[m].mean()
        w = cnt / N
        ece += w * abs(a - c)
        bin_conf.append(c)
        bin_acc.append(a)
        bin_count.append(cnt)
    return np.array(bin_conf), np.array(bin_acc), float(ece)


def plot_reliability(confidences, accuracy, ece, out_png):
    plt.figure()
    plt.plot(confidences, accuracy, 'o-')
    plt.plot([0, 1], [0, 1], '--')
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.title(f'Reliability diagram — ECE={ece:.4f}')
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def nll_per_token(logits: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    B, T, V = logits.shape
    ce = F.cross_entropy(
        logits.reshape(B * T, V),
        x.reshape(B * T),
        reduction='none'
    ).view(B, T)
    return ce.mean(dim=0).detach().cpu()


def plot_nll_per_position(nll_per_pos, out_png):
    plt.figure()
    plt.plot(nll_per_pos)
    plt.xlabel('Position t')
    plt.ylabel('Token NLL')
    plt.title('NLL per position')
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()


def rbf_kernel(t: torch.Tensor, lengthscale: torch.Tensor, variance: torch.Tensor) -> torch.Tensor:
    t = t.view(-1, 1)
    d2 = (t - t.T) ** 2
    return variance * torch.exp(-0.5 * d2 / (lengthscale ** 2))


def mvn_logpdf_zero_mean(z: torch.Tensor, K: torch.Tensor, jitter: float = 1e-5) -> torch.Tensor:
    B, T, D = z.shape
    I = torch.eye(T, device=K.device, dtype=K.dtype)
    jitter_local = jitter
    for _ in range(5):
        try:
            L = torch.linalg.cholesky(K + jitter_local * I)
            break
        except RuntimeError:
            jitter_local *= 10.0
    else:
        L = torch.linalg.cholesky(K + jitter_local * I)

    log_det = 2.0 * torch.sum(torch.log(torch.diag(L)))
    z_bd = z.permute(0, 2, 1).reshape(B * D, T)
    y_bd = torch.cholesky_solve(z_bd.unsqueeze(-1), L).squeeze(-1)
    quad = (z_bd * y_bd).sum(dim=-1).view(B, D).sum(dim=-1)
    const = T * math.log(2 * math.pi)
    logp = -0.5 * (D * (log_det + const) + quad)
    return logp


def diag_mvn_logpdf(z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    var = torch.exp(logvar)
    const = math.log(2 * math.pi)
    logp = -0.5 * (const + logvar + (z - mu) ** 2 / var)
    return logp.sum(dim=(1, 2))


class SinusoidalPos(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.d = d

    def forward(self, T: int, device):
        D = self.d
        pe = torch.zeros(T, D, device=device)
        position = torch.arange(T, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, D, 2, device=device) * (-math.log(10000.0) / D))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return (self.scale * pe).unsqueeze(0)


@torch.no_grad()
def sample_logits_from_timewise_logits(
    logits: torch.Tensor,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 0.9
):
    B, T, V = logits.shape
    x = logits / max(1e-6, float(temperature))
    if top_k is not None and 0 < top_k < V:
        kth = torch.topk(x, top_k, dim=-1).values[..., -1:].expand(B, T, V)
        x = torch.where(x < kth, torch.full_like(x, -1e10), x)
    if top_p is not None and 0.0 < top_p < 1.0:
        probs = torch.softmax(x, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        cutoff = cum > top_p
        cutoff[..., 1:] = cutoff[..., :-1].clone()
        cutoff[..., 0] = False
        sorted_logits = torch.log(sorted_probs.clamp_min(1e-20))
        sorted_logits[cutoff] = -1e10
        x = torch.full_like(x, -1e10).scatter(-1, sorted_idx, sorted_logits)
    ids = torch.distributions.Categorical(logits=x).sample()
    return ids


@dataclass
class Config:
    vocab_size: int
    d_model: int = 256
    d_latent: int = 64
    block_size: int = 64
    emb_dim: int = 256
    gp_learn_hypers: bool = True
    gp_lengthscale_init: float = 8.0
    gp_variance_init: float = 1.0
    gp_reg_lambda: float = 5e-5
    lr: float = 2e-4
    label_smoothing: float = 0.01
    free_bits_nats: float = 0.3
    kl_cap_nats: float = 8.0
    kl_target_nats: float = 6.0
    beta_init: float = 5e-4
    beta_max: float = 0.35
    beta_adapt_rate: float = 0.03
    use_adaptive_beta: bool = True
    K_multi: int = 3
    multi_lambda_scheme: str = "harmonic"
    gamma_multi: float = 0.3
    logvar_min: float = -6.0
    logvar_max: float = 2.0
    logvar_init: float = -4.0
    embed_reg_weight: float = 0.1
    embed_reg_mode: str = "cos+mse"
    n_pyramid: int = 3
    dilations: tuple = (1, 2, 4, 8)
    kernel_size: int = 5
    dropout: float = 0.05
    word_dropout_p: float = 0.05
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class CausalDilatedBlock(nn.Module):
    def __init__(self, channels, kernel_size=5, dilation=1, dropout=0.05):
        super().__init__()
        pad = (kernel_size - 1) * dilation
        self.left_pad = pad
        self.norm1 = nn.LayerNorm(channels)
        self.conv = nn.Conv1d(channels, 2 * channels, kernel_size, padding=pad, dilation=dilation)
        self.proj = nn.Conv1d(channels, channels, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_bt_d):
        B, T, D = x_bt_d.shape
        h = self.norm1(x_bt_d)
        h = h.transpose(1, 2)
        h = self.conv(h)
        if self.left_pad > 0:
            h = h[..., :-self.left_pad]
        a, b = h.chunk(2, dim=1)
        h = a * torch.sigmoid(b)
        h = self.proj(h).transpose(1, 2)
        h = self.dropout(h)
        return x_bt_d + h


class PyramidConvEncoder(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim)
        self.in_proj = nn.Linear(cfg.emb_dim, cfg.d_model)
        blocks = []
        for _ in range(cfg.n_pyramid):
            for d in cfg.dilations:
                blocks.append(
                    CausalDilatedBlock(
                        cfg.d_model,
                        kernel_size=cfg.kernel_size,
                        dilation=d,
                        dropout=cfg.dropout
                    )
                )
        self.blocks = nn.ModuleList(blocks)
        self.to_mu = nn.Linear(cfg.d_model, cfg.d_latent)
        self.to_logvar = nn.Linear(cfg.d_model, cfg.d_latent)
        nn.init.zeros_(self.to_mu.weight)
        nn.init.zeros_(self.to_mu.bias)
        nn.init.zeros_(self.to_logvar.weight)
        nn.init.constant_(self.to_logvar.bias, cfg.logvar_init)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.tok_emb(x)
        if self.training and self.cfg.word_dropout_p > 0.0:
            keep_prob = 1.0 - self.cfg.word_dropout_p
            mask = torch.empty(h.size(0), h.size(1), 1, device=h.device).bernoulli_(keep_prob)
            h = h * mask
        h = self.in_proj(h)
        for blk in self.blocks:
            h = blk(h)
        mu = self.to_mu(h)
        logvar = self.to_logvar(h).clamp(self.cfg.logvar_min, self.cfg.logvar_max)
        return mu, logvar


class TokenDecoder(nn.Module):
    def __init__(self, cfg: Config, tied_weight: torch.Tensor):
        super().__init__()
        self.cfg = cfg
        self.pe = SinusoidalPos(cfg.d_latent)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.d_latent, cfg.d_model),
            nn.GELU(),
            nn.Linear(cfg.d_model, cfg.d_model),
            nn.GELU()
        )
        self.ln = nn.LayerNorm(cfg.d_model)
        self.post = nn.Conv1d(cfg.d_model, cfg.d_model, kernel_size=3, padding=1)
        self.to_emb = nn.Linear(cfg.d_model, cfg.emb_dim, bias=False)
        self.tied_weight = nn.Parameter(tied_weight)
        self.bias = nn.Parameter(torch.zeros(cfg.vocab_size))
        self.sem_head = nn.Linear(cfg.d_model, cfg.emb_dim, bias=False)

    def forward(self, z: torch.Tensor):
        z = z + self.pe(T=z.size(1), device=z.device)
        h = self.mlp(z)
        h = self.ln(h)
        h2 = self.post(h.transpose(1, 2)).transpose(1, 2)
        h = h + h2
        e_proj = self.to_emb(h)
        tw = F.normalize(self.tied_weight, dim=-1)
        logits = torch.matmul(e_proj, tw.t()) + self.bias
        e_hat = self.sem_head(h)
        return logits, e_hat


class GPVAE(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.encoder = PyramidConvEncoder(cfg)
        self.decoder = TokenDecoder(cfg, tied_weight=self.encoder.tok_emb.weight)
        self.register_buffer('t_train', torch.arange(cfg.block_size, dtype=torch.float32))
        self._ls_unconstrained = nn.Parameter(torch.tensor(cfg.gp_lengthscale_init).log())
        self._var_unconstrained = nn.Parameter(torch.tensor(cfg.gp_variance_init).log())
        self.softplus = nn.Softplus(beta=1.0)

    def gp_hypers(self):
        lengthscale = self.softplus(self._ls_unconstrained) + 1e-6
        variance = self.softplus(self._var_unconstrained) + 1e-6
        lengthscale = torch.clamp(lengthscale, 1e-2, 1e2)
        variance = torch.clamp(variance, 1e-3, 1e2)
        return lengthscale, variance

    def K_tt(self, t: torch.Tensor) -> torch.Tensor:
        ls, var = self.gp_hypers()
        return rbf_kernel(t, ls, var)

    def _label_smoothing_ce(self, logits, targets, eps):
        log_probs = F.log_softmax(logits, dim=-1)
        nll = F.nll_loss(log_probs.transpose(1, 2), targets, reduction='mean')
        smooth = -log_probs.mean(dim=-1).mean()
        return (1 - eps) * nll + eps * smooth

    def _multi_horizon_ll(self, logits, x) -> torch.Tensor:
        K = self.cfg.K_multi
        if K <= 0:
            return torch.zeros((), device=logits.device)
        B, T, V = logits.shape
        total = 0.0
        for k in range(1, K + 1):
            if T - k <= 0:
                break
            lam = 1.0 / k if self.cfg.multi_lambda_scheme == "harmonic" else 1.0
            ll_k = - self._label_smoothing_ce(
                logits[:, :T - k, :],
                x[:, k:],
                self.cfg.label_smoothing
            )
            total = total + self.cfg.gamma_multi * lam * ll_k
        return total

    def _embed_reg_losses(self, e_hat, x, token_emb_table: torch.Tensor):
        with torch.no_grad():
            e_tgt = token_emb_table[x]
        loss = 0.0
        mode = self.cfg.embed_reg_mode
        if mode in ("mse", "cos+mse"):
            loss = loss + F.mse_loss(e_hat, e_tgt, reduction="mean")
        if mode in ("cos", "cos+mse"):
            loss = loss + (1.0 - F.cosine_similarity(e_hat, e_tgt, dim=-1).mean())
        return torch.as_tensor(loss, device=e_hat.device)

    def elbo(self, x: torch.Tensor, beta_override: Optional[float] = None):
        cfg = self.cfg
        B, T = x.shape
        mu, logvar = self.encoder(x)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)

        logits, e_hat = self.decoder(z)
        ll_0 = - self._label_smoothing_ce(logits, x, cfg.label_smoothing)
        ll_multi = self._multi_horizon_ll(logits, x)
        emb_reg = self._embed_reg_losses(e_hat, x, self.encoder.tok_emb.weight.detach())

        tvec = self.t_train[:T]
        K_full = self.K_tt(tvec)
        log_pz_b = mvn_logpdf_zero_mean(z, K_full)
        log_qz_b = diag_mvn_logpdf(z, mu, logvar)
        kl_tok_raw_t = (log_qz_b - log_pz_b).mean() / T

        kl_tok_fb_t = torch.clamp(kl_tok_raw_t, min=cfg.free_bits_nats)
        kl_tok_capped = torch.clamp(kl_tok_fb_t, max=cfg.kl_cap_nats)

        beta = beta_override if beta_override is not None else getattr(self, "_beta_state", cfg.beta_init)

        gp_reg = cfg.gp_reg_lambda * ((self._ls_unconstrained ** 2 + self._var_unconstrained ** 2))
        elbo_tok_t = (ll_0 + ll_multi) - beta * kl_tok_capped - cfg.embed_reg_weight * emb_reg - gp_reg

        with torch.no_grad():
            V = logits.size(-1)
            nll_ce = F.cross_entropy(
                logits.reshape(-1, V),
                x.reshape(-1),
                reduction='mean'
            ).item()
            ppl = math.exp(nll_ce) if nll_ce < 50 else float("inf")

        stats = {
            'll0_tok': float(ll_0.detach().item()),
            'll_multi_tok': float(ll_multi.detach().item()),
            'emb_reg': float(emb_reg.detach().item()),
            'kl_tok_raw': float(kl_tok_raw_t.detach().item()),
            'kl_tok_cap': float(kl_tok_capped.detach().item()),
            'beta': float(beta),
            'elbo_tok': float(elbo_tok_t.detach().item()),
            'nll_ce': float(nll_ce),
            'ppl': float(ppl),
        }
        comps = {
            'll0_tok_t': ll_0,
            'll_multi_tok_t': ll_multi,
            'emb_reg_t': emb_reg,
            'kl_tok_raw_t': kl_tok_raw_t,
            'kl_tok_cap_t': kl_tok_capped,
            'elbo_tok_t': elbo_tok_t,
            'beta_t': torch.tensor(beta, device=x.device),
        }
        return elbo_tok_t, stats, comps, logits

    @torch.no_grad()
    def _gp_conditional_step(self, K: torch.Tensor, z_prefix: torch.Tensor, t_next: int):
        B, tp, Dz = z_prefix.shape
        if tp == 0:
            var0 = K[0, 0].clamp_min(1e-9)
            return torch.randn(B, Dz, device=z_prefix.device) * var0.sqrt()
        K_11 = K[:tp, :tp]
        k_12 = K[:tp, t_next]
        K_22 = K[t_next, t_next]
        jitter = 1e-6
        I = torch.eye(tp, device=K.device, dtype=K.dtype)
        for _ in range(5):
            try:
                L_11 = torch.linalg.cholesky(K_11 + jitter * I)
                break
            except RuntimeError:
                jitter *= 10.0
        else:
            L_11 = torch.linalg.cholesky(K_11 + jitter * I)
        alpha = torch.cholesky_solve(k_12.view(tp, 1), L_11).view(tp)
        mean_t = torch.einsum('t, btd -> bd', alpha, z_prefix)
        var_t = (K_22 - (k_12 * alpha).sum()).clamp_min(1e-9)
        return mean_t + torch.randn(B, Dz, device=z_prefix.device) * var_t.sqrt()

    @torch.no_grad()
    def sample_latents_parallel(self, T: int, batch_size: int = 1, temperature_z: float = 1.0):
        """
        Parallel sampling of z_{1:T} under the GP prior.

        p(z_{1:T}) = N(0, K ⊗ I_D)
        Equivalent to the causal factorization ∏_t p(z_t | z_{<t}), but sampled in one shot via Cholesky.
        """
        device = self.t_train.device
        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        I = torch.eye(T, device=device, dtype=K.dtype)
        L = torch.linalg.cholesky(K + 1e-5 * I)

        B = batch_size
        Dz = self.cfg.d_latent
        eps = torch.randn(B * Dz, T, device=device)
        z_bd = eps @ L.T
        z = z_bd.view(B, Dz, T).permute(0, 2, 1)
        if temperature_z != 1.0:
            z = z * float(temperature_z)
        return z

    @torch.no_grad()
    def generate(
        self,
        T: int,
        batch_size: int = 1,
        top_k: int = 50,
        top_p: float = 0.9,
        temperature: float = 0.9,
        temperature_z: float = 1.0
    ):
        device = self.t_train.device
        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        Dz = self.cfg.d_latent
        B = batch_size
        z = torch.zeros(B, T, Dz, device=device)
        for tp in range(T):
            step = self._gp_conditional_step(K, z[:, :tp, :], tp)
            z[:, tp, :] = step * temperature_z
        logits, _ = self.decoder(z)
        return sample_logits_from_timewise_logits(
            logits,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature
        )

    @torch.no_grad()
    def generate_parallel(
        self,
        T: int,
        batch_size: int = 1,
        top_k: int = 50,
        top_p: float = 0.9,
        temperature: float = 0.9,
        temperature_z: float = 1.0
    ):
        """
        Unconditional generation with parallel GP-sampled latents.
        Same marginal prior p(z_{1:T}), just vectorized.
        """
        z = self.sample_latents_parallel(T=T, batch_size=batch_size, temperature_z=temperature_z)
        logits, _ = self.decoder(z)
        return sample_logits_from_timewise_logits(
            logits,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature
        )

    @torch.no_grad()
    def generate_with_prompt(
        self,
        prompt_ids: torch.Tensor,
        total_len: int,
        eos_id: int,
        top_k: int = 50,
        top_p: float = 0.9,
        temperature: float = 0.9,
        temperature_z: float = 1.0
    ):
        device = self.t_train.device
        self.eval()
        B, T0 = prompt_ids.shape
        T = total_len
        x_in = torch.full(
            (B, T),
            fill_value=eos_id,
            dtype=torch.long,
            device=device
        )
        x_in[:, :T0] = prompt_ids.to(device)
        mu, logvar = self.encoder(x_in[:, :T0])
        std = torch.exp(0.5 * logvar)
        z_prompt = mu + std * torch.randn_like(std)
        t = torch.arange(T, dtype=torch.float32, device=device)
        K = self.K_tt(t)
        Dz = self.cfg.d_latent
        z = torch.zeros(B, T, Dz, device=device)
        if T0 > 0:
            z[:, :T0, :] = z_prompt
        for tp in range(T0, T):
            step = self._gp_conditional_step(K, z[:, :tp, :], tp)
            z[:, tp, :] = step * temperature_z
        logits, _ = self.decoder(z)
        new_ids = sample_logits_from_timewise_logits(
            logits[:, T0:, :],
            top_k=top_k,
            top_p=top_p,
            temperature=temperature
        )
        x_out = x_in.clone()
        if T > T0:
            x_out[:, T0:] = new_ids
        return x_out, logits


def train_gpvae(block_size=64, batch_size=16, steps=1000, lr=2e-4, log_every=50, seed: int = 1234):
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_ds, tok = load_wikitext2(block_size=block_size, split='train')
    val_ds, _ = load_wikitext2(block_size=block_size, split='validation')
    cfg = Config(vocab_size=tok.vocab_size, block_size=block_size, lr=lr)
    device = cfg.device

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True
    )

    model = GPVAE(cfg).to(device)
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        betas=(0.9, 0.95),
        weight_decay=0.01
    )

    scaler = torch.amp.GradScaler('cuda', enabled=device.startswith("cuda"))

    env_warmup = int(os.getenv('KL_WARMUP_STEPS', '4000'))
    warmup_steps = min(env_warmup, steps)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim,
        T_max=max(steps, 1),
        eta_min=lr * 0.1
    )

    hist_elbo, hist_ll0, hist_llm, hist_kl_raw = [], [], [], []
    hist_kl_cap, hist_beta, hist_ppl_train = [], [], []

    model.train()
    t0 = time.perf_counter()
    beta = cfg.beta_init
    model._beta_state = beta

    for step, x in enumerate(train_loader, 1):
        x = x.to(device)
        optim.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=device.startswith("cuda")):
            elbo_tok_t, stats, comps, _ = model.elbo(x, beta_override=beta)

            if step < warmup_steps:
                incr = (cfg.beta_max - cfg.beta_init) / max(1, warmup_steps)
                beta = min(cfg.beta_max, beta + incr)
            elif cfg.use_adaptive_beta:
                diff = float(comps['kl_tok_cap_t'].detach().item()) - cfg.kl_target_nats
                if diff > 0:
                    beta = float(beta) * (1.0 + cfg.beta_adapt_rate)
                else:
                    beta = float(beta) * (1.0 - cfg.beta_adapt_rate)
                beta = float(np.clip(beta, 1e-4, cfg.beta_max))

            model._beta_state = beta
            loss = -elbo_tok_t

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optim)
        scaler.update()
        scheduler.step()

        hist_elbo.append(float(stats['elbo_tok']))
        hist_ll0.append(float(stats['ll0_tok']))
        hist_llm.append(float(stats['ll_multi_tok']))
        hist_kl_raw.append(float(stats['kl_tok_raw']))
        hist_kl_cap.append(float(stats['kl_tok_cap']))
        hist_beta.append(float(beta))
        hist_ppl_train.append(float(stats['ppl']))

        if step % log_every == 0:
            dt = time.perf_counter() - t0
            tokps = (batch_size * block_size * log_every) / max(1e-9, dt)
            print(
                f"step {step} | elbo/tok {stats['elbo_tok']:.3f} | "
                f"ll0 {stats['ll0_tok']:.3f} | ll_multi {stats['ll_multi_tok']:.3f} | "
                f"kl/tok_raw {stats['kl_tok_raw']:.3f} | kl/tok_cap {stats['kl_tok_cap']:.3f} | "
                f"beta {beta:.3f} | ppl(train) {stats['ppl']:.2f} | tok/s {tokps:,.0f}"
            )
            t0 = time.perf_counter()

        if step >= steps:
            break

    plot_curve(hist_elbo, "ELBO/token evolution", "ELBO/token", "elbo_token.png")
    plot_curve(hist_ll0, "LL0/token (k=0) evolution", "LL0/token", "ll0_token.png")
    plot_curve(hist_llm, "LL_multi/token (k=1..K) evolution", "LL_multi/token", "llmulti_token.png")
    plot_curve(hist_kl_raw, "KL/token (raw) evolution", "KL/token (raw)", "kl_token_raw.png")
    plot_curve(hist_kl_cap, "KL/token (capped) evolution", "KL/token (capped)", "kl_token_cap.png")
    plot_curve(hist_beta, "Beta evolution", "beta", "beta.png")
    plot_curve(hist_ppl_train, "PPL (train) evolution", "PPL(train)", "ppl_train.png")

    model.eval()
    with torch.no_grad():
        try:
            x_val = next(iter(val_loader)).to(device)
        except StopIteration:
            x_val = next(iter(train_loader)).to(device)
        elbo_tok_v, st_v, comps_v, logits_v = model.elbo(x_val, beta_override=model._beta_state)
        V = logits_v.size(-1)
        nll = F.cross_entropy(
            logits_v.reshape(-1, V),
            x_val.reshape(-1),
            reduction='mean'
        ).item()
        ppl = math.exp(nll) if nll < 50 else float("inf")
        bin_conf, bin_acc, ece = reliability_diagram_data(logits_v, x_val, n_bins=15)
        nll_pos = nll_per_token(logits_v, x_val)

    plot_reliability(bin_conf, bin_acc, ece, "reliability.png")
    plot_nll_per_position(nll_pos, "nll_per_position.png")

    print(
        f"[VAL] elbo/tok {st_v['elbo_tok']:.3f} | ll0 {st_v['ll0_tok']:.3f} | "
        f"ll_multi {st_v['ll_multi_tok']:.3f} | kl/tok_cap {st_v['kl_tok_cap']:.3f} | ppl {ppl:.2f}"
    )

    return model, tok, {
        "elbo_token_png": "elbo_token.png",
        "ll0_token_png": "ll0_token.png",
        "llmulti_token_png": "llmulti_token.png",
        "kl_token_raw_png": "kl_token_raw.png",
        "kl_token_cap_png": "kl_token_cap.png",
        "beta_png": "beta.png",
        "ppl_train_png": "ppl_train.png",
        "reliability_png": "reliability.png",
        "nll_per_position_png": "nll_per_position.png",
        "val_nll": nll,
        "val_ppl": ppl,
        "val_elbo_tok": st_v["elbo_tok"],
    }


@torch.no_grad()
def score_continuation_gpvae_from_logits(model, x_full, T0):
    B, T = x_full.shape
    mu, logvar = model.encoder(x_full)
    z = mu
    logits, _ = model.decoder(z)
    targets = x_full[:, T0:]
    logits_c = logits[:, T0:, :]
    V = logits_c.size(-1)
    nll = F.cross_entropy(
        logits_c.reshape(-1, V),
        targets.reshape(-1),
        reduction='mean'
    ).item()
    ppl = math.exp(nll) if nll < 50 else float("inf")
    return {"nll": nll, "ppl": ppl}


if __name__ == "__main__":
    block_size = int(os.getenv('BLOCK_SIZE', 64))
    batch_size = int(os.getenv('BATCH_SIZE', 16))
    steps = int(os.getenv('STEPS', 1000))
    lr = float(os.getenv('LR', 2e-4))

    model, tok, figs = train_gpvae(
        block_size=block_size,
        batch_size=batch_size,
        steps=steps,
        lr=lr
    )

    model.eval()
    device = next(model.parameters()).device
    prompt_text = "The meaning of life"
    ids = tok.encode(prompt_text)[:block_size]
    eos_id = getattr(tok, "eos_token_id", 0)
    if len(ids) == 0:
        ids = [eos_id]
    prompt_ids = torch.tensor([ids], dtype=torch.long, device=device)
    total_len = min(block_size * 2, len(ids) + 64)

    x_out = model.generate(
        T=64,
        batch_size=1,
        top_k=50,
        top_p=0.9,
        temperature=0.9,
        temperature_z=0.9
    )
    print("\n[Sample] Unconditional (sequential latent AR)\n", tok.decode(x_out[0].tolist()))

    x_out_par = model.generate_parallel(
        T=64,
        batch_size=1,
        top_k=50,
        top_p=0.9,
        temperature=0.9,
        temperature_z=0.9
    )
    print("\n[Sample] Unconditional (parallel latent GP)\n", tok.decode(x_out_par[0].tolist()))

    x_out2, logits_prompt = model.generate_with_prompt(
        prompt_ids=prompt_ids,
        total_len=total_len,
        eos_id=eos_id,
        top_k=50,
        top_p=0.9,
        temperature=0.9,
        temperature_z=0.9
    )
    print("\n[Sample] Prompt-conditioned\n", tok.decode(x_out2[0].tolist()))

    T0 = prompt_ids.size(1)
    sc = score_continuation_gpvae_from_logits(model, x_out2, T0)
    print(f"[GP-VAE] NLL(cont)={sc['nll']:.4f} | PPL(cont)={sc['ppl']:.2f}")


README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

step 50 | elbo/tok -14.493 | ll0 -8.649 | ll_multi -5.486 | kl/tok_raw 66128.875 | kl/tok_cap 8.000 | beta 0.018 | ppl(train) 5537.65 | tok/s 6,768
step 100 | elbo/tok -12.684 | ll0 -6.428 | ll_multi -5.741 | kl/tok_raw 171276.484 | kl/tok_cap 8.000 | beta 0.035 | ppl(train) 585.53 | tok/s 8,568
step 150 | elbo/tok -11.023 | ll0 -4.866 | ll_multi -5.514 | kl/tok_raw 260920.625 | kl/tok_cap 8.000 | beta 0.053 | ppl(train) 120.17 | tok/s 8,568
step 200 | elbo/tok -10.650 | ll0 -4.343 | ll_multi -5.531 | kl/tok_raw 273149.625 | kl/tok_cap 8.000 | beta 0.070 | ppl(train) 70.69 | tok/s 8,574
step 250 | elbo/tok -10.061 | ll0 -3.648 | ll_multi -5.505 | kl/tok_raw 315782.625 | kl/tok_cap 8.000 | beta 0.088 | ppl(train) 34.94 | tok/s 8,570
step 300 | elbo/tok -9.458 | ll0 -2.942 | ll_multi -5.475 | kl/tok_raw 321722.906 | kl/tok_cap 8.000 | beta 0.105 | ppl(train) 17.06 | tok/s 8,570
step 350 | elbo/tok -9.366 | ll0 -2.781 | ll_multi -5.406 | kl/tok_raw 326754.000 | kl/tok_cap 8.000 | beta 0.1