# Creating your own GAN II: Pix2Pix

In the last notebook we implemented the `LSGAN` together; this time we're tackling the [Pix2PixGAN](https://arxiv.org/pdf/1611.07004.pdf). Also check out [this resource](https://blog.eduonix.com/artificial-intelligence/pix2pix-gan/) for a good explanation. At the time of writing this notebook (2021-04-08 19:14), there are only 4 GAN architectures implemented in `vegans` (8 including their conditional variants): `VanillaGAN`, `WassersteinGAN`, `WassersteinGANGP` and `LSGAN`. The Pix2Pix GAN is a purely conditional network which, as the name suggests, is popular for image-to-image translation tasks (rotating MNIST digits, summer to winter scenery, horses to zebras, person to person with beard, ...). We could use conditional Wasserstein GANs for this task, but they are not really optimized for it, so we will implement this algorithm here.

We assume you have read the previous notebook on the creation of the `LSGAN`. If not, please go over it first so you have an understanding of the abstract base classes. We will not repeat that material here and will jump right into the implementation.

First import the usual libraries:

In [2]:
import os
import torch

os.chdir("/home/thomas/Backup/Algorithmen/GAN-pytorch")  # author's local repository path; adjust or remove for your setup
from vegans.GAN import ConditionalWassersteinGAN, ConditionalWassersteinGANGP
from vegans.utils.utils import plot_losses, plot_images, get_input_dim

## ConditionalPix2Pix

As we mentioned in the introduction, there is not really an unconditional version of the Pix2Pix GAN, because you always need the input data and the corresponding transformed image (in a way, the label). We still choose this name to be consistent with the naming scheme of the `vegans` library: if a model takes `y_dim` as input, its name starts with `Conditional`.

We already know that we need to inherit from the `ConditionalGenerativeModel` base class in some way. However, if you are familiar with the Pix2Pix algorithm, you'll know that it consists of two networks: a generator and a discriminator. So from the last notebook we know that we should be able to inherit from `ConditionalGAN1v1`. Let's try that (if this sentence is still in the notebook while you're reading it, it probably means it worked).

In [4]:
from vegans.models.conditional.ConditionalGAN1v1 import ConditionalGAN1v1
from torch.nn import BCELoss

In [5]:
class ConditionalPix2Pix(ConditionalGAN1v1):
    def __init__(
            self,
            generator,
            adversariat,
            x_dim,
            z_dim,
            y_dim,
            optim=None,
            optim_kwargs=None,
            fixed_noise_size=32,
            device=None,
            folder="./ConditionalPix2Pix",
            ngpu=None):

        super().__init__(
            generator=generator, adversariat=adversariat,
            x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, adv_type="Discriminator",
            optim=optim, optim_kwargs=optim_kwargs,
            fixed_noise_size=fixed_noise_size,
            device=device, folder=folder, ngpu=ngpu
        )
    
    def _default_optimizer(self):
        return torch.optim.Adam

    def _define_loss(self):
        self.loss_functions = {"Generator": BCELoss(), "Adversariat": BCELoss()}
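
Because both entries in `loss_functions` are `BCELoss`, which expects probabilities, the adversariat passed to this class has to end in a sigmoid. A minimal sketch of such a network for image pairs, assuming that the image and its condition reach the network concatenated along the channel axis. The class name `ToyConditionalDiscriminator`, the `in_channels` value and the layer sizes are illustrative choices, not `vegans` requirements:

```python
import torch
from torch import nn


class ToyConditionalDiscriminator(nn.Module):
    """Hypothetical adversariat for 32x32 image pairs (not part of vegans)."""

    def __init__(self, in_channels=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),  # average each feature map down to 1x1
            nn.Flatten(),
            nn.Linear(32, 1),
            nn.Sigmoid(),  # probabilities in (0, 1), as BCELoss expects
        )

    def forward(self, x):
        return self.net(x)


# Assumption: image and condition are concatenated along the channel axis.
images = torch.rand(4, 1, 32, 32)
conditions = torch.rand(4, 1, 32, 32)
discriminator = ToyConditionalDiscriminator(in_channels=2)
probabilities = discriminator(torch.cat([images, conditions], dim=1))
print(probabilities.shape)  # torch.Size([4, 1])
```

Dropping the final `nn.Sigmoid()` and switching to `torch.nn.BCEWithLogitsLoss` would be a numerically more stable alternative, but it would also require changing `_define_loss`.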

Note that this is not the whole story. If we left it like that, it would be a normal `VanillaGAN` (the original GAN formulation). The Pix2Pix GAN adds a pixel-wise loss for the generator, so the generator's goal is twofold:

   1. Fool the discriminator into outputting 1 for fake images (classifying them as real).
   2. Minimize the pixel-wise mean absolute error (L1 distance) between the generated and the target image.
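
Written out for a batch of $N$ image pairs, the generator loss described by these two goals reads as follows. Here $y_i$ is the input (conditioning) image, $x_i$ the target image, $z_i$ the noise, $P$ the number of values (channels × height × width) per image, $\lambda$ the weight that is called `lambda_l1` in the code, and, as a simplifying assumption, $D$ outputs a single probability per sample:

$$
\mathcal{L}_G = -\frac{1}{N}\sum_{i=1}^{N} \log D\big(G(y_i, z_i),\, y_i\big) \;+\; \lambda \, \frac{1}{N P}\sum_{i=1}^{N}\sum_{p=1}^{P} \big| x_{i,p} - G(y_i, z_i)_p \big|
$$

The first term is the binary cross-entropy against a target of 1 (goal 1); the second is the mean absolute error as computed by `torch.nn.L1Loss` with its default mean reduction (goal 2).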
   
So far we have only covered the first part, so let's implement the pixel-wise loss. We can do this by hooking into the `ConditionalGAN1v1` implementation. During training, this parent class executes the following code snippet in every training step:

```python
def calculate_losses(self, X_batch, Z_batch, y_batch, who=None):
    if who == "Generator":
        self._calculate_generator_loss(X_batch=X_batch, Z_batch=Z_batch, y_batch=y_batch)
    elif who == "Adversariat":
        self._calculate_adversariat_loss(X_batch=X_batch, Z_batch=Z_batch, y_batch=y_batch)
    else:
        self._calculate_generator_loss(X_batch=X_batch, Z_batch=Z_batch, y_batch=y_batch)
        self._calculate_adversariat_loss(X_batch=X_batch, Z_batch=Z_batch, y_batch=y_batch)
        self._losses["Loss/LossRatio"] = self._losses["Adversariat_real"]/self._losses["Adversariat_fake"]
```

where `who` is either "Generator" or "Adversariat". (Sometimes, **OUTSIDE** of training, it might be `None`. This is used for calculating losses which are saved, logged and printed to the console, but never for training!) So we can customize the generator loss by defining our own version of `self._calculate_generator_loss(X_batch=X_batch, Z_batch=Z_batch, y_batch=y_batch)`. The "original" implementation of this method within `ConditionalGAN1v1` looks like this:

```python
def _calculate_generator_loss(self, X_batch, Z_batch, y_batch):
    fake_images = self.generate(y=y_batch, z=Z_batch)
    fake_predictions = self.predict(x=fake_images, y=y_batch)
    gen_loss = self.loss_functions["Generator"](
        fake_predictions, torch.ones_like(fake_predictions, requires_grad=False)
    )
    self._losses.update({"Generator": gen_loss})
```
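
Overriding `_calculate_generator_loss` works because of the classic template-method pattern: the parent's `calculate_losses` decides *when* each loss is computed, and a subclass only redefines *how* the generator loss is computed. The following framework-free sketch illustrates this with hypothetical `ToyGAN1v1` and `ToyPix2Pix` classes (not part of `vegans`) that use constant stand-in values instead of real tensors:

```python
# Hypothetical toy classes illustrating the hook pattern; they are NOT the
# vegans classes and use constant stand-in values instead of real tensors.
class ToyGAN1v1:
    def __init__(self):
        self._losses = {}

    def calculate_losses(self, X_batch, Z_batch, y_batch, who=None):
        # The parent decides *when* each loss is computed ...
        if who == "Generator":
            self._calculate_generator_loss(X_batch, Z_batch, y_batch)
        elif who == "Adversariat":
            self._calculate_adversariat_loss(X_batch, Z_batch, y_batch)
        else:
            self._calculate_generator_loss(X_batch, Z_batch, y_batch)
            self._calculate_adversariat_loss(X_batch, Z_batch, y_batch)
        return self._losses

    def _calculate_generator_loss(self, X_batch, Z_batch, y_batch):
        self._losses.update({"Generator": 1.0})  # stand-in adversarial loss

    def _calculate_adversariat_loss(self, X_batch, Z_batch, y_batch):
        self._losses.update({"Adversariat": 2.0})  # stand-in adversariat loss


class ToyPix2Pix(ToyGAN1v1):
    def __init__(self, lambda_l1=10):
        super().__init__()
        self.lambda_l1 = lambda_l1

    # ... and the subclass only redefines *how* the generator loss is computed.
    def _calculate_generator_loss(self, X_batch, Z_batch, y_batch):
        gen_loss_original = 1.0    # stand-in for the BCE term
        gen_loss_pixel_wise = 0.5  # stand-in for the L1 term
        self._losses.update({
            "Generator": gen_loss_original + self.lambda_l1 * gen_loss_pixel_wise,
            "Generator_Original": gen_loss_original,
            "Generator_L1": gen_loss_pixel_wise,
        })


losses = ToyPix2Pix(lambda_l1=10).calculate_losses(None, None, None, who="Generator")
print(losses)  # {'Generator': 6.0, 'Generator_Original': 1.0, 'Generator_L1': 0.5}
```

The parent's dispatch logic stays untouched; only the overridden method changes, which is exactly what the real implementation below relies on.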

So we reuse what we can but additionally include a pixel-wise image loss. We also introduce an additional parameter called `lambda_l1` that weights the pixel-wise loss; the appropriate place for it is the constructor. You do not have to include this parameter in `self.hyperparameters` (which is created by `GenerativeModel`), but it's nice to have all these values in one centralized dictionary.

In [7]:
class ConditionalPix2Pix(ConditionalGAN1v1):
    def __init__(
            self,
            generator,
            adversariat,
            x_dim,
            z_dim,
            y_dim,
            optim=None,
            optim_kwargs=None,
            lambda_l1=10,
            fixed_noise_size=32,
            device=None,
            folder="./ConditionalPix2Pix",
            ngpu=None):

        super().__init__(
            generator=generator, adversariat=adversariat,
            x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, adv_type="Discriminator",
            optim=optim, optim_kwargs=optim_kwargs,
            fixed_noise_size=fixed_noise_size,
            device=device, folder=folder, ngpu=ngpu
        )
        self.lambda_l1 = lambda_l1
        self.hyperparameters["lambda_l1"] = self.lambda_l1
    
    def _default_optimizer(self):
        return torch.optim.Adam

    def _define_loss(self):
        self.loss_functions = {"Generator": BCELoss(), "Adversariat": BCELoss(), "L1": torch.nn.L1Loss()}
        
    def _calculate_generator_loss(self, X_batch, Z_batch, y_batch):
        fake_images = self.generate(y=y_batch, z=Z_batch)
        fake_predictions = self.predict(x=fake_images, y=y_batch)
        gen_loss_original = self.loss_functions["Generator"](
            fake_predictions, torch.ones_like(fake_predictions, requires_grad=False)
        )
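        # Pixel-wise L1 distance between generated and target images
        gen_loss_pixel_wise = self.loss_functions["L1"](
            fake_images, X_batch
        )
        # Total generator loss: adversarial term plus weighted L1 term
        gen_loss = gen_loss_original + self.lambda_l1*gen_loss_pixel_wise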
        self._losses.update({
            "Generator": gen_loss,
            "Generator_Original": gen_loss_original,
            "Generator_L1": gen_loss_pixel_wise
        })

Note that we have included multiple losses in the `self._losses` dictionary. The only required entry is the one with key "Generator", because there is a network with this name in the `self.neural_nets` dictionary (the other one is called "Adversariat" for GANs inheriting from `ConditionalGAN1v1`). No separate backward step is performed on the other two entries; they are only logged to TensorBoard and stored in the internal `self.logged_losses` dictionary.
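
The interplay of the three entries can be reproduced outside of `vegans` with plain PyTorch. The sketch below mirrors the arithmetic of `_calculate_generator_loss` on random tensors; the image shapes, the sigmoid-of-mean stand-in for the discriminator and the `logged` dictionary are assumptions for illustration only, not how `vegans` logs internally:

```python
import torch
from torch.nn import BCELoss, L1Loss

torch.manual_seed(0)

# Assumed shapes: a batch of 4 single-channel 8x8 images with values in [0, 1).
X_batch = torch.rand(4, 1, 8, 8)                          # target images
fake_images = torch.rand(4, 1, 8, 8, requires_grad=True)  # stand-in generator output
# Stand-in discriminator: one probability per sample, computed from the fake
# images so that gradients can flow back to them.
fake_predictions = torch.sigmoid(fake_images.mean(dim=(1, 2, 3))).unsqueeze(1)

lambda_l1 = 10
gen_loss_original = BCELoss()(fake_predictions, torch.ones_like(fake_predictions))
gen_loss_pixel_wise = L1Loss()(fake_images, X_batch)
gen_loss = gen_loss_original + lambda_l1 * gen_loss_pixel_wise

# Only the combined loss is backpropagated; its gradient contains both terms.
gen_loss.backward()

# The individual terms are only read out as plain numbers for logging.
logged = {
    "Generator": gen_loss.item(),
    "Generator_Original": gen_loss_original.item(),
    "Generator_L1": gen_loss_pixel_wise.item(),
}
print(logged)
```

Because `backward()` is called on the sum, both the adversarial and the L1 term shape the gradient, and `lambda_l1` directly scales how much the L1 term contributes.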

**IMPORTANT NOTE**: If you are familiar with the [Pix2Pix Paper](https://arxiv.org/pdf/1611.07004.pdf) or other implementations of this algorithm, you might wonder why we still have a `Z_batch` in there. The paper itself states that the input noise isn't really useful here. Unfortunately, this is one of the current drawbacks of `vegans`: in favor of ease of implementation and generality, we cannot take care of every special use case.
We still recommend using a very small value for `z_dim`, so that the noise does not noticeably increase computation time.

Please note again that this is a preliminary tutorial implementation which might change in future releases of `vegans`. So this implementation might not be completely up to date, but it is a viable implementation nonetheless.