# CycleGAN

This notebook trains a CycleGAN that translates photos into Monet-style paintings and vice versa.

The generator uses a ResNet-based architecture: residual blocks (`ResidualBlock`) are stacked between downsampling and upsampling layers to assemble the generator model.
The notebook also defines a PatchGAN discriminator; the cycle-consistency, identity, and adversarial losses; the optimizers; and the training loop.

The Generator is a ResNet-based network with reflection padding, downsampling, several residual blocks, and upsampling.
The Discriminator is a PatchGAN that outputs a feature map (not a single scalar) used to classify patches as real or fake.

Loss Functions:
The implementation uses MSELoss for the adversarial loss (the least-squares GAN objective common in CycleGAN implementations) and L1Loss for both the cycle-consistency and identity losses.





In [None]:
import os
import glob
import torch
import itertools
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as K

from PIL import Image
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Select the compute device: a CUDA GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


########################################
# 1. DATALOADER
########################################

In [None]:
# function to plot samples and metadata of a dataset
def plot_dataset(
    dataset,
    n_rows=4,
    figsize=(6, 6),
    denormalize=False,
):
    """Print basic dataset metadata and plot random A/B sample pairs.

    Args:
        dataset: Dataset exposing ``files_A`` / ``files_B`` lists and whose
            items are dicts ``{'A': tensor, 'B': tensor}`` of channel-first
            (C, H, W) images.
        n_rows: Number of randomly drawn samples; each row shows the A image
            on the left and the B image on the right.
        figsize: Matplotlib figure size.
        denormalize: If True, map images back from [-1, 1] to [0, 1]
            (inverse of ``Normalize(mean=0.5, std=0.5)``) and clamp before
            display.
    """
    print(f'Number of samples: {len(dataset)}')
    print(f'Number of samples A: {len(dataset.files_A)}')
    print(f'Number of samples B: {len(dataset.files_B)}')
    first = dataset[0]  # load (and transform) the first item only once
    print(f'Sample shape A: {first["A"].shape}')
    print(f'Sample shape B: {first["B"].shape}')

    def _show(ax, image, title):
        # Draw one (C, H, W) image tensor on `ax`.
        if denormalize:
            # x * 0.5 + 0.5 undoes Normalize(0.5, 0.5); clamp for imshow's [0, 1] range
            image = (image * 0.5 + 0.5).clamp(0, 1)
        image = image.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow
        if image.shape[2] == 1:
            # imshow does not accept (H, W, 1); drop the channel axis, use a gray colormap
            ax.imshow(image[..., 0], cmap='gray')
        else:
            ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    samples = np.random.randint(0, len(dataset), size=n_rows)
    # squeeze=False keeps `axes` 2-D so axes[i, j] also works when n_rows == 1
    _, axes = plt.subplots(n_rows, 2, figsize=figsize, squeeze=False)

    for row, idx in enumerate(samples):
        sample = dataset[idx]
        _show(axes[row, 0], sample['A'], 'A')
        _show(axes[row, 1], sample['B'], 'B')

    plt.show()

In [2]:
# dataloader

class ImageDataset(Dataset):
    """
    A dataset class for unpaired image-to-image translation.
    Expects the following folder structure inside the dataset root:

        <root>/
            trainA/
            trainB/
            testA/
            testB/

    In the monet2photo dataset used in this notebook, domain 'A' (trainA/testA)
    holds Monet paintings and domain 'B' (trainB/testB) holds photos.

    Args:
        root: Dataset root directory.
        transforms_: Optional callable applied independently to each PIL image.
        mode: Split prefix; files are read from ``<root>/<mode>A`` and
            ``<root>/<mode>B``.
        unaligned: If False (default), item ``index`` pairs
            ``files_A[index % len(files_A)]`` with
            ``files_B[index % len(files_B)]``, so the same A/B pairs recur
            every epoch. If True, the B image is drawn uniformly at random on
            each access, so pairings vary between epochs.

    Raises:
        FileNotFoundError: If ``<mode>A`` or ``<mode>B`` contains no files.
    """
    def __init__(self, root, transforms_=None, mode='train', unaligned=False):
        self.transform = transforms_
        self.unaligned = unaligned
        dir_A = os.path.join(root, f'{mode}A')
        dir_B = os.path.join(root, f'{mode}B')
        self.files_A = sorted(glob.glob(os.path.join(dir_A, '*.*')))
        self.files_B = sorted(glob.glob(os.path.join(dir_B, '*.*')))
        # Fail early with a clear message instead of a ZeroDivisionError in
        # __getitem__ (index % 0) or an empty-sampler error in the DataLoader.
        for folder, files in ((dir_A, self.files_A), (dir_B, self.files_B)):
            if not files:
                raise FileNotFoundError(f'No image files found in {folder!r}')

    @staticmethod
    def _load_rgb(path):
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy, so it stays valid after the file closes.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        img_A = self._load_rgb(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            # torch RNG is re-seeded per DataLoader worker, so workers differ
            index_B = int(torch.randint(len(self.files_B), (1,)))
        else:
            index_B = index % len(self.files_B)
        img_B = self._load_rgb(self.files_B[index_B])

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {'A': img_A, 'B': img_B}

    def __len__(self):
        # The larger domain defines the epoch length; the smaller one wraps around.
        return max(len(self.files_A), len(self.files_B))

# Define image transformations, following the CycleGAN paper's recipe:
# upscale by 1.12, take a random 256x256 crop, randomly flip horizontally,
# then map pixel values to [-1, 1].
# The geometric ops run on the uint8 image *before* the float conversion, so
# bicubic overshoot is rounded/clamped to [0, 255]. After Normalize, the inputs
# therefore stay inside [-1, 1], the same range as the generator's Tanh output.
transforms_ = T.Compose([
    T.ToImage(),
    T.Resize(int(256 * 1.12), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.RandomCrop(256),
    T.RandomHorizontalFlip(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Set the dataset root
dataset_root = "/pgeoprj2/ciag2024/dados/cycleGAN/monet2photo/"
train_dataset = ImageDataset(root=dataset_root, transforms_=transforms_, mode='train')
# shuffle=True: each epoch visits the (A, B) items in a new random order
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)

In [3]:
# Report how many training images each domain folder contributed.
print(f"Number of images in domain A (trainA): {len(train_dataset.files_A)}")
print(f"Number of images in domain B (trainB): {len(train_dataset.files_B)}")

Number of images in domain A (trainA): 1072
Number of images in domain B (trainB): 6287


In [None]:
plot_dataset(dataset=train_dataset, denormalize=True)  # stats + random denormalized A/B pairs

########################################
# 2. MODEL ARCHITECTURES
########################################

# Generator

In [4]:
# Generator

# --- Residual Block for the Generator ---
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with instance norm, plus an identity skip.

    Input and output have the same shape, (N, dim, H, W).
    """
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim)
        )

    def forward(self, x):
        # The block learns a residual correction that is added to its input.
        return x + self.block(x)

# --- Generator: ResNet-based ---
class Generator(nn.Module):
    """ResNet-style generator.

    Layout: 7x7 stem -> ``n_downsampling`` stride-2 convs -> residual blocks ->
    ``n_downsampling`` transposed convs -> 7x7 output conv with Tanh, so
    outputs lie in [-1, 1].

    Args:
        input_nc: Channels of the input image.
        output_nc: Channels of the output image.
        n_residual_blocks: Number of ResidualBlocks at the bottleneck.
        ngf: Channels produced by the stem; doubled at each downsampling step.
        n_downsampling: Number of down/upsampling stages. The output has the
            input's spatial size when H and W are divisible by
            ``2 ** n_downsampling``.
    """
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, ngf=64, n_downsampling=2):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True)
        ]

        # Downsampling: each stage halves H, W and doubles the channel count
        in_features = ngf
        for _ in range(n_downsampling):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual blocks at the bottleneck resolution
        model += [ResidualBlock(in_features) for _ in range(n_residual_blocks)]

        # Upsampling: mirror of the downsampling path
        for _ in range(n_downsampling):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output layer; in_features is back to ngf after the mirrored upsampling
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_nc, kernel_size=7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Discriminator

In [5]:
# --- Discriminator: PatchGAN ---
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Maps an image of shape (N, input_nc, H, W) to a single-channel score map
    (N, 1, H', W'), where each score judges one receptive-field patch as real
    or fake. For a 256x256 input the score map is 30x30.
    """
    def __init__(self, input_nc):
        super().__init__()
        # First layer: no normalization
        layers = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Conv -> InstanceNorm -> LeakyReLU stages as (in_channels, out_channels, stride)
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Final projection to one score per patch (no activation; used with MSE loss)
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


########################################
# 3. INSTANTIATE MODELS & LOSS FUNCTIONS
########################################

In [6]:
# 3. INSTANTIATE MODELS & LOSS FUNCTIONS

# In monet2photo, trainA holds Monet paintings and trainB holds photos,
# so domain A = Monet and domain B = photo throughout this notebook.
G_A2B = Generator(input_nc=3, output_nc=3).to(device)  # Translates Monet -> Photo
G_B2A = Generator(input_nc=3, output_nc=3).to(device)  # Translates Photo -> Monet
D_A = Discriminator(input_nc=3).to(device)  # Discriminator for domain A (Monet)
D_B = Discriminator(input_nc=3).to(device)  # Discriminator for domain B (Photo)

# Losses: least-squares (MSE) adversarial loss; L1 for the cycle and identity terms
criterion_GAN = nn.MSELoss().to(device)
criterion_cycle = nn.L1Loss().to(device)
criterion_identity = nn.L1Loss().to(device)

########################################
# 4. OPTIMIZERS (no learning-rate scheduler is defined yet)
########################################

In [7]:
########################################
# 4. OPTIMIZERS (no learning-rate scheduler is defined yet)
########################################

lr, beta1 = 0.0002, 0.5

# Both generators share one Adam optimizer, because their losses are summed
# into a single loss_G. Each discriminator is stepped with its own optimizer.
optimizer_G = optim.Adam(
    [*G_A2B.parameters(), *G_B2A.parameters()],
    lr=lr,
    betas=(beta1, 0.999),
)
optimizer_D_A, optimizer_D_B = (
    optim.Adam(disc.parameters(), lr=lr, betas=(beta1, 0.999)) for disc in (D_A, D_B)
)

########################################
# 5. TRAINING LOOP (simplified for demonstration)
########################################

In [None]:
########################################
# 5. TRAINING LOOP (simplified for demonstration)
########################################

num_epochs = 20  # For demonstration; increase for better results
lambda_cycle = 10.0     # weight of the cycle-consistency term
lambda_identity = 5.0   # weight of the identity term (0.5 * lambda_cycle)

# Target values for the least-squares (MSE) GAN loss
# Real label = 1.0, Fake label = 0.0
real_label = 1.0
fake_label = 0.0

for epoch in range(num_epochs):
    # Put every network in training mode explicitly, in case an earlier cell
    # (e.g. a visualization) left them in eval mode.
    for net in (G_A2B, G_B2A, D_A, D_B):
        net.train()

    for batch in (pbar := tqdm(train_dataloader, unit='batch')):
        # Set model inputs
        real_A = batch['A'].to(device)  # domain A (Monet)
        real_B = batch['B'].to(device)  # domain B (photo)

        #### Train Generators G_A2B and G_B2A ####
        # Freeze the discriminators. Gradients still flow *through* them into
        # the generators, but no .grad is computed for D parameters, which this
        # step does not update anyway.
        D_A.requires_grad_(False)
        D_B.requires_grad_(False)
        optimizer_G.zero_grad()

        # Identity loss: G_B2A produces domain-A images, so a real domain-A
        # input should come back (nearly) unchanged; likewise for G_A2B with B.
        loss_id_A = criterion_identity(G_B2A(real_A), real_A)
        loss_id_B = criterion_identity(G_A2B(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) * lambda_identity

        # Adversarial loss: the generators want the discriminators to say "real"
        fake_B = G_A2B(real_A)
        pred_fake_B = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.full_like(pred_fake_B, real_label))

        fake_A = G_B2A(real_B)
        pred_fake_A = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.full_like(pred_fake_A, real_label))
        loss_GAN = loss_GAN_A2B + loss_GAN_B2A

        # Cycle consistency loss: A -> B -> A and B -> A -> B should reconstruct the input
        rec_A = G_B2A(fake_B)
        loss_cycle_A = criterion_cycle(rec_A, real_A)
        rec_B = G_A2B(fake_A)
        loss_cycle_B = criterion_cycle(rec_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) * lambda_cycle

        # Total generator loss
        loss_G = loss_identity + loss_GAN + loss_cycle
        loss_G.backward()
        optimizer_G.step()

        # Unfreeze the discriminators for their own updates
        D_A.requires_grad_(True)
        D_B.requires_grad_(True)

        #### Train Discriminator D_A ####
        optimizer_D_A.zero_grad()
        # Real loss for D_A
        pred_real_A = D_A(real_A)
        loss_D_A_real = criterion_GAN(pred_real_A, torch.full_like(pred_real_A, real_label))
        # Fake loss for D_A; detach() keeps this step from backpropagating into G_B2A
        pred_fake_A = D_A(fake_A.detach())
        loss_D_A_fake = criterion_GAN(pred_fake_A, torch.full_like(pred_fake_A, fake_label))
        # Halving the D objective slows D's learning relative to G
        loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        #### Train Discriminator D_B ####
        optimizer_D_B.zero_grad()
        # Real loss for D_B
        pred_real_B = D_B(real_B)
        loss_D_B_real = criterion_GAN(pred_real_B, torch.full_like(pred_real_B, real_label))
        # Fake loss for D_B; detach() keeps this step from backpropagating into G_A2B
        pred_fake_B = D_B(fake_B.detach())
        loss_D_B_fake = criterion_GAN(pred_fake_B, torch.full_like(pred_fake_B, fake_label))
        loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        pbar.set_description(f'[epoch {epoch+1}/{num_epochs}] G_loss: {loss_G.item():.4f} DA_loss: {loss_D_A.item():.4f} DB_loss: {loss_D_B.item():.4f}')

  0%|          | 0/393 [00:00<?, ?batch/s]

########################################
# 6. VISUALIZATION FUNCTION
########################################

In [None]:
########################################
# 6. VISUALIZATION FUNCTION
########################################

def visualize_cycle(generator_AB, generator_BA, dataset, num_samples=3):
    """
    Visualizes a few examples of cycle translation.
    For each of the first num_samples dataset items, one figure row shows:
      - Real A, its A -> B translation (Fake B) and the reconstruction back to A (Rec A).
      - Real B, its B -> A translation (Fake A) and the reconstruction back to B (Rec B).

    Args:
        generator_AB: Generator translating domain A -> B.
        generator_BA: Generator translating domain B -> A.
        dataset: Dataset whose items are dicts with keys 'A' and 'B' holding
            normalized (C, H, W) image tensors in [-1, 1].
        num_samples: Number of dataset items (figure rows) to show.

    The generators run in eval mode and are restored to their previous
    train/eval mode afterwards.
    """
    # Remember the current modes so calling this mid-training has no lasting side effect.
    was_training = (generator_AB.training, generator_BA.training)
    generator_AB.eval()
    generator_BA.eval()

    n_cols = 6
    # One row per sample, six roughly 3-inch panels per row.
    plt.figure(figsize=(3 * n_cols, 3 * num_samples))

    def to_display(img):
        # Undo Normalize(0.5, 0.5) ([-1, 1] -> [0, 1]), clamp to imshow's valid
        # float range, drop the batch dim and move channels last.
        return (img.squeeze(0) * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().numpy()

    try:
        for i in range(num_samples):
            sample = dataset[i]
            real_A = sample['A'].unsqueeze(0).to(device)
            real_B = sample['B'].unsqueeze(0).to(device)

            with torch.no_grad():
                fake_B = generator_AB(real_A)
                rec_A = generator_BA(fake_B)
                fake_A = generator_BA(real_B)
                rec_B = generator_AB(fake_A)

            # Columns 1-3: domain A cycle; columns 4-6: domain B cycle
            panels = (
                ("Real A", real_A), ("Fake B", fake_B), ("Rec A", rec_A),
                ("Real B", real_B), ("Fake A", fake_A), ("Rec B", rec_B),
            )
            for col, (title, img) in enumerate(panels, start=1):
                plt.subplot(num_samples, n_cols, i * n_cols + col)
                plt.imshow(to_display(img))
                plt.title(title)
                plt.axis("off")
    finally:
        generator_AB.train(was_training[0])
        generator_BA.train(was_training[1])

    plt.tight_layout()
    plt.show()



In [None]:
# After training: show the first 3 training items with their translations and reconstructions.
visualize_cycle(
    generator_AB=G_A2B,
    generator_BA=G_B2A,
    dataset=train_dataset,
    num_samples=3,
)

# Domain A: Monet paintings, domain B: photos

# Randomly selecting examples

We redefine the previous function, adding the parameter random_sample (default False). When random_sample is True, a random index is chosen for each sample using Python's random.randint.

Index Selection:
In each iteration, if random_sample is True, the index idx is chosen randomly from [0, len(dataset)-1]; otherwise, it is set to the loop counter i.

Remaining Code:
The rest of the function remains the same, generating cycle translation outputs and plotting them.

In [None]:
import random

def visualize_cycle(generator_AB, generator_BA, dataset, num_samples=3, random_sample=False):
    """
    Visualizes a few examples of cycle translation.
    For each sample, one figure row shows:
      - Real A, its A -> B translation (Fake B) and the reconstruction back to A (Rec A).
      - Real B, its B -> A translation (Fake A) and the reconstruction back to B (Rec B).

    Args:
        generator_AB (nn.Module): Generator for translating from A to B.
        generator_BA (nn.Module): Generator for translating from B to A.
        dataset (Dataset): Dataset whose items are dicts with keys 'A' and 'B'
            (unpaired images from the two domains) holding normalized
            (C, H, W) tensors in [-1, 1].
        num_samples (int): Number of samples (figure rows) to visualize.
        random_sample (bool): If True, select samples randomly from the dataset
                              (indices may repeat). If False, select the first
                              num_samples images.

    The generators run in eval mode and are restored to their previous
    train/eval mode afterwards.
    """
    # Remember the current modes so calling this mid-training has no lasting side effect.
    was_training = (generator_AB.training, generator_BA.training)
    generator_AB.eval()
    generator_BA.eval()

    n_cols = 6
    # One row per sample, six roughly 3-inch panels per row.
    plt.figure(figsize=(3 * n_cols, 3 * num_samples))

    def to_display(img):
        # Undo Normalize(0.5, 0.5) ([-1, 1] -> [0, 1]), clamp to imshow's valid
        # float range, drop the batch dim and move channels last.
        return (img.squeeze(0) * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().numpy()

    dataset_length = len(dataset)
    try:
        for i in range(num_samples):
            # Choose index either sequentially or at random
            idx = random.randint(0, dataset_length - 1) if random_sample else i

            sample = dataset[idx]
            real_A = sample['A'].unsqueeze(0).to(device)
            real_B = sample['B'].unsqueeze(0).to(device)

            with torch.no_grad():
                fake_B = generator_AB(real_A)
                rec_A = generator_BA(fake_B)
                fake_A = generator_BA(real_B)
                rec_B = generator_AB(fake_A)

            # Columns 1-3: domain A cycle; columns 4-6: domain B cycle
            panels = (
                ("Real A", real_A), ("Fake B", fake_B), ("Rec A", rec_A),
                ("Real B", real_B), ("Fake A", fake_A), ("Rec B", rec_B),
            )
            for col, (title, img) in enumerate(panels, start=1):
                plt.subplot(num_samples, n_cols, i * n_cols + col)
                plt.imshow(to_display(img))
                plt.title(title)
                plt.axis("off")
    finally:
        generator_AB.train(was_training[0])
        generator_BA.train(was_training[1])

    plt.tight_layout()
    plt.show()


In [None]:
# For sequential sampling (the first num_samples items):
# visualize_cycle(G_A2B, G_B2A, train_dataset, num_samples=3, random_sample=False)

# For random sampling:
visualize_cycle(G_A2B, G_B2A, train_dataset, num_samples=3, random_sample=True)
