# WGAN-GP — Bug-Fixes

In [None]:
# WGAN-GP — FIXED SOLUTION

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---------------------------
# Data (CIFAR-10 32×32; normalized for Tanh generator)
# ---------------------------
# ToTensor yields values in [0, 1]; normalising each of the three channels
# with mean 0.5 / std 0.5 maps them to [-1, 1], the range of the generator's
# Tanh output.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Training split, downloaded into ./data if it is not already there.
ds = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# drop_last=True keeps every batch at exactly 64 samples.
loader = DataLoader(
    dataset=ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

# ---------------------------
# Hyperparams
# ---------------------------
z_dim = 128        # length of the latent noise vector fed to the generator
g_lr = 2e-4        # generator learning rate
d_lr = 2e-4        # critic learning rate
n_critic = 5       # FIX [E-4]: critic updates per generator update
lambda_gp = 10.0   # weight of the gradient-penalty term

# ---------------------------
# Models
# ---------------------------
class Critic(nn.Module):
    """Convolutional WGAN critic that returns one unbounded score per image.

    Three stride-2 convolutions each halve the spatial size, and a final
    4x4 convolution without padding collapses the map to a single value.
    For 32x32 inputs this gives exactly one score per sample, which
    ``forward`` relies on when flattening the output to shape ``(batch,)``.

    Args:
        ch: Channel width of the first convolution; later stages use
            ``2 * ch`` and ``4 * ch``.
    """

    def __init__(self, ch=64):
        super().__init__()

        # Stem: no normalisation on the first layer.
        layers = [
            nn.Conv2d(3, ch, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

        # Two down-sampling stages, each doubling the channel count.
        for width in (ch, ch * 2):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2, affine=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ])

        # FIX [SE-1]: no Sigmoid after the head — the critic emits raw scores.
        layers.append(nn.Conv2d(ch * 4, 1, kernel_size=4, stride=1, padding=0))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of images; returns a 1-D tensor of length ``x.size(0)``."""
        scores = self.net(x)
        batch = x.size(0)
        return scores.view(batch)

class Gen(nn.Module):
    """Transposed-convolution generator mapping a latent vector to a 3x32x32 image.

    The latent vector is treated as a 1x1 feature map and upsampled
    1 -> 4 -> 8 -> 16 -> 32. The final Tanh bounds outputs to [-1, 1].

    Args:
        z: Number of latent channels expected by the first layer.
        ch: Base channel width; the stages use ``4 * ch``, ``2 * ch``, ``ch``.
    """

    def __init__(self, z=128, ch=64):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each normalised stage.
        stages = (
            (z, ch * 4, 1, 0),
            (ch * 4, ch * 2, 2, 1),
            (ch * 2, ch, 2, 1),
        )

        layers = []
        for c_in, c_out, stride, pad in stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: RGB image squashed by Tanh, no normalisation.
        layers.append(nn.ConvTranspose2d(ch, 3, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent codes ``z`` of shape ``(batch, z_dim)``."""
        batch, latent = z.size(0), z.size(1)
        seed_map = z.view(batch, latent, 1, 1)
        return self.net(seed_map)

# Build the critic first, then the generator, and move both to `device`.
D = Critic()
D = D.to(device)
G = Gen(z=z_dim)
G = G.to(device)

# ---------------------------
# Optimizers (WGAN-GP-friendly betas)
# ---------------------------
# FIX [E-3]: both optimizers use Adam with betas (0.0, 0.9).
optG = torch.optim.Adam(params=G.parameters(), lr=g_lr, betas=(0.0, 0.9))
optD = torch.optim.Adam(params=D.parameters(), lr=d_lr, betas=(0.0, 0.9))

# ---------------------------
# Gradient penalty
# ---------------------------
def gradient_penalty(Dnet, real, fake, lambda_gp=10.0):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    For each sample a coefficient ``eps ~ U(0, 1)`` mixes the real and fake
    example, the critic is evaluated on the mixture, and the squared
    deviation of the per-sample gradient L2 norm from 1 is averaged.

    Args:
        Dnet: Callable mapping a batch of inputs to one score per sample.
        real: Batch of real samples; the first dimension is the batch.
        fake: Batch of generated samples with the same shape as ``real``.
        lambda_gp: Weight applied to the penalty.

    Returns:
        Scalar tensor ``lambda_gp * mean((||grad||_2 - 1) ** 2)``. It stays
        differentiable w.r.t. the critic's parameters so it can be added to
        the critic loss.
    """
    # The penalty is a critic-only term: cut any graph leading back to the
    # generator (or the data pipeline) before interpolating.
    real = real.detach()
    fake = fake.detach()

    b = real.size(0)

    # FIX [M-5]: one mixing coefficient per sample, broadcast over every
    # non-batch dimension so inputs of any rank work, not only NCHW images.
    eps_shape = (b,) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device).to(real.dtype)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)

    d_hat = Dnet(x_hat)

    # FIX [M-6]: gradient w.r.t. the interpolates. create_graph=True keeps the
    # graph so the penalty can itself be back-propagated (it also implies
    # retain_graph); the deprecated `only_inputs` flag is not needed.
    (grads,) = torch.autograd.grad(
        outputs=d_hat.sum(), inputs=x_hat, create_graph=True
    )

    # reshape (not view) so non-contiguous gradients are handled as well.
    grad_norm = grads.reshape(b, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()               # FIX [M-6]

# ---------------------------
# Training (short sanity run)
# ---------------------------
G.train(); D.train()
for step, (real, _) in enumerate(loader):
    # The loader pins host memory, so the copy to the device can be asynchronous.
    real = real.to(device, non_blocking=True)
    b = real.size(0)

    # ---- Critic updates ----
    # Re-enable critic gradients; they are switched off for the generator
    # step below, and doing it here also recovers from an interrupted run.
    D.requires_grad_(True)
    # NOTE: all n_critic steps reuse the same real batch; only the noise is
    # resampled.
    for _ in range(n_critic):                                        # FIX [E-4]
        z = torch.randn(b, z_dim, device=device)
        # FIX [E-7]: the critic step must not back-propagate into G, so do not
        # build the generator graph at all (same values as G(z).detach()).
        with torch.no_grad():
            fake = G(z)

        optD.zero_grad()
        d_real = D(real).mean()                                      # FIX [SE-9]: use .mean()
        d_fake = D(fake).mean()                                      # FIX [SE-9]
        gp = gradient_penalty(D, real, fake, lambda_gp=lambda_gp)

        lossD = (d_fake - d_real) + gp                               # FIX [SE-2] Wasserstein + GP
        lossD.backward()
        optD.step()

        # FIX [H-8]: no weight clipping when using GP (intentionally omitted)

    # ---- Generator update ----
    # Freeze the critic so backward() only computes the gradients that are
    # needed to reach G, instead of also filling every critic parameter's
    # .grad.
    D.requires_grad_(False)
    z = torch.randn(b, z_dim, device=device)
    optG.zero_grad()
    fake = G(z)
    lossG = -D(fake).mean()                                          # FIX [SE-2] Wasserstein G loss
    lossG.backward()
    optG.step()                                                      # FIX [H-10]

    if step % 100 == 0:
        print(f"step {step:05d}  lossD={lossD.item():.3f}  lossG={lossG.item():.3f}")
    if step == 600:   # small sanity run
        break

# Leave the critic trainable for any code that runs after this loop.
D.requires_grad_(True)


## What the fixes are (mapping to bug tags)

1. **\[SE-1]** removed Sigmoid from critic → raw scores
2. **\[SE-2]** replaced BCE objectives with Wasserstein: `lossD=(D(fake)−D(real)).mean()+GP`, `lossG=−D(fake).mean()`
3. **\[E-3]** Adam betas → `(0, 0.9)`
4. **\[E-4]** `n_critic = 5`
5. **\[M-5]** interpolation epsilon `eps ~ U(0,1)` with shape `(b,1,1,1)`
6. **\[M-6]** GP grads wrt `x_hat`, `create_graph=True`, two-sided penalty on the per-sample L2 gradient norm: `(‖∇D(x_hat)‖₂ − 1)²`
7. **\[E-7]** detached `fake` during critic update
8. **\[H-8]** removed weight clipping (don’t mix with GP)
9. **\[SE-9]** used `.mean()` for stable loss scaling
10. **\[H-10]** correct optimizer usage and grad zeroing per update
