#### AE-VAE-VQVAE-VQVAE2 Implementation
#### Dataset: Imagenet100

In [None]:
# Connect Google Drive (Colab only): mount the drive and switch into the
# project folder, so the relative paths defined further down resolve inside it.
from google.colab import drive
drive.mount('/content/gdrive')
# IPython magic: change the notebook's working directory.
%cd /content/gdrive/MyDrive/Colab Notebooks/Colab Notebooks/autoencoders

# IPython magic: print the resulting working directory.
%pwd

In [None]:
import os, csv, json, time, shutil, tempfile
from datetime import datetime
from pathlib import Path
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Paths: everything is relative to the current working directory.
# The commented-out assignments are alternative absolute locations.
#OUT_DIR   = Path("/home/krajnish/autoencoders/best_models/output_inet/output_inet_vae")
OUT_DIR   = Path("./best_models/output_inet/output_inet_vaetest")
#TENSORCACHE_DIR   = Path("/home/krajnish/autoencoders")
TENSORCACHE_DIR   = Path(".")
CACHE_DIR = TENSORCACHE_DIR / "datasets/inet100"
TRAIN_PT  = CACHE_DIR / "train.pt"
TEST_PT   = CACHE_DIR / "test.pt"
# Create the output directory up front; run folders and central CSVs live under it.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# SOAP import (second-order optimizer).
# Best effort: if the import fails for any reason, SOAP is set to None and the
# code that needs it checks for None before use.
try:
    from soap import SOAP
except Exception:
    SOAP = None

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
print("Using cached tensors:")
print(" -", TRAIN_PT, TRAIN_PT.exists())
print(" -", TEST_PT,  TEST_PT.exists())
# Abort early when either cached tensor file is missing.
if not TRAIN_PT.exists() or not TEST_PT.exists():
    raise FileNotFoundError("cache not found")
# ---------------- Config ----------------
class CFG:
    """Central experiment configuration.

    All settings are class attributes. The sweep at the end of the file
    overwrites ``PRECISION`` and ``OPTIMIZER`` on the shared ``cfg`` instance.
    """
    # Training
    EPOCHS       = 1      # alternative value noted by the author: 300
    BATCH_SIZE   = 256
    LR           = 1e-3   # learning rate passed to Adam
    WEIGHT_DECAY = 1e-4   # weight decay passed to Adam
    NUM_WORKERS  = 4      # DataLoader worker processes
    PIN_MEMORY   = torch.cuda.is_available()
    PERSISTENT   = (NUM_WORKERS > 0)   # persistent_workers flag for the DataLoaders
    PATIENCE     = None   # early-stopping patience in epochs; None disables early stopping
    MIN_DELTA    = 1e-5   # minimum drop in validation loss that counts as an improvement
    GRAD_CLIP    = 1.0    # max gradient norm; a falsy value skips clipping
    SEED         = 42

    # Precision / Optimizer
    PRECISION    = "fp32"          # Precision: "fp32", "fp64"
    OPTIMIZER    = "adam"          # Optimizer: "adam", "soap"
    ADAM_BETAS   = (0.9, 0.999)

    # SOAP settings (labelled "recommended defaults" by the author)
    SOAP_LR      = 3e-3
    SOAP_BETAS   = (0.95, 0.95)
    SOAP_WD      = 1e-2
    SOAP_PREFREQ = 10     # passed to SOAP as precondition_frequency

    # Latent/channel sizes
    BOTTLENECK_CH = 56    # default bottleneck channel count of every model
    LATENT_SHAPE  = (56, 7, 7)   # NOTE(review): not referenced in the code visible here

    # VQ specifics
    CODEBOOK_SIZE = 512   # number of codebook entries per quantizer
    COMMIT_BETA   = 0.25  # default beta of VectorQuantizer

    # VQ loss weights
    VQ_WEIGHT         = 1.0   # for VQVE (VQ-VAE)
    VQ_TOP_WEIGHT     = 1.0   # for VQVA2 top level
    VQ_BOTTOM_WEIGHT  = 1.0   # for VQVA2 bottom level

    # VQ-VAE-2 specifics
    TOP_CH        = 56    # channel count of the top-level latent in VQVA2

    # VAE KL weight
    VAE_BETA      = 1.0

    # Scenarios
    #TRAIN_FRACTIONS_B = [0.50, 0.60, 0.70]     # scenario b
    TRAIN_FRACTIONS_B = [0.70]
    #TEST_NOISES       = [1, 5, 10]             # scenarios c & d
    TEST_NOISES       = [1]                     # noise levels in percent
    TRAIN_FRACTION_FIXED = 0.70                 # for c/d

    # Central CSVs (shared by all runs, appended to)
    EXP_CSV  = OUT_DIR / "experiments_inet100.csv"
    LOSS_CSV = OUT_DIR / "loss_history_inet100.csv"
    MET_CSV  = OUT_DIR / "metrics_inet100.csv"

# Single shared configuration instance used by every function below.
cfg = CFG()
# Seed the CPU RNG and, when CUDA is present, every GPU RNG.
torch.manual_seed(cfg.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.SEED)

# New floating-point tensors and module parameters follow the configured precision.
torch.set_default_dtype(torch.float64 if cfg.PRECISION == "fp64" else torch.float32)

# ---------------- Snake activation ----------------
class Snake(nn.Module):
    """
    Snake activation with a single learnable scalar ``alpha``:

        snake(x) = x + (1/a) * sin^2(a * x),   a = |alpha| + 1e-6

    The absolute value plus the small offset keeps ``a`` strictly positive,
    so the division is always defined.
    """
    def __init__(self, alpha=1.0):
        super().__init__()
        initial = torch.tensor(alpha)
        self.alpha = nn.Parameter(initial)

    def forward(self, x):
        freq = torch.abs(self.alpha) + 1e-6
        wave = torch.sin(freq * x)
        return x + (1.0 / freq) * wave.pow(2)

# ---------------- CSV helpers ----------------
def _read_header_if_exists(path: Path):
    """Return the first CSV row of ``path`` as a list of strings.

    Returns ``None`` when the file does not exist, is zero bytes long, or
    yields no rows.
    """
    has_content = path.exists() and path.stat().st_size != 0
    if not has_content:
        return None
    with path.open("r", newline="") as handle:
        # next() with a default replaces the explicit StopIteration handler.
        return next(csv.reader(handle), None)

def _upgrade_csv_header_if_needed(path: Path, desired_fields):
    """Make sure the CSV at ``path`` has a column for every desired field.

    Returns the column list to use for writing:
    * a copy of ``desired_fields`` when the file has no header yet,
    * a copy of the existing header when it already contains all desired fields,
    * otherwise the existing header extended with the missing fields; in that
      case the file is rewritten through a ``.tmp`` sibling with the new
      columns left empty in all existing rows.
    """
    current = _read_header_if_exists(path)
    if current is None:
        return desired_fields[:]

    missing = [name for name in desired_fields if name not in current]
    if not missing:
        return current[:]

    widened = current[:] + missing
    scratch = Path(str(path) + ".tmp")
    with path.open("r", newline="") as src, scratch.open("w", newline="") as dst:
        writer = csv.DictWriter(dst, fieldnames=widened)
        writer.writeheader()
        for record in csv.DictReader(src):
            # Columns absent from the old row are written as empty strings.
            writer.writerow({col: record.get(col, "") for col in widened})
    shutil.move(scratch, path)
    return widened

def _ensure_writer(path: Path, desired_fields):
    """Open ``path`` for appending and return ``(file, DictWriter, columns)``.

    Creates the parent directory, upgrades the existing header if columns are
    missing, and writes the header when the file is still empty.  The caller
    is responsible for closing the returned file object.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns = _upgrade_csv_header_if_needed(path, desired_fields)
    handle = path.open("a", newline="")
    writer = csv.DictWriter(handle, fieldnames=columns)
    file_is_empty = path.stat().st_size == 0
    if file_is_empty:
        writer.writeheader()
    return handle, writer, columns

def log_exp(row: dict):
    """Append one experiment-summary row to the central experiments CSV.

    Keys of ``row`` that match a CSV column are written; columns without a
    matching key are left empty.
    """
    fields = [
        "timestamp","dataset","run_name","run_path","model","scenario","sc",
        "epochs","batch_size","seed",
        "train_frac","val_frac","train_noise_pct","val_noise_pct","test_noise_pct",
        "bottleneck_ch","best_val",
        "precision","optimizer"
    ]
    f, w, cols = _ensure_writer(cfg.EXP_CSV, fields)
    # `with` closes the handle even if writing the row raises.
    with f:
        w.writerow({k: row.get(k, "") for k in cols})

def log_loss(run_info: dict, epoch: int, train_total: float, val_total: float):
    """Append one per-epoch loss row to the central loss-history CSV.

    ``run_info`` may provide "run_name", "run_path", "model" and "scenario";
    missing keys are written as empty strings.
    """
    fields = ["dataset","run_name","run_path","model","scenario","epoch","train_total","val_total"]
    f, w, cols = _ensure_writer(cfg.LOSS_CSV, fields)
    row = {
        "dataset": "inet100",
        "run_name": run_info.get("run_name",""),
        "run_path": run_info.get("run_path",""),
        "model":    run_info.get("model",""),
        "scenario": run_info.get("scenario",""),
        "epoch": epoch,
        "train_total": train_total,
        "val_total": val_total,
    }
    # `with` closes the handle even if writing the row raises.
    with f:
        w.writerow({k: row.get(k, "") for k in cols})

def log_metrics(row: dict):
    """Append one test-metrics row to the central metrics CSV.

    Keys of ``row`` that match a CSV column are written; columns without a
    matching key are left empty.
    """
    fields = [
        "dataset","run_name","run_path","model","scenario","sc",
        "train_frac","val_frac","train_noise_pct","val_noise_pct","test_noise_pct",
        "recon_huber_mean","aux_loss_mean","total_loss_mean",
        "relL1_mean","relL2_mean",
        "train_size","val_size","test_size",
        "precision","optimizer"
    ]
    f, w, cols = _ensure_writer(cfg.MET_CSV, fields)
    # `with` closes the handle even if writing the row raises.
    with f:
        w.writerow({k: row.get(k, "") for k in cols})

# ---------------- Dataset / Loader ----------------
class _PairDS(torch.utils.data.Dataset):
    """Dataset yielding ``(noisy_input, clean_target)`` pairs.

    Defined at module level (not inside a function) so instances can be
    pickled, which DataLoader worker processes need under the "spawn" start
    method.
    """

    def __init__(self, base, pct):
        self.base = base
        # Noise level is given in percent; store it as a standard deviation.
        self.std = pct / 100.0

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x = self.base[i][0].float()  # cache typically fp32
        if self.std > 0:
            # Fresh Gaussian noise on every access, clamped to [-1, 1].
            x_noisy = (x + torch.randn_like(x) * self.std).clamp(-1, 1)
        else:
            x_noisy = x.clone()
        return x_noisy, x


def _noisy_dataset_from_tensor(X: torch.Tensor, noise_pct: float):
    """Wrap tensor ``X`` as a dataset of ``(noisy, clean)`` sample pairs.

    Args:
        X: tensor indexed along dim 0, one sample per index.
        noise_pct: Gaussian noise standard deviation in percent
            (``std = noise_pct / 100``); values <= 0 give an exact copy.

    Returns:
        A map-style dataset whose items are ``(x_noisy, x)`` float32 tensors.
    """
    return _PairDS(TensorDataset(X), noise_pct)

def build_loaders(train_fraction: float, batch_size: int,
                  train_noise_pct: float = 0.0,
                  val_noise_pct: float   = 0.0,
                  test_noise_pct: float  = 0.0):
    """Load the cached tensors and build train/val/test datasets and loaders.

    The first ``round(train_fraction * N)`` cached training samples form the
    training set and the remainder the validation set; the test set is the
    cached test tensor.  Each split gets its own input-noise level (percent).

    Returns:
        ``(train_ds, val_ds, test_ds, train_loader, val_loader, test_loader)``.
        Only the training loader shuffles; validation and test keep the stored
        order so per-sample outputs line up with dataset indices.
    """
    Xtr = torch.load(TRAIN_PT, map_location="cpu", weights_only=True)
    Xte = torch.load(TEST_PT,  map_location="cpu", weights_only=True)
    n_train = int(round(train_fraction * Xtr.shape[0]))
    X_train, X_val = Xtr[:n_train], Xtr[n_train:]

    train_ds = _noisy_dataset_from_tensor(X_train, train_noise_pct)
    val_ds   = _noisy_dataset_from_tensor(X_val,   val_noise_pct)
    test_ds  = _noisy_dataset_from_tensor(Xte,     test_noise_pct)

    kw = dict(batch_size=batch_size,
              num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY,
              persistent_workers=cfg.PERSISTENT)
    # Reshuffle training samples every epoch so SGD does not see the same
    # batches in the same order.  Skipped for an empty split because the
    # random sampler may reject a zero-length dataset.
    shuffle_train = len(train_ds) > 0
    return (train_ds, val_ds, test_ds,
            DataLoader(train_ds, shuffle=shuffle_train, **kw),
            DataLoader(val_ds,   shuffle=False, **kw),
            DataLoader(test_ds,  shuffle=False, **kw))

# ---------------- Losses / Metrics ----------------
def huber_recon(xhat, y, delta=1.0):
    """Mean Huber-style reconstruction loss between ``xhat`` and ``y``.

    Per element, with ``a = |xhat - y|``: ``0.5 * a^2 / delta`` while
    ``a <= delta``, and ``a - 0.5 * delta`` beyond that.  The result is the
    mean over all elements.
    """
    residual = (xhat - y).abs()
    # Part of the residual inside the quadratic zone, and the excess beyond it.
    inside = torch.clamp(residual, max=delta)
    excess = residual - inside
    per_element = 0.5 * inside * inside / delta + excess
    return per_element.mean()

def rel_errors(xhat, y, eps=1e-12):
    """Per-sample relative L1 and L2 errors of ``xhat`` with respect to ``y``.

    All dimensions after the first are flattened, so each returned tensor has
    one entry per sample.  ``eps`` is added to the L1 denominator, and inside
    the square root of the L2 denominator, to avoid division by zero.
    """
    err = (xhat - y).flatten(1)
    ref = y.flatten(1)
    rel_l1 = err.abs().sum(1) / (ref.abs().sum(1) + eps)
    l2_num = torch.sqrt((err ** 2).sum(1))
    l2_den = torch.sqrt((ref ** 2).sum(1) + eps)
    return rel_l1, l2_num / l2_den

# ---------------- Models ----------------
class AuE(nn.Module):
    """Plain convolutional autoencoder with Snake activations.

    Encoder: one stride-1 and two stride-2 convolutions (1 -> 32 -> 64 -> ch
    channels).  Decoder: two stride-2 transposed convolutions (ch -> 64 -> 32)
    followed by a 1x1 convolution to one channel and Tanh.  Expects
    single-channel input.
    """
    def __init__(self, ch=cfg.BOTTLENECK_CH):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), Snake(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), Snake(),
            # Bottleneck: ch channels after two stride-2 steps
            # (56x7x7 would correspond to a 28x28 input — TODO confirm input size).
            nn.Conv2d(64, ch, 4, stride=2, padding=1), Snake()
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(ch, 64, 2, stride=2), Snake(),
            nn.ConvTranspose2d(64, 32, 2, stride=2), Snake(),
            nn.Conv2d(32, 1, 1), nn.Tanh()   # Tanh bounds the output to (-1, 1)
        )
    def forward(self, x):
        """Return ``(reconstruction, {})``; the empty dict means no auxiliary losses."""
        z = self.enc(x)
        xh = self.dec(z)
        return xh, {}

# ---------- VAE ----------
class VAE(nn.Module):
    """Convolutional variational autoencoder with a spatial latent map.

    The encoder contains a single stride-2 convolution; two separate 3x3
    convolutions then predict the mean and log-variance maps with ``ch``
    channels.  The decoder applies one stride-2 transposed convolution and
    ends in a 1x1 convolution followed by Tanh.
    """
    def __init__(self, ch=cfg.BOTTLENECK_CH):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), Snake(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), Snake()
        )
        # Heads producing the parameters of the diagonal Gaussian posterior.
        self.mu     = nn.Conv2d(64, ch, 3, padding=1)
        self.logvar = nn.Conv2d(64, ch, 3, padding=1)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(ch, 64, 2, stride=2), Snake(),
            nn.Conv2d(64, 32, 3, padding=1), Snake(),
            nn.Conv2d(32, 1, 1), nn.Tanh()
        )

    def reparam(self, mu, logv):
        """Reparameterization trick: return ``mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = (0.5 * logv).exp()   # std = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, {"kld": scalar KL term})``."""
        h  = self.enc(x)
        mu = self.mu(h)
        lv = self.logvar(h)

        # Sample during training; in eval mode decode the mean for a
        # deterministic reconstruction.
        if self.training:
            z = self.reparam(mu, lv)
        else:
            z = mu

        xh = self.dec(z)

        # Element-wise KL divergence between N(mu, exp(lv)) and N(0, 1).
        kl_map = -0.5 * (1 + lv - mu.pow(2) - lv.exp())   # (B,C,H,W)
        # Note: mean (not sum) over latent dimensions, then mean over the batch.
        kl_per_sample = kl_map.mean(dim=[1, 2, 3])        # average over latent dims
        kld = kl_per_sample.mean()                        # average over batch

        return xh, {"kld": kld}

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Args:
        K: number of codebook entries.
        D: dimensionality of each entry; must equal the channel count of the
            input to ``forward``.
        beta: weight of the commitment term (the term that pulls the encoder
            output towards its selected code).
    """
    def __init__(self, K, D, beta=cfg.COMMIT_BETA):
        super().__init__()
        self.K, self.D, self.beta = K, D, beta
        self.emb = nn.Embedding(K, D)
        self.emb.weight.data.uniform_(-1.0 / D, 1.0 / D)

    def forward(self, z):
        """Quantize ``z`` of shape (B, D, H, W).

        Returns:
            ``(zq, loss)``: ``zq`` has the same shape as ``z`` and holds the
            selected code vectors (with straight-through gradients to ``z``);
            ``loss`` is the scalar codebook loss plus beta * commitment loss.
        """
        # Channels last, so that each spatial position is one D-dim vector.
        zf   = z.permute(0,2,3,1).contiguous()
        flat = zf.view(-1, self.D)
        # Squared Euclidean distance to every code: |x|^2 - 2 x.e + |e|^2.
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.emb.weight.t()
                + self.emb.weight.pow(2).sum(1))
        ind  = dist.argmin(1)
        zq   = self.emb(ind).view_as(zf)
        # Codebook term: gradient flows only into the embeddings.
        codebk = ((zq - zf.detach())**2).mean()
        # Commitment term: gradient flows only into the encoder output; this
        # is the term beta scales.
        commit = self.beta * ((zf - zq.detach())**2).mean()
        loss = codebk + commit
        # Straight-through estimator: forward value is zq, gradient goes to zf.
        zq   = zf + (zq - zf).detach()
        return zq.permute(0,3,1,2).contiguous(), loss

class VQVE(nn.Module):
    """Single-level VQ-VAE.

    Two stride-2 convolutions encode the single-channel input to ``ch``
    channels, a ``VectorQuantizer`` with ``K`` codes quantizes that map, and
    two stride-2 transposed convolutions decode it back to one channel.
    """
    def __init__(self, ch=cfg.BOTTLENECK_CH, K=cfg.CODEBOOK_SIZE, beta=cfg.COMMIT_BETA):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1), Snake(),
            nn.Conv2d(32, ch, 4, stride=2, padding=1), Snake(),
        )
        self.vq  = VectorQuantizer(K, ch, beta)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(ch, 32, 4, stride=2, padding=1), Snake(),
            # Note: a Snake activation is applied before the final Tanh.
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1), Snake(),
            nn.Tanh()
        )
    def forward(self, x):
        """Return ``(reconstruction, {"vq": quantization loss})``."""
        z = self.enc(x)
        zq, lvq = self.vq(z)
        xh = self.dec(zq)
        return xh, {"vq": lvq}

class VQVA2(nn.Module):
    """Two-level VQ-VAE (VQ-VAE-2 style) with a bottom and a top codebook.

    The bottom encoder downsamples the input by two stride-2 convolutions;
    the top encoder downsamples the bottom features once more.  The quantized
    top map is upsampled back to the bottom resolution, added to the bottom
    features before bottom quantization, and also concatenated with the
    quantized bottom map as decoder input.
    """
    def __init__(self, ch=cfg.BOTTLENECK_CH, top_ch=cfg.TOP_CH,
                 K=cfg.CODEBOOK_SIZE, beta=cfg.COMMIT_BETA):
        super().__init__()

        # ---- Bottom encoder: input -> (ch, H/4, W/4) ----
        self.enc_b = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1), Snake(),
            nn.Conv2d(32, ch, 4, stride=2, padding=1), Snake()
        )

        # ---- Top encoder: bottom map -> (top_ch, floor(Hb/2), floor(Wb/2)) ----
        self.enc_t = nn.Sequential(
            nn.Conv2d(ch, top_ch, 4, stride=2, padding=1), Snake()
        )

        # ---- Quantizers ----
        self.vq_t = VectorQuantizer(K, top_ch, beta)
        self.vq_b = VectorQuantizer(K, ch, beta)

        # ---- Upsample top quantized map by 2x (towards bottom resolution) ----
        self.up_t = nn.Sequential(
            nn.ConvTranspose2d(top_ch, ch, 4, stride=2, padding=1), Snake(),
            nn.Conv2d(ch, ch, 3, padding=1), Snake()
        )

        # ---- Decoder (zq_b concatenated with upsampled top) -> image ----
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(ch * 2, 64, 4, stride=2, padding=1), Snake(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), Snake(),
            nn.Conv2d(32, 1, 1), Snake(),
            nn.Tanh()
        )

    def forward(self, x):
        """Return ``(reconstruction, {"vq_top": loss, "vq_bottom": loss})``."""
        zb = self.enc_b(x)     # (B, ch, Hb, Wb)
        zt = self.enc_t(zb)    # (B, top_ch, floor(Hb/2), floor(Wb/2))

        zqt, lt = self.vq_t(zt)
        up = self.up_t(zqt)    # (B, ch, 2*floor(Hb/2), 2*floor(Wb/2))

        # Align to the bottom map spatially.  With an odd bottom size (e.g. 7)
        # the upsampled map is one short (6) and gets zero-padded at the
        # bottom/right.  Height and width are handled independently, so a
        # mixed case (one axis short, the other long) is aligned correctly.
        Hb, Wb = zb.shape[-2:]
        Hu, Wu = up.shape[-2:]
        pad_h, pad_w = max(0, Hb - Hu), max(0, Wb - Wu)
        if pad_h or pad_w:
            up = nn.functional.pad(up, (0, pad_w, 0, pad_h))
        # Crop any excess; a no-op when the sizes already match.
        up = up[:, :, :Hb, :Wb]

        # Condition the bottom quantizer on the top-level information.
        zb_input = zb + up
        zqb, lb = self.vq_b(zb_input)

        dec_in = torch.cat([zqb, up], dim=1)  # (B, ch*2, Hb, Wb)
        xh = self.dec(dec_in)

        return xh, {"vq_top": lt, "vq_bottom": lb}

# ---------------- Training helpers ----------------
def forward_losses(model, xb, yb, name: str):
    """Run ``model`` on ``xb`` and assemble the training loss for model ``name``.

    The total loss is the Huber reconstruction loss against ``yb`` plus the
    weighted auxiliary terms the model reports: KL for "VAE", the VQ loss for
    "VQVE", top and bottom VQ losses for "VQVA2".  Any other name uses the
    reconstruction loss alone.

    Returns:
        ``(reconstruction, total_loss, recon_loss, extra_dict)``.
    """
    xh, extra = model(xb)
    rec = huber_recon(xh, yb)

    # (key in `extra`, weight) pairs, added to the loss in this order.
    weighted_terms = {
        "VAE":   (("kld", cfg.VAE_BETA),),
        "VQVE":  (("vq", cfg.VQ_WEIGHT),),
        "VQVA2": (("vq_top", cfg.VQ_TOP_WEIGHT),
                  ("vq_bottom", cfg.VQ_BOTTOM_WEIGHT)),
    }.get(name, ())

    loss = rec
    for key, weight in weighted_terms:
        loss = loss + weight * extra[key]
    return xh, loss, rec, extra

@torch.no_grad()
def test_pass(model, loader, name: str):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(relL1_mean, relL2_mean, recon_mean, total_mean, relL1_array, relL2_array)``
        where the arrays hold one value per sample in loader order.  The
        recon/total means are weighted by batch size, so a smaller final batch
        does not skew them.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    r1s, r2s = [], []
    rec_sum, tot_sum, n_seen = 0.0, 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        if cfg.PRECISION == "fp64":
            xb = xb.to(torch.float64); yb = yb.to(torch.float64)
        xh, loss, rec, _ = forward_losses(model, xb, yb, name)
        r1, r2 = rel_errors(xh, yb)
        r1s.append(r1.cpu()); r2s.append(r2.cpu())
        # Batch losses are means; weight them by the number of samples.
        bs = yb.shape[0]
        rec_sum += rec.item() * bs
        tot_sum += loss.item() * bs
        n_seen  += bs
    if n_seen == 0:
        raise ValueError("test_pass: loader produced no samples")
    r1_all = torch.cat(r1s)
    r2_all = torch.cat(r2s)
    return (r1_all.mean().item(), r2_all.mean().item(),
            rec_sum / n_seen, tot_sum / n_seen,
            r1_all.numpy(), r2_all.numpy())

def make_optimizer(model):
    """Build the optimizer selected by ``cfg.OPTIMIZER`` (case-insensitive).

    Returns:
        A SOAP optimizer for "soap", a ``torch.optim.Adam`` for "adam".

    Raises:
        RuntimeError: "soap" was requested but the SOAP class is unavailable.
        ValueError: the configured optimizer name is not recognised.
    """
    name = cfg.OPTIMIZER.lower()
    if name == "soap":
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if SOAP is None:
            raise RuntimeError("soap.py not found.")
        return SOAP(
            params=model.parameters(),
            lr=cfg.SOAP_LR,
            betas=cfg.SOAP_BETAS,
            weight_decay=cfg.SOAP_WD,
            precondition_frequency=cfg.SOAP_PREFREQ
        )
    if name == "adam":
        # First-order Adam
        return torch.optim.Adam(
            model.parameters(),
            lr=cfg.LR,
            betas=cfg.ADAM_BETAS,
            weight_decay=cfg.WEIGHT_DECAY
        )
    # A misspelled name must not silently train with a different optimizer.
    raise ValueError(f"Unknown optimizer {cfg.OPTIMIZER!r}; expected 'adam' or 'soap'.")

def _cfg_snapshot():
    """Collect every upper-case setting reachable on ``cfg`` into a plain dict.

    ``cfg.__dict__`` only holds attributes assigned on the instance, so the
    class-level settings are gathered via ``dir``.  Paths are stored as strings.
    """
    snap = {}
    for key in dir(cfg):
        if key.isupper():
            val = getattr(cfg, key)
            snap[key] = str(val) if isinstance(val, Path) else val
    return snap


def run_one(model_class, model_name, tag, tf, trn, vn, tsn):
    """Train, evaluate and log one model for a single scenario.

    Args:
        model_class: model constructor, called without arguments.
        model_name: name used for loss selection, logging and file names.
        tag: scenario label; its first character is logged as "sc".
        tf: fraction of the cached training tensor used for training
            (the remainder is the validation set).
        trn, vn, tsn: input-noise levels in percent for train / val / test.

    Side effects:
        Creates a timestamped run directory under ``OUT_DIR`` containing
        ``history.json``, ``loss_history.csv``, ``per_sample_metrics.npz`` and
        ``model.pt``, and appends rows to the three central CSV files.

    Returns:
        A dict describing the run (paths, scenario, settings and
        ``total_loss_mean`` on the test set) for global-best selection.
    """
    run_dir = OUT_DIR / f"{model_name}_{tag}_tr{int(tf*100)}_tn{trn}_vn{vn}_ts{tsn}_{cfg.PRECISION}_{cfg.OPTIMIZER}_{time.strftime('%Y%m%d_%H%M%S')}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Data
    tr_ds, va_ds, te_ds, trL, vaL, teL = build_loaders(
        train_fraction=tf, batch_size=cfg.BATCH_SIZE,
        train_noise_pct=trn, val_noise_pct=vn, test_noise_pct=tsn
    )

    # Model & optim
    model = model_class().to(device)
    if cfg.PRECISION == "fp64":
        model = model.double()

    opt = make_optimizer(model)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.EPOCHS)

    best_val, no_imp = float("inf"), 0
    history = {"train": [], "val": []}

    for ep in range(1, cfg.EPOCHS + 1):
        model.train(); t_losses = []
        for xb, yb in trL:
            xb, yb = xb.to(device), yb.to(device)
            if cfg.PRECISION == "fp64":
                xb = xb.to(torch.float64); yb = yb.to(torch.float64)
            opt.zero_grad(set_to_none=True)
            _, loss, _, _ = forward_losses(model, xb, yb, model_name)
            loss.backward()
            if cfg.GRAD_CLIP:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.GRAD_CLIP)
            opt.step()
            t_losses.append(loss.item())
        # The cosine schedule advances once per epoch.
        sch.step()

        model.eval(); v_losses = []
        with torch.no_grad():
            for xb, yb in vaL:
                xb, yb = xb.to(device), yb.to(device)
                if cfg.PRECISION == "fp64":
                    xb = xb.to(torch.float64); yb = yb.to(torch.float64)
                _, vloss, _, _ = forward_losses(model, xb, yb, model_name)
                v_losses.append(vloss.item())
        # An empty split is recorded as +inf rather than NaN.
        v_mean = float(np.mean(v_losses)) if v_losses else float("inf")
        history["train"].append(float(np.mean(t_losses)) if t_losses else float("inf"))
        history["val"].append(v_mean)

        # central per-epoch
        log_loss(
            {"run_name": run_dir.name, "run_path": str(run_dir.resolve()), "model": model_name, "scenario": tag},
            ep, history["train"][-1], v_mean
        )

        improved = (best_val - v_mean) > cfg.MIN_DELTA
        if improved:
            best_val, no_imp = v_mean, 0
        else:
            no_imp += 1

        print(f"[{model_name}/{tag}/{cfg.PRECISION}/{cfg.OPTIMIZER}] epoch {ep:03d} | val={v_mean:.6f}")

        if (cfg.PATIENCE is not None) and (no_imp >= cfg.PATIENCE):
            print("Early stopping")
            break

    # Save local history files
    with (run_dir / "history.json").open("w") as f:
        json.dump(history, f, indent=2)
    with (run_dir / "loss_history.csv").open("w", newline="") as f:
        w = csv.writer(f); w.writerow(["epoch","train_total","val_total"])
        for e,(tr,va) in enumerate(zip(history["train"], history["val"]), start=1):
            w.writerow([e, tr, va])

    # Test (uses the weights of the last trained epoch)
    r1m, r2m, rec_m, tot_m, r1_arr, r2_arr = test_pass(model, teL, model_name)
    np.savez_compressed(run_dir / "per_sample_metrics.npz",
                        relL1=r1_arr, relL2=r2_arr, indices=np.arange(r1_arr.shape[0]))

    # Report Rel L1 / L2 to stdout
    print(
        f"[TEST {model_name}/{tag}/{cfg.PRECISION}/{cfg.OPTIMIZER}] "
        f"recon={rec_m:.6f} total={tot_m:.6f} relL1={r1m:.6f} relL2={r2m:.6f}"
    )

    # Save checkpoint for this run
    ckpt_path = run_dir / "model.pt"
    torch.save(
        {
            "model_state": model.state_dict(),
            "model_name": model_name,
            # Full settings snapshot, including class-level attributes.
            "cfg": _cfg_snapshot(),
            "run_dir": str(run_dir.resolve()),
            "scenario": tag,
            "precision": cfg.PRECISION,
            "optimizer": cfg.OPTIMIZER,
        },
        ckpt_path,
    )

    # central experiment info + metrics
    exp_row = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "dataset": "inet100",
        "run_name": run_dir.name, "run_path": str(run_dir.resolve()),
        "model": model_name, "scenario": tag, "sc": tag[:1],
        "epochs": cfg.EPOCHS, "batch_size": cfg.BATCH_SIZE, "seed": cfg.SEED,
        "train_frac": tf, "val_frac": 1.0 - tf,
        "train_noise_pct": trn, "val_noise_pct": vn, "test_noise_pct": tsn,
        "bottleneck_ch": cfg.BOTTLENECK_CH, "best_val": best_val,
        "precision": cfg.PRECISION, "optimizer": cfg.OPTIMIZER
    }
    log_exp(exp_row)

    met_row = {
        "dataset": "inet100",
        "run_name": run_dir.name, "run_path": str(run_dir.resolve()),
        "model": model_name, "scenario": tag, "sc": tag[:1],
        "train_frac": tf, "val_frac": 1.0 - tf,
        "train_noise_pct": trn, "val_noise_pct": vn, "test_noise_pct": tsn,
        # Auxiliary part = total minus reconstruction (weighted KL / VQ terms).
        "recon_huber_mean": rec_m, "aux_loss_mean": tot_m - rec_m, "total_loss_mean": tot_m,
        "relL1_mean": r1m, "relL2_mean": r2m,
        "train_size": len(tr_ds), "val_size": len(va_ds), "test_size": len(te_ds),
        "precision": cfg.PRECISION, "optimizer": cfg.OPTIMIZER
    }
    log_metrics(met_row)

    print(f"[DONE {model_name}/{tag}/{cfg.PRECISION}/{cfg.OPTIMIZER}] {run_dir}")

    # Return info for global best selection
    return {
        "run_dir": str(run_dir.resolve()),
        "model_name": model_name,
        "scenario": tag,
        "train_frac": tf,
        "train_noise_pct": trn,
        "val_noise_pct": vn,
        "test_noise_pct": tsn,
        "precision": cfg.PRECISION,
        "optimizer": cfg.OPTIMIZER,
        "total_loss_mean": tot_m,
        "ckpt_path": str(ckpt_path),
    }

def train_all_for_model(model_class, model_name):
    """Run every configured scenario for one model and collect the results.

    Scenarios, in execution order:
      b) one run per training fraction in ``cfg.TRAIN_FRACTIONS_B``, no noise;
      c) fixed training fraction, one run per test-noise level;
      d) fixed training fraction, the same noise level on train and test.

    Returns:
        The list of ``run_one`` result dicts, in execution order.
    """
    def planned_runs():
        # Yields (tag, train_fraction, train_noise, val_noise, test_noise) lazily.
        for frac in cfg.TRAIN_FRACTIONS_B:
            yield "b_split", frac, 0.0, 0.0, 0.0
        for level in cfg.TEST_NOISES:
            yield "c_test_noise", cfg.TRAIN_FRACTION_FIXED, 0.0, 0.0, float(level)
        for level in cfg.TEST_NOISES:
            yield ("d_train_and_test_noise", cfg.TRAIN_FRACTION_FIXED,
                   float(level), 0.0, float(level))

    results = [run_one(model_class, model_name, *spec) for spec in planned_runs()]
    print(f"[{model_name}] all scenarios complete.")
    return results

# ---------------- Run models ----------------
# (model class, model name) pairs to train; commented-out entries are skipped.
ALL_MODELS = [
    # (AuE,   "AuE"),
    (VAE,   "VAE"),
    # (VQVE,  "VQVE"),
    # (VQVA2, "VQVA2"),
]

# Collect all run results for global best selection
all_results_by_model = {name: [] for (_, name) in ALL_MODELS}

# 4-case sweep: (fp32|fp64) × (adam|soap)
PRECISIONS = ["fp32", "fp64"]
OPTIMIZERS = ["adam", "soap"]

# Fail fast: check SOAP availability before any training starts, instead of
# aborting mid-sweep after the Adam runs have already been trained.
if "soap" in OPTIMIZERS and SOAP is None:
    raise RuntimeError("Requested SOAP but soap.py not found. Place soap.py next to this script.")

for prec in PRECISIONS:
    cfg.PRECISION = prec
    torch.set_default_dtype(torch.float64 if cfg.PRECISION == "fp64" else torch.float32)
    for opt in OPTIMIZERS:
        cfg.OPTIMIZER = opt
        for cls, name in ALL_MODELS:
            results = train_all_for_model(cls, name)
            all_results_by_model[name].extend(results)

# Global best per model (over all scenarios / prec / optimizers)
for model_name, runs in all_results_by_model.items():
    if not runs:
        continue
    # Lowest mean total test loss wins.
    best_run = min(runs, key=lambda r: r["total_loss_mean"])
    ckpt_path = best_run["ckpt_path"]
    # The checkpoint is a dict written by this script (not only tensors), hence
    # weights_only=False; only load checkpoints from a trusted source this way.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    best_out_path = OUT_DIR / f"best_overall_{model_name}.pt"
    torch.save(state, best_out_path)
    print(
        f"\n[GLOBAL BEST] {model_name} total_loss_mean={best_run['total_loss_mean']:.6f} "
        f"(scenario={best_run['scenario']}, prec={best_run['precision']}, opt={best_run['optimizer']})\n"
        f"Saved to: {best_out_path}"
    )

print("\nAll models completed. Central logs:")
print("  -", cfg.EXP_CSV.resolve())
print("  -", cfg.LOSS_CSV.resolve())
print("  -", cfg.MET_CSV.resolve())