
Training a GAN

We shall try to implement something more complicated using torchbearer: a Generative Adversarial Network (GAN). This tutorial is a modified version of the GAN from PyTorch_GAN, the brilliant collection of GAN implementations by eriklindernoren on GitHub.

Data and Constants

We first define all constants for the example.

/_static/examples/gan.py
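
The example file defines its hyperparameters at the top. A minimal sketch of what this might look like (the values here are illustrative assumptions, not taken verbatim from the file; this and the later sketches build on one another rather than repeating imports):

    import os
    import torch

    os.makedirs('images', exist_ok=True)  # output directory for generated samples

    # Training hyperparameters (illustrative values)
    batch_size = 64
    lr = 0.0002            # learning rate for both Adam optimisers
    latent_dim = 100       # size of the generator's noise input
    sample_interval = 400  # save an image grid every 400 training steps
    img_shape = (1, 28, 28)  # MNIST images: 1 channel, 28x28 pixels

    adversarial_loss = torch.nn.BCELoss()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Target labels for the adversarial loss
    valid = torch.ones(batch_size, 1, device=device)
    fake = torch.zeros(batch_size, 1, device=device)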

We then define a number of state keys for convenience using state_key. This is optional; however, it automatically avoids key conflicts.

/_static/examples/gan.py
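
A sketch of the kind of keys involved (the names are assumptions chosen to match the rest of this walkthrough):

    from torchbearer import state_key

    # Outputs of the forward passes
    GEN_IMGS = state_key('gen_imgs')          # images produced by the generator
    DISC_GEN = state_key('disc_gen')          # discriminator score on generated images
    DISC_GEN_DET = state_key('disc_gen_det')  # as above but detached, so no generator gradients
    DISC_REAL = state_key('disc_real')        # discriminator score on real images
    DISC_IMGS = state_key('disc_imgs')        # prediction key for the discriminator closure

    # Losses we want to log as metrics
    G_LOSS = state_key('g_loss')
    D_LOSS = state_key('d_loss')

    # The discriminator's model, criterion and the two optimisers
    DISC_MODEL = state_key('disc_model')
    DISC_CRIT = state_key('disc_crit')
    DISC_OPT = state_key('disc_opt')
    GEN_OPT = state_key('gen_opt')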

We then define the dataset and dataloader - for this example, MNIST.

/_static/examples/gan.py
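
This is the standard torchvision MNIST pipeline; a sketch, normalising to [-1, 1] to match the generator's Tanh output:

    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # scale pixels to [-1, 1]
    ])

    dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)
    # drop_last keeps every batch the same size as the valid/fake label tensors
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)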

Model

We use the generator and discriminator from PyTorch_GAN.

/_static/examples/gan.py
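
The only torchbearer-specific change is that the forward passes write their outputs into state instead of merely returning them (torchbearer passes state to a model whose forward accepts it). A condensed sketch; the layer structure follows PyTorch_GAN, and the state handling is the point of interest:

    import numpy as np
    import torch.nn as nn
    import torchbearer

    class Generator(nn.Module):
        def __init__(self):
            super().__init__()

            def block(in_feat, out_feat, normalize=True):
                layers = [nn.Linear(in_feat, out_feat)]
                if normalize:
                    layers.append(nn.BatchNorm1d(out_feat, 0.8))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers

            self.model = nn.Sequential(
                *block(latent_dim, 128, normalize=False),
                *block(128, 256),
                *block(256, 512),
                *block(512, 1024),
                nn.Linear(1024, int(np.prod(img_shape))),
                nn.Tanh()
            )

        def forward(self, real_imgs, state):
            # Sample noise, generate a batch and store it in state for the discriminator
            z = torch.randn(real_imgs.size(0), latent_dim, device=state[torchbearer.DEVICE])
            img = self.model(z).view(-1, *img_shape)
            state[GEN_IMGS] = img
            return img

    class Discriminator(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(int(np.prod(img_shape)), 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )

        def forward(self, real_imgs, state):
            # Score real images, the generated batch, and a detached copy of it
            flat_gen = state[GEN_IMGS].view(real_imgs.size(0), -1)
            state[DISC_REAL] = self.model(real_imgs.view(real_imgs.size(0), -1))
            state[DISC_GEN] = self.model(flat_gen)
            state[DISC_GEN_DET] = self.model(flat_gen.detach())
            return state[DISC_REAL]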

We then create the models and optimisers.

/_static/examples/gan.py
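
As in PyTorch_GAN, each network gets its own Adam optimiser; a sketch:

    from torch.optim import Adam

    generator = Generator()
    discriminator = Discriminator()

    optimizer_G = Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))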

Loss

GANs usually require two different losses, one for the generator and one for the discriminator. We define these as functions of state so that we can access the discriminator model for the additional forward passes required.

/_static/examples/gan.py
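
Each criterion is a function of state rather than of (y_pred, y_true), reading the discriminator outputs stored by the forward passes above. A sketch:

    def gen_crit(state):
        # The generator wants the discriminator to label its images as real
        loss = adversarial_loss(state[DISC_GEN], valid)
        state[G_LOSS] = loss
        return loss

    def disc_crit(state):
        # The discriminator wants real images labelled real and generated ones fake;
        # the detached output stops these gradients reaching the generator
        real_loss = adversarial_loss(state[DISC_REAL], valid)
        fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
        loss = (real_loss + fake_loss) / 2
        state[D_LOSS] = loss
        return loss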

We will see later how we get a torchbearer trial to use these losses.

Metrics

We would like to follow the discriminator and generator losses during training - note that we added these to state during the model definition. In torchbearer, state keys are also metrics, so we can take means and running means of them and tell torchbearer to output them as metrics.

/_static/examples/gan.py
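
A sketch, assuming the mean and running_mean decorators compose over state keys as in the torchbearer 0.4 metric API:

    from torchbearer import metrics

    # Running means are printed during the epoch; plain means at the end of it
    metrics_list = [metrics.mean(metrics.running_mean(G_LOSS)),
                    metrics.mean(metrics.running_mean(D_LOSS))]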

We will add this metric list to the trial when we create it.

Closures

The training loop of a GAN is a bit different to a standard model training loop. GANs require separate forward and backward passes for the generator and discriminator. To achieve this in torchbearer we can write a new closure. Since the individual training loops for the generator and discriminator each follow the standard pattern, we can build them with base_closure from torchbearer.bases. The base closure takes state keys for the required objects (data, model, optimiser, etc.) and returns a standard closure consisting of:

  1. Zero gradients
  2. Forward pass
  3. Loss calculation
  4. Backward pass

We create a separate closure for the generator and for the discriminator. Most objects get their own state keys so that the two closures can operate independently, although the loss is easier to deal with in a single key.

/_static/examples/gan.py
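
A sketch of the two closures, assuming base_closure takes the keys in the order (x, model, y_pred, y_true, criterion, loss, optimiser):

    import torchbearer
    from torchbearer.bases import base_closure

    # Generator closure: built from the standard model, criterion and prediction keys
    closure_gen = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED,
                               torchbearer.Y_TRUE, torchbearer.CRITERION,
                               torchbearer.LOSS, GEN_OPT)

    # Discriminator closure: built from the discriminator-specific keys,
    # sharing the single loss key as described above
    closure_disc = base_closure(torchbearer.X, DISC_MODEL, DISC_IMGS,
                                torchbearer.Y_TRUE, DISC_CRIT,
                                torchbearer.LOSS, DISC_OPT)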

We then create a main closure (a simple function of state) that runs both of these and steps the optimisers.

/_static/examples/gan.py
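
The main closure simply chains the two, stepping each optimiser after its backward pass; a sketch:

    def closure(state):
        closure_gen(state)   # zero grads, forward, loss, backward for the generator
        state[GEN_OPT].step()
        closure_disc(state)  # and the same again for the discriminator
        state[DISC_OPT].step()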

We will add this closure to the trial next.

Training

We now create the torchbearer trial on the GPU in the standard way. Note that when torchbearer is passed a None optimiser it creates a mock optimiser that will just run the closure. Since we are using the standard torchbearer state keys for the generator model and criterion, we can pass them in here.

/_static/examples/gan.py
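
A sketch of the trial construction (saver_callback is the image-saving callback defined in the Visualising section below):

    import torchbearer

    trial = torchbearer.Trial(generator, None, criterion=gen_crit,
                              metrics=metrics_list, callbacks=[saver_callback])
    trial.with_train_generator(dataloader)
    trial.to(device)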

We now update state with the keys required for the discriminator's closure and add the new closure to the trial. Note that torchbearer doesn't know that the discriminator is a model here, so we have to send it to the GPU ourselves.

/_static/examples/gan.py
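
A sketch, assuming the trial exposes its state dictionary as trial.state and accepts a custom closure via with_closure, as in torchbearer 0.4:

    new_keys = {
        DISC_MODEL: discriminator.to(device),  # torchbearer won't move this for us
        DISC_CRIT: disc_crit,
        DISC_OPT: optimizer_D,
        GEN_OPT: optimizer_G,
    }
    trial.state.update(new_keys)
    trial.with_closure(closure)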

Finally, we run the trial.

/_static/examples/gan.py
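
Running is then the usual one-liner:

    trial.run(epochs=200)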

Visualising

We borrow the image saving method from PyTorch_GAN and put it in a callback decorated with on_step_training (from torchbearer.callbacks.decorators). We generate from the same inputs each time to get a better visualisation.

/_static/examples/gan.py
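
A sketch of the callback, assuming the only_if decorator from torchbearer.callbacks to gate saving to every sample_interval steps; a fixed noise batch keeps successive grids comparable:

    import torchbearer
    from torchbearer import callbacks
    from torchvision.utils import save_image

    # A fixed latent batch: generating from the same inputs each time makes
    # progress over training easier to see
    fixed_noise = torch.randn(25, latent_dim, device=device)

    @callbacks.on_step_training
    @callbacks.only_if(lambda state: state[torchbearer.BATCH] % sample_interval == 0)
    def saver_callback(state):
        # Call the generator's inner network directly on our fixed noise
        samples = state[torchbearer.MODEL].model(fixed_noise).view(-1, *img_shape)
        save_image(samples, 'images/%d.png' % state[torchbearer.BATCH],
                   nrow=5, normalize=True)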

Here is a GIF created from the saved images.

Source Code

The source code for the example is given below:

Download Python source code: gan.py (/_static/examples/gan.py)