In [None]:
# ============================================================
# ES-SSM (Elastic Spectral State Space Model) on PG19
# ============================================================

import os, sys, math, time, random, hashlib, shutil
from pathlib import Path


# Choose the visible GPUs before torch is imported; setdefault keeps any value
# that is already present in the environment.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
# Fail fast when torch is already loaded in this interpreter, because the
# variable set above would then not take effect (see the message below).
if "torch" in sys.modules:
    raise RuntimeError(
        "torch is already imported in this kernel. Restart the kernel, then run this cell again "
        "so CUDA_VISIBLE_DEVICES takes effect."
    )

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset, get_worker_info, Dataset
from tqdm import tqdm
from datasets import load_dataset



# 1) CONFIG (PG19)

# Single global configuration dictionary read by the dataset, model, sampler,
# evaluation and training code in this file.
CONFIG = {
    # Repro
    "seed": 86,
    "deterministic": False,       # True switches cudnn.benchmark off and enables deterministic algorithms

    # Data
    "dataset_name": "pg19",
    "vocab_size": 258,            # byte values are shifted by +2 in _bytes_to_tokens, giving ids 2..257
    "seq_len": 4096,              # model input length; dataset segments hold seq_len + 1 tokens
    "num_workers": 2,
    "batch_size": 4,                
    "steps_per_epoch": 4000,      # multiplied by grad_accum_steps to get the micro-batches per epoch
    "pin_memory": True,
    "train_shuffle_buffer_books": 512,


    "eval_cache_dir": "./eval_cache_pg19",
    "eval_cache_segments": 1024,  # segments cached for each of the validation and test splits
    "eval_cache_shuffle_books": True,
    "eval_cache_shuffle_buffer_books": 4096,

    # Regularization
    "token_dropout_prob": 0.0,    # probability of replacing an input token by id 1 during training
    "dropout": 0.10,
    "stoch_depth": 0.05,          # largest DropPath rate; per-layer rates are linearly spaced from 0

    # Model (ES-SSM)
    "d_model": 256,
    "n_layers": 8,
    "K_max": 32,                  # number of Hankel filters stored per layer
    "K_chunk": 8,                 # filters processed per chunk inside ESSSM_Layer.forward

    # Training
    "epochs": 300,  
    "warmup_epochs": 2,           # epochs during which BudgetSampler always returns K_max
    "grad_accum_steps": 4,          

    "lr": 2e-4,
    "min_lr_ratio": 0.10,
    "warmup_ratio": 0.06,         # fraction of all scheduler steps used for linear LR warmup

    "weight_decay": 0.03,
    "adam_betas": (0.9, 0.95),
    "adam_eps": 1e-8,

    "grad_clip": 1.0,
    "use_amp": True,
    "ema_decay": 0.999,

    # Budget dropout
    "budget_enabled": True,
    "budget_k_min": 2,
    "budget_full_every": 8,       # every N-th update step runs with the full K_max
    "budget_ks": [2, 3, 4, 6, 8, 12, 16, 24, 32],
    "budget_bias_to_large": 0.6,    

    # Evaluation schedule
    "eval_batches_train": 25,       
    "test_every": 3,
    "eval_curve_every": 3,
    "eval_batches_final": 256,    

    # Early stopping
    "early_stop_patience": 10,
    "early_stop_min_delta": 1e-4,

    # Save paths 
    "weights_dir": "./weights_esssm",
    "run_name": "pg19_bytes_esssm_dp_L4096",
    "resume": False,
    "fresh_start": True,          # True wipes the run directory at the start of train()
    "save_last_every_epoch": 1,

    # Hankel filter cache
    "filter_cache_dir": "./cache_filters",
    "filter_iters": 70,
    "filter_oversample": 8,
    "filter_tol": 1e-6,

    # Numeric stability
    "rms_eps": 1e-6,


    "device": "cuda" if torch.cuda.is_available() else "cpu",
}



# 2) Utilities
def print0(*args, **kwargs):
    """Forward every positional and keyword argument unchanged to ``print``."""
    return print(*args, **kwargs)

def unwrap_model(m: nn.Module) -> nn.Module:
    """Return the module held by an ``nn.DataParallel`` wrapper, or ``m`` itself."""
    if isinstance(m, nn.DataParallel):
        return m.module
    return m

def set_seed(seed: int):
    """Seed Python's ``random`` and torch (CPU, plus every CUDA device if present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed the RNGs once, at import time of this cell.
set_seed(CONFIG["seed"])

# cuDNN autotuning is on by default; the deterministic branch below turns it off again.
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    # Allow TF32 for matmul and cuDNN kernels.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best effort: any failure of this call is ignored.
        pass
if CONFIG["deterministic"]:
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

# Make sure every output/cache directory exists before anything writes to it.
Path(CONFIG["weights_dir"]).mkdir(parents=True, exist_ok=True)
Path(CONFIG["filter_cache_dir"]).mkdir(parents=True, exist_ok=True)
Path(CONFIG["eval_cache_dir"]).mkdir(parents=True, exist_ok=True)

# Device type string handed to torch.amp.autocast throughout the file.
AMP_DEVICE = "cuda" if CONFIG["device"] == "cuda" else "cpu"

# One-time summary of the run configuration.
print0(f"[Device] {CONFIG['device']} | CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
print0(f"[Run] L={CONFIG['seq_len']} | batch={CONFIG['batch_size']} | grad_accum={CONFIG['grad_accum_steps']} | AMP={CONFIG['use_amp']}")
print0(f"[Model] d={CONFIG['d_model']} | layers={CONFIG['n_layers']} | K_max={CONFIG['K_max']} | K_chunk={CONFIG['K_chunk']}")
print0(f"[Data] PG19 via HF datasets: {CONFIG['dataset_name']} | streaming=True")



# 3) Byte tokenization
def _bytes_to_tokens(b: bytes) -> torch.Tensor:
    fb = getattr(torch, "frombuffer", None)
    if fb is None:
        t = torch.tensor(list(b), dtype=torch.uint8)
    else:
        t = torch.frombuffer(bytearray(b), dtype=torch.uint8)
    return t.to(torch.int64).add_(2) 



# 4) PG19 streaming dataset for TRAIN
class PG19ByteLMIterable(IterableDataset):
    """Endless stream of fixed-length byte-token segments from a streaming HF split.

    Each example's ``"text"`` field is UTF-8 encoded, followed by a newline
    byte, and appended to a rolling buffer. Segments of ``seq_len + 1`` bytes
    are cut from the buffer with a stride of ``seq_len`` (consecutive segments
    share one byte) and converted with ``_bytes_to_tokens``.

    When the stream is exhausted it is reopened and iteration continues; the
    shuffle seed is varied per pass so later passes do not replay the exact
    book order of the first one. The first pass uses the same seed as before.

    Args:
        split: Dataset split name passed to ``load_dataset``.
        seq_len: Model input length; yielded segments have ``seq_len + 1`` tokens.
        seed: Base seed for the book shuffle.
        shuffle_books: Whether to shuffle examples with a shuffle buffer.
        shuffle_buffer_books: Size of that shuffle buffer.
    """

    def __init__(self, split: str, seq_len: int, seed: int, shuffle_books: bool, shuffle_buffer_books: int):
        super().__init__()
        self.split = split
        self.seq_len = int(seq_len)
        self.seg_len = int(seq_len) + 1
        self.seed = int(seed)
        self.shuffle_books = bool(shuffle_books)
        self.shuffle_buffer_books = int(shuffle_buffer_books)

    def _make_stream(self):
        """Open a fresh streaming view of the configured dataset split."""
        return load_dataset(CONFIG["dataset_name"], split=self.split, streaming=True)

    def _prepare_stream(self, ds, worker_id: int, eff_workers: int, pass_idx: int):
        """Apply this worker's shard selection and (optionally) the book shuffle."""
        if eff_workers > 1:
            ds = ds.shard(num_shards=eff_workers, index=worker_id)
        if self.shuffle_books:
            ds = ds.shuffle(
                # pass_idx == 0 reproduces the original seed; later passes get a new order.
                seed=self.seed + 13 * worker_id + 100003 * pass_idx,
                buffer_size=self.shuffle_buffer_books,
            )
        return ds

    def __iter__(self):
        wi = get_worker_info()
        worker_id = 0 if wi is None else wi.id
        num_workers = 1 if wi is None else wi.num_workers

        # Open the stream once and reuse it both for the shard count and for
        # the first pass (shard/shuffle return new objects).
        stream = self._make_stream()
        n_shards = getattr(stream, "n_shards", None)
        if n_shards is None:
            n_shards = getattr(stream, "num_shards", None)
        n_shards = 1 if n_shards is None else int(n_shards)

        # Workers beyond the number of shards have nothing to read.
        eff_workers = min(num_workers, n_shards)
        if wi is not None and worker_id >= eff_workers:
            return

        buf = bytearray()
        stride = self.seq_len
        pass_idx = 0

        while True:
            ds = self._prepare_stream(stream, worker_id, eff_workers, pass_idx)
            saw_text = False

            for ex in ds:
                text = ex.get("text", None)
                if not text:
                    continue
                b = text.encode("utf-8", errors="ignore")
                if not b:
                    continue
                saw_text = True

                buf.extend(b)
                buf.extend(b"\n")

                while len(buf) >= self.seg_len:
                    chunk = bytes(buf[: self.seg_len])
                    del buf[: stride]
                    yield _bytes_to_tokens(chunk)  # (L+1,)

            if not saw_text:
                # Without this guard an empty split/shard would reopen the
                # stream forever without ever yielding a segment.
                raise RuntimeError(
                    f"PG19ByteLMIterable: split '{self.split}' produced no text "
                    f"for worker {worker_id}; refusing to loop forever."
                )

            pass_idx += 1
            stream = self._make_stream()

def collate_fixed(batch):
    """Stack equally sized token tensors along a new leading batch axis."""
    stacked = torch.stack(batch, dim=0)  # (B, L+1)
    return stacked



# 5) Eval cache
@torch.no_grad()
def _token_unigram_bpb(tokens_u16: torch.Tensor) -> float:
    """Bits per token of a unigram model on a 2-D block of token segments.

    Token frequencies are counted over every position except the last of each
    row; the resulting distribution scores every position except the first.
    The mean negative log-likelihood is returned in bits.
    """
    as_long = tokens_u16.to(torch.long)
    context = as_long[:, :-1].reshape(-1)
    targets = as_long[:, 1:].reshape(-1)

    vocab = int(CONFIG["vocab_size"])
    hist = torch.bincount(context, minlength=vocab).float()
    total = hist.sum().clamp_min(1.0)

    # Clamp so tokens absent from the histogram do not produce log(0).
    target_p = (hist / total)[targets].clamp_min(1e-12)
    mean_nll_nats = (-torch.log(target_p)).mean().item()
    return float(mean_nll_nats / math.log(2.0))

def _seg_hash_u16(seg_u16: torch.Tensor) -> str:
    seg_u8 = seg_u16.contiguous().view(torch.uint8).cpu()
    b = bytes(seg_u8.tolist())
    return hashlib.sha1(b).hexdigest()

def _stream_n_shards(split: str) -> int:
    """Return the shard count reported by the streaming dataset for ``split``.

    Looks for an ``n_shards`` attribute first, then ``num_shards``; falls back
    to 1 when neither is available.
    """
    stream = load_dataset(CONFIG["dataset_name"], split=split, streaming=True)
    for attr in ("n_shards", "num_shards"):
        count = getattr(stream, attr, None)
        if count is not None:
            return int(count)
    return 1

@torch.no_grad()
def build_eval_cache(split: str, out_path: Path, num_segments: int, seed: int,
                     shuffle_books: bool, shuffle_buffer_books: int, seq_len: int):
    """Load or build a fixed set of evaluation segments for one split.

    If ``out_path`` exists and its ``tokens_u16`` tensor has the requested
    shape ``(num_segments, seq_len + 1)`` it is returned as is. A cache with a
    different shape is reported and rebuilt. New caches are written to a
    temporary file and moved into place, so an interrupted run cannot leave a
    truncated cache behind.

    Args:
        split: Dataset split to draw segments from.
        out_path: Cache file location.
        num_segments: Number of segments to store.
        seed: Seed handed to ``PG19ByteLMIterable``.
        shuffle_books: Whether the stream shuffles examples.
        shuffle_buffer_books: Shuffle buffer size for the stream.
        seq_len: Model input length; segments hold ``seq_len + 1`` tokens.

    Returns:
        A ``(num_segments, seq_len + 1)`` uint16 tensor on the CPU.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    want_shape = (int(num_segments), int(seq_len) + 1)

    if out_path.exists():
        ck = torch.load(out_path, map_location="cpu")
        cached = ck["tokens_u16"]
        if tuple(cached.shape) == want_shape:
            print0(f"[EvalCache] Using existing cache: {out_path}")
            return cached
        print0(
            f"[EvalCache] Cache {out_path} has shape {tuple(cached.shape)}, "
            f"expected {want_shape}; rebuilding."
        )

    print0(f"[EvalCache] Building {split} cache: segments={num_segments} -> {out_path}")
    ds = PG19ByteLMIterable(
        split=split,
        seq_len=seq_len,
        seed=seed,
        shuffle_books=shuffle_books,
        shuffle_buffer_books=shuffle_buffer_books
    )

    tokens = []
    it = iter(ds)
    for i in range(int(num_segments)):
        seg = next(it)  # (L+1,) int64
        tokens.append(seg.to(torch.uint16).cpu())
        if (i + 1) % 200 == 0:
            print0(f"  cached {i+1}/{num_segments} segments ...")

    tokens_u16 = torch.stack(tokens, dim=0).contiguous()  # (N, L+1) uint16
    payload = {
        "tokens_u16": tokens_u16,
        "split": split,
        "seq_len": int(seq_len),
        "seed": int(seed),
        "shuffle_books": bool(shuffle_books),
        "shuffle_buffer_books": int(shuffle_buffer_books),
        "num_segments": int(num_segments),
    }
    # Write-then-rename so a partially written file never sits at out_path.
    tmp_path = str(out_path) + ".tmp"
    torch.save(payload, tmp_path)
    os.replace(tmp_path, str(out_path))
    print0(f"[EvalCache] Saved {split} -> {out_path}")
    return tokens_u16

class CachedSegmentsDataset(Dataset):
    def __init__(self, tokens_u16: torch.Tensor):
        super().__init__()
        assert tokens_u16.dtype == torch.uint16
        self.tokens_u16 = tokens_u16

    def __len__(self):
        return self.tokens_u16.shape[0]

    def __getitem__(self, idx):
        return self.tokens_u16[idx].to(torch.long)



# 6) Hankel top-K via FFT subspace iteration

def hankel_toeplitz_embed_g(L: int, eps: float = 1e-6) -> torch.Tensor:
    """Build the length-2L generator vector used for FFT-based Hankel products.

    With ``a[k] = 2 / (s**3 - s + eps)`` and ``s = k + 2`` for
    ``k = 0 .. 2L-2``, the result is ``[a[L-1:], 0, a[:L-1]]``.
    """
    offsets = torch.arange(0, 2 * L - 1, dtype=torch.float32)
    s = offsets + 2.0
    entries = 2.0 / (s**3 - s + eps)

    head = entries[L - 1 :]
    tail = entries[: L - 1]
    pieces = [head, torch.zeros(1), tail]  # len 2L
    return torch.cat(pieces).contiguous()

@torch.no_grad()
def hankel_mv_batch_fft(g_f: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
    """Multiply the Hankel matrix encoded by ``g_f`` with the columns of ``X``.

    ``X`` has shape (L, p). Its rows are reversed, each column is zero-padded
    to length 2L, multiplied with ``g_f`` in the frequency domain and
    transformed back; the first L samples form the (L, p) result.
    """
    rows, cols = X.shape
    fft_len = 2 * rows

    padded = torch.zeros((cols, fft_len), dtype=torch.float32, device=X.device)
    padded[:, :rows] = torch.flip(X, dims=[0]).T

    spectrum = torch.fft.rfft(padded, dim=-1) * g_f.unsqueeze(0)
    product = torch.fft.irfft(spectrum, n=fft_len, dim=-1)
    return product[:, :rows].T

@torch.no_grad()
def hankel_topk_subspace(L: int, K: int, cache_path: str,
                         iters: int, oversample: int, tol: float, seed: int = 0):
    """Return the top-K eigenpairs of the L x L Hankel matrix, using a file cache.

    The eigenpairs are approximated by block subspace iteration with FFT-based
    matrix products, followed by a Rayleigh-Ritz step. Results are stored at
    ``cache_path`` and reused on later calls.

    Args:
        L: Matrix size.
        K: Number of leading eigenpairs to return.
        cache_path: File holding ``{"sigma", "phi"}``; created if missing.
        iters: Maximum number of subspace iterations.
        oversample: Extra columns carried in the iterated block.
        tol: Stop once the largest relative residual of the top-K pairs,
            checked every 10 iterations, falls below this value.
        seed: Seed of the private generator that draws the random start block,
            so computing the filters does not consume the global torch RNG.

    Returns:
        ``(sigma, phi)`` as CPU float tensors of shape (K,) and (L, K).
    """
    cache_path = str(cache_path)
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        os.makedirs(cache_dir, exist_ok=True)
    if os.path.exists(cache_path):
        ck = torch.load(cache_path, map_location="cpu")
        return ck["sigma"].float(), ck["phi"].float()

    print0(f"[Filters] Computing Hankel top-K (FFT subspace): L={L}, K={K}, iters={iters}, oversample={oversample}")
    device = "cpu"
    g = hankel_toeplitz_embed_g(L).to(device)
    g_f = torch.fft.rfft(g)

    def rayleigh_ritz(basis):
        """Diagonalize the projected matrix; return descending values and the rotated basis."""
        AQ = hankel_mv_batch_fft(g_f, basis)
        Bm = basis.T @ AQ
        w, V = torch.linalg.eigh(Bm)
        idx = torch.argsort(w, descending=True)
        return w[idx], basis @ V[:, idx]

    p = K + oversample
    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))
    Q = torch.randn(L, p, dtype=torch.float32, device=device, generator=gen)
    Q, _ = torch.linalg.qr(Q, mode="reduced")

    last = None
    for t in range(1, iters + 1):
        ZQ = hankel_mv_batch_fft(g_f, Q)
        Q, _ = torch.linalg.qr(ZQ, mode="reduced")

        if (t % 10 == 0) or (t == iters):
            w, Q = rayleigh_ritz(Q)

            # Relative residual ||A q - w q|| / |w| of the leading K pairs.
            Qk = Q[:, :K]
            AQk = hankel_mv_batch_fft(g_f, Qk)
            R = AQk - Qk * w[:K].unsqueeze(0)
            rel = (R.norm(dim=0) / (w[:K].abs() + 1e-12))
            mx = float(rel.max().item())
            msg = f"[Filters] iter {t:03d}/{iters} | max_rel_res(topK)={mx:.3e}"
            if last is not None:
                msg += f" | delta={last - mx:+.2e}"
            print0(msg)
            last = mx
            if mx < tol:
                print0("[Filters] Converged.")
                break

    w, Q = rayleigh_ritz(Q)

    sigma = w[:K].contiguous()
    phi = Q[:, :K].contiguous()

    # Write-then-rename so an interrupted save never leaves a broken cache file.
    tmp_path = cache_path + ".tmp"
    torch.save({"sigma": sigma.cpu(), "phi": phi.cpu()}, tmp_path)
    os.replace(tmp_path, cache_path)
    print0(f"[Filters] Saved torch cache -> {cache_path}")
    return sigma.cpu(), phi.cpu()



# 7) ES-SSM layer

class DropPath(nn.Module):
    """Per-sample stochastic depth.

    In training mode with a positive rate, each sample along dim 0 is either
    zeroed or scaled by ``1 / keep_prob``; otherwise the input is returned
    untouched.
    """

    def __init__(self, drop_prob: float):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob <= 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One random draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        draws = torch.rand(mask_shape, device=x.device, dtype=x.dtype)
        keep_mask = draws < keep_prob
        return x * keep_mask / keep_prob

def rms_rescale_logits(s: torch.Tensor, eps: float) -> torch.Tensor:
    """Scale ``s`` along its last axis by ``sqrt(K) / max(||s||_2, eps)``.

    ``K`` is the size of the last axis; the norm is computed in float32.
    """
    n_channels = s.shape[-1]
    l2 = torch.linalg.vector_norm(s.float(), ord=2, dim=-1, keepdim=True)
    l2 = l2.clamp_min(eps)
    return s * (float(n_channels) ** 0.5) / l2

class ESSSM_Layer(nn.Module):
    """One ES-SSM mixing layer built on a fixed bank of Hankel filters.

    The layer stores K_max filters (``phi``) with their eigenvalues
    (``sigma``), one learned (d_model x d_model) mixing matrix per filter, a
    small MLP that produces per-position logits over the filters, and a
    linear skip term ``D``. At run time only the first K filters are used.
    """

    def __init__(self, d_model: int, L: int, K_max: int, K_chunk: int, drop_path: float = 0.0):
        """
        d_model:   feature width of the input/output
        L:         fixed sequence length this layer accepts
        K_max:     number of filters stored
        K_chunk:   number of filters processed together in forward
        drop_path: DropPath rate applied to the layer output
        """
        super().__init__()
        self.K_max = int(K_max)
        self.L = int(L)
        self.K_chunk = int(K_chunk)

        # Filters are computed once per (L, K_max) and cached on disk.
        cache_dir = Path(CONFIG["filter_cache_dir"])
        cache_path = cache_dir / f"hankel_topk_L{L}_K{K_max}.pt"

        sigma, phi = hankel_topk_subspace(
            L=L, K=K_max, cache_path=str(cache_path),
            iters=int(CONFIG["filter_iters"]),
            oversample=int(CONFIG["filter_oversample"]),
            tol=float(CONFIG["filter_tol"]),
        )
        self.register_buffer("phi", phi)       # (L, K_max)
        self.register_buffer("sigma", sigma)   # (K_max,)

        # FFT(phi_k) for convolution, zero-padded to length 2L and stored as
        # separate real/imag parts so the buffer itself is a real tensor.
        n = 2 * L
        with torch.no_grad():
            phi_f = torch.fft.rfft(phi.T.contiguous().float(), n=n, dim=-1)   # (K, F) complex
            phi_f_ri = torch.view_as_real(phi_f).contiguous()                 # (K, F, 2)
        self.register_buffer("phi_f_ri", phi_f_ri)

        # Mixing matrices M_k (small random init)
        self.M_phi = nn.Parameter(torch.randn(K_max, d_model, d_model) * 1e-3)

        # Gate MLP -> logits s_k(t)
        self.selector = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, K_max),
        )
        # Start with all gate logits biased to -2.0.
        nn.init.constant_(self.selector[-1].bias, -2.0)

        # D u(t)
        self.D = nn.Linear(d_model, d_model, bias=False)
        self.drop_path = DropPath(drop_path)

    def forward(self, u, K_runtime: int = None):
        """
        u: (B, L, D)
        K_runtime: runtime budget K <= K_max; None means K_max. A given value
                   is clamped to [CONFIG["budget_k_min"], K_max].

        Returns (y, K): y is the (B, L, D) layer output after DropPath and
        K is the number of filters actually used.
        """
        B, L, D = u.shape
        assert L == self.L, f"Expected L={self.L}, got {L}"
        K_max = self.K_max
        n = 2 * L

        # logits s(t) over all channels
        s_all = self.selector(u)  # (B, L, K_max)

        # budget K
        if K_runtime is None:
            K = K_max
        else:
            K = int(K_runtime)
            K = max(int(CONFIG.get("budget_k_min", 2)), min(K_max, K))

        # active prefix 1..K: rescale the logits, then softmax over the K active filters
        s = s_all[:, :, :K]
        s_tilde = rms_rescale_logits(s, eps=float(CONFIG["rms_eps"]))
        w = torch.softmax(s_tilde.float(), dim=-1).to(dtype=u.dtype)  # alpha_k(t), (B, L, K)

        # sigma^(1/4)
        sigma_scale = torch.clamp(self.sigma[:K], min=1e-12).pow(0.25).to(dtype=u.dtype)  # (K,)

        # FFT(u) in float32 (autocast disabled for the FFT)
        with torch.amp.autocast(device_type=AMP_DEVICE, enabled=False):
            u_f = torch.fft.rfft(u.float().permute(0, 2, 1), n=n, dim=-1)  # (B, D, F) complex
            u_f_ri = torch.view_as_real(u_f).contiguous()                  # (B, D, F, 2)

        y_acc = torch.zeros((B, L, D), device=u.device, dtype=u.dtype)

        # chunk over K so only K_chunk filter responses are materialized at a time
        for k0 in range(0, K, self.K_chunk):
            k1 = min(K, k0 + self.K_chunk)
            Kc = k1 - k0

            phi_f_chunk = self.phi_f_ri[k0:k1]            # (Kc, F, 2)
            M_chunk = self.M_phi[k0:k1]                   # (Kc, D, D)
            w_chunk = w[:, :, k0:k1]                      # (B, L, Kc)
            s_chunk = sigma_scale[k0:k1]                  # (Kc,)

            # Convolution (Phi_k * u) via FFT multiply and iFFT
            with torch.amp.autocast(device_type=AMP_DEVICE, enabled=False):
                a = u_f_ri.unsqueeze(2)                   # (B, D, 1, F, 2)
                b = phi_f_chunk.unsqueeze(0).unsqueeze(0) # (1, 1, Kc, F, 2)

                # Complex product written out on the real/imag parts.
                real = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
                imag = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]

                U_f_ri = torch.stack((real, imag), dim=-1).contiguous()  # (B, D, Kc, F, 2)
                U_f_c = torch.view_as_complex(U_f_ri)                    # (B, D, Kc, F)
                # Keep only the first L samples of the length-2L result.
                U_time = torch.fft.irfft(U_f_c, n=n, dim=-1)[..., :L]    # (B, D, Kc, L)
                U_out = U_time.permute(0, 3, 2, 1).contiguous()          # (B, L, Kc, D)

            U_out = U_out.to(dtype=u.dtype)


            # Mix each filter response with its own M_k, then weight by
            # sigma_k^(1/4) and the gate alpha_k(t) and sum over the chunk.
            projected = torch.einsum("blkd,kdo->blko", U_out, M_chunk.to(dtype=u.dtype))
            contrib = (projected * s_chunk.view(1, 1, Kc, 1) * w_chunk.unsqueeze(-1)).sum(dim=2)
            y_acc = y_acc + contrib

        core = y_acc + self.D(u)
        return self.drop_path(core), K



# 8) Deep LM 

class DeepESSM_LM(nn.Module):
    """Language model made of stacked ES-SSM and feed-forward residual blocks.

    All sizes are read from ``CONFIG``. Every block applies
    LayerNorm -> ESSSM_Layer and LayerNorm -> FFN (with DropPath), each added
    residually. The output projection shares its weight with the embedding.
    """

    def __init__(self):
        super().__init__()
        vocab = int(CONFIG["vocab_size"])
        width = int(CONFIG["d_model"])
        seq_len = int(CONFIG["seq_len"])
        k_max = int(CONFIG["K_max"])
        k_chunk = int(CONFIG["K_chunk"])
        depth = int(CONFIG["n_layers"])
        p_drop = float(CONFIG["dropout"])

        self.embedding = nn.Embedding(vocab, width, padding_idx=0)
        self.drop = nn.Dropout(p_drop)

        # Stochastic-depth rates rise linearly from 0 to CONFIG["stoch_depth"].
        sd_rates = torch.linspace(0.0, float(CONFIG["stoch_depth"]), depth).tolist()

        ssm_stack = []
        for layer_idx in range(depth):
            ssm_stack.append(
                ESSSM_Layer(width, seq_len, k_max, k_chunk, drop_path=float(sd_rates[layer_idx]))
            )
        self.ssm_layers = nn.ModuleList(ssm_stack)

        self.norm1 = nn.ModuleList([nn.LayerNorm(width) for _ in range(depth)])
        self.norm2 = nn.ModuleList([nn.LayerNorm(width) for _ in range(depth)])

        ffn_stack = []
        for _ in range(depth):
            ffn_stack.append(
                nn.Sequential(
                    nn.Linear(width, width * 4),
                    nn.GELU(),
                    nn.Dropout(p_drop),
                    nn.Linear(width * 4, width),
                )
            )
        self.ffn = nn.ModuleList(ffn_stack)
        self.ffn_drop = nn.ModuleList([DropPath(float(sd_rates[j])) for j in range(depth)])

        self.final_norm = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, vocab, bias=False)

        # Re-initialize the embedding, then tie the output projection to it.
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        self.lm_head.weight = self.embedding.weight

    def forward(self, tokens, K_runtime=None):
        """Return ``(logits, K_used)``; ``K_used`` is a (1, n_layers) float tensor."""
        x = self.drop(self.embedding(tokens))  # (B, L, d)

        budgets = []
        blocks = zip(self.ssm_layers, self.norm1, self.norm2, self.ffn, self.ffn_drop)
        for ssm, pre_ssm, pre_ffn, ffn, ffn_drop in blocks:
            mixed, k_eff = ssm(pre_ssm(x), K_runtime=K_runtime)
            x = x + mixed
            budgets.append(float(k_eff))
            x = x + ffn_drop(ffn(pre_ffn(x)))

        logits = self.lm_head(self.final_norm(x))
        K_used_t = torch.tensor(budgets, device=logits.device, dtype=torch.float32).unsqueeze(0)
        return logits, K_used_t



# 9) EMA

class EMA:
    """Exponential moving average over a model's floating-point state entries.

    ``shadow`` holds the averaged tensors. ``apply`` copies them into the
    model after saving the live values; ``restore`` copies the saved values
    back.
    """

    def __init__(self, model: nn.Module, decay: float):
        self.decay = float(decay)
        self.shadow = {}
        for name, tensor in model.state_dict().items():
            if tensor.dtype.is_floating_point:
                self.shadow[name] = tensor.detach().clone()
        self.backup = None

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Blend the model's current values into the running average."""
        live = model.state_dict()
        for name, avg in self.shadow.items():
            avg.mul_(self.decay).add_(live[name].detach(), alpha=1.0 - self.decay)

    def apply(self, model: nn.Module):
        """Swap the averaged values into ``model``, keeping a backup of the live ones."""
        live = model.state_dict()
        self.backup = {}
        for name, avg in self.shadow.items():
            self.backup[name] = live[name].detach().clone()
            live[name].copy_(avg)

    def restore(self, model: nn.Module):
        """Undo ``apply``; does nothing when no backup is stored."""
        if self.backup is None:
            return
        live = model.state_dict()
        for name, saved in self.backup.items():
            live[name].copy_(saved)
        self.backup = None



# 10) LR schedule

def make_warmup_cosine_scheduler(optimizer, total_steps: int, warmup_steps: int, min_lr_ratio: float):
    """Create a ``LambdaLR`` with linear warmup followed by cosine decay.

    The LR multiplier rises linearly to 1.0 over ``warmup_steps`` steps, then
    follows a half cosine down to ``min_lr_ratio`` at ``total_steps`` and
    stays there for any later step.

    Args:
        optimizer: Optimizer whose learning rates are scaled.
        total_steps: Step index at which the decay reaches its floor.
        warmup_steps: Number of linear warmup steps.
        min_lr_ratio: Final multiplier relative to the base learning rate.

    Returns:
        The configured ``torch.optim.lr_scheduler.LambdaLR``.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp so steps beyond total_steps hold the floor instead of riding
        # the cosine back up.
        progress = min(1.0, progress)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    return LambdaLR(optimizer, lr_lambda=lr_lambda)


# 11) Budget dropout
class BudgetSampler:
    """Draws the runtime filter budget K used for each optimizer update.

    Candidates come from ``CONFIG["budget_ks"]``. Sampling is log-uniform
    between the smallest and largest usable candidate, skewed toward large
    values, and snapped to the nearest candidate.
    """

    def __init__(self, seed: int):
        self.rng = random.Random(seed + 777)
        self.cands = sorted(list(CONFIG["budget_ks"]))
        if not self.cands:
            raise ValueError("CONFIG['budget_ks'] is empty.")
        self.k_min = max(int(CONFIG["budget_k_min"]), int(self.cands[0]))
        self.k_max = int(self.cands[-1])

    def _snap(self, k: float) -> int:
        """Return the candidate closest to ``k`` (the smaller one on ties)."""
        nearest = min(self.cands, key=lambda cand: abs(cand - k))
        return int(nearest)

    def sample(self, update_step: int, epoch: int) -> int:
        """Pick K for this update; returns K_max when budgeting is off or forced full."""
        full_k = int(CONFIG["K_max"])

        in_warmup = epoch < int(CONFIG["warmup_epochs"])
        if in_warmup or not CONFIG["budget_enabled"]:
            return full_k

        period = int(CONFIG["budget_full_every"])
        if period > 0 and update_step % period == 0:
            return full_k

        # log-uniform with bias toward large K
        draw = self.rng.random()
        skew = float(CONFIG.get("budget_bias_to_large", 0.6))
        skew = max(1e-6, min(1.0, skew))
        draw = 1.0 - (1.0 - draw) ** (1.0 / skew)

        log_lo = math.log(self.k_min)
        log_hi = math.log(self.k_max)
        continuous_k = math.exp(log_lo + draw * (log_hi - log_lo))

        snapped = self._snap(continuous_k)
        floor_k = int(CONFIG["budget_k_min"])
        return int(max(floor_k, min(self.k_max, snapped)))



# 12) Eval

@torch.no_grad()
def evaluate_lm(model, loader, K_eval=None, max_batches=200, use_amp=True):
    """Evaluate next-token loss over at most ``max_batches`` batches of ``loader``.

    The model is put in eval mode. Each batch holds (B, L+1) tokens; the first
    L are inputs and the last L are targets.

    Returns:
        Dict with ``loss_nats``, ``bpb`` (loss / ln 2), ``ppl`` (exp of the
        loss) and ``tok_s`` (tokens scored per second of wall time).
    """
    model.eval()
    device = CONFIG["device"]
    amp_on = bool(use_amp and device == "cuda")
    limit = int(max_batches)

    nll_sum = 0.0
    tok_sum = 0
    started = time.time()

    for seen, batch in enumerate(loader):
        if seen >= limit:
            break
        batch = batch.to(device, non_blocking=True)  # (B, L+1)
        inputs = batch[:, :-1].contiguous()
        targets = batch[:, 1:].contiguous()

        with torch.amp.autocast(device_type=AMP_DEVICE, enabled=amp_on):
            logits, _ = model(inputs, K_runtime=K_eval)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="mean"
            )

        # Weight the batch mean by its token count to get a token-level sum.
        count = targets.numel()
        nll_sum += float(loss.item()) * count
        tok_sum += int(count)

    mean_nll = nll_sum / max(1, tok_sum)
    wall = max(1e-9, time.time() - started)
    return {
        "loss_nats": mean_nll,
        "bpb": mean_nll / math.log(2.0),
        "ppl": math.exp(mean_nll),
        "tok_s": tok_sum / wall,
    }

@torch.no_grad()
def budget_quality_curve(model, val_loader, K_list, max_batches):
    """Print validation BPB, perplexity and throughput for each budget in ``K_list``.

    Each K is evaluated with ``evaluate_lm`` on ``val_loader`` using at most
    ``max_batches`` batches and AMP enabled.
    """
    print0("\n" + "=" * 72)
    # \u2013 is an en dash; the literal was previously stored mis-encoded.
    print0(f"[Budget\u2013Quality] sweep on VAL | batches={max_batches}")
    print0("=" * 72)
    for K in K_list:
        m = evaluate_lm(model, val_loader, K_eval=int(K), max_batches=max_batches, use_amp=True)
        print0(f"K={int(K):02d} | BPB={m['bpb']:.4f} | PPL={m['ppl']:.2f} | tok/s={m['tok_s']:.0f}")
    print0("=" * 72 + "\n")



# 13) Save/load training state

def state_paths():
    """Create the run directory if needed and return its checkpoint paths.

    The directory is ``weights_dir / run_name`` from ``CONFIG``; the result
    maps ``"dir"``, ``"last"`` and ``"best"`` to the corresponding paths.
    """
    run_dir = Path(CONFIG["weights_dir"]) / CONFIG["run_name"]
    run_dir.mkdir(parents=True, exist_ok=True)
    return {
        "dir": run_dir,
        "last": run_dir / "last_state.pt",
        "best": run_dir / "best_state.pt",
    }

def wipe_run_dir(run_dir: Path):
    """Delete everything directly inside ``run_dir`` while keeping the directory.

    Files and symlinks are unlinked and sub-directories are removed
    recursively. Removal is best effort: an entry that cannot be deleted is
    reported and skipped instead of aborting the wipe. A missing ``run_dir``
    is a no-op.
    """
    if not run_dir.exists():
        return
    for p in run_dir.glob("*"):
        try:
            if p.is_file() or p.is_symlink():
                p.unlink()
            elif p.is_dir():
                shutil.rmtree(p)
        except OSError as e:
            # Keep going, but make the failure visible instead of hiding it.
            print0(f"[FreshStart] Could not remove {p}: {e!r}")

def save_state(path: Path, model: nn.Module, optimizer, scheduler, scaler, ema: EMA,
              epoch: int, global_step: int, best_bpb: float):
    """Write a full training checkpoint to ``path``.

    The payload is first saved to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. It contains the unwrapped model's state, optimizer,
    scheduler and (if given) scaler states, the EMA shadow, the counters and
    the global ``CONFIG``.
    """
    scaler_state = None if scaler is None else scaler.state_dict()
    payload = {
        "epoch": epoch,
        "global_step": global_step,
        "best_bpb": float(best_bpb),
        "model": unwrap_model(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler_state,
        "ema_shadow": ema.shadow,
        "config": CONFIG,
    }
    final_path = str(path)
    staging_path = final_path + ".tmp"
    torch.save(payload, staging_path)
    os.replace(staging_path, final_path)

def load_state(path: Path, model: nn.Module, optimizer, scheduler, scaler, ema: EMA):
    """Restore a training checkpoint written by ``save_state``.

    Loads model, optimizer and scheduler states (and the scaler state when
    both a scaler and a saved state exist) and installs the saved EMA shadow.
    The checkpoint is read onto the CPU, so every shadow tensor is moved to
    the device and dtype of the matching tensor in the current ``ema.shadow``;
    otherwise ``EMA.update`` would mix CPU and GPU tensors after a resume.

    Returns:
        ``(epoch, global_step, best_bpb)`` from the checkpoint; ``best_bpb``
        defaults to 1e9 when absent.
    """
    st = torch.load(path, map_location="cpu")
    unwrap_model(model).load_state_dict(st["model"], strict=True)
    optimizer.load_state_dict(st["optimizer"])
    scheduler.load_state_dict(st["scheduler"])
    if scaler is not None and st.get("scaler", None) is not None:
        scaler.load_state_dict(st["scaler"])

    saved_shadow = st.get("ema_shadow", None)
    if saved_shadow is not None:
        restored = {}
        for name, tensor in saved_shadow.items():
            ref = ema.shadow.get(name)
            if ref is not None:
                tensor = tensor.to(device=ref.device, dtype=ref.dtype)
            restored[name] = tensor
        ema.shadow = restored

    return int(st["epoch"]), int(st["global_step"]), float(st.get("best_bpb", 1e9))



# 14) Optim param groups 

def build_param_groups(model: nn.Module, weight_decay: float):
    """Split trainable parameters into a decayed and a non-decayed optimizer group.

    A parameter gets no weight decay when it is 1-D, or when its lower-cased
    name ends with "bias" or contains "norm" or "embedding". Frozen
    parameters are skipped and each parameter object is listed only once.

    Returns:
        ``[decay_group, no_decay_group]`` as dicts accepted by torch optimizers.
    """
    with_decay, without_decay = [], []
    visited = set()

    for name, param in model.named_parameters():
        if param is None or not param.requires_grad:
            continue
        key = id(param)
        if key in visited:
            continue
        visited.add(key)

        lowered = name.lower()
        exempt = (
            param.ndim == 1
            or lowered.endswith("bias")
            or "norm" in lowered
            or "embedding" in lowered
        )
        target = without_decay if exempt else with_decay
        target.append(param)

    return [
        {"params": with_decay, "weight_decay": float(weight_decay)},
        {"params": without_decay, "weight_decay": 0.0},
    ]



# 15) Train

def train():
    DEVICE = CONFIG["device"]
    paths = state_paths()

    if bool(CONFIG.get("fresh_start", False)):
        print0(f"[FreshStart] Wiping run dir: {paths['dir']}")
        wipe_run_dir(paths["dir"])


    Nseg = int(CONFIG["eval_cache_segments"])
    L = int(CONFIG["seq_len"])
    cache_dir = Path(CONFIG["eval_cache_dir"])
    cache_dir.mkdir(parents=True, exist_ok=True)

    val_cache_path = cache_dir / f"pg19_val_L{L}_N{Nseg}.pt"
    test_cache_path = cache_dir / f"pg19_test_L{L}_N{Nseg}.pt"

    base_seed = int(CONFIG["seed"])
    val_seed = base_seed + 111
    test_seed = base_seed + 222

    val_u16 = build_eval_cache(
        split="validation",
        out_path=val_cache_path,
        num_segments=Nseg,
        seed=val_seed,
        shuffle_books=bool(CONFIG["eval_cache_shuffle_books"]),
        shuffle_buffer_books=int(CONFIG["eval_cache_shuffle_buffer_books"]),
        seq_len=L
    )
    test_u16 = build_eval_cache(
        split="test",
        out_path=test_cache_path,
        num_segments=Nseg,
        seed=test_seed,
        shuffle_books=bool(CONFIG["eval_cache_shuffle_books"]),
        shuffle_buffer_books=int(CONFIG["eval_cache_shuffle_buffer_books"]),
        seq_len=L
    )


    print0(f"[Sanity] Unigram BPB | VAL={_token_unigram_bpb(val_u16):.4f} | TEST={_token_unigram_bpb(test_u16):.4f}")
    hv = set(_seg_hash_u16(val_u16[i]) for i in range(min(128, val_u16.shape[0])))
    ht = set(_seg_hash_u16(test_u16[i]) for i in range(min(128, test_u16.shape[0])))
    print0(f"[Sanity] Segment hash overlap (first 128) | intersection={len(hv.intersection(ht))} (expect 0)")


    val_loader = DataLoader(
        CachedSegmentsDataset(val_u16),
        batch_size=int(CONFIG["batch_size"]),
        shuffle=False,
        num_workers=0,
        pin_memory=bool(CONFIG["pin_memory"]),
        drop_last=True,
    )
    test_loader = DataLoader(
        CachedSegmentsDataset(test_u16),
        batch_size=int(CONFIG["batch_size"]),
        shuffle=False,
        num_workers=0,
        pin_memory=bool(CONFIG["pin_memory"]),
        drop_last=True,
    )


    train_shards = _stream_n_shards("train")
    train_nw = min(int(CONFIG["num_workers"]), int(train_shards))
    print0(f"[Data] HF n_shards | train={train_shards} val=cache test=cache")
    print0(f"[Data] DataLoader num_workers capped | train={train_nw} (val/test use cached loaders)")

    train_ds = PG19ByteLMIterable(
        split="train",
        seq_len=L,
        seed=int(CONFIG["seed"]),
        shuffle_books=True,
        shuffle_buffer_books=int(CONFIG["train_shuffle_buffer_books"]),
    )
    dl_kw = dict(
        batch_size=int(CONFIG["batch_size"]),
        num_workers=int(train_nw),
        pin_memory=bool(CONFIG["pin_memory"]),
        collate_fn=collate_fixed,
    )
    if train_nw > 0:
        dl_kw["persistent_workers"] = True
        dl_kw["prefetch_factor"] = 4
    train_loader = DataLoader(train_ds, **dl_kw)


    print0("[Model] Building ES-SSM LM ...")
    model = DeepESSM_LM().to(DEVICE)

    if DEVICE == "cuda" and torch.cuda.device_count() >= 2:
        print0(f"[DP] Using DataParallel over {torch.cuda.device_count()} GPUs (device_ids=[0,1])")
        model = nn.DataParallel(model, device_ids=[0, 1])
    else:
        print0(f"[DP] disabled (cuda_count={torch.cuda.device_count()}). Running single GPU/CPU.")

    n_params = sum(p.numel() for p in unwrap_model(model).parameters())
    print0(f"[Model] Params: {n_params/1e6:.2f}M")


    base_model = unwrap_model(model)
    param_groups = build_param_groups(base_model, weight_decay=float(CONFIG["weight_decay"]))
    opt_kwargs = dict(
        lr=float(CONFIG["lr"]),
        betas=tuple(CONFIG["adam_betas"]),
        eps=float(CONFIG["adam_eps"]),
    )
    try:
        optimizer = torch.optim.AdamW(param_groups, **opt_kwargs, fused=True)
        print0("[Opt] Using fused AdamW.")
    except TypeError:
        optimizer = torch.optim.AdamW(param_groups, **opt_kwargs)
        print0("[Opt] Using standard AdamW.")

    scaler = torch.amp.GradScaler(enabled=(bool(CONFIG["use_amp"]) and DEVICE == "cuda"))

    total_steps = int(CONFIG["steps_per_epoch"]) * int(CONFIG["epochs"])
    warmup_steps = int(total_steps * float(CONFIG["warmup_ratio"]))
    scheduler = make_warmup_cosine_scheduler(
        optimizer, total_steps=total_steps, warmup_steps=warmup_steps, min_lr_ratio=float(CONFIG["min_lr_ratio"])
    )
    print0(f"[LR] total_steps={total_steps} | warmup_steps={warmup_steps}")

    ema = EMA(base_model, decay=float(CONFIG["ema_decay"]))
    budget = BudgetSampler(int(CONFIG["seed"]))


    start_epoch = 0
    global_step = 0
    best_bpb = 1e9

    if bool(CONFIG.get("resume", False)) and paths["last"].exists():
        print0(f"[Resume] Loading last state: {paths['last']} ...")
        try:
            start_epoch, global_step, best_bpb = load_state(paths["last"], model, optimizer, scheduler, scaler, ema)
            start_epoch += 1
            print0(f"[Resume] start_epoch={start_epoch} | global_step={global_step} | best_VAL_BPB={best_bpb:.4f}")
        except Exception as e:
            print0(f"[Resume] Incompatible state. Ignoring resume. Error:\n  {repr(e)}")
            start_epoch, global_step, best_bpb = 0, 0, 1e9

    no_improve = 0


    train_it_ref = [iter(train_loader)]
    def next_batch(it_ref, loader):
        try:
            return next(it_ref[0])
        except StopIteration:
            it_ref[0] = iter(loader)
            return next(it_ref[0])


    for epoch in range(start_epoch, int(CONFIG["epochs"])):
        model.train()
        t0 = time.time()

        optimizer.zero_grad(set_to_none=True)
        total_nll = 0.0
        total_tok = 0
        K_meter = []

        current_K = budget.sample(update_step=global_step, epoch=epoch)
        desc = f"Epoch {epoch+1:02d}/{CONFIG['epochs']} | ES-SSM | warmup={'Y' if epoch < CONFIG['warmup_epochs'] else 'N'}"
        pbar = tqdm(range(int(CONFIG["steps_per_epoch"]) * int(CONFIG["grad_accum_steps"])), desc=desc)

        micro_in_accum = 0
        for _ in pbar:
            tokens_full = next_batch(train_it_ref, train_loader).to(DEVICE, non_blocking=True)  # (B, L+1)
            x = tokens_full[:, :-1].contiguous()
            y = tokens_full[:, 1:].contiguous()


            p_drop = float(CONFIG.get("token_dropout_prob", 0.0))
            if p_drop > 0.0:
                drop = (torch.rand_like(x.float()) < p_drop)
                x = x.masked_fill(drop, 1)

            with torch.amp.autocast(device_type=AMP_DEVICE, enabled=(bool(CONFIG["use_amp"]) and DEVICE == "cuda")):
                logits, K_used_t = model(x, K_runtime=current_K)
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), reduction="mean")
                loss = loss / int(CONFIG["grad_accum_steps"])

            scaler.scale(loss).backward()

            n_tok = y.numel()
            total_nll += float(loss.item()) * int(CONFIG["grad_accum_steps"]) * n_tok
            total_tok += int(n_tok)

            K_meter.append(float(K_used_t.mean().item()))

            micro_in_accum += 1
            if micro_in_accum >= int(CONFIG["grad_accum_steps"]):
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(CONFIG["grad_clip"]))
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                scheduler.step()
                ema.update(unwrap_model(model))

                global_step += 1
                micro_in_accum = 0
                current_K = budget.sample(update_step=global_step, epoch=epoch)

            avg_nll = total_nll / max(1, total_tok)
            avg_bpb = avg_nll / math.log(2.0)
            avg_ppl = math.exp(avg_nll)
            avgK = sum(K_meter[-50:]) / max(1, len(K_meter[-50:]))
            lr_now = optimizer.param_groups[0]["lr"]

            pbar.set_postfix({
                "BPB": f"{avg_bpb:.4f}",
                "PPL": f"{avg_ppl:.2f}",
                "avgK": f"{avgK:.1f}",
                "K": int(current_K),
                "lr": f"{lr_now:.2e}",
            })


        if micro_in_accum > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(CONFIG["grad_clip"]))
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            scheduler.step()
            ema.update(unwrap_model(model))
            global_step += 1


        ema.apply(unwrap_model(model))

        eval_train_batches = int(CONFIG.get("eval_batches_train", 25))
        test_every = int(CONFIG.get("test_every", 3))
        curve_every = int(CONFIG.get("eval_curve_every", 3))

        val_full = evaluate_lm(
            model, val_loader, K_eval=int(CONFIG["K_max"]),
            max_batches=eval_train_batches, use_amp=True
        )

        do_test = (test_every > 0 and ((epoch + 1) % test_every == 0))
        if do_test:
            test_full = evaluate_lm(
                model, test_loader, K_eval=int(CONFIG["K_max"]),
                max_batches=eval_train_batches, use_amp=True
            )
        else:
            test_full = {"bpb": float("nan"), "ppl": float("nan"), "tok_s": float("nan")}

        elapsed = time.time() - t0
        print0(f"\n[Epoch {epoch+1:02d}] time={elapsed:.1f}s | "
               f"VAL BPB={val_full['bpb']:.4f} PPL={val_full['ppl']:.2f} | "
               f"TEST BPB={test_full['bpb']:.4f} PPL={test_full['ppl']:.2f}")

        if curve_every > 0 and ((epoch + 1) % curve_every == 0):
            budget_quality_curve(
                model, val_loader,
                K_list=[32, 24, 16, 12, 8, 6, 4, 3, 2],
                max_batches=eval_train_batches
            )

        if int(CONFIG.get("save_last_every_epoch", 1)) > 0:
            save_state(paths["last"], model, optimizer, scheduler, scaler, ema, epoch, global_step, best_bpb)

        min_delta = float(CONFIG.get("early_stop_min_delta", 1e-4))
        improved = (val_full["bpb"] < best_bpb - min_delta)
        if improved:
            best_bpb = float(val_full["bpb"])
            no_improve = 0
            save_state(paths["best"], model, optimizer, scheduler, scaler, ema, epoch, global_step, best_bpb)
            print0(f"[Best] New best VAL BPB={best_bpb:.4f} | saved -> {paths['best']}")
        else:
            no_improve += 1
            print0(f"[EarlyStop] no_improve={no_improve}/{int(CONFIG['early_stop_patience'])} | best_bpb={best_bpb:.4f}")

        ema.restore(unwrap_model(model))

        if no_improve >= int(CONFIG["early_stop_patience"]):
            print0("[EarlyStop] Triggered. Stop training.")
            break


    if paths["best"].exists():
        print0(f"[Final] Loading best state -> {paths['best']}")
        st = torch.load(paths["best"], map_location="cpu")
        unwrap_model(model).load_state_dict(st["model"], strict=True)

        final_batches = int(CONFIG.get("eval_batches_final", 256))
        final_val = evaluate_lm(model, val_loader, K_eval=int(CONFIG["K_max"]), max_batches=final_batches, use_amp=True)
        final_test = evaluate_lm(model, test_loader, K_eval=int(CONFIG["K_max"]), max_batches=final_batches, use_amp=True)

        print0("\n" + "=" * 72)
        print0(f"[Final Best] VAL BPB={final_val['bpb']:.4f} PPL={final_val['ppl']:.2f} | "
               f"TEST BPB={final_test['bpb']:.4f} PPL={final_test['ppl']:.2f}")
        print0("=" * 72)

    return model


def main():
    """Script entry point: run ``train()`` and discard the model it returns."""
    train()


# Run main() only when this file is executed as the top-level script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
