# StyleGAN-lite — Bug-Fixes

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data (CIFAR-10 → 16×16; normalized for Tanh)
# ---------------------------
# Normalize with mean 0.5 / std 0.5 per channel so that ToTensor's [0, 1]
# output lands in [-1, 1], the same range as the generator's Tanh output.
_to_tanh_range = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
transform = transforms.Compose([transforms.Resize(16), transforms.ToTensor(), _to_tanh_range])

# Training split of CIFAR-10, downloaded into ./data on first use.
ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# drop_last=True keeps every batch at exactly 64 samples.
loader = DataLoader(
    dataset=ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

# Latent (z) and style (w) dimensionality.
z_dim, w_dim = 128, 128
# Learning rates for the generator and the discriminator.
g_lr = 2e-4
d_lr = 2e-4

# ---------------------------
# Mapping network  [E-5]
# ---------------------------
class Mapping(nn.Module):
    """MLP that maps a latent code z to a style vector w.

    The network is `depth` repetitions of Linear -> LeakyReLU(0.2); the first
    Linear maps z_dim -> w_dim and every later one maps w_dim -> w_dim.
    With depth=0 the network is an empty Sequential (identity).
    """
    def __init__(self, z_dim=128, w_dim=128, depth=2):
        super().__init__()
        # Feature widths at each layer boundary: [z_dim, w_dim, w_dim, ...].
        widths = [z_dim] + [w_dim] * depth
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.LeakyReLU(0.2, True))
        self.net = nn.Sequential(*stages)
    def forward(self, z):
        """Return w for a batch of latents z (features along dim 1)."""
        # FIX [E-5]: rescale every latent to unit L2 norm before the MLP.
        unit_z = F.normalize(z, dim=1)
        return self.net(unit_z)

# ---------------------------
# AdaIN  [M-7]
# ---------------------------
class AdaIN(nn.Module):
    """Adaptive instance normalization modulated by a style vector.

    Each (sample, channel) feature map is normalized to zero mean and unit
    variance over its spatial dims (2, 3), then scaled and shifted by
    per-channel values predicted from `w` through two Linear layers.

    Args:
        ch: number of feature channels in `x`.
        w_dim: size of the style vector `w`.
        eps: constant added to the variance before the reciprocal square
            root, for numerical stability.
    """
    def __init__(self, ch, w_dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Linear(w_dim, ch)
        self.shift = nn.Linear(w_dim, ch)
    def forward(self, x, w):
        """Normalize `x` (B, ch, H, W) and modulate it with style `w` (B, w_dim)."""
        mu = x.mean(dim=(2, 3), keepdim=True)
        centered = x - mu
        # Biased (population) variance: stays finite for 1×1 maps, where the
        # Bessel-corrected std() used previously returned NaN.
        var = centered.pow(2).mean(dim=(2, 3), keepdim=True)
        # eps sits inside the square root so the gradient is finite even when
        # a feature map is constant (zero variance).
        x_n = centered * torch.rsqrt(var + self.eps)
        # (B, ch) -> (B, ch, 1, 1) so the style broadcasts over H and W.
        s = self.scale(w)[:, :, None, None]
        b = self.shift(w)[:, :, None, None]
        return x_n * s + b

# ---------------------------
# Styled block  [SE-3][M-7]
# ---------------------------
class StyledBlock(nn.Module):
    """One generator stage: optional 2× upsample -> 3×3 conv -> noise -> AdaIN -> LeakyReLU.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by the convolution.
        w_dim: size of the style vector passed to AdaIN.
        upsample: when true, the input is upsampled 2× (nearest) before the conv.
    """
    def __init__(self, in_ch, out_ch, w_dim, upsample):
        super().__init__()
        self.upsample = upsample
        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.adain = AdaIN(out_ch, w_dim)       # FIX: use AdaIN
        # Learned noise gain; starts at zero so noise is initially switched off.
        self.noise_strength = nn.Parameter(torch.zeros(1))
    def forward(self, x, w):
        """Apply the block to features `x` using style `w`."""
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv(x)
        # FIX [SE-3]: one noise plane per sample, (B,1,H,W), broadcast over channels.
        # The noise follows x's dtype as well as its device, so a half/bfloat16
        # model is not silently promoted back to float32 by the addition below.
        noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        x = x + self.noise_strength * noise
        x = F.leaky_relu(self.adain(x, w), 0.2, True)
        return x

# ---------------------------
# Generator  [SE-1][M-8][H-9]
# ---------------------------
class G_Lite(nn.Module):
    """Small style-based generator: learned 4×4 constant -> three styled blocks -> RGB.

    The latent z is mapped to a style w once, and the same w modulates every
    block. Two of the three blocks upsample 2×, so the output is 16×16.
    """
    def __init__(self, z_dim=128, w_dim=128, ch=64, out_ch=3):
        super().__init__()
        wide, mid, narrow = ch * 4, ch * 2, ch
        self.map = Mapping(z_dim, w_dim)
        # FIX [M-8]: the synthesis input is a learnable 4×4 constant, not z.
        self.const = nn.Parameter(torch.randn(1, wide, 4, 4))
        self.b1 = StyledBlock(wide, wide, w_dim, upsample=False)
        self.b2 = StyledBlock(wide, mid, w_dim, upsample=True)
        self.b3 = StyledBlock(mid, narrow, w_dim, upsample=True)
        self.to_rgb = nn.Conv2d(narrow, out_ch, 1)
    def forward(self, z):
        """Generate one image per latent in `z`; values lie in [-1, 1]."""
        # FIX [H-9]: blocks are driven by the mapped style w, not by raw z.
        style = self.map(z)
        batch = z.size(0)
        feat = self.const.repeat(batch, 1, 1, 1).to(z.device)
        for block in (self.b1, self.b2, self.b3):
            feat = block(feat, style)
        # FIX [SE-1]: Tanh keeps the output in the range of the normalized data.
        return torch.tanh(self.to_rgb(feat))

# ---------------------------
# Discriminator (logits; no Sigmoid)  [SE-2]
# ---------------------------
class D_Simple(nn.Module):
    """Convolutional discriminator that returns one raw logit per image.

    Three stride-2 convolutions shrink the input by 8×; a final valid
    convolution whose kernel covers the whole remaining map collapses it to
    a single logit. No Sigmoid is applied (pair with BCEWithLogitsLoss).

    Args:
        ch: base channel width.
        img_size: side length of the square input images. Must be a positive
            multiple of 8. The default of 16 matches the 16×16 data and
            generator output used in this file; the previous fixed 4×4 head
            only fit 32×32 inputs and raised on 16×16 ones
            (16 -> 8 -> 4 -> 2, then a 4×4 kernel on a 2×2 map).
            `img_size=32` reproduces the original layer shapes.

    Raises:
        ValueError: if `img_size` is not a positive multiple of 8.
    """
    def __init__(self, ch=64, img_size=16):
        super().__init__()
        if img_size < 8 or img_size % 8 != 0:
            raise ValueError(f"img_size must be a positive multiple of 8, got {img_size}")
        # Spatial size left after the three stride-2 convs.
        head_kernel = img_size // 8
        self.net = nn.Sequential(
            nn.Conv2d(3, ch,   4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch, ch*2, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch*2, ch*4, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch*4, 1, head_kernel, 1, 0)   # logits, (B,1,1,1)
        )
    def forward(self, x):
        """Return logits of shape (B,) for images `x` of shape (B, 3, img_size, img_size)."""
        return self.net(x).view(x.size(0))

# Build both networks and move them to the selected device.
G = G_Lite(z_dim, w_dim).to(device)
D = D_Simple().to(device)

# ---------------------------
# Loss & Optimizers  [E-4]
# ---------------------------
# D emits raw logits, so the sigmoid is folded into the loss.
crit = nn.BCEWithLogitsLoss()
adam_betas = (0.0, 0.99)   # FIX betas
optG = torch.optim.Adam(G.parameters(), lr=g_lr, betas=adam_betas)
optD = torch.optim.Adam(D.parameters(), lr=d_lr, betas=adam_betas)

# ---------------------------
# Training (short sanity run)
# ---------------------------
G.train()
D.train()
for step, (real, _) in enumerate(loader):
    real = real.to(device)
    b = real.size(0)
    # Targets: 1 = "real", 0 = "fake"; one entry per sample.
    real_label = torch.ones(b, device=device)
    fake_label = torch.zeros(b, device=device)

    # ---- D step ----
    z = torch.randn(b, z_dim, device=device)
    # No graph is built through G here, so only D receives gradients.  [E-6]
    with torch.no_grad():
        fake = G(z)
    optD.zero_grad()
    loss_real = crit(D(real), real_label)
    loss_fake = crit(D(fake), fake_label)
    lossD = loss_real + loss_fake          # FIX labels & optimizer usage  [H-10]
    lossD.backward()
    optD.step()

    # ---- G step ----
    optG.zero_grad()
    z = torch.randn(b, z_dim, device=device)
    fake = G(z)
    # G is rewarded when D labels its samples "real": push D(fake) → 1  [H-10]
    lossG = crit(D(fake), real_label)
    lossG.backward()
    optG.step()

    if step % 100 == 0:
        print(f"step {step:04d}  lossD={lossD.item():.3f}  lossG={lossG.item():.3f}")
    # Small sanity run: stop after step 400.
    if step == 400:
        break


## What was fixed (mapped to bug tags)

1. **\[SE-1]** Added `Tanh` at generator output (matches dataset normalization).
2. **\[SE-2]** Discriminator now outputs **logits** (no Sigmoid) for `BCEWithLogitsLoss`.
3. **\[SE-3]** Noise injected with shape **(B,1,H,W)** instead of (B,C,H,W).
4. **\[E-4]** Optimizer betas set to `(0, 0.99)` (`(0, 0.9)` is an acceptable alternative).
5. **\[E-5]** Mapping network normalizes `z` with `F.normalize(z, dim=1)` before MLP.
6. **\[E-6]** Fake images **detached** (or computed under `no_grad`) during D update.
7. **\[M-7]** Replaced BatchNorm with **AdaIN** inside styled blocks.
8. **\[M-8]** Generator uses a **learnable constant** (`nn.Parameter`) at 4×4 as the input.
9. **\[H-9]** Blocks consume **w** (mapped style) rather than raw `z`; same `w` reused across blocks in this lite version.
10. **\[H-10]** Correct labels, `zero_grad()`/`backward()`/`step()` order, and stepping the **right optimizer** per phase.
