
# Generative Models: Investigating **VAE** vs **GAN** Biases (CIFAR-10)

This notebook trains and evaluates a **Variational Autoencoder (VAE)** and a **DCGAN** on **CIFAR-10**, then compares their biases:
- **Fidelity vs. Diversity**
- **Reconstruction vs. Generation**
- **Latent Space Structure & Interpolations**
- **Out-of-Distribution (OOD) behavior**
- **Training dynamics & stability**



In [None]:


!python -m pip install torch-fidelity==0.3.0 scipy


Collecting torch-fidelity==0.3.0
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting torchvision (from torch-fidelity==0.3.0)
  Using cached torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting torch (from torch-fidelity==0.3.0)
  Using cached torch-2.8.0-cp313-none-macosx_11_0_arm64.whl.metadata (30 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Using cached torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl (1.9 MB)
Downloading torch-2.8.0-cp313-none-macosx_11_0_arm64.whl (73.6 MB)
[2K   [91m━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.2/73.6 MB[0m [31m60.7 kB/s[0m eta [36m0:17:42[0m^C


## Imports & Utilities

In [None]:

import os, json, math, random, numpy as np
from pathlib import Path
from datetime import datetime

import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, utils as vutils

import matplotlib.pyplot as plt
from tqdm import tqdm

try:
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
    SKLEARN_OK = True
except Exception:
    SKLEARN_OK = False

try:
    from torch_fidelity import calculate_metrics as tf_calculate_metrics
    TORCH_FIDELITY_OK = True
except Exception:
    TORCH_FIDELITY_OK = False

def set_seed(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for deterministic
    kernels with autotuning disabled.

    Args:
        seed: Integer seed handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Autotuning may select different kernels between runs, so turn it off
    # alongside requesting deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def save_grid(tensor, path, nrow=8, normalize=True, value_range=(-1,1)):
    """Tile a batch of images into a single grid and save it as a figure.

    Args:
        tensor: Image batch forwarded to ``torchvision.utils.make_grid``.
        path: Destination file for the rendered figure.
        nrow: Number of images per grid row.
        normalize: Whether ``make_grid`` rescales pixel values for display.
        value_range: ``(min, max)`` used by ``make_grid`` when normalizing.
    """
    grid = vutils.make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range)
    # make_grid returns channels-first data; imshow wants channels-last.
    image_hwc = np.transpose(grid.cpu().numpy(), (1, 2, 0))
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis('off')
    ax.imshow(image_hwc)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)

def save_fig(path):
    """Tighten, save (150 dpi) and close the current matplotlib figure.

    Args:
        path: Destination file for the figure.
    """
    current = plt.gcf()
    current.tight_layout()
    current.savefig(path, dpi=150)
    plt.close(current)


## Configuration

In [None]:

class Cfg:
    """Notebook-wide settings, read as class attributes (never instantiated here)."""

    # --- what to run -------------------------------------------------------
    run = 'both'             # which models to train: 'vae', 'gan', or 'both'
    compute_metrics = False  # gate for the GAN FID / Inception Score cell
    run_ood = False          # gate for the VAE out-of-distribution cell

    # --- shared training settings -------------------------------------------
    epochs = 30
    batch_size = 128
    z_dim = 128              # latent size used for both the VAE and the GAN
    seed = 42
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- VAE -----------------------------------------------------------------
    beta = 1.0               # weight on the KL term in vae_loss
    recon_loss = 'mse'       # 'mse'; any other value selects the BCE branch of vae_loss
    lr_vae = 0.001

    # --- GAN -----------------------------------------------------------------
    lr_g = 0.0002
    lr_d = 0.0002

    # --- output --------------------------------------------------------------
    outdir = None            # None -> a timestamped directory under ./runs

# Seed first so model initialisation and data shuffling below are repeatable.
set_seed(Cfg.seed)

# Every run writes into its own directory: Cfg.outdir when given, otherwise a
# timestamped folder under ./runs.
stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
OUTDIR = Path(Cfg.outdir) if Cfg.outdir else Path(f'./runs/{stamp}')
OUTDIR.joinpath('checkpoints').mkdir(parents=True, exist_ok=True)
OUTDIR.joinpath('figures').mkdir(parents=True, exist_ok=True)

DEVICE = torch.device(Cfg.device)
OUTDIR


## Data Loaders (CIFAR-10, plus OOD sets if enabled)

In [None]:

def get_dataloaders(batch_size=128, num_workers=2, root='./data'):
    """Build the CIFAR-10 train and test loaders (downloading if needed).

    Images are converted to tensors and normalized per channel with mean 0.5
    and std 0.5.

    Args:
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.
        root: Directory where torchvision stores / looks up the dataset.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    half = (0.5,) * 3
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=preprocess)
    test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=preprocess)
    train_loader = DataLoader(train_set, shuffle=True, **shared)
    test_loader = DataLoader(test_set, shuffle=False, **shared)
    return train_loader, test_loader

def get_ood_loaders(batch_size=128, num_workers=2, root='./data'):
    """Build unshuffled test loaders for the two OOD datasets.

    Both datasets go through ``Resize(32)``, tensor conversion and the same
    mean-0.5 / std-0.5 normalization.

    Args:
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.
        root: Directory where torchvision stores / looks up the datasets.

    Returns:
        ``(cifar100_loader, svhn_loader)`` over the CIFAR-100 test set and the
        SVHN ``'test'`` split.
    """
    half = (0.5,) * 3
    preprocess = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])
    shared = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    cifar100_set = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=preprocess)
    svhn_set = torchvision.datasets.SVHN(root=root, split='test', download=True, transform=preprocess)
    return DataLoader(cifar100_set, **shared), DataLoader(svhn_set, **shared)

train_loader, test_loader = get_dataloaders(Cfg.batch_size)

# Pull a single batch as a smoke test: an empty loader raises StopIteration.
try:
    next(iter(train_loader))
except StopIteration:
    print("Data issue.")
else:
    print("CIFAR-10 ready.")


## VAE: Model & Loss

In [None]:

class Encoder(nn.Module):
    """Convolutional VAE encoder: image batch -> Gaussian posterior parameters.

    Three stride-2 convolutions each halve the spatial size, and the linear
    heads expect ``256 * 4 * 4`` features, so inputs must be ``(B, 3, 32, 32)``.

    Args:
        z_dim: Size of the latent vector produced by both heads.
    """

    def __init__(self, z_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),      # 32x32 -> 16x16
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1),    # 16x16 -> 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1),   # 8x8 -> 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        # Size the heads from the constructor argument. Reading the global
        # Cfg.z_dim here would silently ignore the value the caller asked for.
        self.fc_mu = nn.Linear(256*4*4, z_dim)
        self.fc_logvar = nn.Linear(256*4*4, z_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(B, z_dim)``."""
        h = self.net(x).view(x.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)

class Decoder(nn.Module):
    """Convolutional VAE decoder: latent vectors -> images squashed by tanh.

    A linear layer expands each latent vector to a ``256 x 4 x 4`` feature map;
    three stride-2 transposed convolutions then double the spatial size each
    time, giving ``(B, 3, 32, 32)`` outputs in ``[-1, 1]``.

    Args:
        z_dim: Size of the latent vectors this decoder accepts.
    """

    def __init__(self, z_dim=128):
        super().__init__()
        # Size the input layer from the constructor argument. Reading the
        # global Cfg.z_dim here would silently ignore the caller's value.
        self.fc = nn.Linear(z_dim, 256*4*4)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 8x8 -> 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),     # 16x16 -> 32x32
            nn.Tanh()
        )

    def forward(self, z):
        """Decode latents of shape ``(B, z_dim)`` into images."""
        h = self.fc(z).view(z.size(0), 256, 4, 4)
        return self.net(h)

class VAE(nn.Module):
    """Variational autoencoder pairing an ``Encoder`` with a ``Decoder``.

    Args:
        z_dim: Latent size forwarded to both sub-networks.
    """

    def __init__(self, z_dim=128):
        super().__init__()
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        ``sigma`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            ``(x_recon, mu, logvar, z)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar, z

def vae_loss(x, x_recon, mu, logvar, recon_type='mse', beta=1.0):
    """Compute the beta-weighted VAE objective, averaged over the batch.

    Args:
        x: Target images.
        x_recon: Reconstructions with the same shape as ``x``.
        mu: Posterior means.
        logvar: Posterior log-variances.
        recon_type: ``'mse'`` for summed squared error; any other value uses
            binary cross-entropy after shifting both tensors by ``(v + 1) / 2``.
        beta: Weight applied to the KL term.

    Returns:
        ``(recon + beta * kl, recon, kl)``; each term is summed over all
        elements and divided by the batch size.
    """
    batch = x.size(0)
    if recon_type != 'mse':
        # BCE needs values in [0, 1]; the shift assumes inputs live in [-1, 1].
        target01 = (x + 1) / 2
        pred01 = (x_recon + 1) / 2
        recon = F.binary_cross_entropy(pred01, target01, reduction='sum') / batch
    else:
        recon = F.mse_loss(x_recon, x, reduction='sum') / batch
    # Closed-form KL between N(mu, exp(logvar)) and the standard normal.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
    total = recon + beta * kl
    return total, recon, kl


### Train VAE

In [None]:

def train_vae():
    """Train a VAE on ``train_loader`` for ``Cfg.epochs`` epochs.

    Reads the module-level ``Cfg``, ``DEVICE``, ``OUTDIR``, ``train_loader``
    and ``test_loader``. After every epoch it saves three image grids
    (originals, reconstructions, prior samples) and a checkpoint; at the end
    it saves the training curves and a final checkpoint.

    Returns:
        ``(vae, history)`` where ``history`` maps ``'loss'``, ``'recon'`` and
        ``'kl'`` to lists holding one per-sample average per epoch.
    """
    vae = VAE(Cfg.z_dim).to(DEVICE)
    opt = optim.Adam(vae.parameters(), lr=Cfg.lr_vae, betas=(0.9, 0.999))
    history = {'loss': [], 'recon': [], 'kl': []}
    # Keep one fixed set of 64 test images so the per-epoch reconstruction
    # grids are directly comparable.
    fixed_test = next(iter(test_loader))[0][:64].to(DEVICE)

    for epoch in range(1, Cfg.epochs+1):
        vae.train()
        total_loss = total_recon = total_kl = 0.0
        for x, _ in tqdm(train_loader, desc=f'VAE Epoch {epoch}/{Cfg.epochs}'):
            x = x.to(DEVICE)
            opt.zero_grad()
            x_recon, mu, logvar, _ = vae(x)
            loss, recon, kl = vae_loss(x, x_recon, mu, logvar, recon_type=Cfg.recon_loss, beta=Cfg.beta)
            loss.backward(); opt.step()
            # vae_loss returns per-sample batch averages; weight by the batch
            # size so the last (smaller) batch does not skew the epoch mean.
            b = x.size(0)
            total_loss += loss.item() * b
            total_recon += recon.item() * b
            total_kl += kl.item() * b

        n = len(train_loader.dataset)
        epoch_loss = total_loss / n
        epoch_recon = total_recon / n
        epoch_kl = total_kl / n
        history['loss'].append(epoch_loss); history['recon'].append(epoch_recon); history['kl'].append(epoch_kl)

        # Per-epoch visual check: reconstructions of the fixed test images and
        # fresh decodes of standard-normal latents.
        vae.eval()
        with torch.no_grad():
            x = fixed_test
            # The full forward pass samples z, so these reconstructions
            # include posterior sampling noise.
            x_recon, _, _, _ = vae(x)
            save_grid(x.cpu(), OUTDIR / 'figures' / f"vae_epoch{epoch:03d}_orig.png")
            save_grid(x_recon.cpu(), OUTDIR / 'figures' / f"vae_epoch{epoch:03d}_recon.png")
            z = torch.randn(64, Cfg.z_dim, device=DEVICE)
            samples = vae.decoder(z)
            save_grid(samples.cpu(), OUTDIR / 'figures' / f"vae_epoch{epoch:03d}_samples.png")

        torch.save(vae.state_dict(), OUTDIR / 'checkpoints' / f"vae_epoch{epoch:03d}.pt")
    # Training curves: one point per epoch for each tracked quantity.
    plt.figure()
    plt.plot(history['loss'], label='loss'); plt.plot(history['recon'], label='recon'); plt.plot(history['kl'], label='kl')
    plt.legend(); save_fig(OUTDIR / 'figures' / "vae_training_curves.png")
    torch.save(vae.state_dict(), OUTDIR / 'checkpoints' / 'vae_final.pt')
    return vae, history

# Later cells test `vae is None`, so both names must exist even when the VAE
# is not trained.
vae = vae_hist = None
if Cfg.run not in ['vae', 'both']:
    print("Skipping VAE training (Cfg.run != 'vae'/'both').")
else:
    vae, vae_hist = train_vae()
    print("VAE trained and saved.")


## DCGAN: Model & Training

In [None]:

class DCGANGenerator(nn.Module):
    """DCGAN generator: latent vectors -> ``(B, 3, 32, 32)`` tanh images.

    Args:
        z_dim: Length of each latent vector.
        ngf: Base channel width; the layers use ``4*ngf``, ``2*ngf`` and ``ngf``.
    """

    def __init__(self, z_dim=128, ngf=64):
        super().__init__()
        layers = [
            # 1x1 -> 4x4
            nn.ConvTranspose2d(z_dim, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(ngf, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Reshape ``z`` to ``(B, z_dim, 1, 1)`` and run the transposed-conv stack."""
        as_feature_map = z.view(*z.shape[:2], 1, 1)
        return self.net(as_feature_map)

class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator: images -> one raw logit per image (no sigmoid).

    Args:
        ndf: Base channel width; the layers use ``ndf``, ``2*ndf`` and ``4*ndf``.
    """

    def __init__(self, ndf=64):
        super().__init__()
        layers = [
            # 32x32 -> 16x16
            nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1 logit
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=0, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits flattened to a 1-D tensor."""
        logits = self.net(x)
        return logits.view(-1)

def train_gan():
    """Train the DCGAN on ``train_loader`` for ``Cfg.epochs`` epochs.

    Reads the module-level ``Cfg``, ``DEVICE``, ``OUTDIR`` and
    ``train_loader``. Each step updates the discriminator once and then the
    generator once. After every epoch it saves a grid generated from a fixed
    latent batch and a checkpoint with both networks; at the end it saves the
    loss curves and a final checkpoint.

    Returns:
        ``(netG, netD, history)`` where ``history`` maps ``'d_loss'`` and
        ``'g_loss'`` to lists holding one per-sample average per epoch.
    """
    netG = DCGANGenerator(Cfg.z_dim).to(DEVICE)
    netD = DCGANDiscriminator().to(DEVICE)
    optG = optim.Adam(netG.parameters(), lr=Cfg.lr_g, betas=(0.5, 0.999))
    optD = optim.Adam(netD.parameters(), lr=Cfg.lr_d, betas=(0.5, 0.999))
    # The discriminator outputs raw logits, hence the with-logits BCE.
    criterion = nn.BCEWithLogitsLoss()

    # Same latents every epoch so the saved grids show how fixed points evolve.
    fixed_z = torch.randn(64, Cfg.z_dim, device=DEVICE)
    history = {'d_loss': [], 'g_loss': []}

    for epoch in range(1, Cfg.epochs+1):
        g_loss_epoch = d_loss_epoch = 0.0
        for x, _ in tqdm(train_loader, desc=f'GAN Epoch {epoch}/{Cfg.epochs}'):
            x = x.to(DEVICE); b = x.size(0)

            # --- Discriminator step: real images -> 1, generated images -> 0.
            optD.zero_grad()
            real_logits = netD(x); real_labels = torch.ones(b, device=DEVICE)
            d_real = criterion(real_logits, real_labels)
            z = torch.randn(b, Cfg.z_dim, device=DEVICE)
            fake = netG(z)
            # detach() keeps the discriminator loss from backpropagating into G.
            fake_logits = netD(fake.detach()); fake_labels = torch.zeros(b, device=DEVICE)
            d_fake = criterion(fake_logits, fake_labels)
            d_loss = d_real + d_fake; d_loss.backward(); optD.step()

            # --- Generator step: score the same fakes with the just-updated
            # discriminator and push them towards the "real" label.
            optG.zero_grad()
            fake_logits = netD(fake)
            g_loss = criterion(fake_logits, real_labels)
            g_loss.backward(); optG.step()

            # Weight batch-mean losses by batch size for a per-sample epoch mean.
            d_loss_epoch += d_loss.item() * b
            g_loss_epoch += g_loss.item() * b

        n = len(train_loader.dataset)
        history['d_loss'].append(d_loss_epoch / n)
        history['g_loss'].append(g_loss_epoch / n)

        # NOTE(review): netG is never switched to eval() in this function, so
        # these samples are produced with BatchNorm in training mode.
        with torch.no_grad():
            samples = netG(fixed_z)
        save_grid(samples.cpu(), OUTDIR / 'figures' / f"gan_epoch{epoch:03d}_samples.png")
        torch.save({'G': netG.state_dict(), 'D': netD.state_dict()}, OUTDIR / 'checkpoints' / f"gan_epoch{epoch:03d}.pt")

    # Loss curves: one point per epoch for each network.
    plt.figure()
    plt.plot(history['d_loss'], label='D loss'); plt.plot(history['g_loss'], label='G loss')
    plt.legend(); save_fig(OUTDIR / 'figures' / "gan_training_curves.png")
    torch.save({'G': netG.state_dict(), 'D': netD.state_dict()}, OUTDIR / 'checkpoints' / 'gan_final.pt')
    return netG, netD, history

# Later cells test `netG is None`, so all three names must exist even when the
# GAN is not trained.
netG, netD, gan_hist = None, None, None
if Cfg.run not in ['gan', 'both']:
    print("Skipping GAN training (Cfg.run != 'gan'/'both').")
else:
    netG, netD, gan_hist = train_gan()
    print("GAN trained and saved.")


## Reconstructions (VAE) & Samples / Interpolations (Both)

In [None]:

@torch.no_grad()
def vae_reconstructions(vae, loader, device, outdir):
    """Reconstruct up to 64 images from the loader's first batch.

    Saves ``vae_recon_orig.png`` and ``vae_recon_recon.png`` under ``outdir``.

    Args:
        vae: Model returning ``(x_recon, mu, logvar, z)`` from its forward pass.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the images are moved to.
        outdir: Directory (``Path``) receiving the two grids.

    Returns:
        Mean squared error between the originals and reconstructions, as a float.
    """
    vae.eval()
    images, _labels = next(iter(loader))
    originals = images.to(device)[:64]
    reconstructions = vae(originals)[0]
    save_grid(originals.cpu(), outdir / "vae_recon_orig.png")
    save_grid(reconstructions.cpu(), outdir / "vae_recon_recon.png")
    return F.mse_loss(reconstructions, originals, reduction='mean').item()

@torch.no_grad()
def vae_interpolations(vae, loader, device, outdir, steps=8):
    """Decode a straight line between the latent means of two images.

    Uses the first two images of the loader's first batch and saves the
    decoded sequence as ``vae_interp.png`` (one row of ``steps`` images).

    Args:
        vae: Model exposing ``encoder`` and ``decoder``.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device used for encoding and decoding.
        outdir: Directory (``Path``) receiving the grid.
        steps: Number of interpolation points, endpoints included.
    """
    vae.eval()
    images, _ = next(iter(loader))
    first, second = images[:1].to(device), images[1:2].to(device)
    # Interpolate between posterior means only; the log-variances are unused.
    z_start = vae.encoder(first)[0]
    z_end = vae.encoder(second)[0]
    weights = torch.linspace(0, 1, steps, device=device).view(-1, 1)
    z_path = (1 - weights) * z_start + weights * z_end
    decoded = vae.decoder(z_path)
    save_grid(decoded.cpu(), outdir / "vae_interp.png", nrow=steps)

@torch.no_grad()
def gan_interpolations(netG, device, outdir, z_dim=128, steps=8):
    """Generate images along a straight line between two random latents.

    Saves the sequence as ``gan_interp.png`` (one row of ``steps`` images).
    Does nothing when ``netG`` is ``None``.

    Args:
        netG: Generator, or ``None`` to skip.
        device: Device on which latents are sampled.
        outdir: Directory (``Path``) receiving the grid.
        z_dim: Latent size expected by the generator.
        steps: Number of interpolation points, endpoints included.
    """
    if netG is None:
        return
    netG.eval()
    z_start = torch.randn(1, z_dim, device=device)
    z_end = torch.randn(1, z_dim, device=device)
    weights = torch.linspace(0, 1, steps, device=device).view(-1, 1)
    z_path = (1 - weights) * z_start + weights * z_end
    save_grid(netG(z_path).cpu(), outdir / "gan_interp.png", nrow=steps)

# Scalar results collected across the evaluation cells; written out at the end.
metrics = {}
if vae is not None:
    recon_mse = vae_reconstructions(vae, loader=test_loader, device=DEVICE, outdir=OUTDIR / 'figures')
    vae_interpolations(vae, loader=test_loader, device=DEVICE, outdir=OUTDIR / 'figures')
    metrics['vae_recon_mse'] = float(recon_mse)
# gan_interpolations returns immediately when netG is None.
gan_interpolations(netG, device=DEVICE, outdir=OUTDIR / 'figures', z_dim=Cfg.z_dim)
print("Saved reconstructions and interpolations.")


## VAE Latent Space Analysis (PCA / optional t-SNE)

In [None]:

@torch.no_grad()
def vae_latent_analysis(vae, loader, device, outdir, max_samples=5000):
    """Project VAE posterior means to 2-D and save label-coloured scatter plots.

    Always writes ``vae_latent_pca.png``; additionally writes
    ``vae_latent_tsne.png`` when scikit-learn is importable. Does nothing when
    ``vae`` is ``None``.

    Args:
        vae: Model exposing ``encoder(x) -> (mu, logvar)``, or ``None`` to skip.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device used for encoding.
        outdir: Directory (``Path``) receiving the figures.
        max_samples: Upper bound on the number of encoded samples that are kept.

    Returns:
        None.
    """
    if vae is None:
        return
    vae.eval()
    zs, ys, count = [], [], 0
    for x, y in loader:
        mu, _ = vae.encoder(x.to(device))
        zs.append(mu.cpu().numpy())
        ys.append(y.numpy())
        count += x.size(0)
        if count >= max_samples:
            break
    if not zs:
        print("[Latent] loader yielded no batches; skipping latent analysis.")
        return
    # The loop stops on a batch boundary, so trim the overshoot to honour
    # max_samples exactly.
    Z = np.concatenate(zs, axis=0)[:max_samples]
    Y = np.concatenate(ys, axis=0)[:max_samples]
    if Z.shape[0] < 2:
        print("[Latent] need at least two samples for a 2-D projection; skipping.")
        return

    # PCA: prefer scikit-learn, but fall back to a plain SVD so the plot is
    # still produced when the optional import failed.
    if SKLEARN_OK:
        Zp = PCA(n_components=2).fit_transform(Z)
    else:
        centered = Z - Z.mean(axis=0, keepdims=True)
        _, _, components = np.linalg.svd(centered, full_matrices=False)
        Zp = centered @ components[:2].T
    plt.figure(figsize=(6,6))
    scatter = plt.scatter(Zp[:,0], Zp[:,1], c=Y, s=5, alpha=0.7)
    plt.title("VAE latent PCA (CIFAR-10 classes)")
    # num=None asks for one legend entry per distinct label value.
    plt.legend(*scatter.legend_elements(num=None), title="Class", loc="best", fontsize=6)
    save_fig(outdir / "vae_latent_pca.png")

    if SKLEARN_OK:
        n_tsne = min(3000, Z.shape[0])
        # t-SNE requires perplexity < n_samples. The iteration count is left
        # at the library default (1000) because the keyword naming it differs
        # between scikit-learn releases.
        tsne = TSNE(n_components=2, init='pca', learning_rate='auto', perplexity=min(30, n_tsne - 1))
        Zt = tsne.fit_transform(Z[:n_tsne])
        plt.figure(figsize=(6,6))
        plt.scatter(Zt[:,0], Zt[:,1], c=Y[:n_tsne], s=5, alpha=0.7)
        plt.title("VAE latent t-SNE (CIFAR-10)")
        save_fig(outdir / "vae_latent_tsne.png")

# The helper returns immediately when no VAE was trained.
vae_latent_analysis(vae, loader=test_loader, device=DEVICE, outdir=OUTDIR / 'figures')
print("Saved latent projections.")


## OOD: VAE Reconstruction Error (CIFAR-100 / SVHN)

In [None]:

@torch.no_grad()
def vae_ood(vae, in_loader, ood_loader, device, outdir, name='cifar10_vs_cifar100'):
    """Score OOD detection using per-image VAE reconstruction error.

    Saves overlaid error histograms to ``vae_ood_<name>.png`` and the AUROC to
    ``vae_ood_<name>.json`` under ``outdir``, and prints the AUROC.

    Args:
        vae: Model returning ``(x_recon, mu, logvar, z)``, or ``None`` to skip.
        in_loader: Loader of in-distribution ``(images, labels)`` batches.
        ood_loader: Loader of out-of-distribution batches.
        device: Device used for the forward passes.
        outdir: Directory (``Path``) receiving the outputs.
        name: Tag embedded in the output file names and the printed line.

    Returns:
        AUROC as a float (OOD is the positive class, higher error = more OOD),
        or ``None`` when ``vae`` is ``None``.
    """
    if vae is None:
        return None
    vae.eval()

    def recon_errors(loader, max_batches=100):
        # Mean squared error per image, over at most `max_batches` batches.
        collected = []
        for seen, (x, _) in enumerate(loader, start=1):
            x = x.to(device)
            x_hat = vae(x)[0]
            squared = F.mse_loss(x_hat, x, reduction='none')
            per_image = squared.view(x.size(0), -1).mean(dim=1)
            collected.append(per_image.cpu().numpy())
            if seen >= max_batches:
                break
        return np.concatenate(collected, axis=0)

    in_err = recon_errors(in_loader)
    ood_err = recon_errors(ood_loader)

    plt.figure(figsize=(6,4))
    for errors, label in ((in_err, 'In (CIFAR-10)'), (ood_err, 'OOD')):
        plt.hist(errors, bins=50, alpha=0.7, label=label, density=True)
    plt.xlabel('Reconstruction MSE per image')
    plt.ylabel('Density')
    plt.title('VAE Reconstruction Error: In vs OOD')
    plt.legend()
    save_fig(outdir / f"vae_ood_{name}.png")

    # AUROC from the rank-sum (Mann-Whitney U) statistic of the OOD scores.
    from scipy.stats import rankdata
    scores = np.concatenate([in_err, ood_err])
    labels = np.concatenate([np.zeros_like(in_err), np.ones_like(ood_err)])
    ranks = rankdata(scores)
    pos = labels == 1
    n_pos = pos.sum()
    n_neg = (~pos).sum()
    u_stat = ranks[pos].sum() - n_pos*(n_pos+1)/2
    auc = u_stat / (n_pos*n_neg)

    with open(outdir / f"vae_ood_{name}.json", 'w') as f:
        json.dump({'auroc': float(auc)}, f, indent=2)
    print(f"[OOD] {name} AUROC ~ {auc:.3f}")
    return float(auc)

# OOD evaluation is opt-in (Cfg.run_ood) and needs a trained VAE.
if vae is not None and Cfg.run_ood:
    cifar100_loader, svhn_loader = get_ood_loaders(Cfg.batch_size)
    auc1 = vae_ood(vae, in_loader=test_loader, ood_loader=cifar100_loader, device=DEVICE,
                   outdir=OUTDIR / 'figures', name='cifar10_vs_cifar100')
    auc2 = vae_ood(vae, in_loader=test_loader, ood_loader=svhn_loader, device=DEVICE,
                   outdir=OUTDIR / 'figures', name='cifar10_vs_svhn')


## FID & Inception Score for GAN

In [None]:

@torch.no_grad()
def dump_gan_samples(netG, outdir, device, z_dim=128, n=50000, batch=250):
    """Generate ``n`` images and save each as ``<index:06d>.png``.

    Files go into ``outdir / "gan_samples_for_metrics"`` (created if missing).
    Generator outputs are shifted by ``(v + 1) / 2`` before saving.

    Args:
        netG: Generator, or ``None`` to skip.
        outdir: Parent directory (``Path``) for the sample folder.
        device: Device on which latents are sampled.
        z_dim: Latent size expected by the generator.
        n: Total number of images to write.
        batch: Number of images generated per forward pass.

    Returns:
        The sample directory, or ``None`` when ``netG`` is ``None``.
    """
    if netG is None:
        return None
    netG.eval()
    save_dir = outdir / "gan_samples_for_metrics"
    save_dir.mkdir(parents=True, exist_ok=True)
    for start in range(0, n, batch):
        # The final chunk may be smaller than `batch`.
        size = min(batch, n - start)
        latents = torch.randn(size, z_dim, device=device)
        images = (netG(latents).cpu() + 1) / 2
        for offset, image in enumerate(images):
            torchvision.utils.save_image(image, save_dir / f"{start + offset:06d}.png")
    return save_dir

def compute_fid_is(sample_dir, ref='cifar10', device='cuda'):
    """Compute Inception Score and FID for a directory of images.

    Args:
        sample_dir: Directory of generated images (torch-fidelity ``input1``).
        ref: Reference input for FID (torch-fidelity ``input2``). The bare
            name ``'cifar10'`` is accepted as an alias for ``'cifar10-train'``
            unless a directory with that exact name exists.
        device: Device string or ``torch.device``; CUDA is requested when its
            string form contains ``'cuda'``.

    Returns:
        Dict mapping metric name to float, or ``{}`` when torch-fidelity could
        not be imported.
    """
    if not TORCH_FIDELITY_OK:
        print("[Metrics] torch-fidelity not installed; skipping FID/IS.")
        return {}
    # torch-fidelity registers split-qualified dataset names, so a bare
    # 'cifar10' would be rejected. Resolve it to the train split, but leave
    # it alone if the caller really has a local directory called 'cifar10'.
    if isinstance(ref, str) and ref == 'cifar10' and not os.path.isdir(ref):
        ref = 'cifar10-train'
    # str() lets callers pass either 'cuda:0' or torch.device('cuda:0').
    use_cuda = 'cuda' in str(device)
    mets = tf_calculate_metrics(input1=str(sample_dir), input2=ref, cuda=use_cuda, isc=True, fid=True, verbose=False)
    return {k: float(v) for k, v in mets.items()}

# FID / IS are opt-in (Cfg.compute_metrics) and need a trained generator.
if netG is not None and Cfg.compute_metrics:
    sample_dir = dump_gan_samples(netG, outdir=OUTDIR / 'figures', device=DEVICE,
                                  z_dim=Cfg.z_dim, n=50000, batch=250)
    mets = compute_fid_is(sample_dir, ref='cifar10', device=Cfg.device)
    metrics['gan_metrics'] = mets
    print(mets)


## Save Summary & Where to Find Results

In [None]:

# Persist the collected scalar metrics next to the figures and checkpoints.
with open(OUTDIR / 'results_summary.json', 'w') as f:
    f.write(json.dumps(metrics, indent=2))

print("Done. Find outputs in:", OUTDIR.resolve())
print("Figures:", (OUTDIR / 'figures').resolve())
print("Checkpoints:", (OUTDIR / 'checkpoints').resolve())
