# Experiment on VAE + diffusion model

In [1]:
# minimal latent-diffusion-on-VAE for MNIST (PyTorch)
import math, os, random, time
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms, utils

class VAE(nn.Module):
    def __init__(self, z_dim: int = 8):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1), nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(True),
            nn.Flatten()
        )
        self.enc_fc = nn.Linear(64 * 7 * 7, 2 * z_dim)
        self.dec_fc = nn.Linear(z_dim, 64 * 7 * 7)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(32, 1, 4, 2, 1)  # logits
        )
        self.z_dim = z_dim
        self.apply(self._init)

    @staticmethod
    def _init(m: nn.Module):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None: nn.init.zeros_(m.bias)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.enc(x)
        mu, logvar = self.enc_fc(h).chunk(2, dim=1)
        return mu, logvar

    @staticmethod
    def reparam(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

    def decode_logits(self, z: torch.Tensor) -> torch.Tensor:
        h = self.dec_fc(z).view(z.size(0), 64, 7, 7)
        return self.dec(h)

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x)
        z = self.reparam(mu, logvar)
        logits = self.decode_logits(z)
        return logits, mu, logvar

def vae_loss(logits: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="sum")
    kld = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum()
    return recon + kld, recon.detach(), kld.detach()

def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    device = t.device
    half = dim // 2
    freqs = torch.exp(torch.linspace(math.log(1e-4), math.log(1.0), half, device=device))
    ang = t.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0,1))
    return emb

class LatentDenoiser(nn.Module):
    def __init__(self, z_dim: int, t_dim: int = 64, h: int = 256):
        super().__init__()
        self.t_dim = t_dim
        self.net = nn.Sequential(
            nn.Linear(z_dim + t_dim, h), nn.SiLU(),
            nn.Linear(h, h), nn.SiLU(),
            nn.Linear(h, z_dim)
        )
        self.apply(self._init)

    @staticmethod
    def _init(m: nn.Module):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, zt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        te = sinusoidal_embedding(t, self.t_dim)
        return self.net(torch.cat([zt, te], dim=1))

@dataclass
class DiffusionCfg:
    T: int = 200
    beta_start: float = 1e-4
    beta_end: float = 0.02

class LatentDDPM:
    def __init__(self, z_dim: int, cfg: DiffusionCfg):
        b = torch.linspace(cfg.beta_start, cfg.beta_end, cfg.T)
        a = 1.0 - b
        ab = torch.cumprod(a, dim=0)
        self.beta, self.alpha, self.alphabar = b, a, ab
        self.T = cfg.T
        self.model = LatentDenoiser(z_dim)

    def add_noise(self, z0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        ab = self.alphabar[t].to(z0.device)[:, None]
        return torch.sqrt(ab) * z0 + torch.sqrt(1 - ab) * noise

    def loss(self, z0: torch.Tensor) -> torch.Tensor:
        B, device = z0.size(0), z0.device
        t = torch.randint(0, self.T, (B,), device=device, dtype=torch.long)
        eps = torch.randn_like(z0)
        zt = self.add_noise(z0, t, eps)
        eps_pred = self.model(zt, t)
        return F.mse_loss(eps_pred, eps)

    @torch.no_grad()
    def sample(self, n: int, device: torch.device) -> torch.Tensor:
        z = torch.randn(n, self.model.net[-1].out_features, device=device)
        for t in reversed(range(self.T)):
            tt = torch.full((n,), t, device=device, dtype=torch.long)
            eps = self.model(z, tt)
            a, b, ab = self.alpha[t].to(device), self.beta[t].to(device), self.alphabar[t].to(device)
            z = (1/torch.sqrt(a))*(z - (b/torch.sqrt(1 - ab)) * eps)
            if t > 0:
                z = z + torch.sqrt(b) * torch.randn_like(z)
        return z

def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # data
    tfm = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST(root="data", train=True, download=True, transform=tfm)
    train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)

    # VAE
    vae = VAE(z_dim=8).to(device)
    opt_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)

    print("Training VAE…")
    vae.train()
    for epoch in range(5):
        s_time, tot, rec, kl = time.time(), 0.0, 0.0, 0.0
        for x, _ in train_loader:
            x = x.to(device)
            logits, mu, logvar = vae(x)
            loss, r, k = vae_loss(logits, x, mu, logvar)
            opt_vae.zero_grad(); loss.backward(); opt_vae.step()
            tot += loss.item(); rec += r.item(); kl += k.item()
        n = len(train_loader.dataset)
        print(f"epoch {epoch+1}: loss={tot/n:.4f} recon={rec/n:.4f} kl={kl/n:.4f} ({time.time()-s_time:.1f}s)")

    # freeze VAE, build latent dataset (use μ for stability)
    vae.eval()
    zs = []
    with torch.no_grad():
        for x, _ in DataLoader(train_set, batch_size=512):
            mu, _ = vae.encode(x.to(device))
            zs.append(mu.cpu())
    z_all = torch.cat(zs)
    z_ds = TensorDataset(z_all)
    z_loader = DataLoader(z_ds, batch_size=512, shuffle=True, num_workers=2, pin_memory=True)

    # Diffusion in latent space
    ddpm = LatentDDPM(z_dim=vae.z_dim, cfg=DiffusionCfg(T=200)).__dict__  # grab buffers before moving
    diffusion = LatentDDPM(z_dim=vae.z_dim, cfg=DiffusionCfg(T=200))      # new object to hold model
    diffusion.beta, diffusion.alpha, diffusion.alphabar = ddpm["beta"], ddpm["alpha"], ddpm["alphabar"]
    diffusion.model.to(device)
    opt_diff = torch.optim.Adam(diffusion.model.parameters(), lr=2e-4)

    print("Training latent DDPM…")
    for epoch in range(10):
        s_time, tot = time.time(), 0.0
        for (z0,) in z_loader:
            z0 = z0.to(device)
            loss = diffusion.loss(z0)
            opt_diff.zero_grad(); loss.backward(); opt_diff.step()
            tot += loss.item() * z0.size(0)
        print(f"epoch {epoch+1}: mse={tot/len(z_ds):.6f} ({time.time()-s_time:.1f}s)")

    # sample → decode → save
    diffusion.model.eval()
    with torch.no_grad():
        z_samp = diffusion.sample(n=64, device=device)
        x_logits = vae.decode_logits(z_samp)
        x = torch.sigmoid(x_logits).cpu()
        os.makedirs("samples", exist_ok=True)
        utils.save_image(x, "samples/mnist_ldm.png", nrow=8)
    print("Saved samples to samples/mnist_ldm.png")

if __name__ == "__main__":
    main()


100%|██████████| 9.91M/9.91M [00:27<00:00, 364kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 189kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.01MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.1MB/s]


Training VAE…
epoch 1: loss=210.4685 recon=198.5951 kl=11.8734 (7.7s)
epoch 2: loss=135.9668 recon=121.0695 kl=14.8972 (8.4s)
epoch 3: loss=128.3153 recon=112.8967 kl=15.4186 (9.2s)
epoch 4: loss=125.0975 recon=109.5625 kl=15.5350 (12.1s)
epoch 5: loss=122.8360 recon=107.1942 kl=15.6418 (16.2s)
Training latent DDPM…


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)