# CycleGAN — Bug-Fixes

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
import itertools

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------
# Data  (normalize BOTH domains to [-1,1] for Tanh)   [SE-1]
# ---------------------------
# Shared per-channel normalization: (v - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# Domain X: images resized (shorter side -> 64), converted to tensors, normalized.
tf_X = transforms.Compose([transforms.Resize(64), transforms.ToTensor(), norm])  # FIX [SE-1]

# Domain Y: the same pipeline plus random color jitter, which is the only
# difference between the two domains.
tf_Y = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
        norm,  # FIX [SE-1] (same normalization)
    ]
)

class UnpairedCIFAR(Dataset):
    """Two views of CIFAR-10 served as (x, y) pairs for unpaired translation.

    ``X`` applies the module-level ``tf_X`` transform and ``Y`` applies
    ``tf_Y``. Item ``i`` returns ``X[i]`` together with
    ``Y[(7 * i) % len(Y)]``, a deterministic partner index. Class labels
    are discarded.
    """

    def __init__(self, train=True):
        root = './data'
        self.X = torchvision.datasets.CIFAR10(root, train=train, download=True, transform=tf_X)
        self.Y = torchvision.datasets.CIFAR10(root, train=train, download=True, transform=tf_Y)

    def __len__(self):
        return min(len(self.X), len(self.Y))

    def __getitem__(self, i):
        partner = (i * 7) % len(self.Y)
        x_img = self.X[i][0]
        y_img = self.Y[partner][0]
        return x_img, y_img

# Training split, shuffled mini-batches of 4 pairs; incomplete final batch dropped.
ds = UnpairedCIFAR(train=True)
loader = DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

# ---------------------------
# Models  (InstanceNorm in G; logits/no Sigmoid in D)   [E-5][SE-2]
# ---------------------------
def c7s1(in_c, out_c):
    """Stem block: 7x7 stride-1 conv (padding 3) -> InstanceNorm -> ReLU.

    Keeps the spatial size and maps ``in_c`` to ``out_c`` channels.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=7, stride=1, padding=3),
        nn.InstanceNorm2d(out_c),                            # FIX [E-5]
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def d_block(in_c, out_c):
    """Downsampling block: 3x3 stride-2 conv (padding 1) -> InstanceNorm -> ReLU.

    Halves H and W (rounding up) and maps ``in_c`` to ``out_c`` channels.
    """
    conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1)
    inorm = nn.InstanceNorm2d(out_c)                            # FIX [E-5]
    return nn.Sequential(conv, inorm, nn.ReLU(inplace=True))

class ResBlock(nn.Module):
    """Residual block: ``x + (conv3x3 -> IN -> ReLU -> conv3x3 -> IN)(x)``.

    Both convolutions keep ``ch`` channels and the spatial size, so the
    residual branch can be added directly to the input.
    """

    def __init__(self, ch):
        super().__init__()
        layers = [
            nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ch),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResGen(nn.Module):
    """ResNet-style generator: encoder -> residual bottleneck -> decoder.

    Args:
        in_ch: channels of the input image.
        ch: base width; the bottleneck runs at ``4 * ch`` channels.
        n_blocks: number of ResBlocks in the bottleneck.
        out_ch: channels of the generated image. Defaults to 3, the value
            that used to be hard-coded, so existing callers are unaffected.

    The two stride-2 downsampling blocks are undone by two stride-2
    transposed convolutions, so for H and W divisible by 4 the output has the
    same spatial size as the input. The final Tanh bounds outputs to [-1, 1].
    """

    def __init__(self, in_ch=3, ch=64, n_blocks=6, out_ch=3):
        super().__init__()
        self.down = nn.Sequential(
            c7s1(in_ch, ch),
            d_block(ch, ch*2),
            d_block(ch*2, ch*4),
        )
        self.res = nn.Sequential(*[ResBlock(ch*4) for _ in range(n_blocks)])
        self.up = nn.Sequential(
            nn.ConvTranspose2d(ch*4, ch*2, 3, 2, 1, output_padding=1), nn.InstanceNorm2d(ch*2), nn.ReLU(True),
            nn.ConvTranspose2d(ch*2, ch,   3, 2, 1, output_padding=1), nn.InstanceNorm2d(ch),   nn.ReLU(True),
            nn.Conv2d(ch, out_ch, 7, 1, 3),                      # was hard-coded to 3 channels
            nn.Tanh()                                            # use Tanh to match [-1,1] inputs     [SE-1]
        )

    def forward(self, x):
        return self.up(self.res(self.down(x)))

class PatchD(nn.Module):
    """Convolutional discriminator that returns a grid of raw per-patch logits.

    Three 4x4 stride-2 convolutions (each followed by LeakyReLU(0.2)) widen the
    features ``in_ch -> ch -> 2*ch -> 4*ch``, each halving H and W (floor
    division). A final 1x1 convolution maps to one logit per location. The
    network has no Sigmoid, so pair it with BCEWithLogitsLoss, and it has no
    normalization layers.   [SE-2]
    """

    def __init__(self, in_ch=3, ch=64):
        super().__init__()
        widths = [in_ch, ch, ch*2, ch*4]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            layers.append(nn.LeakyReLU(0.2, True))
        layers.append(nn.Conv2d(ch*4, 1, 1, 1, 0))  # logits, NO Sigmoid    [SE-2]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # (B, 1, H', W')

# Two generators (X→Y and Y→X) and TWO discriminators, one per domain.   [H-10]
# Built left to right in the same order as before: G_XY, G_YX, D_X, D_Y.
G_XY, G_YX = ResGen().to(device), ResGen().to(device)   # X → Y,  Y → X
D_X, D_Y = PatchD().to(device), PatchD().to(device)     # judges domain X,  judges domain Y

# ---------------------------
# Losses & Optimizers
# ---------------------------
bce = nn.BCEWithLogitsLoss()   # adversarial loss on raw discriminator logits
l1 = nn.L1Loss()               # mean absolute error
lambda_cyc = 10.0              # weight of the cycle-consistency term

# One Adam optimizer over both generators; each discriminator has its own.   [E-7]
optG = torch.optim.Adam(
    list(G_XY.parameters()) + list(G_YX.parameters()),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optDX = torch.optim.Adam(D_X.parameters(), lr=2e-4, betas=(0.5, 0.999))
optDY = torch.optim.Adam(D_Y.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Proper cycle & identity losses                           [M-8][E-6]
def cycle_consistency_loss(G_XY, G_YX, x, y, lam=10.0):
    """Return ``lam * (L1(G_YX(G_XY(x)), x) + L1(G_XY(G_YX(y)), y))``.

    Both L1 terms use mean reduction (``F.l1_loss`` default).   FIX [M-8]
    """
    x_round_trip = G_YX(G_XY(x))   # X -> Y -> X
    y_round_trip = G_XY(G_YX(y))   # Y -> X -> Y
    err_x = F.l1_loss(x_round_trip, x)
    err_y = F.l1_loss(y_round_trip, y)
    return lam * (err_x + err_y)

def identity_loss(G_XY, G_YX, x, y, lam=10.0):
    """Return ``0.5 * lam * (L1(G_YX(x), x) + L1(G_XY(y), y))``.

    Each generator is fed an image that is already in its target domain
    and is penalized for changing it. Both L1 terms use mean reduction.
    FIX [E-6]
    """
    err_x = F.l1_loss(G_YX(x), x)
    err_y = F.l1_loss(G_XY(y), y)
    return 0.5 * lam * (err_x + err_y)

# ---------------------------
# Training (short sanity run)
# ---------------------------
def _set_requires_grad(nets, flag):
    """Set ``requires_grad`` to *flag* on every parameter of every module in *nets*."""
    for net in nets:
        for p in net.parameters():
            p.requires_grad_(flag)


G_XY.train(); G_YX.train(); D_X.train(); D_Y.train()
for step, (x, y) in enumerate(loader):
    x, y = x.to(device), y.to(device)

    # ---- Discriminator X ----
    with torch.no_grad():                                             # fakes carry no graph into the D update  [SE-4]
        fake_x = G_YX(y)                                             # Y→X
    optDX.zero_grad()
    logits_real_X = D_X(x)
    logits_fake_X = D_X(fake_x)
    lossDX = bce(logits_real_X, torch.ones_like(logits_real_X)) + \
             bce(logits_fake_X, torch.zeros_like(logits_fake_X))     # targets match patch shape   [SE-3]
    lossDX.backward()
    optDX.step()

    # ---- Discriminator Y ----
    with torch.no_grad():
        fake_y = G_XY(x)                                             # X→Y
    optDY.zero_grad()
    logits_real_Y = D_Y(y)
    logits_fake_Y = D_Y(fake_y)
    lossDY = bce(logits_real_Y, torch.ones_like(logits_real_Y)) + \
             bce(logits_fake_Y, torch.zeros_like(logits_fake_Y))
    lossDY.backward()
    optDY.step()

    # ---- Generators (adv + cycle + identity) ----
    # Freeze both discriminators so the generator backward pass does not fill
    # their .grad buffers with unused values. Gradients still flow *through*
    # D into the fakes, because the fakes themselves require grad.
    _set_requires_grad((D_X, D_Y), False)
    optG.zero_grad()
    fake_x = G_YX(y)                                                 # recompute (no detach)    [M-9]
    fake_y = G_XY(x)
    logits_gen_X = D_X(fake_x)                                       # one D pass per domain; previously
    logits_gen_Y = D_Y(fake_y)                                       # D ran a second time just to build the target
    adv_X = bce(logits_gen_X, torch.ones_like(logits_gen_X))         # want fakes judged real
    adv_Y = bce(logits_gen_Y, torch.ones_like(logits_gen_Y))
    # Equivalent to cycle_consistency_loss(G_XY, G_YX, x, y, lambda_cyc), but
    # reuses fake_y == G_XY(x) and fake_x == G_YX(y) instead of re-running
    # both generators.                                                       [M-8]
    loss_cyc = lambda_cyc * (l1(G_YX(fake_y), x) + l1(G_XY(fake_x), y))
    loss_id  = identity_loss(G_XY, G_YX, x, y, lambda_cyc)            # FIX [E-6]
    lossG = adv_X + adv_Y + loss_cyc + loss_id
    lossG.backward()
    optG.step()
    _set_requires_grad((D_X, D_Y), True)                             # re-enable for next D update

    if step % 100 == 0:
        print(f"step {step:04d} | DX={lossDX.item():.3f} DY={lossDY.item():.3f} G={lossG.item():.3f} (advX={adv_X.item():.3f}, advY={adv_Y.item():.3f}, cyc={loss_cyc.item():.3f}, id={loss_id.item():.3f})")
    if step == 400:  # small sanity run
        break


## What was fixed (mapped to bug tags)

1. **\[SE-1]** both domains are normalized identically to `[-1, 1]`, matching the `Tanh` output range of the generators.
2. **\[SE-2]** discriminators output **logits** (no Sigmoid) when using `BCEWithLogitsLoss`.
3. **\[SE-3]** patch targets now use `ones_like/zeros_like(logits)` to match spatial shape.
4. **\[SE-4]** fakes **detached** (or computed under `no_grad`) during D updates.
5. **\[E-5]** **InstanceNorm2d** in the generators (the discriminators use no normalization layers at all).
6. **\[E-6]** added **identity loss**: `0.5*λ*(|G_YX(x)-x|₁ + |G_XY(y)-y|₁)`.
7. **\[E-7]** separate discriminators **D\_X** and **D\_Y** with separate optimizers **optDX/optDY**.
8. **\[M-8]** correct **cycle consistency**: `λ*(|G_YX(G_XY(x))-x|₁ + |G_XY(G_YX(y))-y|₁)`.
9. **\[M-9]** fakes are recomputed for the generator step **without** detaching, so the adversarial loss backpropagates into the generators; `zero_grad()`/`step()` are called in the proper order.
10. **\[H-10]** replaced the single shared D with **two discriminators** and stepped the **correct** optimizer for each update.
