# CycleGAN — Bug-Fix Labs - 10 Bugs to Fix

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
import itertools

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data  (Unpaired domains X, Y built from CIFAR-10 and CIFAR-10 with color jitter)
# ---------------------------
tf_X = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor()                             # BUG
])
tf_Y = transforms.Compose([
    transforms.Resize(64),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor()                             # BUG
])

class UnpairedCIFAR(Dataset):
    def __init__(self, train=True):
        self.X = torchvision.datasets.CIFAR10('./data', train=train, download=True, transform=tf_X)
        self.Y = torchvision.datasets.CIFAR10('./data', train=train, download=True, transform=tf_Y)
    def __len__(self): return min(len(self.X), len(self.Y))
    def __getitem__(self, i):
        x,_ = self.X[i]
        y,_ = self.Y[(i*7) % len(self.Y)]   # different index ⇒ unpaired
        return x, y

ds = UnpairedCIFAR(train=True)
loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# ---------------------------
# Models
# ---------------------------
def conv(in_c, out_c, k, s, p, norm=True):
    layers = [nn.Conv2d(in_c, out_c, k, s, p)]
    # BUG
    if norm: layers += [nn.BatchNorm2d(out_c)]  # WRONG: use InstanceNorm2d
    layers += [nn.ReLU(True)]
    return nn.Sequential(*layers)

class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.c1 = conv(ch, ch, 3, 1, 1)   # uses BN (bug above)
        self.c2 = conv(ch, ch, 3, 1, 1)   # uses BN (bug above)
    def forward(self, x): return x + self.c2(self.c1(x))

class ResGen(nn.Module):
    def __init__(self, in_ch=3, ch=64, n_blocks=3):
        super().__init__()
        self.down = nn.Sequential(
            conv(in_ch, ch,   7, 1, 3),
            conv(ch,   ch*2,  3, 2, 1),
            conv(ch*2, ch*4,  3, 2, 1),
        )
        self.res = nn.Sequential(*[ResBlock(ch*4) for _ in range(n_blocks)])
        self.up  = nn.Sequential(
            nn.ConvTranspose2d(ch*4, ch*2, 3, 2, 1, output_padding=1), nn.ReLU(True),
            nn.ConvTranspose2d(ch*2, ch,   3, 2, 1, output_padding=1), nn.ReLU(True),
            nn.Conv2d(ch, 3, 7, 1, 3)   # BUG
                                        # BUG
        )
    def forward(self, x): return self.up(self.res(self.down(x)))

class PatchD(nn.Module):
    def __init__(self, in_ch=3, ch=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, ch, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch, ch*2, 4, 2, 1), nn.BatchNorm2d(ch*2), nn.LeakyReLU(0.2, True),  # BUG
            nn.Conv2d(ch*2, ch*4, 4, 2, 1), nn.BatchNorm2d(ch*4), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch*4, 1, 1, 1, 0),
            nn.Sigmoid()   # BUG
        )
    def forward(self, x): return self.net(x)  # (B,1,H,W)

# BUG
D_shared = PatchD().to(device)

G_XY = ResGen().to(device)   # X → Y
G_YX = ResGen().to(device)   # Y → X

# ---------------------------
# Losses & Optimizers
# ---------------------------
bce = nn.BCEWithLogitsLoss()
l1  = nn.L1Loss()

# BUG
optD = torch.optim.Adam(D_shared.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = torch.optim.Adam(itertools.chain(G_XY.parameters(), G_YX.parameters()), lr=2e-4, betas=(0.5, 0.999))

lambda_cyc = 10.0

def cycle_consistency_loss(G_XY, G_YX, x, y, lam=10.0):
    # BUG
    cyc_x = G_YX(x)                    # WRONG
    loss = lam * F.l1_loss(cyc_x, x)   # WRONG
    return loss

# ---------------------------
# Training loop (intentionally wrong)
# ---------------------------
for x, y in loader:
    x, y = x.to(device), y.to(device)

    # ---- Discriminator steps ----
    # BUG
    fake_y = G_XY(x)                                 # BUG
    fake_x = G_YX(y)                                 # BUG
    optD.zero_grad()

    logits_real_X = D_shared(x)
    logits_fake_X = D_shared(fake_x)                 # same D judging X-domain fakes
    # BUG
    t1 = torch.ones(x.size(0), 1, device=device)     # WRONG
    t0 = torch.zeros(x.size(0), 1, device=device)    # WRONG
    lossDX = bce(logits_real_X, t1) + bce(logits_fake_X, t0)

    logits_real_Y = D_shared(y)
    logits_fake_Y = D_shared(fake_y)
    lossDY = bce(logits_real_Y, t1) + bce(logits_fake_Y, t0)

    lossD = (lossDX + lossDY) * 0.5
    lossD.backward()
    optG.step()                                      # BUG

    # ---- Generator step ----
    # BUG
    optG.zero_grad()
    adv_X = bce(D_shared(G_YX(y)), torch.ones_like(logits_real_X))   # judge Y→X
    adv_Y = bce(D_shared(G_XY(x)), torch.zeros_like(logits_real_Y))  # BUG
    # BUG
    loss_cyc = cycle_consistency_loss(G_XY, G_YX, x, y, lambda_cyc)  # BUG
    lossG = adv_X + adv_Y + loss_cyc
    lossG.backward()
    # (missing)                            # BUG
    break  # keep the broken demo short

print("Your task: fix all bugs until CycleGAN training runs and X↔Y translations look reasonable.")
