# Creating your own GAN III: LR-GAN

In the last notebook we implemented the `Pix2Pix` GAN together; this time we're tackling the Latent-Regressor GAN (LR-GAN). It is described in one of our favorite papers, which presents a really amazing algorithm: the [BicycleGAN](https://arxiv.org/pdf/1711.11586.pdf). This will soon be implemented in `vegans` as well. At the time of writing this notebook (2021-04-08 20:22, only one hour after we started implementing the Pix2Pix :) ) there are only 4 (9) GAN architectures implemented in `vegans`: `VanillaGAN`, `WassersteinGAN`, `WassersteinGANGP`, `LSGAN` and all their conditional variants, plus the Pix2Pix, which exists only as a conditional algorithm. The LR-GAN finally tackles the problem of mode collapse, which means that all random inputs to the generator are mapped to a single output. That output might fool the discriminator perfectly without looking like a real image to humans. Even if it does look real, this is not desirable because we want a high variety of output images. This is one of the most pressing problems, so it is good that we finally deal with it (somewhat).

We assume you have read the previous notebooks on the creation of the `LSGAN` and `Pix2Pix`. If not, please go over them so you have an understanding of the abstract base classes. We will not present them here and instead jump right into the implementation.

First import the usual libraries:

In [1]:
import os
import torch

os.chdir("/home/thomas/Backup/Algorithmen/GAN-pytorch")
from vegans.GAN import ConditionalWassersteinGAN, ConditionalWassersteinGANGP
from vegans.utils.utils import plot_losses, plot_images, get_input_dim

## Latent-Regressor GAN

Remember that so far we only looked at two networks learning against each other: a discriminator, which is trained to differentiate between real and fake images, and a generator, which is trained to fool the discriminator. As mentioned above, this quite often leads to mode collapse, where the generator produces a single "perfect" image regardless of input. In the well-known example of generating handwritten digits, the generator might learn to produce the perfect image of a zero but produce nothing else. The discriminator can never tell whether the generated images are real or fake, so the generator is content and stops learning. The latent regressor GAN approaches this problem elegantly by introducing another helper network. (Another technique to deal with this is minibatch discrimination, which is not yet supported in `vegans` but hopefully will be in the future; it might even be implemented by the time you read this.)

The job of this helper network is to take the output of the generator (often a generated image) and compress it back into the latent space. So if we start from a random latent input (say a noise image with shape [1, 4, 4]), the generator produces an image from it and the helper network maps it back to a space of the initial dimension (again [1, 4, 4]). For this reason this network is called an **Encoder**.
The output of the encoder is then compared to the initial latent input to the generator by computing an L1 (or L2) norm of their difference. The goal of the generator (and encoder) is to minimise this L1 (L2) norm.

This helps against mode collapse because if every latent input to the generator is mapped to basically the same output image, the encoder will transform all those images back to one single latent vector. Most of the time this vector will be quite different from the original input, so the L1 norm, which the generator has to minimize, increases. The generator therefore needs to produce a distinct output for every input, one from which the encoder can recover that input.
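
The argument above can be checked numerically with a small sketch. The two "generators" and the "encoder" below are hypothetical hand-written functions, not trained networks and not part of `vegans`; they only illustrate how a collapsed generator is penalised by the L1 term while an invertible one is not.

```python
import torch
from torch.nn import L1Loss

torch.manual_seed(0)
l1 = L1Loss()

# A batch of 8 latent inputs with the [1, 4, 4] shape used in the text.
z = torch.randn(8, 1, 4, 4)


# Hypothetical toy "networks" (plain functions, for illustration only).
def collapsed_generator(z):
    # Mode collapse: every latent input yields the same output image.
    return torch.ones_like(z)


def diverse_generator(z):
    # Invertible mapping: different inputs give different outputs.
    return 2.0 * z


def encoder(x):
    # Exact inverse of diverse_generator.
    return 0.5 * x


# Latent reconstruction error: compare encoder(generator(z)) with z.
loss_collapsed = l1(encoder(collapsed_generator(z)), z)
loss_diverse = l1(encoder(diverse_generator(z)), z)

print(f"collapsed: {loss_collapsed.item():.3f}, diverse: {loss_diverse.item():.3f}")
```

The collapsed generator forces the encoder to return the same vector for every input, so the reconstruction error stays well above zero, while the invertible pair reconstructs `z` exactly.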

Let's now start with the implementation of the LRGAN. First note that, unlike before, we can't use the parent class `GAN1v1` because we now have three networks instead of one generator playing against one adversary. There is no base class for this case in `vegans` (yet), so we can only inherit from `GenerativeModel` and do a lot of the legwork on our own. This means we need to implement all abstract methods:

- `__init__(self, x_dim, z_dim, optim, optim_kwargs, fixed_noise_size, device, folder, ngpu)`:
    This takes care of the initialization, and the method

    super().__init__(
        x_dim=x_dim, z_dim=z_dim, optim=optim, optim_kwargs=optim_kwargs,
        fixed_noise_size=fixed_noise_size, device=device, folder=folder, ngpu=ngpu
    )

    must be called at the end of the `__init__` method.
     
- `_default_optimizer(self)`: returns an optimizer from `torch.optim` that is used if the user doesn't specify any optimizers via the `optim` keyword when constructing the class.

- `_define_loss(self)`: Not strictly necessary, but it is kept as an abstract method so that users have to think about what they want to implement here. You can also implement it with a single `pass` statement. However, we will show you its intended use here.

- `calculate_losses(self, X_batch, Z_batch, who)`: The core function that needs to be implemented. For every batch it must populate an already existing (but empty) dictionary `self._losses`. The keys of the dictionary must include at least the keys used in `self.neural_nets` (explained below), but can also contain other losses. We will discuss this further in later implementations.
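
The contract described by this list can be sketched with a stand-in base class. Note that `ToyGenerativeModel`, `ToyModel` and `check_losses` are hypothetical names invented for this illustration; they are not the `vegans` `GenerativeModel`, and the check on `neural_nets` inside `__init__` is only an assumption about why the parent constructor has to be called last.

```python
from abc import ABC, abstractmethod


class ToyGenerativeModel(ABC):
    """Hypothetical stand-in for a base class; NOT the vegans implementation."""

    def __init__(self):
        # Assumption: the parent needs self.neural_nets to exist already,
        # which is why subclasses call this constructor last.
        if not hasattr(self, "neural_nets"):
            raise NotImplementedError("Subclass must define 'self.neural_nets'.")
        self._losses = {}

    @abstractmethod
    def _default_optimizer(self):
        ...

    @abstractmethod
    def _define_loss(self):
        ...

    @abstractmethod
    def calculate_losses(self, X_batch, Z_batch, who=None):
        ...

    def check_losses(self):
        # Every network needs a loss entry; additional keys are allowed.
        missing = set(self.neural_nets) - set(self._losses)
        if missing:
            raise KeyError(f"Missing losses for: {sorted(missing)}")


class ToyModel(ToyGenerativeModel):
    def __init__(self):
        self.neural_nets = {"Generator": object(), "Adversariat": object()}
        super().__init__()  # called last, after neural_nets exists

    def _default_optimizer(self):
        return None  # placeholder; a real model would return an optimizer class

    def _define_loss(self):
        pass  # allowed: a single pass statement

    def calculate_losses(self, X_batch, Z_batch, who=None):
        # Dummy numbers instead of real loss tensors.
        self._losses.update(
            {"Generator": 0.7, "Adversariat": 0.6, "RealFakeRatio": 1.0}
        )


model = ToyModel()
model.calculate_losses(X_batch=None, Z_batch=None)
model.check_losses()  # passes: every key of neural_nets has a loss
```

A subclass that leaves out one of the abstract methods cannot be instantiated at all, which is the "has to think about it" effect described for `_define_loss`.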

We start with the `__init__` method, which **MUST** create the `self.neural_nets` dictionary. No worries: if you forget to specify it, the `GenerativeModel` class will kindly remind you to populate the dictionary. I will copy and reuse some of the code from [GAN1v1](https://github.com/tneuer/GAN-pytorch/blob/main/vegans/models/unconditional/GAN1v1.py) just to get a feeling for how to start, because that's what past-me would have wanted.

In [2]:
from vegans.models.unconditional.GenerativeModel import GenerativeModel
from torch.nn import BCELoss, L1Loss
from torch.nn import MSELoss as L2Loss
from vegans.utils.networks import Generator, Adversariat, Encoder

In [4]:
class LRGAN(GenerativeModel):
    #########################################################################
    # Actions before training
    #########################################################################
    def __init__(
            self,
            generator,
            adversariat,
            encoder,
            x_dim,
            z_dim,
            optim=None,
            optim_kwargs=None,
            fixed_noise_size=32,
            device=None,
            folder="./GAN1v1",
            ngpu=0):

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator = Generator(generator, input_size=z_dim, device=device, ngpu=ngpu)
        self.adversariat = Adversariat(adversariat, input_size=x_dim, adv_type="Discriminator", device=device, ngpu=ngpu)
        self.encoder = Encoder(encoder, input_size=x_dim, device=device, ngpu=ngpu)
        self.neural_nets = {
            "Generator": self.generator, "Adversariat": self.adversariat, "Encoder": self.encoder
        }

        GenerativeModel.__init__(
            self, x_dim=x_dim, z_dim=z_dim, optim=optim, optim_kwargs=optim_kwargs,
            fixed_noise_size=fixed_noise_size, device=device, folder=folder, ngpu=ngpu
        )
    
    def _default_optimizer(self):
        return torch.optim.Adam

    def _define_loss(self):
        self.loss_functions = {"Generator": BCELoss(), "Adversariat": BCELoss(), "L1": L1Loss()}
        

I included the construction of three `vegans` classes from the `vegans.utils.networks` module: `Generator`, `Adversariat` and `Encoder`. If you are sneaky and look at the [code](https://github.com/tneuer/GAN-pytorch/blob/main/vegans/utils/networks.py) you will notice that there is absolutely no difference between the implementations of `Generator` and `Encoder` (apart from the `name` attribute, which only shows up when printing the network). On an abstract level both do the same thing: take an input (image or vector), apply weights and biases, and finally produce an output (image or vector). Everything else is implementation detail. We also already implemented the `_default_optimizer` method (which is always the easiest part, mostly `torch.optim.Adam`) and the `_define_loss` method, which includes the L1 loss for the encoder and generator.

Next we will implement the last abstract method: `calculate_losses(...)`! I again copy and modify existing code from the [GAN1v1](https://github.com/tneuer/GAN-pytorch/blob/main/vegans/models/unconditional/GAN1v1.py) class.

In [5]:
class LRGAN(GenerativeModel):
    #########################################################################
    # Actions before training
    #########################################################################
    def __init__(
            self,
            generator,
            adversariat,
            encoder,
            x_dim,
            z_dim,
            optim=None,
            optim_kwargs=None,
            fixed_noise_size=32,
            device=None,
            folder="./GAN1v1",
            ngpu=0):

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator = Generator(generator, input_size=z_dim, device=device, ngpu=ngpu)
        self.adversariat = Adversariat(adversariat, input_size=x_dim, adv_type="Discriminator", device=device, ngpu=ngpu)
        self.encoder = Encoder(encoder, input_size=x_dim, device=device, ngpu=ngpu)
        self.neural_nets = {
            "Generator": self.generator, "Adversariat": self.adversariat, "Encoder": self.encoder
        }

        GenerativeModel.__init__(
            self, x_dim=x_dim, z_dim=z_dim, optim=optim, optim_kwargs=optim_kwargs,
            fixed_noise_size=fixed_noise_size, device=device, folder=folder, ngpu=ngpu
        )
    
    def _default_optimizer(self):
        return torch.optim.Adam

    def _define_loss(self):
        self.loss_functions = {"Generator": BCELoss(), "Adversariat": BCELoss(), "L1": L1Loss()}
        
        
    #########################################################################
    # Actions during training
    #########################################################################
    def calculate_losses(self, X_batch, Z_batch, who=None):
        if who == "Generator":
            self._calculate_generator_loss(Z_batch=Z_batch)
        elif who == "Adversariat":
            self._calculate_adversariat_loss(X_batch=X_batch, Z_batch=Z_batch)
        elif who == "Encoder":
            self._calculate_encoder_loss(X_batch=X_batch, Z_batch=Z_batch)
        else:
            self._calculate_generator_loss(Z_batch=Z_batch)
            self._calculate_adversariat_loss(X_batch=X_batch, Z_batch=Z_batch)
            self._losses["RealFakeRatio"] = self._losses["Adversariat_real"]/self._losses["Adversariat_fake"]

    def _calculate_generator_loss(self, Z_batch):
        fake_images = self.generate(z=Z_batch)
        fake_predictions = self.predict(x=fake_images)
        gen_loss = self.loss_functions["Generator"](
            fake_predictions, torch.ones_like(fake_predictions, requires_grad=False)
        )
        self._losses.update({"Generator": gen_loss})

    def _calculate_adversariat_loss(self, X_batch, Z_batch):
        fake_images = self.generate(z=Z_batch).detach()
        fake_predictions = self.predict(x=fake_images)
        real_predictions = self.predict(x=X_batch.float())

        adv_loss_fake = self.loss_functions["Adversariat"](
            fake_predictions, torch.zeros_like(fake_predictions, requires_grad=False)
        )
        adv_loss_real = self.loss_functions["Adversariat"](
            real_predictions, torch.ones_like(real_predictions, requires_grad=False)
        )
        adv_loss = 0.5*(adv_loss_fake + adv_loss_real)
        self._losses.update({
            "Adversariat": adv_loss,
            "Adversariat_fake": adv_loss_fake,
            "Adversariat_real": adv_loss_real,
        })
        
    def _calculate_encoder_loss(self, X_batch, Z_batch):
        pass
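
The `_calculate_encoder_loss` stub is where the latent regression term described earlier would go. As a minimal standalone sketch (assuming small fully connected networks in place of the wrapped `vegans` networks; the helper `latent_regression_loss` is a hypothetical name, not part of the library), the term could be computed like this:

```python
import torch
from torch import nn

torch.manual_seed(0)

z_dim, x_dim = 4, 16
# Hypothetical tiny networks standing in for the wrapped generator and encoder.
generator = nn.Sequential(nn.Linear(z_dim, x_dim), nn.Tanh())
encoder = nn.Sequential(nn.Linear(x_dim, z_dim))
l1_loss = nn.L1Loss()


def latent_regression_loss(generator, encoder, Z_batch):
    """L1 distance between the latent input and its reconstruction."""
    fake_images = generator(Z_batch)      # z -> x
    encoded_space = encoder(fake_images)  # x -> reconstructed z
    return l1_loss(encoded_space, Z_batch)


Z_batch = torch.randn(32, z_dim)
latent_loss = latent_regression_loss(generator, encoder, Z_batch)

# fake_images is not detached, so the gradient reaches both networks.
latent_loss.backward()
```

Unlike in `_calculate_adversariat_loss`, the generated images are not detached here, because both the generator and the encoder are supposed to minimise this term. How the term is weighted against the adversarial generator loss is a design choice that this sketch leaves open.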

Please note again that this is a preliminary tutorial implementation which might or might not change in future releases of `vegans`. So this implementation might not be completely up-to-date, but it is a viable implementation nonetheless.