# Implementing the CycleGAN (vanilla architecture)

## Objective

Reproduce a vanilla CycleGAN, based on the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593v7).

More specifically, build, train, and document this GAN architecture using PyTorch, based on the implementation at [https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py).

In [2]:
import torch
import matplotlib.pyplot as plt
torch.manual_seed(0)
from torchvision.utils import make_grid
from torchvision import datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## CycleGAN generator

Each CycleGAN generator has three sections:

- Encoder
- Transformer
- Decoder

The input image is passed into the encoder, which uses convolutions to extract features: it compresses the spatial representation of the image while increasing the number of channels.

The encoder consists of three convolutions. The first keeps the spatial size, and the other two have stride 2, so together they reduce the height and width to 1/4 of the input size. For an input image of size (256, 256, 3), the encoder output is (64, 64, 256).

After the activation function, the encoder output is passed into the transformer, which contains 6 or 9 residual blocks depending on the input size: the paper uses 6 blocks for 128×128 training images and 9 for 256×256 or higher-resolution images.

The transformer output is then passed into the decoder, which uses two fractionally strided convolution (deconvolution) blocks to upsample the representation back to the original size.
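The size bookkeeping above can be checked with the output-size formulas that PyTorch documents for `nn.Conv2d` and `nn.ConvTranspose2d` (with dilation 1). The sketch below is plain Python; the helpers `conv_out` and `conv_transpose_out` are hypothetical names, and the padding values are assumed to match the `Generator` code below (reflection padding of 3 around the 7×7 layers, padding 1 for the strided 3×3 layers, and `output_padding=1` for the transposed convolutions).

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output height/width of nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


shapes = [("input", 3, 256)]  # (layer, channels, height = width)
size = 256

# Encoder: c7s1-64 (after ReflectionPad2d(3)), then d128 and d256
size = conv_out(size + 2 * 3, kernel=7)
shapes.append(("c7s1-64", 64, size))
for channels in (128, 256):
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append((f"d{channels}", channels, size))

# Transformer: each residual block applies two reflection-padded 3x3 convs, keeping the size
for _ in range(9):
    size = conv_out(conv_out(size + 2, kernel=3) + 2, kernel=3)
shapes.append(("R256 x9", 256, size))

# Decoder: u128 and u64 (transposed convs), then c7s1-3
for channels in (128, 64):
    size = conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1)
    shapes.append((f"u{channels}", channels, size))
size = conv_out(size + 2 * 3, kernel=7)
shapes.append(("c7s1-3", 3, size))

for name, channels, s in shapes:
    print(f"{name:>8}: ({s}, {s}, {channels})")
```

The printed trace goes 256 → 128 → 64 through the encoder, stays at 64 through the residual blocks, and returns to 256 in the decoder.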

### Architecture

The architecture of the generator with 6 residual blocks is:

`c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, u128, u64, c7s1-3`

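where c7s1-k denotes a 7×7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1; dk denotes a 3×3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2; Rk denotes a residual block that contains two 3×3 convolutional layers with the same number of filters in both layers; and uk denotes a 3×3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters and stride 1/2 (i.e., a transposed convolution, often called a deconvolution). Reflection padding is used to reduce artifacts. The code below defaults to the 9-block version (`n_residual_blocks=9`), and its final c7s1-3 layer ends with a Tanh instead of InstanceNorm-ReLU, so the outputs lie in [-1, 1].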
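To make the notation concrete, the sketch below expands each token of the 6-block specification into the layer it stands for. `describe_layer` is a hypothetical helper written only for this illustration; it reads the notation literally and is not part of the paper or the reference implementation.

```python
import re


def describe_layer(token):
    """Expand one token of the CycleGAN generator notation (hypothetical helper)."""
    token = token.strip()
    match = re.fullmatch(r"c(\d+)s(\d+)-(\d+)", token)
    if match:
        kernel, stride, filters = (int(g) for g in match.groups())
        return f"{kernel}x{kernel} Conv-InstanceNorm-ReLU, {filters} filters, stride {stride}"
    match = re.fullmatch(r"([dRu])(\d+)", token)
    if match is None:
        raise ValueError(f"Unknown layer token: {token!r}")
    kind, filters = match.group(1), int(match.group(2))
    if kind == "d":
        return f"3x3 Conv-InstanceNorm-ReLU, {filters} filters, stride 2"
    if kind == "R":
        return f"Residual block: two 3x3 convs, {filters} filters each"
    return f"3x3 fractional-strided Conv-InstanceNorm-ReLU, {filters} filters, stride 1/2"


spec = "c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, u128, u64, c7s1-3"
layers = [describe_layer(token) for token in spec.split(",")]
for token, description in zip(spec.split(","), layers):
    print(f"{token.strip():>8}  ->  {description}")
```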

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # c7s1-64: reflection-padded 7x7 convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling (d128, d256): each Conv-InstanceNorm-ReLU block halves height and width
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(64 * mult, 64 * mult * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(64 * mult * 2),
                nn.ReLU(inplace=True),
            ]

        # Transformer: R256 residual blocks
        mult = 2 ** n_downsampling
        for i in range(n_residual_blocks):
            model += [ResidualBlock(64 * mult)]

        # Upsampling (u128, u64): each transposed convolution doubles height and width
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(64 * mult, 64 * mult // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(64 * mult // 2),
                nn.ReLU(inplace=True),
            ]

        # c7s1-3 output layer; Tanh maps the outputs to [-1, 1]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(64, output_nc, 7)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [3]:
# Instantiate generators
gen_AtoB = Generator(3, 3)
gen_BtoA = Generator(3, 3)

# Basic tests with random input
input_tensor = torch.randn(1, 3, 256, 256)

output_tensor = gen_AtoB(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256), "Generator output has incorrect shape"

output_tensor = gen_BtoA(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256), "Generator output has incorrect shape"


print("Generator instantiation and basic tests passed successfully.")


Generator instantiation and basic tests passed successfully.


## CycleGAN Discriminator

For the discriminators, the authors use a PatchGAN. A regular GAN discriminator maps a 256×256 image to a single scalar output that signifies “real” or “fake”, whereas a PatchGAN maps the 256×256 image to an N×N array of outputs X, where each X_ij signifies whether the corresponding patch ij of the image is real or fake. In the 70×70 PatchGAN used by the paper, each output unit looks at an overlapping 70×70 patch of the input, and a 256×256 input yields a 30×30 output array.

### Architecture

The architecture of the discriminator is:

`C64-C128-C256-C512`

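where Ck denotes a 4×4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2. InstanceNorm is not applied to the first layer (C64). After the last layer, a convolution produces a one-channel output map. The reference implementation gives C512 stride 1 rather than 2, which is what makes the receptive field 70×70, and the implementation below additionally averages the output map into a single score per image.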
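Where the "70×70" and the 30×30 output come from can be checked with a short receptive-field calculation. The sketch below is plain Python and assumes padding 1 on every layer and, as in the reference implementation, stride 1 for C512 and for the final one-channel convolution; `conv_out`, `output_size` and `receptive_field` are hypothetical helper names.

```python
def conv_out(size, kernel, stride, padding):
    """Output height/width of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def output_size(size, layers):
    """Spatial size after a stack of (kernel, stride, padding) convolutions."""
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size


def receptive_field(layers):
    """Side length of the input patch that a single output unit depends on."""
    field, jump = 1, 1  # jump = distance in input pixels between adjacent units
    for kernel, stride, _ in layers:
        field += (kernel - 1) * jump
        jump *= stride
    return field


# (kernel, stride, padding) for C64, C128, C256, C512 and the final one-channel conv,
# assuming C512 uses stride 1 as in the reference implementation.
patchgan = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]
# Variant in which C512 also has stride 2, for comparison.
all_stride_2 = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1)]

for name, layers in [("patchgan", patchgan), ("all_stride_2", all_stride_2)]:
    print(name, receptive_field(layers), output_size(256, layers))
```

With C512 at stride 1, each output unit sees a 70×70 patch and a 256×256 input gives a 30×30 map; giving C512 stride 2 as well would enlarge the receptive field to 94×94 and shrink the map to 15×15.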

In [4]:
class Discriminator(nn.Module):
    def __init__(self, input_nc, output_nc=1):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # C64: no InstanceNorm on the first layer
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            self.discriminator_block(64, 128),
            self.discriminator_block(128, 256),
            # C512 uses stride 1, as in the reference implementation (70x70 receptive field)
            self.discriminator_block(256, 512, stride=1),

            # Final convolution to a one-channel map of patch scores
            nn.Conv2d(512, output_nc, 4, padding=1)
        )

    def discriminator_block(self, input_dim, output_dim, stride=2):
        """Returns a Ck block: 4x4 Conv-InstanceNorm-LeakyReLU with the given stride"""
        return nn.Sequential(
                      nn.Conv2d(input_dim, output_dim, kernel_size=4, stride=stride, padding=1),
                      nn.InstanceNorm2d(output_dim),
                      nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.model(x)
        # Average the N x N patch scores into one score per image and flatten
        # (the reference implementation keeps the full N x N map instead)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

In [5]:
# Instantiate discriminators
dis_A = Discriminator(3, 1)
dis_B = Discriminator(3, 1)

# Basic tests with random input
input_tensor = torch.randn(1, 3, 256, 256)

output_tensor = dis_A(input_tensor)
assert output_tensor.shape == (1, 1), "Discriminator output has incorrect shape"

output_tensor = dis_B(input_tensor)
assert output_tensor.shape == (1, 1), "Discriminator output has incorrect shape"

print("Discriminator instantiation and basic tests passed successfully.")


Discriminator instantiation and basic tests passed successfully.


# Cost Functions

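- **Adversarial Loss:** An adversarial loss is applied to both mappings: G: X → Y with its discriminator D_y, and F: Y → X with its discriminator D_x. In practice the paper replaces the negative log-likelihood objective with a least-squares loss, which is more stable during training. Each generator minimizes:

$$ L_{advers}\left ( G, D_y, X, Y \right ) =\frac{1}{m}\sum_{i=1}^{m} \left ( 1 - D_y\left ( G\left ( x_i \right ) \right ) \right )^{2} $$

$$ L_{advers}\left ( F, D_x, Y, X \right ) =\frac{1}{m}\sum_{i=1}^{m} \left ( 1 - D_x\left ( F\left ( y_i \right ) \right ) \right )^{2} $$

Each discriminator, in turn, is trained to output 1 on real images and 0 on generated ones: D_y minimizes $\frac{1}{m}\sum_{i=1}^{m}\left[ \left( D_y(y_i) - 1 \right)^2 + D_y(G(x_i))^2 \right]$, D_x is trained analogously, and the paper divides this discriminator objective by 2.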

- **Cycle Consistency Loss:** With large enough capacity, an adversarial network can map the same set of input images to any random permutation of images in the target domain, and any of these mappings can induce an output distribution that matches the target distribution. Adversarial losses alone therefore cannot guarantee that an individual input x_i is mapped to a desired output y_i. To reduce the space of possible mappings, the authors require the mappings to be cycle-consistent.

  This loss measures the reconstruction error of the round trips x → G(x) → F(G(x)) and y → F(y) → G(F(y)), pushing F(G(x)) to closely match the real input x and G(F(y)) to closely match y:

$$ L_{cyc}\left ( G, F, X, Y \right ) =\frac{1}{m}\sum_{i=1}^{m}\left [ \left \| F\left ( G\left ( x_i \right ) \right )-x_i \right \|_1 +\left \| G\left ( F\left ( y_i \right ) \right )-y_i \right \|_1 \right ] $$
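The two loss terms can be made concrete with a few numbers. The sketch below is plain Python on made-up values; `lsgan_generator_loss` and `l1_cycle_loss` are hypothetical helpers, not functions from the notebook or from PyTorch. Note that `nn.L1Loss`, as used later, defaults to `reduction='mean'` and averages over every pixel, so it differs from the formula only by a constant factor (the number of values per image).

```python
def lsgan_generator_loss(d_fake):
    """(1/m) * sum_i (1 - D(G(x_i)))^2 over discriminator scores on generated images."""
    return sum((1.0 - d) ** 2 for d in d_fake) / len(d_fake)


def l1_cycle_loss(reconstructed, original):
    """Mean absolute error between reconstructed and original pixel values."""
    return sum(abs(r - o) for r, o in zip(reconstructed, original)) / len(original)


# Made-up discriminator scores for three generated images.
d_fake = [0.2, 0.5, 0.9]
adv = lsgan_generator_loss(d_fake)  # (0.64 + 0.25 + 0.01) / 3 = 0.3

# Made-up pixel values of one image x and its reconstruction F(G(x)), flattened.
x = [0.0, 0.5, -1.0, 1.0]
f_g_x = [0.1, 0.4, -0.8, 1.0]
cyc = l1_cycle_loss(f_g_x, x)  # (0.1 + 0.1 + 0.2 + 0.0) / 4 = 0.1

print(f"adversarial: {adv:.4f}, cycle: {cyc:.4f}")
```

A score of 0.9 on a fake barely contributes to the generator loss, while a score of 0.2 dominates it, which is the pressure that pushes generated images toward being classified as real.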


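The full objective is the sum of the two adversarial losses and the cycle-consistency loss, where λ controls the relative importance of the cycle term (the paper and the code below use λ = 10):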


$$ L\left ( G, F, D_x, D_y \right ) = L_{advers}\left (G, D_y, X, Y \right ) + L_{advers}\left (F, D_x, Y, X \right ) + \lambda L_{cyc}\left ( G, F, X, Y \right ) $$

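and the aim is to solve:

$$ G^{*}, F^{*} = \arg \underset{G, F}{\min}\,\underset{D_x, D_y}{\max}\, L\left ( G, F, D_x, D_y \right ) $$

The `compute_loss` method below additionally adds an identity loss, $\frac{1}{m}\sum_{i=1}^{m}\left[ \left\| F(x_i) - x_i \right\|_1 + \left\| G(y_i) - y_i \right\|_1 \right]$, weighted by 0.5λ = 5. The paper uses this term for some tasks, such as generating photos from paintings, to preserve the color composition between input and output.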
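The weighting used for the generators in `compute_loss` (λ = 10 on the cycle term, 0.5λ = 5 on the identity term) can be sketched as a plain function. The loss values passed in below are made up for illustration, and `generator_objective` is a hypothetical name.

```python
LAMBDA_CYC = 10.0             # cycle-consistency weight (lambda), as in compute_loss
LAMBDA_ID = 0.5 * LAMBDA_CYC  # identity weight, 5 in compute_loss


def generator_objective(adv_ab, adv_ba, cyc_a, cyc_b, id_a, id_b):
    """Weighted generator loss: adversarial + lambda * cycle + 0.5 * lambda * identity."""
    return (adv_ab + adv_ba) + LAMBDA_CYC * (cyc_a + cyc_b) + LAMBDA_ID * (id_a + id_b)


# Made-up loss values for illustration only.
total = generator_objective(adv_ab=0.30, adv_ba=0.25, cyc_a=0.10, cyc_b=0.12, id_a=0.05, id_b=0.04)
print(round(total, 4))  # 0.55 + 10 * 0.22 + 5 * 0.09 = 3.2
```

With these numbers the cycle term contributes 2.2 of the 3.2 total, illustrating how strongly λ = 10 weights reconstruction relative to the adversarial terms.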

In [8]:
class CycleGANLoss(nn.Module):
    """Define different GAN objectives.

    The CycleGANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.

    gan_mode='lsgan' (default) uses the least-squares loss (nn.MSELoss) that the paper
    uses in practice; gan_mode='vanilla' uses the original cross-entropy GAN loss on
    raw discriminator logits (nn.BCEWithLogitsLoss).
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0, gan_mode='lsgan'):
        super(CycleGANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise ValueError(f"Unsupported gan_mode: {gan_mode}")

    def _get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        target_tensor = self._get_target_tensor(prediction, target_is_real)
        loss = self.loss(prediction, target_tensor)
        return loss

In [10]:
# Instantiate optimizers
optimizer_G = optim.Adam(list(gen_AtoB.parameters()) + list(gen_BtoA.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(dis_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(dis_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Instantiate loss functions
criterionGAN = CycleGANLoss()
criterionCycle = torch.nn.L1Loss()

# Assert tests for dimensions, forward and backward pass, and test the losses calculations
# Test generator forward pass
input_tensor = torch.randn(1, 3, 256, 256)
output_tensor = gen_AtoB(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256)

# Test discriminator forward pass
output_tensor = dis_A(input_tensor)
assert output_tensor.shape == (1, 1)

# Test backward pass for generator (through the discriminator's prediction on a fake)
fake_B = gen_AtoB(input_tensor)
loss = criterionGAN(dis_B(fake_B), True)
loss.backward()

# Test backward pass for discriminator
output_tensor = dis_A(input_tensor)
loss = criterionGAN(output_tensor, True)
loss.backward()

# Test cycle consistency loss
fake_B = gen_AtoB(input_tensor)
rec_A = gen_BtoA(fake_B)
cycle_loss = criterionCycle(rec_A, input_tensor)
assert cycle_loss.shape == ()

# Test GAN loss
fake_B = gen_AtoB(input_tensor)
pred_fake = dis_B(fake_B)
gan_loss = criterionGAN(pred_fake, True)
assert gan_loss.shape == ()

print("Optimizers, loss functions, and assert tests instantiated successfully.")


Optimizers, loss functions, and assert tests instantiated successfully.


# CycleGAN Class

In [None]:
class CycleGAN: 
    def __init__(self, input_nc, output_nc, lr=0.0002, beta1=0.5, beta2=0.999, device='cpu'):
        """
        Initializes the CycleGAN model, optimizers, and losses.
        
        Args:
        - input_nc: Number of input channels (e.g. 3 for RGB).
        - output_nc: Number of output channels.
        - lr: Learning rate for optimizers.
        - beta1, beta2: Beta parameters for Adam optimizer.
        - device: 'cuda' or 'cpu'.
        """
        self.device = device
        
        # Initialize generators (A -> B maps input_nc to output_nc channels, B -> A the reverse)
        self.gen_AtoB = Generator(input_nc, output_nc).to(self.device)
        self.gen_BtoA = Generator(output_nc, input_nc).to(self.device)
        
        # Initialize discriminators (dis_A judges domain A images, dis_B judges domain B images)
        self.dis_A = Discriminator(input_nc, 1).to(self.device)
        self.dis_B = Discriminator(output_nc, 1).to(self.device)
        
        # Define loss functions
        self.adversarial_loss = CycleGANLoss().to(self.device)
        self.cycle_loss    = nn.L1Loss().to(self.device)
        self.identity_loss = nn.L1Loss().to(self.device)
        
        # Optimizers
        self.optimizer_G = optim.Adam(list(self.gen_AtoB.parameters()) + list(self.gen_BtoA.parameters()), 
                                      lr=lr, betas=(beta1, beta2))
        self.optimizer_D_A = optim.Adam(self.dis_A.parameters(), lr=lr, betas=(beta1, beta2))
        self.optimizer_D_B = optim.Adam(self.dis_B.parameters(), lr=lr, betas=(beta1, beta2))
    
    def forward(self, real_A, real_B):
        """
        Forward pass for both generators.
        
        Args:
        - real_A: Real image from domain A.
        - real_B: Real image from domain B.
        
        Returns:
        - fake_B: Generated image for domain B.
        - fake_A: Generated image for domain A.
        - recovered_A: Reconstructed image from fake_B -> A.
        - recovered_B: Reconstructed image from fake_A -> B.
        """
        fake_B = self.gen_AtoB(real_A)
        fake_A = self.gen_BtoA(real_B)
        recovered_A = self.gen_BtoA(fake_B)
        recovered_B = self.gen_AtoB(fake_A)
        
        return fake_B, fake_A, recovered_A, recovered_B
    
    def compute_loss(self, real_A, real_B):
        """
        Computes the total loss for generators and discriminators.
        
        Args:
        - real_A: Real image from domain A.
        - real_B: Real image from domain B.
        
        Returns:
        - loss_G: Generator total loss.
        - loss_D_A: Discriminator A loss.
        - loss_D_B: Discriminator B loss.
        """
        # Forward pass for generators
        fake_B, fake_A, recovered_A, recovered_B = self.forward(real_A, real_B)
        
        # Identity loss (gen_BtoA(A) should equal A, gen_AtoB(B) should equal B)
        loss_identity_A = self.identity_loss(self.gen_BtoA(real_A), real_A)
        loss_identity_B = self.identity_loss(self.gen_AtoB(real_B), real_B)

        # GAN loss: each generator tries to make the opposite discriminator label its fakes as real.
        # CycleGANLoss takes a boolean target and builds the label tensor itself.
        loss_G_AtoB = self.adversarial_loss(self.dis_B(fake_B), True)
        loss_G_BtoA = self.adversarial_loss(self.dis_A(fake_A), True)

        # Cycle-consistency loss
        loss_cycle_A = self.cycle_loss(recovered_A, real_A)
        loss_cycle_B = self.cycle_loss(recovered_B, real_B)

        # Total generator loss (lambda = 10 for the cycle term, 0.5 * lambda = 5 for identity)
        loss_G = (loss_G_AtoB + loss_G_BtoA) + 10 * (loss_cycle_A + loss_cycle_B) + 5 * (loss_identity_A + loss_identity_B)

        # Discriminator A loss; fakes are detached so no gradient reaches the generators
        loss_real_A = self.adversarial_loss(self.dis_A(real_A), True)
        loss_fake_A = self.adversarial_loss(self.dis_A(fake_A.detach()), False)
        loss_D_A = (loss_real_A + loss_fake_A) * 0.5

        # Discriminator B loss
        loss_real_B = self.adversarial_loss(self.dis_B(real_B), True)
        loss_fake_B = self.adversarial_loss(self.dis_B(fake_B.detach()), False)
        loss_D_B = (loss_real_B + loss_fake_B) * 0.5
        
        return loss_G, loss_D_A, loss_D_B
    
    def optimize(self, real_A, real_B):
        """
        Perform one optimization step for the generators and discriminators.
        
        Args:
        - real_A: Real image from domain A.
        - real_B: Real image from domain B.
        
        Returns:
        - loss_G: Generator total loss.
        - loss_D_A: Discriminator A loss.
        - loss_D_B: Discriminator B loss.
        """
        # Compute losses
        loss_G, loss_D_A, loss_D_B = self.compute_loss(real_A, real_B)

        # Optimize Generators
        self.optimizer_G.zero_grad()
        loss_G.backward()
        self.optimizer_G.step()

        # Optimize Discriminator A
        self.optimizer_D_A.zero_grad()
        loss_D_A.backward()
        self.optimizer_D_A.step()

        # Optimize Discriminator B
        self.optimizer_D_B.zero_grad()
        loss_D_B.backward()
        self.optimizer_D_B.step()
        
        return loss_G.item(), loss_D_A.item(), loss_D_B.item()
    
    def save_model(self, epoch, path='cycle_gan_model.pth'):
        """
        Save the current model state.
        
        Args:
        - epoch: Current epoch number.
        - path: Path to save the model.
        """
        torch.save({
            'epoch': epoch,
            'gen_AtoB_state_dict': self.gen_AtoB.state_dict(),
            'gen_BtoA_state_dict': self.gen_BtoA.state_dict(),
            'dis_A_state_dict': self.dis_A.state_dict(),
            'dis_B_state_dict': self.dis_B.state_dict(),
            'optimizer_G_state_dict': self.optimizer_G.state_dict(),
            'optimizer_D_A_state_dict': self.optimizer_D_A.state_dict(),
            'optimizer_D_B_state_dict': self.optimizer_D_B.state_dict(),
        }, path)

    def load_model(self, path):
        """
        Load a saved model state.
        
        Args:
        - path: Path to the saved model.
        """
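        # map_location lets a checkpoint saved on GPU be loaded on CPU (and vice versa)
        checkpoint = torch.load(path, map_location=self.device)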
        self.gen_AtoB.load_state_dict(checkpoint['gen_AtoB_state_dict'])
        self.gen_BtoA.load_state_dict(checkpoint['gen_BtoA_state_dict'])
        self.dis_A.load_state_dict(checkpoint['dis_A_state_dict'])
        self.dis_B.load_state_dict(checkpoint['dis_B_state_dict'])
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        self.optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A_state_dict'])
        self.optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B_state_dict'])


# Dataset

In [6]:
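def show_img(img):
    """
    Show an image together with its size.

    img: tensor of shape (C, H, W) with values in [0, 1]; for images normalized
    to [-1, 1] (e.g. generator outputs), pass (img + 1) / 2 instead.
    """

    # Move to CPU and drop the autograd graph so matplotlib can read the data
    img = img.detach().cpu().permute(1, 2, 0)
    if img.shape[2] == 1:
        img = img.squeeze(2)
    plt.title(f'Image has size {tuple(img.shape)}')
    plt.imshow(img.numpy(), cmap='gray')
    plt.axis('off')
    plt.show()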

In [None]:
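transform = v2.Compose([
    v2.Resize((256, 256)),
    v2.ToImage(),                                   # PIL image -> uint8 tensor image
    v2.ToDtype(torch.float32, scale=True),          # uint8 [0, 255] -> float [0, 1]
    v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0, 1] -> [-1, 1], matching the generator's Tanh range
])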

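# Load datasets. ImageFolder expects each root to contain at least one sub-directory of
# images (root/<subfolder>/<image>); the paths below are placeholders and the labels are ignored.
dataset_A = datasets.ImageFolder('path_to_dataset_A', transform=transform)
dataset_B = datasets.ImageFolder('path_to_dataset_B', transform=transform)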

# DataLoaders
dataloader_A = DataLoader(dataset_A, batch_size=16, shuffle=True)
dataloader_B = DataLoader(dataset_B, batch_size=16, shuffle=True)
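The training loop below iterates with `zip(dataloader_A, dataloader_B)`, which stops as soon as the shorter loader is exhausted, so each epoch skips the remaining batches of the larger domain. A minimal sketch of one alternative follows: the hypothetical generator `unpaired_batches` runs for the length of the longer loader and restarts the shorter one with `iter()`; for a `DataLoader` with `shuffle=True`, each restart draws a fresh shuffle. It works with any object that supports `len()` and `iter()`.

```python
def _next_or_restart(iterator, loader):
    """Return (next batch, iterator), starting a fresh pass over loader if exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def unpaired_batches(loader_a, loader_b):
    """Yield max(len(loader_a), len(loader_b)) unpaired (batch_a, batch_b) pairs."""
    it_a, it_b = iter(loader_a), iter(loader_b)
    for _ in range(max(len(loader_a), len(loader_b))):
        batch_a, it_a = _next_or_restart(it_a, loader_a)
        batch_b, it_b = _next_or_restart(it_b, loader_b)
        yield batch_a, batch_b


# Lists stand in for DataLoaders here: domain A has 2 batches, domain B has 5.
pairs = list(unpaired_batches(["a0", "a1"], ["b0", "b1", "b2", "b3", "b4"]))
print(pairs)
```

In the training loop it would replace the `zip(...)` call: `for i, (real_A, real_B) in enumerate(unpaired_batches(dataloader_A, dataloader_B)):`.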


# Training

## Hyperparameters

In [3]:
hyperparameters = {
    "num_epochs" : 100,
    "batch_size" : 16,
    "device" : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "lr" : 0.0002,
    "b1" : 0.5,
    "b2" : 0.999,
    "n_cpu" : 8,
    "img_size" : 256,
    "channels" : 3,
    "sample_interval" : 100,
    "checkpoint_interval" : 10,
}

## Training loop

The training loop includes the following steps:

1. **Generate fake images** from the generators.
2. **Compute adversarial loss** for both discriminators and generators.
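3. Compute the **cycle-consistency loss** so that translating an image to the other domain and back reconstructs it.
4. Compute the **identity loss** so that each generator leaves images already in its target domain (nearly) unchanged, which helps preserve color composition.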
5. Update the **discriminator and generator** weights using backpropagation.
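The alternating update in steps 2 and 5 hinges on one detail: the discriminator is trained on `fake.detach()`, so its loss sends no gradient into the generator. The sketch below shows that pattern in isolation with hypothetical toy modules (a 1×1 convolution as "generator", a single strided convolution as "discriminator") and the least-squares GAN loss; it leaves out the cycle and identity terms and is not the notebook's `CycleGAN` class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy modules standing in for one generator and one discriminator.
gen = nn.Conv2d(3, 3, kernel_size=1)
disc = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)
mse = nn.MSELoss()  # least-squares GAN loss

opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

real_a = torch.randn(2, 3, 16, 16)
real_b = torch.randn(2, 3, 16, 16)

# Step 1: generate fake images.
fake_b = gen(real_a)

# Generator update: push D(fake) toward the "real" label 1.
opt_g.zero_grad()
pred_fake = disc(fake_b)
loss_g = mse(pred_fake, torch.ones_like(pred_fake))
loss_g.backward()
opt_g.step()

# Discriminator update: real -> 1, fake -> 0. detach() cuts the graph,
# so this loss produces no gradient for the generator's parameters.
opt_d.zero_grad()  # also discards the grads that loss_g left on disc
opt_g.zero_grad()  # clear generator grads so the check below is meaningful
pred_real = disc(real_b)
pred_fake = disc(fake_b.detach())
loss_d = 0.5 * (mse(pred_real, torch.ones_like(pred_real))
                + mse(pred_fake, torch.zeros_like(pred_fake)))
loss_d.backward()
opt_d.step()

print(f"loss_g={loss_g.item():.4f}, loss_d={loss_d.item():.4f}")
```

After `loss_d.backward()`, the generator's weights carry no gradient while the discriminator's do, which is why the two optimizers can be stepped independently.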

In [None]:
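# Training hyperparameters
num_epochs = hyperparameters["num_epochs"]
device = hyperparameters["device"]
cycle_gan = CycleGAN(input_nc=3, output_nc=3,
                     lr=hyperparameters["lr"], beta1=hyperparameters["b1"], beta2=hyperparameters["b2"],
                     device=device)

for epoch in range(num_epochs):
    # zip stops at the shorter loader, so each epoch covers min(len(dataloader_A), len(dataloader_B)) batches
    for i, (batch_A, batch_B) in enumerate(zip(dataloader_A, dataloader_B)):
        # ImageFolder yields (images, labels); only the images are needed
        real_A = batch_A[0].to(device)
        real_B = batch_B[0].to(device)
        
        # Perform one optimization step
        loss_G, loss_D_A, loss_D_B = cycle_gan.optimize(real_A, real_B)
    
    # Print the losses of the last batch in the epoch
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss_G: {loss_G:.4f}, Loss_D_A: {loss_D_A:.4f}, Loss_D_B: {loss_D_B:.4f}")
    
    # Save a checkpoint every `checkpoint_interval` epochs, without overwriting earlier ones
    if (epoch + 1) % hyperparameters["checkpoint_interval"] == 0:
        cycle_gan.save_model(epoch, path=f'cycle_gan_epoch_{epoch + 1}.pth')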