Commit b89f285

Example/gan (#223)

* Add gan example

* Finish GAN example

MattPainter01 authored and ethanwharris committed Jul 20, 2018
1 parent 22c9399 commit b89f285
Showing 3 changed files with 257 additions and 0 deletions.
160 changes: 160 additions & 0 deletions docs/_static/examples/gan.py
@@ -0,0 +1,160 @@
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

import torchbearer as tb
from torchbearer.callbacks import Callback

os.makedirs('images', exist_ok=True)

# Define constants
epochs = 200
batch_size = 64
lr = 0.0002
nworkers = 8
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)


class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)

def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *img_shape)
return img


class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()

self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)

def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)

return validity


class GAN(nn.Module):
def __init__(self):
super().__init__()
self.discriminator = Discriminator()
self.generator = Generator()

def forward(self, real_imgs, state):
# Generator Forward
        z = torch.randn(real_imgs.shape[0], latent_dim, device=state[tb.DEVICE])
state['gen_imgs'] = self.generator(z)
state['disc_gen'] = self.discriminator(state['gen_imgs'])
        # We don't want the discriminator to accumulate gradients from the generator forward pass
self.discriminator.zero_grad()

# Discriminator Forward
state['disc_gen_det'] = self.discriminator(state['gen_imgs'].detach())
state['disc_real'] = self.discriminator(real_imgs)


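# Callback which computes the generator and discriminator losses from the values
# stored in state by the GAN forward pass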
class LossCallback(Callback):
def on_start(self, state):
super().on_start(state)
self.adversarial_loss = torch.nn.BCELoss()
state['valid'] = torch.ones(batch_size, 1, device=state[tb.DEVICE])
state['fake'] = torch.zeros(batch_size, 1, device=state[tb.DEVICE])

def on_criterion(self, state):
super().on_criterion(state)
fake_loss = self.adversarial_loss(state['disc_gen_det'], state['fake'])
real_loss = self.adversarial_loss(state['disc_real'], state['valid'])
state['g_loss'] = self.adversarial_loss(state['disc_gen'], state['valid'])
state['d_loss'] = (real_loss + fake_loss) / 2
# This is the loss that backward is called on.
state[tb.LOSS] = state['g_loss'] + state['d_loss']


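# Callback which saves a grid of generated images every sample_interval training steps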
class SaverCallback(Callback):
def on_step_training(self, state):
super().on_step_training(state)
batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH]
if batches_done % sample_interval == 0:
save_image(state['gen_imgs'].data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)


# Configure data loader
os.makedirs('./data/mnist', exist_ok=True)
transform = transforms.Compose([
transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # MNIST images have a single channel
])

dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)

# drop_last=True keeps every batch at exactly batch_size, matching the 'valid' and 'fake' targets
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=nworkers)


# Model and optimizer
model = GAN()
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))


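# Metrics which fetch the losses from state; the decorators aggregate each one
# as a batch mean and a running mean for display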
@tb.metrics.running_mean
@tb.metrics.mean
class g_loss(tb.metrics.Metric):
def __init__(self):
super().__init__('g_loss')

def process(self, state):
return state['g_loss']


@tb.metrics.running_mean
@tb.metrics.mean
class d_loss(tb.metrics.Metric):
def __init__(self):
super().__init__('d_loss')

def process(self, state):
return state['d_loss']


def zero_loss(y_pred, y_true):
    # Dummy criterion; the real losses are computed in LossCallback and stored in state[tb.LOSS]
    return torch.zeros(y_true.shape[0], 1)


torchbearermodel = tb.Model(model, optim, zero_loss, ['loss', g_loss(), d_loss()])
torchbearermodel.to('cuda')
torchbearermodel.fit_generator(dataloader, epochs=epochs, pass_state=True, callbacks=[LossCallback(), SaverCallback()])
Binary file added docs/_static/img/172400.png
97 changes: 97 additions & 0 deletions docs/examples/gan.rst
@@ -0,0 +1,97 @@
Training a GAN
====================================

We shall try to implement something more complicated using torchbearer - a Generative Adversarial Network (GAN).
This tutorial is a modified version of the GAN_ from the brilliant collection of GAN implementations, PyTorch_GAN_, by eriklindernoren on GitHub.

.. _PyTorch_GAN: https://github.com/eriklindernoren/PyTorch-GAN
.. _GAN: https://github.com/eriklindernoren/PyTorch-GAN#gan

Data and Constants
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We first define all constants for the example.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 18-24

We then define the dataset and dataloader - for this example, MNIST. Note that ``drop_last=True`` is needed so that every batch contains exactly ``batch_size`` images, matching the fixed-size ``valid`` and ``fake`` targets created in the loss callback.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 119-126

Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We use the generator and discriminator from PyTorch_GAN_ and combine them into a model that performs a single forward pass.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 73-89

Note that we have to be careful to remove the gradient information from the discriminator after the generator forward pass.

Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Since the loss computation in this example is more complicated than usual, we forgo the basic loss criterion used in normal torchbearer models and instead pass a dummy criterion that simply returns zeros.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 154-155

Instead, we use a callback to provide the loss.
We also use this callback to add the constant ``valid`` and ``fake`` target tensors to state.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 92-106

Note that we sum the separate discriminator and generator losses; this is allowable because their computation graphs are separate, so each set of gradients flows back through its own graph.
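
The following standalone sketch (not part of the example) illustrates the idea: when one term of a summed loss is computed from detached inputs, calling ``backward`` on the sum routes each gradient only through its own graph.

.. code-block:: python

    import torch

    x = torch.randn(4, requires_grad=True)   # stands in for the generator output
    w = torch.randn(1, requires_grad=True)   # stands in for a discriminator weight

    g_term = (x * w).sum()            # graph reaches both x and w
    d_term = (x.detach() * w).sum()   # graph reaches w only

    (g_term + d_term).backward()
    print(x.grad)  # gradient from g_term only; d_term cannot reach x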

Metrics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We would like to follow the discriminator and generator losses during training - note that these are added to state in the loss callback.
We can then create metrics from them by decorating simple state-fetcher metrics.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 134-141
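
The same pattern works for any value stored in state. For instance, a hypothetical metric (not part of the example) tracking the mean discriminator output on real images could be written as:

.. code-block:: python

    @tb.metrics.running_mean
    @tb.metrics.mean
    class d_real(tb.metrics.Metric):
        def __init__(self):
            super().__init__('d_real')

        def process(self, state):
            # 'disc_real' is stored in state by the GAN forward pass
            return state['disc_real']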

Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can then train the torchbearer model on the GPU in the standard way.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 158-160
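
If a GPU is not available, a small variation (an assumption, not shown in the example) selects the device at runtime:

.. code-block:: python

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torchbearermodel.to(device)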

Visualising
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We borrow the image saving method from PyTorch_GAN_ and put it in a callback which saves a grid of generated images every ``sample_interval`` training steps.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 109-114

After 172400 iterations we see the following.

.. figure:: /_static/img/172400.png
:scale: 200 %
:alt: GAN generated samples after 172400 iterations


Source Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The source code for the example is given below:

:download:`Download Python source code: gan.py </_static/examples/gan.py>`


