#### AutoEncoders Implementation
#### AE-VAE-VQVAE-VQVAE2
#### Dataset: MNIST

In [None]:
# Connect Google Drive (Colab only): mount the drive, then move into the
# project folder so that relative paths (Cfg.ROOT is ".") resolve against it.
from google.colab import drive
drive.mount('/content/gdrive')
# IPython magic: change the notebook's working directory.
%cd /content/gdrive/MyDrive/Colab Notebooks/Colab Notebooks/autoencoders

# IPython magic: show the current working directory.
%pwd

In [None]:
import os, csv, json, time, gzip, struct
from datetime import datetime
from pathlib import Path

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms

# SOAP import (second-order optimizer)
try:
    from soap import SOAP
except Exception:
    SOAP = None


class Cfg:
    """Central configuration for the MNIST autoencoder experiments.

    Every setting is a class attribute. The module creates one instance
    (``cfg``) and reads all values through it; the sweep at the bottom of the
    file reassigns ``PRECISION`` and ``OPTIMIZER`` on that instance.
    """
    # Filesystem layout; ROOT is relative to the current working directory.
    ROOT      = Path(".")
    #ROOT      = Path("/home/krajnish/autoencoders")
    DATA_DIR  = ROOT / "datasets/mnist"
    TRAIN_DIR = DATA_DIR / "train"
    TEST_DIR  = DATA_DIR / "test"
    OUT_DIR   = ROOT / "best_models/output_mnist/output_mnist_vae"

    # Training
    # EPOCHS bounds the training loop and also feeds the scheduler's T_0.
    EPOCHS       = 1    #300
    BATCH_SIZE   = 256
    LR           = 1e-3
    WEIGHT_DECAY = 1e-4
    NUM_WORKERS  = 0
    # Early stopping: number of non-improving epochs tolerated; None disables it.
    PATIENCE     = None
    # Minimum decrease in validation loss that counts as an improvement.
    MIN_DELTA    = 1e-5
    # Max gradient norm passed to clip_grad_norm_; a falsy value skips clipping.
    GRAD_CLIP    = 1.0

    # Optimizer/Precision
    PRECISION    = "fp32"          # precision: "fp32", "fp64"
    OPTIMIZER    = "adam"          # optimizer: "adam", "soap"
    ADAM_BETAS   = (0.9, 0.999)

    # SOAP defaults (from official repo examples)
    SOAP_LR      = 3e-3
    SOAP_BETAS   = (0.95, 0.95)
    SOAP_WD      = 1e-2
    SOAP_PREFREQ = 10

    # Shared latent/channel sizes
    BOTTLENECK_CH = 56
    # NOTE(review): LATENT_SHAPE is not referenced anywhere in the visible
    # code; it appears to document (BOTTLENECK_CH, 7, 7) — keep in sync by hand.
    LATENT_SHAPE  = (56, 7, 7)

    # VQ specifics
    CODEBOOK_SIZE = 512
    COMMIT_BETA   = 0.25

    # VQ loss weights
    VQ_WEIGHT         = 1.0   # for VQVAE
    VQ_TOP_WEIGHT     = 1.0   # for VQVA2 top level
    VQ_BOTTOM_WEIGHT  = 1.0   # for VQVA2 bottom level

    # VQ-VAE-2 specifics
    TOP_CH        = 56

    # VAE KL weight
    VAE_BETA      = 1.0

    # Scenarios
    #TRAIN_FRACTIONS_B = [0.50, 0.60, 0.70]     # scenario b
    TRAIN_FRACTIONS_B = [0.70]
    #TEST_NOISES       = [1, 5, 10]             # scenarios c & d
    TEST_NOISES       = [1]
    TRAIN_FRACTION_FIXED = 0.70                # for c/d

    # Checkpoint subdir
    CKPT_SUBDIR = "ckpts"

    # Logging: shared CSV files that every run appends to.
    EXP_CSV = OUT_DIR / "experiments_all.csv"
    LOSS_CSV = OUT_DIR / "loss_history_all.csv"
    METRICS_CSV = OUT_DIR / "metrics_all.csv"

    SEED = 42
    # Passed as ``shuffle=`` to the train, val and test DataLoaders alike.
    SHUFFLE = False


# Instantiate the configuration and make sure the output directory exists
# before any run writes into it.
cfg = Cfg()
cfg.OUT_DIR.mkdir(parents=True, exist_ok=True)

# Torch device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed the torch RNG (and every CUDA device when CUDA is present).
torch.manual_seed(cfg.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.SEED)

print("Device:", device)

# Data (IDX loader & noise)
# Image -> tensor, then (x - 0.5) / 0.5 on the single channel.
to_tensor_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


def _open_idx(path: Path):
    """Open an IDX file for binary reading, handling gzip transparently.

    A ``.gz`` suffix selects ``gzip.open``; any other suffix uses the built-in
    ``open``. The caller is responsible for closing the returned handle.
    """
    opener = gzip.open if path.suffix == ".gz" else open
    return opener(path, "rb")


def parse_idx_images(path: Path):
    """Read an IDX3 image file into a ``(num, rows, cols)`` uint8 array.

    The file starts with a 16-byte big-endian header (magic number, image
    count, rows, cols); everything after it is read as one unsigned byte per
    pixel.

    Args:
        path: Location of the IDX file; it is opened through ``_open_idx``.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with shape ``(num, rows, cols)``.

    Raises:
        RuntimeError: If the header is shorter than 16 bytes, the magic number
            is not 2051, or the pixel payload does not hold exactly
            ``num * rows * cols`` bytes.
    """
    with _open_idx(path) as f:
        header = f.read(16)
        if len(header) != 16:
            raise RuntimeError("Malformed IDX header")
        magic, num, rows, cols = struct.unpack(">IIII", header)
        if magic != 2051:
            raise RuntimeError(f"Bad magic {magic}")
        data = np.frombuffer(f.read(), dtype=np.uint8)

    # Validate the payload up front so a truncated or oversized file is
    # reported as a malformed IDX file rather than as a reshape error.
    expected = num * rows * cols
    if data.size != expected:
        raise RuntimeError(
            f"Malformed IDX payload: expected {expected} bytes "
            f"({num}x{rows}x{cols}), got {data.size}"
        )
    return data.reshape(num, rows, cols)


def find_idx_in_dir(dir_path: Path, candidates):
    """Locate the first file matching one of ``candidates`` under ``dir_path``.

    Direct children are tried first, in the order given by ``candidates``.
    If none exists, the directory tree is walked recursively and the first
    regular file whose name is in ``candidates`` is returned.

    Returns:
        The matching ``Path``, or ``None`` when nothing matches.
    """
    # Pass 1: direct children, honouring the candidate order.
    direct_hit = next(
        (dir_path / name for name in candidates if (dir_path / name).exists()),
        None,
    )
    if direct_hit is not None:
        return direct_hit

    # Pass 2: anywhere below dir_path, matched by base name.
    return next(
        (
            entry
            for entry in dir_path.rglob("*")
            if entry.is_file() and entry.name in candidates
        ),
        None,
    )


def add_noise(x: torch.Tensor, pct: float) -> torch.Tensor:
    """Add zero-mean Gaussian noise to ``x`` and clamp the result to [-1, 1].

    ``pct`` is a percentage: the noise standard deviation is ``pct / 100``.
    For ``pct <= 0`` the input tensor is returned unchanged (same object).
    """
    if pct <= 0:
        return x
    sigma = float(pct) / 100.0
    perturbation = torch.randn_like(x) * sigma
    return torch.clamp(x + perturbation, -1.0, 1.0)


class MNISTIdxDataset(Dataset):
    """Dataset over an in-memory image array that yields (noisy, clean) pairs.

    Args:
        images: Array indexed by sample; each item is converted to a PIL
            image in mode "L".
        transform: Callable applied to the PIL image; when falsy,
            ``transforms.ToTensor()`` is used instead.
        noise_pct: Noise level forwarded to ``add_noise`` for the noisy view.
    """

    def __init__(self, images: np.ndarray, transform=None, noise_pct: float = 0.0):
        self.images = images
        self.transform = transform
        self.noise_pct = float(noise_pct)

    def __len__(self):
        """Number of images in the underlying array."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(x_noisy, x_clean)`` for sample ``idx``."""
        pil_img = Image.fromarray(self.images[idx], mode="L")
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        clean = to_tensor(pil_img)
        # The clean tensor doubles as the reconstruction target.
        return add_noise(clean, self.noise_pct), clean


# Load IDX paths
# Each lookup accepts either of the two listed file names for the image file.
train_idx = find_idx_in_dir(cfg.TRAIN_DIR, ["train-images.idx3-ubyte", "train-images-idx3-ubyte.gz"])
test_idx  = find_idx_in_dir(cfg.TEST_DIR,  ["t10k-images.idx3-ubyte", "t10k-images-idx3-ubyte.gz"])
if train_idx is None or test_idx is None:
    raise RuntimeError("Expect 60k train + 10k test IDX in mnist/{train,test}")

# Decode both image files once at module load; build_loaders slices these arrays.
TRAIN_FULL = parse_idx_images(train_idx)
TEST_FULL  = parse_idx_images(test_idx)


def build_loaders(train_fraction: float,
                  batch_size: int,
                  train_noise_pct: float = 0.0,
                  val_noise_pct: float   = 0.0,
                  test_noise_pct: float  = 0.0):
    """Create the train/val/test datasets and their DataLoaders.

    The first ``int(train_fraction * len(TRAIN_FULL))`` images of
    ``TRAIN_FULL`` form the training set and the remainder the validation
    set; ``TEST_FULL`` is the test set. Each split gets its own noise level.

    Returns:
        ``(train_ds, val_ds, test_ds, train_loader, val_loader, test_loader)``.
    """
    split_at = int(train_fraction * len(TRAIN_FULL))
    split_specs = (
        (TRAIN_FULL[:split_at], train_noise_pct),
        (TRAIN_FULL[split_at:], val_noise_pct),
        (TEST_FULL, test_noise_pct),
    )
    datasets = [
        MNISTIdxDataset(images, transform=to_tensor_norm, noise_pct=pct)
        for images, pct in split_specs
    ]

    # All three loaders share the same settings, including cfg.SHUFFLE.
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=cfg.SHUFFLE,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )
    loaders = [DataLoader(ds, **loader_kwargs) for ds in datasets]
    return (*datasets, *loaders)


# ---------------- Snake activation ----------------
class Snake(nn.Module):
    """Snake activation with a learnable frequency ``alpha``.

    Computes ``x + (1 / a) * sin(a * x) ** 2`` where ``a = |alpha| + 1e-6``;
    the small offset keeps the reciprocal finite when ``alpha`` is zero.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha))

    def forward(self, x):
        freq = torch.abs(self.alpha) + 1e-6
        inv_freq = 1.0 / freq
        return x + inv_freq * torch.sin(freq * x).pow(2)


# Models
class EncoderBlock(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> GroupNorm -> Snake`` stages.

    Args:
        cin: Input channel count of the first convolution.
        cout: Output channel count of both convolutions.
    """
    def __init__(self, cin, cout):
        super().__init__()
        # GroupNorm needs num_channels to be divisible by num_groups. Pick the
        # largest group count <= 8 that divides cout: this equals min(8, cout)
        # whenever that value was already valid, and keeps widths such as 12
        # or 20 from raising at construction time.
        groups = next(g for g in range(min(8, cout), 0, -1) if cout % g == 0)
        self.net = nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.GroupNorm(num_groups=groups, num_channels=cout), Snake(),
            nn.Conv2d(cout, cout, 3, padding=1),
            nn.GroupNorm(num_groups=groups, num_channels=cout), Snake(),
        )

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.net(x)


class DecoderBlock(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> GroupNorm -> Snake`` stages.

    Args:
        cin: Input channel count of the first convolution.
        cout: Output channel count of both convolutions.
    """
    def __init__(self, cin, cout):
        super().__init__()
        # GroupNorm needs num_channels to be divisible by num_groups. Pick the
        # largest group count <= 8 that divides cout: this equals min(8, cout)
        # whenever that value was already valid, and keeps widths such as 12
        # or 20 from raising at construction time.
        groups = next(g for g in range(min(8, cout), 0, -1) if cout % g == 0)
        self.net = nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.GroupNorm(num_groups=groups, num_channels=cout), Snake(),
            nn.Conv2d(cout, cout, 3, padding=1),
            nn.GroupNorm(num_groups=groups, num_channels=cout), Snake(),
        )

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.net(x)


# --- AuE (plain autoencoder) ---
class AuE(nn.Module):
    """Plain convolutional autoencoder; the latent has ``ch`` channels at 7x7
    for 28x28 inputs (two stride-2 downsamplings)."""

    def __init__(self, ch=56):
        super().__init__()
        encoder_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 28 -> 14
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, ch, 4, stride=2, padding=1),   # 14 -> 7, latent (ch,7,7)
            Snake(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        decoder_layers = [
            DecoderBlock(ch, 128),
            nn.ConvTranspose2d(128, 64, 2, stride=2),    # 7 -> 14
            Snake(),
            DecoderBlock(64, 64),
            nn.ConvTranspose2d(64, 32, 2, stride=2),     # 14 -> 28
            Snake(),
            nn.Conv2d(32, 1, 1),
            nn.Tanh(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the reconstruction and an info dict with a zero ``aux_loss``."""
        recon = self.dec(self.enc(x))
        zero_aux = torch.tensor(0.0, device=x.device)
        return recon, {"aux_loss": zero_aux}


# --- VAE ---
class VAE(nn.Module):
    """Convolutional VAE with a diagonal Gaussian at every latent cell.

    The encoder produces 128 feature channels at 7x7 (for 28x28 inputs); two
    1x1 convolutions map them to the mean and log-variance with ``ch``
    channels each. ``beta_kl`` scales the KL term reported as ``aux_loss``.
    """

    def __init__(self, ch=56, beta_kl=1.0):
        super().__init__()
        self.beta_kl = beta_kl

        encoder_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # -> (128,7,7)
            Snake(),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        self.mu     = nn.Conv2d(128, ch, 1)
        self.logvar = nn.Conv2d(128, ch, 1)

        decoder_layers = [
            DecoderBlock(ch, 128),
            nn.ConvTranspose2d(128, 64, 2, stride=2),
            Snake(),
            DecoderBlock(64, 64),
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            Snake(),
            nn.Conv2d(32, 1, 1),
            nn.Tanh(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def reparam(self, mu, logvar):
        """Draw ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Encode, sample (training) or take the mean (eval), and decode.

        Returns:
            ``(xhat, info)`` where ``info`` holds the weighted KL as
            ``aux_loss`` plus the raw ``kl``, ``mu`` and ``logvar``.
        """
        features = self.enc(x)
        mu = self.mu(features)
        logvar = self.logvar(features)

        # Stochastic latent only while training; deterministic mean otherwise.
        z = self.reparam(mu, logvar) if self.training else mu
        xhat = self.dec(z)

        # KL to the standard normal, averaged over (C,H,W) then over the batch.
        kl_elementwise = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        kl = kl_elementwise.mean(dim=[1, 2, 3]).mean()

        info = {
            "aux_loss": self.beta_kl * kl,
            "kl": kl,
            "mu": mu,
            "logvar": logvar,
        }
        return xhat, info


# --- Vector Quantizer ---
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through estimator.

    Args:
        K: Number of codebook entries.
        D: Dimensionality of each entry; must equal the channel count of the
            tensor passed to ``forward``.
        beta_commit: Weight of the commitment term (the part of the VQ loss
            that pulls the encoder output towards its selected code).
    """
    def __init__(self, K, D, beta_commit=0.25):
        super().__init__()
        self.K = K
        self.D = D
        self.beta = beta_commit
        self.codebook = nn.Embedding(K, D)
        # Use D for initialization range
        nn.init.uniform_(self.codebook.weight, -1.0 / D, 1.0 / D)

    def forward(self, z_e):
        """Quantize ``z_e`` of shape ``(B, D, H, W)``.

        Returns:
            ``(z_q_st, vq_loss, perplexity, indices)`` where ``z_q_st`` has the
            values of the selected codes but passes gradients straight through
            to ``z_e``, ``vq_loss`` is codebook + beta * commitment, and
            ``indices`` has shape ``(B, H, W)``.
        """
        B, D, H, W = z_e.shape
        z = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D)  # (BHW, D)
        e = self.codebook.weight
        # Squared Euclidean distance between every latent vector and every code.
        dist = (z.pow(2).sum(1, keepdim=True) + e.pow(2).sum(1) - 2 * z @ e.t())
        idx = torch.argmin(dist, dim=1)
        z_q = self.codebook(idx).view(B, H, W, D).permute(0, 3, 1, 2).contiguous()

        # Codebook term: gradient reaches only the codebook (encoder detached).
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        # Commitment term: gradient reaches only the encoder (codes detached);
        # beta must scale this term, not the codebook update.
        commit_loss = self.beta * F.mse_loss(z_e, z_q.detach())
        vq_loss = codebook_loss + commit_loss

        # Straight-through: forward value is z_q, backward gradient goes to z_e.
        z_q_st = z_e + (z_q - z_e).detach()

        with torch.no_grad():
            one_hot = F.one_hot(idx, num_classes=self.K).float()
            avg_probs = one_hot.mean(0)
            # exp(entropy) of the code-usage distribution in this batch.
            perplexity = torch.exp(-(avg_probs * (avg_probs + 1e-10).log()).sum())
        return z_q_st, vq_loss, perplexity, idx.view(B, H, W)


# --- VQVAE ---
class VQVAE(nn.Module):
    """Single-level VQ-VAE; the quantized latent has ``D`` channels at 7x7
    for 28x28 inputs and the codebook holds ``K`` entries of dimension ``D``."""

    def __init__(self, K=512, D=56, beta_commit=0.25):
        super().__init__()
        encoder_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 28 -> 14
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, D, 4, stride=2, padding=1),   # 14 -> 7, (D,7,7)
            Snake(),                                    # activation after encoder
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.quant = VectorQuantizer(K, D, beta_commit=beta_commit)

        decoder_layers = [
            DecoderBlock(D, 128),
            nn.ConvTranspose2d(128, 64, 2, stride=2),   # 7 -> 14
            Snake(),
            DecoderBlock(64, 64),
            nn.ConvTranspose2d(64, 32, 2, stride=2),    # 14 -> 28
            Snake(),
            nn.Conv2d(32, 1, 1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            ``(xhat, aux)``; ``aux`` carries the VQ loss (as both ``aux_loss``
            and ``vq_loss``), the codebook ``perplexity`` and the ``indices``.
        """
        quantized, vq_loss, perplexity, codes = self.quant(self.encoder(x))
        xhat = self.decoder(quantized)
        return xhat, {
            "aux_loss": vq_loss,
            "vq_loss": vq_loss,
            "perplexity": perplexity,
            "indices": codes,
        }


# --- VQVA2 (VQ-VAE-2; two scales) ---
class VQVA2(nn.Module):
    """Two-level VQ-VAE (VQ-VAE-2 style).

    A bottom encoder produces ``D`` channels at 7x7 (for 28x28 inputs); a top
    encoder downsamples that to ``top_ch`` channels at 4x4. The top latent is
    quantized, upsampled back to the bottom resolution, and used both to
    condition the bottom latent before its own quantization and as extra
    decoder input.
    """

    def __init__(self, K=512, D=56, beta_commit=0.25, top_ch=56):
        super().__init__()
        bottom_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 28 -> 14
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, D, 4, stride=2, padding=1),    # 14 -> 7, (D,7,7)
            Snake(),                                     # activation after bottom encoder
        ]
        self.enc_bottom = nn.Sequential(*bottom_layers)

        top_layers = [
            nn.Conv2d(D, 128, 3, padding=1),
            Snake(),
            nn.Conv2d(128, top_ch, 4, stride=2, padding=1),  # 7 -> 4
            Snake(),                                         # activation after top encoder
        ]
        self.enc_top = nn.Sequential(*top_layers)
        self.quant_top = VectorQuantizer(K, top_ch, beta_commit=beta_commit)

        # 1x1 conv that fuses bottom features with the upsampled top codes.
        self.bottom_condition = nn.Sequential(
            nn.Conv2d(D + top_ch, D, 1),
            Snake(),
        )
        self.quant_bottom = VectorQuantizer(K, D, beta_commit=beta_commit)

        decoder_layers = [
            DecoderBlock(D + top_ch, 128),
            nn.ConvTranspose2d(128, 64, 2, stride=2),   # 7 -> 14
            Snake(),
            DecoderBlock(64, 64),
            nn.ConvTranspose2d(64, 32, 2, stride=2),    # 14 -> 28
            Snake(),
            nn.Conv2d(32, 1, 1),
            nn.Tanh(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run both quantization levels and decode.

        Returns:
            ``(xhat, aux)``; ``aux_loss`` is the weighted sum of the two VQ
            losses using ``cfg.VQ_TOP_WEIGHT`` and ``cfg.VQ_BOTTOM_WEIGHT``,
            and the remaining keys expose per-level losses, perplexities and
            code indices.
        """
        bottom_feat = self.enc_bottom(x)                 # (D,7,7)
        top_feat = self.enc_top(bottom_feat)             # (top_ch,4,4)
        top_q, loss_top, ppl_top, codes_top = self.quant_top(top_feat)

        # Bring the top codes back to the bottom spatial size.
        top_up = F.interpolate(top_q, size=bottom_feat.shape[-2:], mode="nearest")
        fused = torch.cat([bottom_feat, top_up], dim=1)
        bottom_q, loss_bottom, ppl_bottom, codes_bottom = self.quant_bottom(
            self.bottom_condition(fused)
        )

        xhat = self.dec(torch.cat([bottom_q, top_up], dim=1))  # (D+top_ch,7,7) in
        weighted_vq = cfg.VQ_TOP_WEIGHT * loss_top + cfg.VQ_BOTTOM_WEIGHT * loss_bottom
        return xhat, {
            "aux_loss": weighted_vq,
            "vq_top": loss_top,
            "vq_bottom": loss_bottom,
            "perplexity_top": ppl_top,
            "perplexity_bottom": ppl_bottom,
            "indices_top": codes_top,
            "indices_bottom": codes_bottom,
        }


# Losses & metrics
def huber_recon(xhat, y, delta=1.0):
    """Mean Huber-style reconstruction loss between ``xhat`` and ``y``.

    Per element with ``d = |xhat - y|``: ``0.5 * d**2 / delta`` when
    ``d <= delta`` and ``d - 0.5 * delta`` otherwise; the result is the mean
    over all elements.
    """
    abs_err = torch.abs(xhat - y)
    capped = abs_err.clamp(max=delta)      # quadratic region
    overflow = abs_err - capped            # linear region beyond delta
    per_element = 0.5 * capped * capped / delta + overflow
    return per_element.mean()


def rel_errors(xhat, y, eps=1e-12):
    """Per-sample relative L1 and L2 errors of ``xhat`` against ``y``.

    All dimensions after the first are flattened. ``eps`` is added to the L1
    denominator, and inside the square root of the L2 denominator.

    Returns:
        ``(relL1, relL2)``, each with one value per sample.
    """
    err = (xhat - y).flatten(1)
    ref = y.flatten(1)
    rel_l1 = err.abs().sum(dim=1) / (ref.abs().sum(dim=1) + eps)
    rel_l2 = err.pow(2).sum(dim=1).sqrt() / (ref.pow(2).sum(dim=1) + eps).sqrt()
    return rel_l1, rel_l2


# Logging utilities (single CSVs)
def _ensure_header(path: Path, fieldnames):
    """Open ``path`` for appending CSV rows, writing the header when needed.

    The header is written when the file does not exist yet or exists but is
    empty (for example after an interrupted earlier run), so the first data
    row is never left without column names.

    Args:
        path: CSV file to append to.
        fieldnames: Column names for the ``csv.DictWriter``.

    Returns:
        ``(f, w)``: the open file handle and the writer bound to it. The
        caller must close ``f``.
    """
    write_header = (not path.exists()) or path.stat().st_size == 0
    f = path.open("a", newline="")
    try:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            w.writeheader()
    except Exception:
        # Do not leak the handle if the header cannot be written.
        f.close()
        raise
    return f, w


def log_experiment_row(row: dict):
    """Append one run-registry row to ``cfg.EXP_CSV``.

    Only the columns listed below are written; keys missing from ``row``
    become empty cells and extra keys are ignored.
    """
    fields = [
        "timestamp","run_name","run_path","model","scenario","sc",
        "epochs","batch_size","seed",
        "train_frac","val_frac","train_noise_pct","val_noise_pct","test_noise_pct",
        "bottleneck_ch","codebook_size","commitment_beta","top_ch",
        "precision","optimizer",
        "best_val","best_ckpt","last_ckpt"
    ]
    f, w = _ensure_header(cfg.EXP_CSV, fields)
    # Context manager guarantees the handle is closed even if writerow raises.
    with f:
        w.writerow({k: row.get(k) for k in fields})


def log_loss_history_rows(run_name, run_path, model, scenario, epoch, train_total, val_total, ppl=None, extra=None):
    """Append one per-epoch loss row to ``cfg.LOSS_CSV``.

    Args:
        run_name, run_path, model, scenario: Identify the run.
        epoch: Epoch number of this row.
        train_total, val_total: Mean total losses for the epoch.
        ppl: Optional perplexity; written as an empty cell when ``None``.
        extra: Optional JSON-serializable dict stored in ``extra_json``.
    """
    fields = ["run_name","run_path","model","scenario","epoch","train_total","val_total","ppl","extra_json"]
    # Build the row first so a serialization error cannot leave a file open.
    record = {
        "run_name": run_name,
        "run_path": str(run_path),
        "model": model,
        "scenario": scenario,
        "epoch": epoch,
        "train_total": train_total,
        "val_total": val_total,
        "ppl": ppl if (ppl is not None) else "",
        "extra_json": json.dumps(extra or {})
    }
    f, w = _ensure_header(cfg.LOSS_CSV, fields)
    # Context manager guarantees the handle is closed even if writerow raises.
    with f:
        w.writerow(record)


def log_metrics_row(row: dict):
    """Append one test-metrics row to ``cfg.METRICS_CSV``.

    Only the columns listed below are written; keys missing from ``row``
    become empty cells and extra keys are ignored.
    """
    fields = [
        "run_name","run_path","model","scenario","sc",
        "train_frac","val_frac","train_noise_pct","val_noise_pct","test_noise_pct",
        "recon_huber_mean","aux_loss_mean","total_loss_mean",
        "relL1_mean","relL2_mean",
        "train_size","val_size","test_size",
        "precision","optimizer"
    ]
    f, w = _ensure_header(cfg.METRICS_CSV, fields)
    # Context manager guarantees the handle is closed even if writerow raises.
    with f:
        w.writerow({k: row.get(k) for k in fields})


# Train/Eval
class CheckpointManager:
    """Writes ``last.pt`` every call and ``best.pt`` on new validation bests.

    The manager keeps its own record of the lowest validation loss it has
    saved. The ``best_val`` argument of ``save`` cannot be used for that
    decision: a caller that updates its running best before calling ``save``
    passes ``best_val == val_loss`` on every improving epoch, so a comparison
    against it would never succeed and ``best.pt`` would never be written.
    """

    def __init__(self, ckpt_dir: Path):
        self.ckpt_dir = Path(ckpt_dir)
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.best_path = self.ckpt_dir / "best.pt"
        self.last_path = self.ckpt_dir / "last.pt"
        # Lowest val_loss for which this manager has written best.pt.
        self._best_saved = float("inf")

    def save(self, epoch, model, optimizer, scheduler, best_val, val_loss, history, extras=None):
        """Persist the current training state.

        Args:
            epoch: Epoch number stored in the checkpoint.
            model: Module whose ``state_dict`` is saved.
            optimizer, scheduler: Saved via ``state_dict`` when truthy.
            best_val: Caller's running best; stored in the checkpoint only.
            val_loss: Validation loss of this epoch; drives the best.pt decision.
            history: Loss history stored as-is.
            extras: Optional dict of additional metadata.

        Returns:
            ``True`` when ``best.pt`` was (re)written, else ``False``.
        """
        state = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict() if optimizer else None,
            "scheduler": scheduler.state_dict() if scheduler else None,
            "best_val": best_val,
            "val_loss": val_loss,
            "history": history,
            "extras": extras or {},
            "seed": cfg.SEED,
        }
        torch.save(state, self.last_path)
        # Require an improvement of more than MIN_DELTA over the best saved so far.
        if val_loss + 1e-12 < self._best_saved - cfg.MIN_DELTA:
            torch.save(state, self.best_path)
            self._best_saved = val_loss
            return True
        return False


def build_model(model_name: str):
    """Instantiate the model registered under ``model_name``.

    Supported names are "AuE", "VAE", "VQVAE" and "VQVA2"; sizes and loss
    weights come from ``cfg``.

    Raises:
        ValueError: For any other name.
    """
    factories = {
        "AuE": lambda: AuE(ch=cfg.BOTTLENECK_CH),
        "VAE": lambda: VAE(ch=cfg.BOTTLENECK_CH, beta_kl=cfg.VAE_BETA),
        "VQVAE": lambda: VQVAE(
            K=cfg.CODEBOOK_SIZE, D=cfg.BOTTLENECK_CH, beta_commit=cfg.COMMIT_BETA
        ),
        "VQVA2": lambda: VQVA2(
            K=cfg.CODEBOOK_SIZE, D=cfg.BOTTLENECK_CH,
            beta_commit=cfg.COMMIT_BETA, top_ch=cfg.TOP_CH
        ),
    }
    factory = factories.get(model_name)
    if factory is None:
        raise ValueError(f"Unknown model: {model_name}")
    return factory()


def forward_and_losses(model_name: str, model, xb, yb):
    """Run ``model`` on ``xb`` and assemble the total loss for ``model_name``.

    The reconstruction term is ``huber_recon(xhat, yb)``. What is added on top
    depends on the model: the ``aux_loss`` for "VAE", the weighted VQ loss for
    "VQVAE", the weighted top/bottom VQ losses for "VQVA2", and nothing for
    any other name.

    Returns:
        ``(xhat, loss, rec, aux_term, ppl)``; ``ppl`` is a float for the VQ
        models (mean of both levels for "VQVA2") and ``None`` otherwise.
    """
    xhat, info = model(xb)
    rec = huber_recon(xhat, yb)
    aux_term = info.get("aux_loss", torch.tensor(0.0, device=xb.device, dtype=xb.dtype))

    # Default: reconstruction only, no perplexity.
    loss, ppl = rec, None

    if model_name == "VAE":
        loss = rec + aux_term

    elif model_name == "VQVAE":
        loss = rec + cfg.VQ_WEIGHT * info.get("vq_loss", aux_term)
        ppl = float(info.get("perplexity", float("nan")))

    elif model_name == "VQVA2":
        vq_top = info.get("vq_top", None)
        vq_bottom = info.get("vq_bottom", None)
        if vq_top is not None and vq_bottom is not None:
            loss = rec + cfg.VQ_TOP_WEIGHT * vq_top + cfg.VQ_BOTTOM_WEIGHT * vq_bottom
        else:
            # Fall back to the pre-combined auxiliary loss.
            loss = rec + aux_term
        ppl_top = float(info.get("perplexity_top", float("nan")))
        ppl_bottom = float(info.get("perplexity_bottom", float("nan")))
        ppl = (ppl_top + ppl_bottom) / 2.0

    return xhat, loss, rec, aux_term, ppl


def make_optimizer(model):
    """Build the optimizer selected by ``cfg.OPTIMIZER`` for ``model``.

    "soap" (case-insensitive) selects SOAP with the ``cfg.SOAP_*`` settings;
    every other value falls back to ``torch.optim.Adam`` with ``cfg.LR``,
    ``cfg.ADAM_BETAS`` and ``cfg.WEIGHT_DECAY``.

    Raises:
        RuntimeError: If SOAP is requested but its import failed.
    """
    opt_name = cfg.OPTIMIZER.lower()
    if opt_name == "soap":
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if SOAP is None:
            raise RuntimeError("SOAP optimizer not found")
        return SOAP(
            params=model.parameters(),
            lr=cfg.SOAP_LR,
            betas=cfg.SOAP_BETAS,
            weight_decay=cfg.SOAP_WD,
            precondition_frequency=cfg.SOAP_PREFREQ
        )
    # default: Adam (first-order)
    return torch.optim.Adam(
        model.parameters(),
        lr=cfg.LR,
        betas=cfg.ADAM_BETAS,
        weight_decay=cfg.WEIGHT_DECAY
    )


def train_eval_once(model_name: str,
                    train_frac: float,
                    train_noise_pct: float,
                    val_noise_pct: float,
                    test_noise_pct: float,
                    tag: str):
    """Train one model under one scenario, test it, and log the results.

    Creates a timestamped run directory under ``cfg.OUT_DIR``, trains for up
    to ``cfg.EPOCHS`` epochs (stopping early when ``cfg.PATIENCE`` is set and
    exhausted), checkpoints every epoch, then reloads ``best.pt`` if it exists
    (otherwise ``last.pt``) and evaluates on the test loader. Writes
    ``history.json``, ``loss_history.csv`` and ``per_sample_metrics.npz`` into
    the run directory and appends rows to the shared CSV logs.

    Args:
        model_name: Name passed to ``build_model`` and ``forward_and_losses``.
        train_frac: Fraction forwarded to ``build_loaders`` as ``train_fraction``.
        train_noise_pct: Noise level for the training split.
        val_noise_pct: Noise level for the validation split.
        test_noise_pct: Noise level for the test split.
        tag: Scenario label used in the run-directory name and in log rows.

    Returns:
        ``(run_root, info_row, metrics_row)``: the run directory and the two
        dicts that were appended to the registry and metrics CSVs.
    """

    # Run directory name encodes model, scenario, split, noise levels,
    # precision, optimizer and a second-resolution timestamp.
    run_root = cfg.OUT_DIR / f"{model_name.lower()}_{tag}_{int(train_frac*100)}_tn{train_noise_pct}_vn{val_noise_pct}_ts{test_noise_pct}_{cfg.PRECISION}_{cfg.OPTIMIZER}_{time.strftime('%Y%m%d_%H%M%S')}"
    run_root.mkdir(parents=True, exist_ok=True)
    ckpt_dir = run_root / cfg.CKPT_SUBDIR
    cpm = CheckpointManager(ckpt_dir)

    train_ds, val_ds, test_ds, train_loader, val_loader, test_loader = build_loaders(
        train_fraction=train_frac, batch_size=cfg.BATCH_SIZE,
        train_noise_pct=train_noise_pct, val_noise_pct=val_noise_pct, test_noise_pct=test_noise_pct
    )

    model = build_model(model_name).to(device)
    if cfg.PRECISION == "fp64":
        model = model.double()

    opt = make_optimizer(model)
    # Restart period is at least 5 epochs, or the full run length if longer.
    sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=max(5, cfg.EPOCHS))

    # best_val: lowest validation loss so far; no_imp: epochs since it improved.
    best_val, no_imp = float("inf"), 0
    history = {"train_total": [], "val_total": [], "ppl": []}

    for epoch in range(1, cfg.EPOCHS + 1):
        # ---- Train
        model.train()
        t_losses, t_ppls = [], []
        for xb, yb in train_loader:
            # xb is the (possibly noisy) input, yb the clean target.
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
            if cfg.PRECISION == "fp64":
                xb = xb.to(torch.float64)
                yb = yb.to(torch.float64)

            opt.zero_grad(set_to_none=True)
            xhat, loss, rec, aux, ppl = forward_and_losses(model_name, model, xb, yb)
            loss.backward()
            if cfg.GRAD_CLIP:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.GRAD_CLIP)
            opt.step()

            t_losses.append(loss.item())
            if ppl is not None:
                t_ppls.append(ppl)
        # Scheduler advances once per epoch, using the 1-based epoch index.
        sch.step(epoch)

        # ---- Val
        model.eval()
        v_losses, v_ppls = [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                if cfg.PRECISION == "fp64":
                    xb = xb.to(torch.float64)
                    yb = yb.to(torch.float64)
                _, loss, rec, aux, ppl = forward_and_losses(model_name, model, xb, yb)
                v_losses.append(loss.item())
                if ppl is not None:
                    v_ppls.append(ppl)
        # Empty loaders yield inf (losses) / nan (perplexity) instead of failing.
        v_mean = float(np.mean(v_losses)) if v_losses else float("inf")
        ppl_mean = float(np.mean(v_ppls)) if v_ppls else float("nan")
        history["train_total"].append(float(np.mean(t_losses)) if t_losses else float("inf"))
        history["val_total"].append(v_mean)
        history["ppl"].append(ppl_mean)

        # log per-epoch row to central loss CSV
        log_loss_history_rows(
            run_name=run_root.name, run_path=str(run_root), model=model_name, scenario=tag,
            epoch=epoch, train_total=history["train_total"][-1], val_total=v_mean, ppl=ppl_mean
        )

        # checkpointing
        improved = (best_val - v_mean) > cfg.MIN_DELTA
        if improved:
            best_val, no_imp = v_mean, 0
        else:
            no_imp += 1

        # NOTE: best_val has already been updated above, so on an improving
        # epoch the value passed here equals v_mean.
        cpm.save(epoch=epoch, model=model, optimizer=opt, scheduler=sch,
                 best_val=best_val, val_loss=v_mean, history=history,
                 extras={"scenario": tag, "train_frac": train_frac})

        print(f"[{model_name}/{tag}/{cfg.PRECISION}/{cfg.OPTIMIZER}] Epoch {epoch:03d} | val={v_mean:.6f} | ppl={ppl_mean if not np.isnan(ppl_mean) else '—'}")

        if (cfg.PATIENCE is not None) and (no_imp >= cfg.PATIENCE):
            print("Early stopping.")
            break

    # Persist the per-run loss history in both JSON and CSV form.
    with open(run_root / "history.json", "w") as f:
        json.dump(history, f, indent=2)
    with open(run_root / "loss_history.csv", "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["epoch","train_total","val_total","ppl"])
        for e,(tr,va,pp) in enumerate(zip(history["train_total"], history["val_total"], history["ppl"]), start=1):
            w.writerow([e, tr, va, pp])

    # Evaluate with the best checkpoint when one exists, else the last one.
    best_to_load = (ckpt_dir / "best.pt") if (ckpt_dir / "best.pt").exists() else (ckpt_dir / "last.pt")
    print("Testing with:", best_to_load)
    state = torch.load(best_to_load, map_location=device, weights_only=False)
    model.load_state_dict(state["model"])
    model.eval()

    # Per-batch scalar losses plus per-sample relative errors on the test set.
    recon_vals, aux_vals, total_vals, r1_vals, r2_vals = [], [], [], [], []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            if cfg.PRECISION == "fp64":
                xb = xb.to(torch.float64)
                yb = yb.to(torch.float64)
            xhat, loss, rec, aux, ppl = forward_and_losses(model_name, model, xb, yb)
            recon_vals.append(rec.item())
            aux_vals.append(float(aux))
            total_vals.append(loss.item())
            r1, r2 = rel_errors(xhat, yb)
            r1_vals.append(r1.cpu())
            r2_vals.append(r2.cpu())

    # Loss means are averages of per-batch means; relative errors are
    # averaged over individual samples.
    recon_mean = float(np.mean(recon_vals))
    aux_mean   = float(np.mean(aux_vals)) if aux_vals else 0.0
    total_mean = float(np.mean(total_vals))
    relL1_mean = torch.cat(r1_vals).mean().item()
    relL2_mean = torch.cat(r2_vals).mean().item()

    # Save per-sample errors
    np.savez_compressed(run_root / "per_sample_metrics.npz",
                        relL1=torch.cat(r1_vals).numpy(),
                        relL2=torch.cat(r2_vals).numpy(),
                        indices=np.arange(len(torch.cat(r1_vals))))

    # Report Rel L1 / L2 to stdout
    print(
        f"[TEST {model_name}/{tag}/{cfg.PRECISION}/{cfg.OPTIMIZER}] "
        f"recon={recon_mean:.6f} aux={aux_mean:.6f} total={total_mean:.6f} "
        f"relL1={relL1_mean:.6f} relL2={relL2_mean:.6f}"
    )

    # Log registry and metrics
    info_row = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "run_name": run_root.name,
        "run_path": str(run_root.resolve()),
        "model": model_name,
        "scenario": tag,
        # "sc" is the first character of the scenario tag.
        "sc": tag[:1] if tag else tag,
        "epochs": cfg.EPOCHS,
        "batch_size": cfg.BATCH_SIZE,
        "seed": cfg.SEED,
        "train_frac": train_frac, "val_frac": 1.0 - train_frac,
        "train_noise_pct": train_noise_pct, "val_noise_pct": val_noise_pct, "test_noise_pct": test_noise_pct,
        "bottleneck_ch": cfg.BOTTLENECK_CH,
        "codebook_size": cfg.CODEBOOK_SIZE, "commitment_beta": cfg.COMMIT_BETA,
        "top_ch": cfg.TOP_CH if model_name == "VQVA2" else "",
        "precision": cfg.PRECISION,
        "optimizer": cfg.OPTIMIZER,
        "best_val": best_val,
        "best_ckpt": str((ckpt_dir / "best.pt").resolve()) if (ckpt_dir / "best.pt").exists() else "",
        "last_ckpt": str((ckpt_dir / "last.pt").resolve()) if (ckpt_dir / "last.pt").exists() else "",
    }
    log_experiment_row(info_row)

    metrics_row = {
        "run_name": run_root.name, "run_path": str(run_root.resolve()),
        "model": model_name, "scenario": tag, "sc": tag[:1] if tag else tag,
        "train_frac": train_frac, "val_frac": 1.0-train_frac,
        "train_noise_pct": train_noise_pct, "val_noise_pct": val_noise_pct, "test_noise_pct": test_noise_pct,
        "recon_huber_mean": recon_mean, "aux_loss_mean": aux_mean, "total_loss_mean": total_mean,
        "relL1_mean": relL1_mean, "relL2_mean": relL2_mean,
        "train_size": len(train_ds), "val_size": len(val_ds), "test_size": len(test_ds),
        "precision": cfg.PRECISION, "optimizer": cfg.OPTIMIZER
    }
    log_metrics_row(metrics_row)

    print(f"[DONE {model_name}/{tag}/{cfg.PRECISION}/{cfg.OPTIMIZER}] run_dir → {run_root.resolve()}")
    return run_root, info_row, metrics_row


# Scenario runner
def run_scenarios_for_model(model_name: str):
    """Run scenarios b, c and d for ``model_name`` and collect their results.

    b) one run per entry of ``cfg.TRAIN_FRACTIONS_B`` with no noise anywhere;
    c) fixed split, noise on the test set only, one run per ``cfg.TEST_NOISES``;
    d) fixed split, the same noise level on both train and test sets.

    Returns:
        List of ``train_eval_once`` results in execution order.
    """
    print(f"\n==================== {model_name} ({cfg.PRECISION}/{cfg.OPTIMIZER}) ====================")

    fixed = cfg.TRAIN_FRACTION_FIXED
    # Each entry: (train_frac, train_noise, val_noise, test_noise, tag)
    plan = [(frac, 0.0, 0.0, 0.0, "b_split") for frac in cfg.TRAIN_FRACTIONS_B]
    plan += [
        (fixed, 0.0, 0.0, float(level), "c_test_noise")
        for level in cfg.TEST_NOISES
    ]
    plan += [
        (fixed, float(level), 0.0, float(level), "d_train_and_test_noise")
        for level in cfg.TEST_NOISES
    ]

    return [
        train_eval_once(model_name, frac, train_noise, val_noise, test_noise, tag=tag)
        for frac, train_noise, val_noise, test_noise, tag in plan
    ]


# Run ALL models
#ALL_MODELS = ["AuE", "VAE", "VQVAE", "VQVA2"]
ALL_MODELS = ["VAE"]

# ---- 4-case sweep: (fp32|fp64) × (adam|soap) ----
# Results are accumulated per model across every precision/optimizer pair.
all_results_by_model = {m: [] for m in ALL_MODELS}

for prec in ["fp32", "fp64"]:
    for opt in ["adam", "soap"]:
        # The sweep mutates the shared cfg instance; every function that
        # reads cfg.PRECISION / cfg.OPTIMIZER picks up these values.
        cfg.PRECISION = prec
        cfg.OPTIMIZER = opt
        if cfg.OPTIMIZER == "soap" and SOAP is None:
            raise RuntimeError("Requested SOAP but soap.py not found")
        # Default dtype follows the precision; it is not reset after the
        # sweep, so it stays at the value set by the last iteration.
        torch.set_default_dtype(torch.float64 if cfg.PRECISION == "fp64" else torch.float32)

        for m in ALL_MODELS:
            results = run_scenarios_for_model(m)
            all_results_by_model[m].extend(results)

# After all runs, pick best for each model
for m in ALL_MODELS:
    runs = all_results_by_model[m]
    if not runs:
        continue

    # Select run with best (lowest) total_loss_mean
    # Each run is (run_root, info_row, metrics_row); r[2] is the metrics dict.
    best_run = min(runs, key=lambda r: r[2]["total_loss_mean"])
    best_run_root, best_info, best_metrics = best_run

    # Prefer the run's best checkpoint; fall back to its last one.
    best_ckpt_path = best_info.get("best_ckpt") or best_info.get("last_ckpt")
    if best_ckpt_path:
        # Re-save the winning checkpoint under a stable per-model name.
        state = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
        best_out_path = cfg.OUT_DIR / f"best_overall_{m}.pt"
        torch.save(state, best_out_path)
        print(
            f"\n[GLOBAL BEST] {m} total_loss_mean={best_metrics['total_loss_mean']:.6f} "
            f"(run={best_info['run_name']}, prec={best_info['precision']}, opt={best_info['optimizer']})\n"
            f"Saved to: {best_out_path}"
        )

print("\nAll experiments finished.")
print(f"Registry:   {cfg.EXP_CSV.resolve()}")
print(f"Loss CSV:   {cfg.LOSS_CSV.resolve()}")
print(f"Metrics:    {cfg.METRICS_CSV.resolve()}")