# Pix2Pix — Bug-Fixes

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF

# Use the CUDA device when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---------------------------
# Data  (normalize both domains to [-1,1] for Tanh)   [SE-1]
# ---------------------------
# Shared preprocessing for source and target: the per-channel (x - 0.5) / 0.5
# maps ToTensor's [0, 1] range onto [-1, 1], the range of G's Tanh output.
base_tf = transforms.Compose(
    [
        transforms.Resize(64),                                  # shorter side -> 64 px; the models assume 64x64 inputs
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),    # FIX [SE-1]
    ]
)

class PairedCIFAR(Dataset):
    """Synthetic paired dataset built from CIFAR-10 for a Pix2Pix sanity run.

    Each item is ``(a, b)``:
      * ``a`` (source) is a CIFAR-10 image passed through ``transform``
        (``base_tf`` by default: resize to 64, normalize to [-1, 1]);
      * ``b`` (target) is ``a`` flipped horizontally and Gaussian-blurred.

    The blur sigma is drawn uniformly from ``blur_sigma`` once per *index*
    (seeded with ``seed + i``). A given source therefore always maps to the
    same target. ``transforms.GaussianBlur(sigma=(lo, hi))`` instead resamples
    sigma on every call. That makes the target a random function of the
    source, i.e. noise on the labels that the L1 term cannot fit.

    Args:
        train: load the CIFAR-10 train split (True) or the test split (False).
        root: dataset directory passed to ``torchvision.datasets.CIFAR10``.
        download: download the dataset if it is not present under ``root``.
        transform: transform applied to the PIL image; ``None`` -> ``base_tf``.
        blur_kernel: odd Gaussian kernel size.
        blur_sigma: ``(min, max)`` range for sigma, or a single fixed number.
        seed: base seed for the per-index sigma draw.
    """

    def __init__(self, train=True, root='./data', download=True, transform=None,
                 blur_kernel=9, blur_sigma=(0.1, 2.0), seed=0):
        tf = base_tf if transform is None else transform
        self.ds = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=tf)
        self.blur_kernel = blur_kernel
        self.blur_sigma = blur_sigma
        self.seed = seed

    def __len__(self):
        return len(self.ds)

    def _sigma_for(self, i):
        """Return the blur sigma for index ``i``; deterministic for a given (seed, i)."""
        if isinstance(self.blur_sigma, (int, float)):
            return float(self.blur_sigma)
        lo, hi = self.blur_sigma
        # A private generator keeps the draw independent of the global RNG state
        # (and of DataLoader worker seeding).
        gen = torch.Generator().manual_seed(self.seed + i)
        return lo + (hi - lo) * torch.rand(1, generator=gen).item()

    def __getitem__(self, i):
        a, _ = self.ds[i]                                   # A = source
        b = TF.gaussian_blur(TF.hflip(a), kernel_size=self.blur_kernel,
                             sigma=self._sigma_for(i))      # B = target (fixed pair per index)
        return a, b

train_set = PairedCIFAR(train=True)
loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    # NOTE(review): with the 'spawn' start method, worker processes must be able
    # to unpickle PairedCIFAR; a class defined inside a notebook may fail there.
    # Drop to num_workers=0 if that happens.
    num_workers=2,
    # Page-locked host memory only speeds up host->CUDA copies, so only request
    # it when the batches will actually be sent to a CUDA device.
    pin_memory=device.type == 'cuda',
    drop_last=True,   # every batch has exactly 64 samples
)

# ---------------------------
# Hyperparams
# ---------------------------
lambda_l1 = 1e2           # weight of the L1 reconstruction term in the generator loss
g_lr = d_lr = 2e-4        # G and D share one Adam learning rate
betas = (0.5, 0.999)      # Adam (beta1, beta2) for both optimizers

# ---------------------------
# Models
# ---------------------------
class UNetG(nn.Module):
    """Four-level U-Net generator mapping a source image A to a target image B.

    Encoder: four Conv2d(k=4, s=2, p=1) stages, each halving H and W
    (64 -> 32 -> 16 -> 8 -> 4 for a 64x64 input). BatchNorm follows every
    encoder stage except the first.
    Decoder: three ConvTranspose2d(k=4, s=2, p=1) stages that double H and W.
    Each stage is fed the previous decoder output concatenated with the
    encoder feature at the same resolution. A final up-sampling layer ends in
    Tanh, so the output lies in [-1, 1].   [M-8] [SE-2]

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the generated image.
        ch: base width; encoder widths are ch, 2ch, 4ch, 8ch.

    H and W should be divisible by 16. Otherwise the skip tensors will not
    match the spatial size of the decoder outputs they are concatenated with.
    """

    @staticmethod
    def _down(c_in, c_out):
        """Encoder stage: LeakyReLU -> strided Conv2d (halves H, W) -> BatchNorm."""
        return nn.Sequential(nn.LeakyReLU(0.2, True), nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out))

    @staticmethod
    def _up(c_in, c_out):
        """Decoder stage: ReLU -> ConvTranspose2d (doubles H, W) -> BatchNorm."""
        return nn.Sequential(nn.ReLU(True), nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out))

    def __init__(self, in_ch=3, out_ch=3, ch=64):
        super().__init__()
        # Creation order (e1..e4, d1..d3, out) fixes the order in which conv
        # weights consume the RNG; the attribute names fix the state_dict keys.
        self.e1 = nn.Conv2d(in_ch, ch, 4, 2, 1)      # outermost stage: no activation/norm
        self.e2 = self._down(ch, ch * 2)
        self.e3 = self._down(ch * 2, ch * 4)
        self.e4 = self._down(ch * 4, ch * 8)
        # Decoder input widths include the concatenated skip channels.
        self.d1 = self._up(ch * 8, ch * 4)           # input: e4
        self.d2 = self._up(ch * 8, ch * 2)           # input: d1 (4ch) + e3 (4ch)
        self.d3 = self._up(ch * 4, ch)               # input: d2 (2ch) + e2 (2ch)
        self.out = nn.Sequential(                    # input: d3 (ch) + e1 (ch)
            nn.ReLU(True), nn.ConvTranspose2d(ch * 2, out_ch, 4, 2, 1), nn.Tanh()
        )

    def forward(self, a):
        # Encoder. e2..e4 begin with an in-place LeakyReLU, so each stage also
        # overwrites the previous stage's output tensor. The stored skips thus
        # hold LeakyReLU-activated features. That is harmless: every decoder
        # stage begins with ReLU, and ReLU(LeakyReLU(x)) == ReLU(x).
        feats = []
        h = a
        for stage in (self.e1, self.e2, self.e3, self.e4):
            h = stage(h)
            feats.append(h)                          # e1, e2, e3, e4

        # Decoder: start from the bottleneck, then fuse skips deepest -> shallowest.
        h = self.d1(feats[-1])
        for stage, skip in zip((self.d2, self.d3, self.out), reversed(feats[:-1])):
            h = stage(torch.cat([h, skip], dim=1))   # FIX [M-8]
        return h

class PatchD(nn.Module):
    """Conditional PatchGAN discriminator.

    Scores the pair ``(a, b)`` (stacked along channels as [a | b]) with a grid
    of raw logits, one per patch, rather than a single scalar. For a 64x64
    input the body gives 64 -> 32 -> 16 -> 8, so the output is (B, 1, 8, 8).
    No Sigmoid is applied; the logits are meant for BCEWithLogitsLoss.
    [M-9] [SE-3] [E-5]

    Args:
        in_ch: channels of a and b combined.
        ch: width of the first conv layer (then 2ch, 4ch).
    """

    def __init__(self, in_ch=6, ch=64):
        super().__init__()
        # First stage has no BatchNorm; the next two are Conv -> BN -> LeakyReLU.
        # Layer order matches the Sequential indices used in the state_dict.
        layers = [nn.Conv2d(in_ch, ch, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        for c_in, c_out in ((ch, ch * 2), (ch * 2, ch * 4)):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2, True)]
        self.body = nn.Sequential(*layers)
        # A 1x1 conv collapses the features to one logit per spatial cell.
        self.head = nn.Conv2d(ch * 4, 1, 1, 1, 0)

    def forward(self, a, b):
        pair = torch.cat((a, b), 1)      # source first, then target   [E-5]
        return self.head(self.body(pair))

# Build G before D (this fixes the order in which their weights consume the RNG),
# then move both to the training device.
G, D = UNetG().to(device), PatchD().to(device)

# ---------------------------
# Loss & Optimizers
# ---------------------------
bce = nn.BCEWithLogitsLoss()   # adversarial loss on D's raw patch logits
l1 = nn.L1Loss()               # reconstruction loss between G(a) and b
# One Adam optimizer per network, each with its own learning rate.
optG, optD = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
    for net, lr in ((G, g_lr), (D, d_lr))
)

# ---------------------------
# Training (short sanity run)
# ---------------------------
G.train(); D.train()
for step, (a, b) in enumerate(loader):
    # non_blocking lets the host->device copy overlap with compute when the
    # batch comes from pinned memory; it is harmless otherwise.
    a = a.to(device, non_blocking=True)
    b = b.to(device, non_blocking=True)

    # Run G once per step and share the result between both updates. A
    # separate no_grad pass for D would double G's forward cost. Because G is
    # in train mode, it would also advance G's BatchNorm running statistics
    # twice per step.
    fake_b = G(a)

    # ---- D step ----
    D.requires_grad_(True)                        # re-enable D grads (frozen during the G step)
    optD.zero_grad()
    logits_real = D(a, b)
    logits_fake = D(a, fake_b.detach())           # detach: D's loss must not backprop into G   [E-6]
    targ_real = torch.ones_like(logits_real)      # match spatial shape   [SE-4]
    targ_fake = torch.zeros_like(logits_fake)
    lossD = bce(logits_real, targ_real) + bce(logits_fake, targ_fake)
    lossD.backward()
    optD.step()                                   # correct optimizer      [H-10]

    # ---- G step ----
    # Freeze D's parameters. Gradients still flow through D into fake_b, but no
    # unused .grad buffers are computed for D.
    D.requires_grad_(False)
    optG.zero_grad()
    logits_fake = D(a, fake_b)                    # fresh forward through the just-updated D
    adv = bce(logits_fake, torch.ones_like(logits_fake))   # target ones   [E-7]
    rec = lambda_l1 * l1(fake_b, b)                        # add L1 term   [E-7]
    lossG = adv + rec
    lossG.backward()
    optG.step()                                   # correct optimizer      [H-10]

    if step % 100 == 0:
        print(f"step {step:05d}  lossD={lossD.item():.3f}  lossG={lossG.item():.3f}  (adv={adv.item():.3f}, L1={rec.item():.3f})")
    if step == 400:  # small sanity run
        break
D.requires_grad_(True)  # leave D trainable for any later cells


## What was fixed (mapping to bug tags)

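1. **\[SE-1]** normalized both domains to `[-1, 1]` (`Normalize` with mean 0.5 and std 0.5 per channel), matching the range of the generator's `Tanh` output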
2. **\[SE-2]** added `Tanh` at generator output
3. **\[SE-3]** discriminator now outputs **logits** (no Sigmoid) with `BCEWithLogitsLoss`
4. **\[SE-4]** patch targets now `ones_like/zeros_like(logits)` to match spatial shape
5. **\[E-5]** fixed concat order in D to `[a|b]`
6. **\[E-6]** detached `G(a)` from G's graph during the D step, so the discriminator loss does not backpropagate into G
7. **\[E-7]** added L1 reconstruction term with `λ=100` and correct adversarial target (ones)
8. **\[M-8]** corrected UNet skip connections: `d2(cat[d1,e3])`, `d3(cat[d2,e2])`, `out(cat[d3,e1])`
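9. **\[M-9]** PatchGAN now outputs a **grid** of per-patch logits via a 1×1 `Conv2d(ch*4, 1, 1, 1, 0)` head (shape `(B,1,8,8)` for 64×64 inputs) instead of a single scalar
10. **\[H-10]** each network is stepped by its own optimizer (`optD` after the D loss, `optG` after the G loss), with gradients zeroed before every backward pass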