# Creating your own GAN I: LSGAN

In this notebook you will learn how to create your own Generative Adversarial Network with `vegans`. This is a more advanced topic that gives you deeper insight into the design of the library. Even if you do not want to implement your own networks, it might still be worth reading through the next couple of notebooks to get a better understanding of what's happening in the background. If you're not interested in that, that's fine as well, and you can work with the already implemented models :) At the time of writing this notebook (2021-04-08 08:32) there are only three GAN architectures (six, counting their conditional variants) implemented in `vegans`: `VanillaGAN`, `WassersteinGAN` and `WassersteinGANGP`, plus their conditional variants. In this notebook I will explain how to implement the `LSGAN` and the `ConditionalLSGAN` (which will then probably become part of the library once this notebook is finished). LSGAN stands for [Least-Squares GAN](https://arxiv.org/abs/1611.04076v3). In the next few tutorials we will successively implement more difficult architectures (Pix2Pix, LR-GAN).

First import the usual libraries:

```python
import os
import torch

from vegans.GAN import ConditionalWassersteinGAN, ConditionalWassersteinGANGP
from vegans.utils.utils import plot_losses, plot_images, get_input_dim
```

## AbstractGenerativeModel

We will first implement the unconditional variant of the `LSGAN`, investigate the base classes to be used and then move on to the conditional version.

The most important class in `vegans` is the `AbstractGenerativeModel`. It takes care of a lot of boilerplate code for logging and saving, checking for correct inputs and defining the necessary variables. Every unconditional model (and also every conditional model, for that matter) should inherit from this class. It is semi-abstract in the sense that the `AbstractGenerativeModel` itself cannot be used to train anything, because it is missing some important methods which MUST be implemented by its children.

These **abstract methods** are:
- `__init__(self, x_dim, z_dim, optim, optim_kwargs, fixed_noise_size, device, folder, ngpu)`: This takes care of the initialization. The parent constructor

    ```python
    super().__init__(
        x_dim=x_dim, z_dim=z_dim, optim=optim, optim_kwargs=optim_kwargs,
        fixed_noise_size=fixed_noise_size, device=device, folder=folder, ngpu=ngpu
    )
    ```

    must be called at the end of the `__init__` method.

- `_define_loss(self)`: Not strictly necessary, but it is still kept as an abstract method so that the user has to think about what they want to implement here. You can also implement it with a single `pass` statement. However, we will show you its intended use here.

- `calculate_losses(self, X_batch, Z_batch, who)`: The core method that needs to be implemented. For every batch it must populate an already existing (but empty) dictionary `self._losses`. The dictionary must contain at least the keys used in `self.neural_nets` (explained below), but it can also contain other losses. We will discuss this further in later implementations.
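To make this contract concrete, here is a toy stand-in that mimics it with Python's `abc` module. This is not the `vegans` source: `ToyGenerativeModel` and `ToyModel` are hypothetical names, the real base class has many more responsibilities, and it may enforce the contract differently. The sketch only shows that the parent cannot be instantiated on its own and that a child which implements both methods, and calls the parent constructor at the end of its own `__init__`, works.

```python
from abc import ABC, abstractmethod


class ToyGenerativeModel(ABC):
    """Hypothetical stand-in for AbstractGenerativeModel; it only illustrates the contract."""

    def __init__(self, x_dim, z_dim):
        # The real base class does much more here (logging, saving, input checks, ...).
        self.x_dim = x_dim
        self.z_dim = z_dim
        self._losses = {}
        self._define_loss()

    @abstractmethod
    def _define_loss(self):
        """Children choose their loss functions (a bare `pass` would also be allowed)."""

    @abstractmethod
    def calculate_losses(self, X_batch, Z_batch, who):
        """Children must populate self._losses for the network named by `who`."""


class ToyModel(ToyGenerativeModel):
    def __init__(self, x_dim, z_dim):
        self.scale = 0.5  # child-specific setup first ...
        super().__init__(x_dim=x_dim, z_dim=z_dim)  # ... parent constructor at the end

    def _define_loss(self):
        self.loss_functions = {"Generator": lambda pred, target: (pred - target) ** 2}

    def calculate_losses(self, X_batch, Z_batch, who):
        # Toy "loss": squared difference between two numbers, scaled by 0.5.
        loss = self.loss_functions[who](Z_batch, X_batch)
        self._losses.update({who: self.scale * loss})


try:
    ToyGenerativeModel(x_dim=1, z_dim=1)
except TypeError as err:
    print("Cannot instantiate the abstract parent:", err)

model = ToyModel(x_dim=1, z_dim=1)
model.calculate_losses(X_batch=1.0, Z_batch=3.0, who="Generator")
print(model._losses)  # {'Generator': 2.0}
```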

The `AbstractGenerativeModel` will also check for the presence of one very **important** attribute:
- `self.neural_nets`: This is a dictionary containing all the different networks to be trained. It might look like

    ```python
    {
        "Generator": generator_nn_Module,
        "Adversary": adversary_nn_Module,
        "Encoder": encoder_nn_Module
    }
    ```

The values of the dictionary must inherit in one way or another from `nn.Module`. The user of the implemented GAN must make sure of that by using `nn.Sequential` or by building their own architectures which inherit from `nn.Module` (as shown in all previous tutorials).

The keys of the dictionary are equally important because they link together the different parts used during training, for example:
- `self.optimizers`: dict with the same keys as `self.neural_nets`, containing one optimizer per network.
- `self.steps`: dict with the same keys as `self.neural_nets`, containing the number of training steps per network.
- `self._losses`: dict containing (at least) the same keys as `self.neural_nets`, holding the loss computed for each network on the current batch.
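Because everything is linked through the keys, a mismatch is easy to catch early. The helper below is a hypothetical sketch (`check_linked_dicts` is not part of `vegans`) that verifies that dictionaries such as the optimizers and steps use exactly the keys of `neural_nets`. Plain strings stand in for the actual `nn.Module` and optimizer objects.

```python
def check_linked_dicts(neural_nets, **linked):
    """Raise a ValueError if any linked dict does not use exactly the keys of neural_nets."""
    expected = set(neural_nets)
    for dict_name, mapping in linked.items():
        keys = set(mapping)
        if keys != expected:
            missing = sorted(expected - keys)
            extra = sorted(keys - expected)
            raise ValueError(f"{dict_name}: missing {missing}, unexpected {extra}")


# Placeholders instead of real networks and optimizers.
neural_nets = {"Generator": "gen_module", "Adversary": "adv_module", "Encoder": "enc_module"}
optimizers = {name: f"optimizer_for_{name}" for name in neural_nets}
steps = {"Generator": 1, "Adversary": 5, "Encoder": 1}

check_linked_dicts(neural_nets, optimizers=optimizers, steps=steps)  # passes silently
```

An exact match is deliberately not required for `self._losses`, since it may hold additional entries beyond the network names.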

Which key is used in each training step is determined by the `who` argument of `calculate_losses`. In this example, `who` will be one of "Generator", "Adversary" or "Encoder". The `fit` method of the `AbstractGenerativeModel` contains the following code snippet, which should now make sense:

```python
for name, _ in self.neural_nets.items():     # iterates over ["Generator", "Adversary", "Encoder"]
    for _ in range(self.steps[name]):        # Train every network for its specified steps
        self._losses = {}                    # Empty _losses dictionary
        self.calculate_losses(X_batch=X, Z_batch=Z, who=name) # populate _losses dictionary / USER-DEFINED
        self._zero_grad(who=name)            # Set appropriate gradients to zero
        self._backward(who=name)             # Calculate the gradients for a certain loss 
        self._step(who=name)                 # Propagate the update through the network
```

Here we iterate over all keys in `self.neural_nets` (["Generator", "Adversary", "Encoder"]) and use each key to fetch the number of steps that network should be trained for. We then create an empty `self._losses` dictionary, which **MUST** be populated when calling `self.calculate_losses`. After that we call the usual torch routines: set the previous gradients to zero, compute the new gradients of the loss via backpropagation and let the optimizer update the network's parameters.
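The loop can be simulated without any neural networks. In the sketch below, `RecordingModel` and its `train_on_batch` method are hypothetical: the loop body is copied from the snippet above, but `_zero_grad`, `_backward` and `_step` only record what would happen. This makes the order of the calls and the reset of `self._losses` before every step visible.

```python
class RecordingModel:
    """Hypothetical stand-in that mimics the structure of the fit loop above."""

    def __init__(self, steps):
        self.neural_nets = {name: None for name in steps}  # placeholders for nn.Modules
        self.steps = steps
        self.log = []

    def calculate_losses(self, X_batch, Z_batch, who):
        # USER-DEFINED in vegans: populate the freshly emptied _losses dict.
        self._losses[who] = f"loss({who})"

    def _zero_grad(self, who):
        self.log.append(("zero_grad", who))

    def _backward(self, who):
        # Only the loss stored under `who` would be back-propagated.
        self.log.append(("backward", self._losses[who]))

    def _step(self, who):
        self.log.append(("step", who))

    def train_on_batch(self, X, Z):
        for name, _ in self.neural_nets.items():
            for _ in range(self.steps[name]):
                self._losses = {}
                self.calculate_losses(X_batch=X, Z_batch=Z, who=name)
                self._zero_grad(who=name)
                self._backward(who=name)
                self._step(who=name)


model = RecordingModel(steps={"Generator": 1, "Adversary": 2})
model.train_on_batch(X=None, Z=None)
for entry in model.log:
    print(entry)
```

With `steps={"Generator": 1, "Adversary": 2}` the generator is updated once and the adversary twice per batch.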

Now we have covered most of the important things. They will be explained again at the appropriate point over the course of the next few notebooks, whenever they become relevant.

## AbstractGAN1v1

We can almost start with the implementation of the LSGAN. There is one more utility class, which is not as abstract as `AbstractGenerativeModel` (in fact it inherits from `AbstractGenerativeModel`) but is not yet a true `GAN` implementation either. This is the `AbstractGAN1v1`, which should be used whenever you want to implement a `GAN` with the structure

```python
self.neural_nets = {
    "Generator": generator_nn_Module,
    "Adversary": adversary_nn_Module
}
```

So: one generator versus one adversary. This includes the VanillaGAN, WassersteinGAN and WassersteinGANGP as well as the LSGAN. `AbstractGAN1v1` already implements the `calculate_losses` abstract method (which can of course be overridden) and takes care of the initialization, so implementing the LSGAN becomes very easy. More advanced cases follow in the next notebooks.
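Conceptually, `calculate_losses` in a 1v1 setting only has to route the `who` argument to the right helper. The helper names `_calculate_generator_loss` and `_calculate_adversary_loss` are the ones whose code appears later in this notebook. The routing itself is our own simplified guess, not the library source: `ToyGAN1v1` is hypothetical, its helpers store placeholder zeros, and raising an error for unknown names is an assumption.

```python
class ToyGAN1v1:
    """Hypothetical, simplified version of the routing a 1v1 GAN needs."""

    def calculate_losses(self, X_batch, Z_batch, who=None):
        if who == "Generator":
            self._calculate_generator_loss(X_batch=X_batch, Z_batch=Z_batch)
        elif who == "Adversary":
            self._calculate_adversary_loss(X_batch=X_batch, Z_batch=Z_batch)
        else:
            raise ValueError(f"Unknown network name: {who!r}")

    # Placeholder helpers; the real ones compute losses with the generator and adversary.
    def _calculate_generator_loss(self, X_batch, Z_batch):
        self._losses.update({"Generator": 0.0})

    def _calculate_adversary_loss(self, X_batch, Z_batch):
        self._losses.update({"Adversary": 0.0, "Adversary_fake": 0.0, "Adversary_real": 0.0})


gan = ToyGAN1v1()
gan._losses = {}  # emptied by the fit loop before every step
gan.calculate_losses(X_batch=None, Z_batch=None, who="Adversary")
print(sorted(gan._losses))  # ['Adversary', 'Adversary_fake', 'Adversary_real']
```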

```python
from vegans.models.unconditional.AbstractGAN1v1 import AbstractGAN1v1
```

## LSGAN

Now let's start with the class definition and the `__init__` method.

```python
class LSGAN(AbstractGAN1v1):
    def __init__(
            self,
            generator,
            adversary,
            x_dim,
            z_dim,
            optim=None,
            optim_kwargs=None,
            feature_layer=None,
            fixed_noise_size=32,
            device=None,
            folder="./LSGAN",
            ngpu=None):

        super().__init__(
            generator=generator, adversary=adversary,
            z_dim=z_dim, x_dim=x_dim, adv_type="Discriminator",
            optim=optim, optim_kwargs=optim_kwargs, feature_layer=feature_layer,
            fixed_noise_size=fixed_noise_size,
            device=device, folder=folder, ngpu=ngpu
        )
```

This is basically a copy of the code for the [VanillaGAN](https://github.com/tneuer/GAN-pytorch/blob/main/vegans/models/unconditional/VanillaGAN.py). We do not need to inherit from `AbstractGenerativeModel` explicitly because this is already done by `AbstractGAN1v1`.

As for all models, we expect an optim(izer) dictionary, optim_kwargs (optimizer keyword arguments), fixed_noise_size (for logging purposes), the device ("cpu" or "cuda"), the folder and ngpu (number of GPUs). We simply pass these to the parent class [AbstractGAN1v1](https://github.com/tneuer/GAN-pytorch/blob/main/vegans/models/unconditional/AbstractGAN1v1.py), which will immediately create the very important

`self.neural_nets = {"Generator": self.generator, "Adversary": self.adversary}`

You are bound by these names ("Generator", "Adversary") if you are using `AbstractGAN1v1`. If you don't like them you need to implement a little bit more (next notebooks).

Notice that we used `adv_type="Discriminator"`, which indicates that the `adversary` must output values in [0, 1]. This is checked when the user passes an adversary architecture. If you want the output to lie in [-Inf, Inf], use `adv_type="Critic"`.
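A check of this kind could look roughly like the following. `check_adversary_output` is a hypothetical helper that works on a plain list of outputs; how `vegans` actually performs the check (for example, which test input it feeds to the network) is not shown here.

```python
import math


def check_adversary_output(outputs, adv_type):
    """Hypothetical check of an adversary's outputs, given as a list of floats."""
    if adv_type == "Discriminator":
        # Probabilities are required; NaN also fails this comparison.
        bad = [value for value in outputs if not 0.0 <= value <= 1.0]
        if bad:
            raise ValueError(f"Discriminator outputs must lie in [0, 1], got {bad}")
    elif adv_type != "Critic":
        raise ValueError(f"Unknown adv_type: {adv_type!r}")
    # A Critic may output any real number, so nothing is checked in that case.


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


raw_scores = [-3.2, 0.0, 4.5]
check_adversary_output(raw_scores, adv_type="Critic")  # passes
check_adversary_output([sigmoid(s) for s in raw_scores], adv_type="Discriminator")  # passes
```

In practice the Discriminator constraint is usually met by ending the network with a sigmoid layer such as `torch.nn.Sigmoid()`.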

The argument `feature_layer` enables a feature loss for the generator, which overrides the generator's default loss function (here, that default would of course be the least-squares loss).

Because it is so simple, we will implement the missing method `_define_loss` in one go. The loss function must take two arguments:

- the output of the discriminator (or critic) -> prediction
- the real or fake labels, which are generated by the `AbstractGAN1v1` as tensors full of ones or zeros -> labels

For the LSGAN (Least-Squares GAN) we use the mean-squared-error loss given by:

$$\text{Discriminator:} \quad L_D = \frac{1}{2}\left[\left(D(x) - b\right)^2 + \left(D(G(z)) - a\right)^2\right]$$

$$\text{Generator:} \quad L_G = \frac{1}{2}\left(D(G(z)) - c\right)^2$$

where $D(x)$ is the discriminator output (prediction) for real images, $G(z)$ is the generator output, and $a$, $b$, $c$ are the target values for fake images, real images and the generator, respectively. Very often we set $a=0$ and $b=c=1$, which is what we will do in our implementation. This results in:

$$\text{Discriminator:} \quad L_D = \frac{1}{2}\left[\left(D(x) - 1\right)^2 + D(G(z))^2\right]$$

$$\text{Generator:} \quad L_G = \frac{1}{2}\left(D(G(z)) - 1\right)^2$$

This means we assign real images a label of 1 and fake images a label of 0, while the generator tries to make the discriminator output 1 for its images. The squared error itself is already implemented in PyTorch as `torch.nn.MSELoss()`, which averages over the batch.
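A small numeric check of these formulas, written in plain Python so the arithmetic is easy to follow. `ls_losses` is a hypothetical helper; like `torch.nn.MSELoss()` with its default `reduction="mean"`, it averages the squared errors over the batch. The defaults use the common choice $a=0$, $b=c=1$.

```python
def mean(values):
    return sum(values) / len(values)


def ls_losses(d_real, d_fake, a=0.0, b=1.0, c=1.0):
    """Least-squares GAN losses for lists of discriminator outputs D(x) and D(G(z))."""
    loss_d = 0.5 * (mean([(d - b) ** 2 for d in d_real]) + mean([(d - a) ** 2 for d in d_fake]))
    loss_g = 0.5 * mean([(d - c) ** 2 for d in d_fake])
    return loss_d, loss_g


# A perfect discriminator: real images -> 1, fake images -> 0.
print(ls_losses(d_real=[1.0, 1.0], d_fake=[0.0, 0.0]))  # (0.0, 0.5)

# A fooled discriminator: it outputs 1 for the fake images as well.
print(ls_losses(d_real=[1.0, 1.0], d_fake=[1.0, 1.0]))  # (0.5, 0.0)
```

The two cases show the adversarial tension: whatever drives one loss to zero makes the other one positive.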

```python
class LSGAN(AbstractGAN1v1):
    def __init__(
            self,
            generator,
            adversary,
            x_dim,
            z_dim,
            optim=None,
            optim_kwargs=None,
            feature_layer=None,
            fixed_noise_size=32,
            device=None,
            folder="./LSGAN",
            ngpu=None):

        super().__init__(
            generator=generator, adversary=adversary,
            z_dim=z_dim, x_dim=x_dim, adv_type="Discriminator",
            optim=optim, optim_kwargs=optim_kwargs, feature_layer=feature_layer,
            fixed_noise_size=fixed_noise_size,
            device=device, folder=folder, ngpu=ngpu
        )

    def _define_loss(self):
        # Least squares: mean squared error against the real/fake labels for both networks.
        self.loss_functions = {"Generator": torch.nn.MSELoss(), "Adversary": torch.nn.MSELoss()}
```

We chose the `torch.optim.Adam` optimizer as a default optimizer and implemented the appropriate loss. The parent classes `AbstractGAN1v1` and `AbstractGenerativeModel` will take care of all the rest. For example if `who="Generator"`, the following code snippet is called:

```python
def _calculate_generator_loss(self, X_batch, Z_batch):
    fake_images = self.generate(z=Z_batch)
    fake_predictions = self.predict(x=fake_images)
    gen_loss = self.loss_functions["Generator"](
        fake_predictions, torch.ones_like(fake_predictions, requires_grad=False)
    )
    self._losses.update({"Generator": gen_loss})
``` 

So a normal least squares loss is calculated between the output of the discriminator for generated images: `D(G(z)) -> self.predict(self.generate(z=Z_batch))` and labels 1 (because c=1). This is saved in the `self._losses` dict which is used by the `AbstractGenerativeModel` to perform optimization. A similar function is called for the adversary, for completeness stated here:

```python
def _calculate_adversary_loss(self, X_batch, Z_batch):
    fake_images = self.generate(z=Z_batch).detach()
    fake_predictions = self.predict(x=fake_images)
    real_predictions = self.predict(x=X_batch.float())

    adv_loss_fake = self.loss_functions["Adversary"](
        fake_predictions, torch.zeros_like(fake_predictions, requires_grad=False)
    )
    adv_loss_real = self.loss_functions["Adversary"](
        real_predictions, torch.ones_like(real_predictions, requires_grad=False)
    )
    adv_loss = 0.5*(adv_loss_fake + adv_loss_real)
    self._losses.update({
        "Adversary": adv_loss,
        "Adversary_fake": adv_loss_fake,
        "Adversary_real": adv_loss_real,
    })
``` 
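To check that these snippets compute the least-squares losses from the formulas, we can feed made-up discriminator outputs through `torch.nn.MSELoss()` the same way the snippets do (prediction first, labels second) and compare against the formulas written out by hand. The tensors are arbitrary stand-ins for the results of `self.predict(...)`; no generator or discriminator is involved. Note that the generator snippet omits the factor 0.5 of the formula, which only rescales the gradient and does not move the minimum.

```python
import torch

mse = torch.nn.MSELoss()  # default reduction="mean"

# Stand-ins for self.predict(x=X_batch) and self.predict(x=fake_images).
real_predictions = torch.tensor([0.9, 0.8, 0.95])
fake_predictions = torch.tensor([0.2, 0.4, 0.1])

# Adversary, as in _calculate_adversary_loss: fake -> 0, real -> 1.
adv_loss_fake = mse(fake_predictions, torch.zeros_like(fake_predictions))
adv_loss_real = mse(real_predictions, torch.ones_like(real_predictions))
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)

# Generator, as in _calculate_generator_loss: fake -> 1 (no factor 0.5 here).
gen_loss = mse(fake_predictions, torch.ones_like(fake_predictions))

# The same quantities written out by hand (a=0, b=c=1).
adv_by_hand = 0.5 * (((real_predictions - 1) ** 2).mean() + (fake_predictions ** 2).mean())
gen_by_hand = ((fake_predictions - 1) ** 2).mean()

print(torch.isclose(adv_loss, adv_by_hand), torch.isclose(gen_loss, gen_by_hand))
```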

Please be aware that this is the code snippet at the time of writing. It may change in the future but will probably not change drastically from this implementation.

The network would now be ready to be used :)

But we won't stop here: let's quickly go over the implementation of the `ConditionalLSGAN`, so that we can also use labels or images as conditional input.

## ConditionalLSGAN

We can basically do the same thing as before and copy from [ConditionalVanillaGAN](https://github.com/tneuer/GAN-pytorch/blob/main/vegans/models/conditional/ConditionalVanillaGAN.py). This time we will inherit from `AbstractConditionalGAN1v1` (which inherits from `AbstractConditionalGenerativeModel`, which in turn inherits from `AbstractGenerativeModel`) as well as from our `LSGAN`, so that its `_define_loss` is reused. Everything is an `AbstractGenerativeModel` in the end.

The main difference is that we now must also pass `y_dim` (the dimension of the labels).
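Why does the model need `y_dim`? A common way to condition a GAN is to concatenate the label vector to the noise vector for the generator and to the (flattened) image for the adversary, so the input size of both networks grows by `y_dim`. We use it here purely as an illustration and do not claim this is exactly what `vegans` does internally. The shapes below (a batch of 4, `z_dim=8`, 28x28 images, 10 classes) are arbitrary.

```python
import torch
import torch.nn.functional as F

batch_size, z_dim, y_dim = 4, 8, 10
x_dim = 28 * 28  # flattened image size

labels = torch.tensor([3, 1, 4, 1])
y = F.one_hot(labels, num_classes=y_dim).float()  # shape (4, 10)

z = torch.randn(batch_size, z_dim)
generator_input = torch.cat([z, y], dim=1)        # shape (4, z_dim + y_dim) = (4, 18)

x = torch.rand(batch_size, x_dim)                 # stand-in for real images
adversary_input = torch.cat([x, y], dim=1)        # shape (4, x_dim + y_dim) = (4, 794)

print(generator_input.shape, adversary_input.shape)
```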

```python
from vegans.models.unconditional.LSGAN import LSGAN
from vegans.models.conditional.AbstractConditionalGAN1v1 import AbstractConditionalGAN1v1

class ConditionalLSGAN(AbstractConditionalGAN1v1, LSGAN):
    def __init__(
            self,
            generator,
            adversary,
            x_dim,
            z_dim,
            y_dim,
            optim=None,
            optim_kwargs=None,
            feature_layer=None,
            fixed_noise_size=32,
            device=None,
            folder="./ConditionalLSGAN",
            ngpu=None):

        super().__init__(
            generator=generator, adversary=adversary,
            x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, adv_type="Discriminator",
            optim=optim, optim_kwargs=optim_kwargs, feature_layer=feature_layer,
            fixed_noise_size=fixed_noise_size,
            device=device, folder=folder, ngpu=ngpu
        )
```
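The class statement `ConditionalLSGAN(AbstractConditionalGAN1v1, LSGAN)` relies on Python's method resolution order (MRO). The toy classes below are hypothetical stand-ins with the same inheritance shape, not the `vegans` classes, and we assume for illustration that the conditional parent does not define `_define_loss` itself. They show that `super().__init__` resolves to the conditional parent, while `_define_loss` is picked up from the LSGAN side.

```python
class Base:  # plays the role of AbstractGenerativeModel
    def __init__(self, **kwargs):
        self.loss_functions = self._define_loss()

    def _define_loss(self):
        raise NotImplementedError


class Gan1v1(Base):  # plays the role of AbstractGAN1v1
    pass


class ConditionalGan1v1(Base):  # plays the role of AbstractConditionalGAN1v1
    def __init__(self, y_dim, **kwargs):
        self.y_dim = y_dim
        super().__init__(**kwargs)  # continues along the MRO of the *instance's* class


class ToyLSGAN(Gan1v1):
    def _define_loss(self):
        return {"Generator": "MSELoss", "Adversary": "MSELoss"}


class ToyConditionalLSGAN(ConditionalGan1v1, ToyLSGAN):
    def __init__(self, y_dim):
        super().__init__(y_dim=y_dim)


model = ToyConditionalLSGAN(y_dim=10)
print([cls.__name__ for cls in ToyConditionalLSGAN.__mro__])
print(model.y_dim, model.loss_functions)
```

The MRO is `ToyConditionalLSGAN, ConditionalGan1v1, ToyLSGAN, Gan1v1, Base, object`, so the conditional constructor runs first and `ToyLSGAN._define_loss` is found before `Base._define_loss`.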


This is all we had to do. The rest is handled by the two parent classes. This model should now be able to generate specific instances of handwritten digits or even translate an image of a summer scene into a winter scene (note that there are other architectures specialized for this last problem, like BicycleGAN). It was pretty easy to implement this network thanks to the abstractions that were already in place. In the next notebook we will implement a Pix2PixGAN, which requires some additional modification of the loss function but also shouldn't be too difficult.

Please note again that this is a preliminary tutorial implementation, which might change in future releases of `vegans`. So this implementation might not be completely up to date, but it is still a viable implementation.