# Pix2Pix — Bug-Fix Labs - 10 Bugs to Fix

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data  (paired A->B built from CIFAR-10; B is a flipped+blurred A)
# ---------------------------
base_tf = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),                     # BUG
])

class PairedCIFAR(Dataset):
    def __init__(self, train=True):
        self.ds = torchvision.datasets.CIFAR10('./data', train=train, download=True, transform=base_tf)
        self.blur = transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0))
    def __len__(self): return len(self.ds)
    def __getitem__(self, i):
        a, _ = self.ds[i]                      # A = original image
        b = self.blur(TF.hflip(a))             # B = "target" image
        return a, b

train_set = PairedCIFAR(train=True)
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# ---------------------------
# Hyperparams
# ---------------------------
lambda_l1 = 100.0
g_lr = 2e-4
d_lr = 2e-4

# ---------------------------
# Models
# ---------------------------
class UNetG(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, ch=64):
        super().__init__()
        # Encoder
        self.e1 = nn.Conv2d(in_ch, ch, 4, 2, 1)                                 # 64 -> 32
        self.e2 = nn.Sequential(nn.LeakyReLU(0.2, True), nn.Conv2d(ch, ch*2, 4, 2, 1), nn.BatchNorm2d(ch*2))  # 32 -> 16
        self.e3 = nn.Sequential(nn.LeakyReLU(0.2, True), nn.Conv2d(ch*2, ch*4, 4, 2, 1), nn.BatchNorm2d(ch*4))# 16 -> 8
        self.e4 = nn.Sequential(nn.LeakyReLU(0.2, True), nn.Conv2d(ch*4, ch*8, 4, 2, 1), nn.BatchNorm2d(ch*8))# 8 -> 4
        # Decoder
        self.d1 = nn.Sequential(nn.ReLU(True), nn.ConvTranspose2d(ch*8, ch*4, 4, 2, 1), nn.BatchNorm2d(ch*4)) # 4 -> 8
        self.d2 = nn.Sequential(nn.ReLU(True), nn.ConvTranspose2d(ch*8, ch*2, 4, 2, 1), nn.BatchNorm2d(ch*2)) # 8 -> 16
        self.d3 = nn.Sequential(nn.ReLU(True), nn.ConvTranspose2d(ch*4, ch,   4, 2, 1), nn.BatchNorm2d(ch))   # 16 -> 32
        self.out = nn.Sequential(nn.ReLU(True), nn.ConvTranspose2d(ch*2, out_ch, 4, 2, 1))                     # 32 -> 64
        # BUG

    def forward(self, a):
        e1 = self.e1(a)                 # (B,ch,32,32)
        e2 = self.e2(e1)                # (B,2ch,16,16)
        e3 = self.e3(e2)                # (B,4ch,8,8)
        e4 = self.e4(e3)                # (B,8ch,4,4)

        d1 = self.d1(e4)                # (B,4ch,8,8)
        # BUG
        d2 = self.d2(torch.cat([d1, e2], dim=1))   # WRONG
        # BUG
        d3 = self.d3(torch.cat([d2, e1], dim=1))   # WRONG
        out = self.out(torch.cat([d3, e4], dim=1)) # WRONG
        return out                                 # and missing ??

class PatchD(nn.Module):
    def __init__(self, in_ch=6, ch=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, ch, 4, 2, 1), nn.LeakyReLU(0.2, True),     # 64 -> 32
            nn.Conv2d(ch, ch*2, 4, 2, 1), nn.BatchNorm2d(ch*2), nn.LeakyReLU(0.2, True),   # 32 -> 16
            nn.Conv2d(ch*2, ch*4, 4, 2, 1), nn.BatchNorm2d(ch*4), nn.LeakyReLU(0.2, True), # 16 -> 8
        )
        # BUG
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(ch*4, 1), nn.Sigmoid()       # BUG
        )
    def forward(self, a, b):
        # BUG
        x = torch.cat([b, a], dim=1)                 # WRONG
        h = self.body(x)
        return self.head(h)                          # returns (B,1) scalar, not (B,1,H,W)

G = UNetG().to(device)
D = PatchD().to(device)

# ---------------------------
# Loss & Optimizers
# ---------------------------
bce = nn.BCEWithLogitsLoss()
l1  = nn.L1Loss()
optG = torch.optim.Adam(G.parameters(), lr=g_lr, betas=(0.5, 0.999))
optD = torch.optim.Adam(D.parameters(), lr=d_lr, betas=(0.5, 0.999))

# ---------------------------
# Training loop (intentionally wrong)
# ---------------------------
for a, b in loader:
    a, b = a.to(device), b.to(device)
    # ---- D step ----
    fake_b = G(a)                                   # BUG
    # BUG
    # (Here D returns (B,1); but PatchGAN should be (B,1,H,W).
    logits_real = D(a, b)
    logits_fake = D(a, fake_b)
    targ_real = torch.ones(a.size(0), 1, device=device)   # WRONG
    targ_fake = torch.zeros(a.size(0), 1, device=device)  # WRONG
    lossD = bce(logits_real, targ_real) + bce(logits_fake, targ_fake)
    lossD.backward()
    optG.step()                                    # BUG

    # ---- G step ----
    fake_b = G(a)
    # BUG
    # BUG
    # BUG
    lossG = bce(D(a, fake_b), torch.zeros_like(logits_real))   # WRONG
    lossG.backward()
    # (missing)                               # BUG
    # (missing)        # BUG
    break  # keep the broken demo short

print("Your task: fix all bugs until Pix2Pix training runs and outputs look reasonable.")
