# Implementing the CycleGAN (vanilla architecture)

## Objective

Reproduce a vanilla CycleGAN based on the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593v7).

More specifically, build, train, and document this GAN architecture in PyTorch, based on the implementation at [https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py).

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2
from torchvision.utils import make_grid

# Fix the random seed for reproducibility
torch.manual_seed(0)

## CycleGAN generator

Each CycleGAN generator has three sections:

- Encoder
- Transformer
- Decoder

The input image is passed into the encoder, which extracts features with convolutions: it compresses the spatial representation of the image while increasing the number of channels.

The encoder consists of three convolutions. The last two use stride 2, so together they reduce the height and width to 1/4 of the input size. For example, an input image of size (256, 256, 3) yields an encoder output of size (64, 64, 256).

The encoder output, after the activation function is applied, is passed into the transformer. The transformer contains 6 residual blocks for 128×128 inputs and 9 residual blocks for 256×256 or larger inputs.

The transformer output is then passed into the decoder, which uses two fractionally strided convolution ("deconvolution") blocks to bring the representation back to the original size.
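The shape bookkeeping above can be checked with the standard output-size formulas for convolutions and transposed convolutions. The sketch below is illustrative only: the helper names `conv_out`, `deconv_out`, and `trace_generator` are made up for this example, and the kernel, stride, and padding values are assumed from the generator described in this notebook (a 7×7 stride-1 convolution with padding 3, 3×3 stride-2 convolutions with padding 1, and 3×3 transposed convolutions with stride 2, padding 1, and output padding 1).

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output length of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_generator(size=256):
    """Return (stage, channels, size) after each stage that changes the tensor shape."""
    trace = []
    size = conv_out(size, kernel=7, stride=1, padding=3)    # c7s1-64
    trace.append(("c7s1-64", 64, size))
    for channels in (128, 256):                              # d128, d256
        size = conv_out(size, kernel=3, stride=2, padding=1)
        trace.append((f"d{channels}", channels, size))
    # The residual blocks keep both the channel count and the spatial size.
    for channels in (128, 64):                               # u128, u64
        size = deconv_out(size, kernel=3, stride=2, padding=1, output_padding=1)
        trace.append((f"u{channels}", channels, size))
    return trace


for stage, channels, size in trace_generator(256):
    print(f"{stage:8s} -> ({size}, {size}, {channels})")
```

Under these assumptions the encoder ends at (64, 64, 256) and the decoder returns to 256×256. The round trip is exact only when the input size is divisible by 4; for example, `trace_generator(50)` ends at 52 rather than 50.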

### Architecture

The architecture of the generator (shown here with 6 residual blocks) is:

`c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, u128, u64, c7s1-3`

where `c7s1-k` denotes a 7×7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1; `dk` denotes a 3×3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2; `Rk` denotes a residual block that contains two 3×3 convolutional layers with the same number of filters in both layers; and `uk` denotes a 3×3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters and stride 1/2 (i.e. a transposed convolution, often called a deconvolution). In the code below, the final `c7s1-3` layer is followed by a Tanh activation instead of InstanceNorm-ReLU.
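To make the notation concrete, the following sketch expands an architecture string into one description per layer and checks that the strides cancel out. The names `PATTERNS` and `expand_spec` and the dictionary fields are hypothetical, chosen for this illustration only; they are not part of the reference implementation.

```python
import re

# Token pattern -> (layer type, kernel size, stride)
PATTERNS = [
    (r"c7s1-(\d+)", ("conv", 7, 1)),
    (r"d(\d+)", ("downsample", 3, 2)),
    (r"R(\d+)", ("residual", 3, 1)),
    (r"u(\d+)", ("upsample", 3, 0.5)),  # stride 1/2: resolution is doubled
]


def expand_spec(spec):
    """Expand a spec such as 'c7s1-64, d128, R256' into a list of layer descriptions."""
    layers = []
    for token in spec.split(","):
        token = token.strip()
        for pattern, (kind, kernel, stride) in PATTERNS:
            match = re.fullmatch(pattern, token)
            if match:
                layers.append({"type": kind, "kernel": kernel,
                               "stride": stride, "filters": int(match.group(1))})
                break
        else:
            raise ValueError(f"unknown layer token: {token!r}")
    return layers


spec = "c7s1-64, d128, d256, " + ", ".join(["R256"] * 6) + ", u128, u64, c7s1-3"
layers = expand_spec(spec)

# Relative output resolution: divide by each stride in turn.
scale = 1.0
for layer in layers:
    scale /= layer["stride"]

print(len(layers))                                          # 12
print(sum(layer["type"] == "residual" for layer in layers)) # 6
print(scale)                                                # 1.0
```

The final scale of 1.0 reflects that the two stride-2 layers are undone by the two stride-1/2 layers, so the generator output has the same resolution as its input.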

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(64 * mult, 64 * mult * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(64 * mult * 2),
                nn.ReLU(inplace=True),  # dk is Convolution-InstanceNorm-ReLU
            ]

        mult = 2 ** n_downsampling
        for i in range(n_residual_blocks):
            model += [ResidualBlock(64 * mult)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(64 * mult, 64 * mult // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(64 * mult // 2),
                nn.ReLU(inplace=True),
            ]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(64, output_nc, 7)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

In [3]:
# Instantiate generators
gen_AtoB = Generator(3, 3)
gen_BtoA = Generator(3, 3)

# Basic tests with random input
input_tensor = torch.randn(1, 3, 256, 256)

output_tensor = gen_AtoB(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256), "Generator output has incorrect shape"

output_tensor = gen_BtoA(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256), "Generator output has incorrect shape"


print("Generator instantiation and basic tests passed successfully.")


Generator instantiation and basic tests passed successfully.


## CycleGAN Discriminator

For the discriminator, the authors use a PatchGAN. A regular GAN discriminator maps a 256×256 image to a single scalar output that signifies "real" or "fake", whereas a PatchGAN maps the image to an N×N array of outputs X, where each X_ij signifies whether the corresponding patch of the image (70×70 pixels in the paper) is real or fake.

### Architecture

The architecture of the discriminator is:

`C64-C128-C256-C512`

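where `Ck` denotes a 4×4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2. InstanceNorm is not applied to the first layer (C64), and the leaky ReLUs use a slope of 0.2. After the last layer, a convolution is applied to produce a 1-channel output. In the code below, these patch-level outputs are then averaged into a single score per image.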
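The patch size seen by each PatchGAN output follows from the kernel sizes and strides of its convolutions. The sketch below computes the receptive field and the output map size for two stride configurations. Both configurations are assumptions made for illustration: `all_stride_2` follows the description above literally (four stride-2 layers plus the final stride-1 convolution), while `reference_impl` assumes, as in the linked reference implementation, that the C512 layer uses stride 1. The helper name `patchgan_stats` is hypothetical.

```python
def patchgan_stats(strides, size=256, kernel=4, padding=1):
    """Return (receptive field, output map size) for a stack of square convolutions."""
    rf, jump = 1, 1
    for stride in strides:
        rf += (kernel - 1) * jump   # each layer widens the receptive field
        jump *= stride              # distance between neighbouring outputs, in input pixels
        size = (size + 2 * padding - kernel) // stride + 1
    return rf, size


all_stride_2 = [2, 2, 2, 2, 1]     # C64, C128, C256, C512 with stride 2, then the final conv
reference_impl = [2, 2, 2, 1, 1]   # assumed: C512 with stride 1, then the final conv

print(patchgan_stats(all_stride_2))    # (94, 15)
print(patchgan_stats(reference_impl))  # (70, 30)
```

Under these assumptions, the 70×70 patch size quoted from the paper corresponds to the second configuration. A stack with four stride-2 layers, like the `Discriminator` class in this notebook, would instead see 94×94 patches and produce a 15×15 map for a 256×256 input.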

In [4]:
class Discriminator(nn.Module):
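    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        # output_nc and n_residual_blocks are not used by this class; they are kept
        # only so that existing calls such as Discriminator(3, 1) keep working.
        super(Discriminator, self).__init__()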

        self.model = nn.Sequential(
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            self.discriminator_block(64, 128),
            self.discriminator_block(128, 256),
            self.discriminator_block(256, 512),

            nn.Conv2d(512, 1, 4, padding=1)
        )

    def discriminator_block(self, input_dim, output_dim):
        """Return one downsampling block: 4x4 stride-2 Conv -> InstanceNorm -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(output_dim),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.model(x)  # patch-level scores of shape (N, 1, H', W')
        # Average the patch scores and flatten to one score per image: (N, 1)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

In [5]:
# Instantiate discriminators
dis_A = Discriminator(3, 1)
dis_B = Discriminator(3, 1)

# Basic tests with random input
input_tensor = torch.randn(1, 3, 256, 256)

output_tensor = dis_A(input_tensor)
assert output_tensor.shape == (1, 1), "Discriminator output has incorrect shape"

output_tensor = dis_B(input_tensor)
assert output_tensor.shape == (1, 1), "Discriminator output has incorrect shape"

print("Discriminator instantiation and basic tests passed successfully.")


Discriminator instantiation and basic tests passed successfully.


## Cost Functions

- **Adversarial Loss:** We apply an adversarial loss to both mappings: $G: X \rightarrow Y$ with its discriminator $D_y$, and $F: Y \rightarrow X$ with its discriminator $D_x$. In least-squares form, the term minimized by each generator over $m$ samples is written as:

$$ Loss_{advers} \left ( G, D_y, X, Y \right ) =\frac{1}{m}\sum_{i=1}^{m} \left ( 1 - D_y\left ( G\left ( x_i \right ) \right ) \right )^{2} $$

$$ Loss_{advers}\left ( F, D_x, Y, X \right ) =\frac{1}{m}\sum_{i=1}^{m} \left ( 1 - D_x\left ( F\left ( y_i \right ) \right ) \right )^{2} $$

- **Cycle Consistency Loss:** With enough capacity, an adversarially trained network can map the same set of input images to any random permutation of images in the target domain, and any of these mappings can induce an output distribution that matches the target distribution. The adversarial loss alone therefore cannot guarantee that an individual input $x_i$ is mapped to the desired output $y_i$. To reduce the space of possible mappings, the authors propose that the learned mappings should be cycle-consistent.

  This loss measures the reconstruction error of the round trip $x \rightarrow G(x) \rightarrow F(G(x))$, and likewise $y \rightarrow F(y) \rightarrow G(F(y))$. Minimizing it pushes $F(G(x))$ to closely match the real input $x$.

$$ Loss_{cyc}\left ( G, F, X, Y \right ) =\frac{1}{m}\sum_{i=1}^{m}\left [ \left \| F\left ( G\left ( x_i \right ) \right )-x_i \right \|_1 +\left \| G\left ( F\left ( y_i \right ) \right )-y_i \right \|_1 \right ] $$


The full cost function is the sum of the two adversarial losses and the cycle-consistency loss, weighted by $\lambda$:


$$ L\left ( G, F, D_x, D_y \right ) = Loss_{advers}\left (G, D_y, X, Y \right ) + Loss_{advers}\left (F, D_x, Y, X \right ) + \lambda Loss_{cyc}\left ( G, F, X, Y \right ) $$

and our aim is:


$$ G^{*}, F^{*} = \arg \min_{G, F} \max_{D_x, D_y} L\left ( G, F, D_x, D_y \right ) $$
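As a rough illustration of how these terms combine, the sketch below evaluates the generator-side objective on random tensors. Everything here is a stand-in: `G`, `F_map`, `D_x`, and `D_y` are single-layer placeholders rather than the networks defined in this notebook, the helper names `adversarial_loss` and `cycle_loss` are hypothetical, and `lambda_cyc = 10.0` is assumed from the value used in the paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny placeholder networks (assumptions for illustration only).
G = nn.Conv2d(3, 3, kernel_size=3, padding=1)       # mapping X -> Y
F_map = nn.Conv2d(3, 3, kernel_size=3, padding=1)   # mapping Y -> X
D_x = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)
D_y = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)


def adversarial_loss(D, fake):
    """Least-squares generator loss: mean of (1 - D(fake))^2."""
    return ((1.0 - D(fake)) ** 2).mean()


def cycle_loss(G, F_map, x, y):
    """L1 reconstruction error of both round trips, averaged over all elements
    (the same reduction as nn.L1Loss, so it also divides by the pixel count)."""
    return (F_map(G(x)) - x).abs().mean() + (G(F_map(y)) - y).abs().mean()


lambda_cyc = 10.0  # assumed weight of the cycle-consistency term
x = torch.randn(2, 3, 16, 16)  # stand-in batch from domain X
y = torch.randn(2, 3, 16, 16)  # stand-in batch from domain Y

total = (adversarial_loss(D_y, G(x))
         + adversarial_loss(D_x, F_map(y))
         + lambda_cyc * cycle_loss(G, F_map, x, y))
total.backward()  # gradients reach both generators through all three terms
print(total.item())
```

Only the generator side is shown. In a full training step, each discriminator would be updated separately, with real images labelled 1 and generated images labelled 0.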

In [8]:
class CycleGANLoss(nn.Module):
    """Define different GAN objectives.

    The CycleGANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super(CycleGANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
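        # Least-squares objective, matching the adversarial loss formulas in the
        # Cost Functions section; use nn.BCEWithLogitsLoss() for the vanilla GAN objective.
        self.loss = nn.MSELoss()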

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        loss = self.loss(prediction, target_tensor)
        return loss

In [10]:
# Instantiate optimizers
optimizer_G = optim.Adam(list(gen_AtoB.parameters()) + list(gen_BtoA.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(dis_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(dis_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Instantiate loss functions
criterionGAN = CycleGANLoss()
criterionCycle = torch.nn.L1Loss()

# Assert tests for dimensions, forward and backward pass, and test the losses calculations
# Test generator forward pass
input_tensor = torch.randn(1, 3, 256, 256)
output_tensor = gen_AtoB(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256)

# Test discriminator forward pass
output_tensor = dis_A(input_tensor)
assert output_tensor.shape == (1, 1)

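# Test backward pass for generator: the adversarial loss is computed on the
# discriminator's prediction for the generated image, not on the image itself
output_tensor = gen_AtoB(input_tensor)
loss = criterionGAN(dis_B(output_tensor), True)
loss.backward()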

# Test backward pass for discriminator
output_tensor = dis_A(input_tensor)
loss = criterionGAN(output_tensor, True)
loss.backward()

# Test cycle consistency loss
fake_B = gen_AtoB(input_tensor)
rec_A = gen_BtoA(fake_B)
cycle_loss = criterionCycle(rec_A, input_tensor)
assert cycle_loss.shape == ()

# Test GAN loss
fake_B = gen_AtoB(input_tensor)
pred_fake = dis_B(fake_B)
gan_loss = criterionGAN(pred_fake, True)
assert gan_loss.shape == ()

print("Optimizers, loss functions, and assert tests instantiated successfully.")


Optimizers, loss functions, and assert tests instantiated successfully.


## Dataset

In [6]:
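def show_img(img):
    """
    Show an image with its size in the title.

    img: tensor of shape (C, H, W)
    """

    # Move to CPU and reorder (C, H, W) -> (H, W, C) for matplotlib
    img = img.detach().cpu().permute(1, 2, 0)
    if img.shape[2] == 1:
        # Drop the channel axis of grayscale images
        img = img.squeeze(2)
    plt.title(f'Image has size {tuple(img.shape)}')
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.show()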

In [None]:
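transforms_all = v2.Compose([
    # Resize to exactly (50, 50); a bare int (size=50) would only resize the shorter edge.
    # Note: the generator reproduces the input size only when height and width are divisible by 4.
    v2.Resize(size=(50, 50)),
    # v2.ToTensor() is deprecated; ToImage + ToDtype gives a float tensor in [0, 1]
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])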

## Training

### Hyperparameters

In [11]:
latent_dim = 100
num_epochs = 100
batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 8
img_size = 256
channels = 3
sample_interval = 100
checkpoint_interval = 10

### Training loop

In [None]:
# Training loop for CycleGAN.
# loader_A and loader_B are assumed to be DataLoaders over the two unpaired image
# domains, yielding (images, labels) batches; they are not defined in this notebook yet.
lambda_cyc = 10.0  # weight of the cycle-consistency loss (value used in the paper)

def train_cyclegan(loader_A, loader_B):
    for net in (gen_AtoB, gen_BtoA, dis_A, dis_B, criterionGAN):
        net.to(device)

    for epoch in range(num_epochs):
        # zip stops at the end of the shorter loader
        for i, (batch_A, batch_B) in enumerate(zip(loader_A, loader_B)):
            real_A = batch_A[0].to(device)
            real_B = batch_B[0].to(device)

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_G.zero_grad()
            fake_B = gen_AtoB(real_A)
            fake_A = gen_BtoA(real_B)
            # Adversarial loss: each generator tries to make its discriminator answer "real"
            loss_gan = criterionGAN(dis_B(fake_B), True) + criterionGAN(dis_A(fake_A), True)
            # Cycle-consistency loss: F(G(x)) should recover x, and G(F(y)) should recover y
            loss_cycle = criterionCycle(gen_BtoA(fake_B), real_A) + criterionCycle(gen_AtoB(fake_A), real_B)
            g_loss = loss_gan + lambda_cyc * loss_cycle
            g_loss.backward()
            optimizer_G.step()

            # ----------------------
            #  Train Discriminators
            # ----------------------
            d_losses = []
            for dis, opt, real, fake in ((dis_A, optimizer_D_A, real_A, fake_A),
                                         (dis_B, optimizer_D_B, real_B, fake_B)):
                opt.zero_grad()
                # detach() keeps the discriminator update from reaching the generators
                d_loss = (criterionGAN(dis(real), True) + criterionGAN(dis(fake.detach()), False)) / 2
                d_loss.backward()
                opt.step()
                d_losses.append(d_loss.item())

            # ---------------------
            #  Progress Monitoring
            # ---------------------
            if (i + 1) % sample_interval == 0:
                print(
                    f"Epoch [{epoch+1}/{num_epochs}] Batch {i+1} "
                    f"Discriminator Loss: {sum(d_losses):.4f} "
                    f"Generator Loss: {g_loss.item():.4f}"
                )

        # Show real A images next to their translations every few epochs
        if (epoch + 1) % checkpoint_interval == 0:
            samples = torch.cat([real_A, fake_B.detach()]).cpu()
            grid = make_grid(samples, nrow=4, normalize=True)
            plt.imshow(grid.permute(1, 2, 0))
            plt.axis("off")
            plt.show()