# StyleGAN-lite — Bug-Fix Labs - 10 Bugs to Fix

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data  (CIFAR-10 → 16×16)
# ---------------------------
transform = transforms.Compose([
    transforms.Resize(16),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])   # OK for Tanh outputs
])
ds = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)
loader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

z_dim = 128
w_dim = 128
g_lr  = 2e-4
d_lr  = 2e-4

# ---------------------------
# Mapping (WRONG / incomplete on purpose)
# ---------------------------
class Mapping(nn.Module):
    def __init__(self, z_dim=128, w_dim=128, depth=2):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers += [nn.Linear(z_dim, w_dim), nn.LeakyReLU(0.2, True)]
            z_dim = w_dim
        self.net = nn.Sequential(*layers)
    def forward(self, z):
        # BUG
        return self.net(z)

# ---------------------------
# AdaIN (OK)
# ---------------------------
class AdaIN(nn.Module):
    def __init__(self, ch, w_dim):
        super().__init__()
        self.scale = nn.Linear(w_dim, ch)
        self.shift = nn.Linear(w_dim, ch)
    def forward(self, x, w):
        mu = x.mean(dim=(2,3), keepdim=True)
        sigma = x.std(dim=(2,3), keepdim=True) + 1e-8
        x_n = (x - mu) / sigma
        s = self.scale(w).unsqueeze(-1).unsqueeze(-1)
        b = self.shift(w).unsqueeze(-1).unsqueeze(-1)
        return x_n * s + b

# ---------------------------
# Styled block (WRONG on purpose)
# ---------------------------
class StyledBlock(nn.Module):
    def __init__(self, in_ch, out_ch, w_dim, upsample):
        super().__init__()
        self.upsample = upsample
        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        # BUG
        self.bn = nn.BatchNorm2d(out_ch)             # WRONG
        self.noise_strength = nn.Parameter(torch.zeros(1))
    def forward(self, x, w):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv(x)
        # BUG
        noise = torch.randn(x.size(0), x.size(1), x.size(2), x.size(3), device=x.device)
        x = x + self.noise_strength * noise
        # WRONG:
        x = F.leaky_relu(self.bn(x), 0.2, True)
        return x

# ---------------------------
# Generator (WRONG in several places)
# ---------------------------
class G_Lite(nn.Module):
    def __init__(self, z_dim=128, w_dim=128, ch=64, out_ch=3):
        super().__init__()
        self.map = Mapping(z_dim, w_dim)
        # BUG
        self.const = torch.randn(1, ch*4, 4, 4)      # WRONG
        self.b1 = StyledBlock(ch*4, ch*4, w_dim, upsample=False)
        self.b2 = StyledBlock(ch*4, ch*2, w_dim, upsample=True)
        self.b3 = StyledBlock(ch*2, ch,   w_dim, upsample=True)
        self.to_rgb = nn.Conv2d(ch, out_ch, 1)
        # BUG
    def forward(self, z):
        # BUG
        # WRONG
        x = self.const.repeat(z.size(0), 1, 1, 1)    # const is not a Parameter; also not on the right device potentially
        x = self.b1(x, z)                            # WRONG
        x = self.b2(x, z)                            # WRONG
        x = self.b3(x, z)                            # WRONG
        x = self.to_rgb(x)                           # BUG
        return x

# ---------------------------
# Discriminator (WRONG head)
# ---------------------------
class D_Simple(nn.Module):
    def __init__(self, ch=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, ch,   4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch, ch*2, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch*2, ch*4, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch*4, 1, 4, 1, 0),
            nn.Sigmoid()                     # BUG
        )
    def forward(self, x): return self.net(x).view(x.size(0))

G = G_Lite(z_dim, w_dim).to(device)
D = D_Simple().to(device)

# ---------------------------
# Loss & Optimizers (partly wrong)
# ---------------------------
crit = nn.BCEWithLogitsLoss()
# BUG
optG = torch.optim.Adam(G.parameters(), lr=g_lr, betas=(0.9, 0.999))   # WRONG
optD = torch.optim.Adam(D.parameters(), lr=d_lr, betas=(0.9, 0.999))   # WRONG

# ---------------------------
# Training loop (intentionally wrong)
# ---------------------------
for real, _ in loader:
    real = F.interpolate(real.to(device), size=16, mode='bilinear', align_corners=False)
    b = real.size(0)
    z = torch.randn(b, z_dim, device=device)

    # ---- D step ----
    fake = G(z)                                    # BUG
    # BUG
    lossD = crit(D(real), torch.zeros(b, device=device)) + crit(D(fake), torch.ones (b, device=device))      
    lossD.backward()
    optG.step()                                    # BUG

    # ---- G step ----
    z = torch.randn(b, z_dim, device=device)
    fake = G(z)
    # BUG
    lossG = crit(D(fake), torch.zeros(b, device=device))
    lossG.backward()
    # (missing)
    # (missing)

    break  # keep the broken demo short

print("Your task: fix all bugs until StyleGAN-lite training runs and images look reasonable.")
