Skip to content

Commit

Permalink
Example/vae (#209)
Browse files Browse the repository at this point in the history
* Merge with master

* Add vae example

* Formatting

* Moved example code files

* Formatting vae example

* Formatting
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 18, 2018
1 parent f944722 commit 958cf70
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,5 @@ ENV/
.idea

# Ignore examples/data
/docs/_static/examples/data
/docs/_static/examples/results
/docs/_static/examples/data
123 changes: 123 additions & 0 deletions docs/_static/examples/vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.utils import save_image

import torchbearer
from torchbearer.callbacks import Callback


class AutoEncoderMNIST(Dataset):
def __init__(self, mnist_dataset):
super().__init__()
self.mnist_dataset = mnist_dataset

def __getitem__(self, index):
character, label = self.mnist_dataset.__getitem__(index)
return character, character

def __len__(self):
return len(self.mnist_dataset)


BATCH_SIZE = 128

transform = transforms.Compose([transforms.ToTensor()])

# Define standard classification mnist dataset

basetrainset = torchvision.datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)

basetestset = torchvision.datasets.MNIST('./data/mnist', train=False, download=True, transform=transform)

# Wrap base classification mnist dataset to return the image as the target

trainset = AutoEncoderMNIST(basetrainset)

testset = AutoEncoderMNIST(basetestset)

traingen = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

testgen = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)


class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()

self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)

def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)

def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu

def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))

def forward(self, x, state):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
state['mu'] = mu
state['logvar'] = logvar
return self.decode(z)


def bce_loss(y_pred, y_true):
BCE = F.binary_cross_entropy(y_pred, y_true.view(-1, 784), size_average=False)
return BCE


class AddKLDLoss(Callback):
def on_criterion(self, state):
super().on_criterion(state)
KLD = self.KLD_Loss(state['mu'], state['logvar'])
state[torchbearer.LOSS] = state[torchbearer.LOSS] + KLD

def KLD_Loss(self, mu, logvar):
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return KLD


class SaveReconstruction(Callback):
def __init__(self, num_images=8, folder='results/'):
super().__init__()
self.num_images = num_images
self.folder = folder

def on_step_validation(self, state):
super().on_step_validation(state)
if state[torchbearer.BATCH] == 0:
data = state[torchbearer.X]
recon_batch = state[torchbearer.Y_PRED]
comparison = torch.cat([data[:self.num_images],
recon_batch.view(128, 1, 28, 28)[:self.num_images]])
save_image(comparison.cpu(),
str(self.folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png', nrow=self.num_images)


model = VAE()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss = bce_loss

from torchbearer import Model

torchbearer_model = Model(model, optimizer, loss, metrics=['loss']).to('cuda')
torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen, callbacks=[AddKLDLoss(), SaveReconstruction()], pass_state=True)
98 changes: 98 additions & 0 deletions docs/_static/examples/vae_standard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class AutoEncoderMNIST(Dataset):
def __init__(self, mnist_dataset):
super().__init__()
self.mnist_dataset = mnist_dataset

def __getitem__(self, index):
character, label = self.mnist_dataset.__getitem__(index)
return character, character

def __len__(self):
return len(self.mnist_dataset)


BATCH_SIZE = 128

normalize = transforms.Compose([transforms.ToTensor()])

# Define standard classification mnist dataset

basetrainset = torchvision.datasets.MNIST('./data/mnist', train=True, download=True, transform=normalize)

basetestset = torchvision.datasets.MNIST('./data/mnist', train=False, download=True, transform=normalize)

# Wrap base classification mnist dataset to return the image as the target

trainset = AutoEncoderMNIST(basetrainset)

testset = AutoEncoderMNIST(basetestset)

traingen = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

testgen = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)


class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()

self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)

def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)

def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu

def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))

def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar


def bce_plus_kld_loss(y_pred, y_true):
recon_x, mu, logvar = y_pred
x = y_true
return loss_function(recon_x, x, mu, logvar)

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

return BCE + KLD


model = VAE()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss = loss_function

from torchbearer import Model

torchbearer_model = Model(model, optimizer, loss, metrics=['loss']).to('cuda')
torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen, pass_state=False)
Binary file added docs/_static/img/reconstruction_9.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
135 changes: 135 additions & 0 deletions docs/examples/vae.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
Training a Variational Auto-Encoder
====================================

This guide will give a quick guide on training a variational auto-encoder (VAE) in torchbearer. We will use the VAE example from the pytorch examples here_:

.. _here: https://github.com/pytorch/examples/tree/master/vae

Defining the Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We shall first copy the VAE example model.

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 44-73

Defining the Data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We get the MNIST dataset from torchvision and transform them to torch tensors.

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 23-31

The output label from this dataset is the classification label, since we are doing a auto-encoding problem, we wish the label to be the original image. To fix this we create a wrapper class which replaces the classification label with the image.

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 10-20

We then wrap the original datasets and create training and testing data generators in the standard pytorch way.

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 33-41

Defining the Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Now we have the model and data, we will need a loss function to optimize.
VAEs typically take the sum of a reconstruction loss and a KL-divergence loss to form the final loss value.
There are two ways this can be done in torchbearer - one is very similar to the PyTorch example method and the other utilises the torchbearer state.

PyTorch method
------------------------------------

The loss function from the PyTorch example is:

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 82-87

This requires the packing of the reconstruction, mean and log-variance into the model output and unpacking it for the loss function to use.

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 70-73

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 76-79


Using Torchbearer State
------------------------------------

Instead of having to pack and unpack the mean and variance in the forward pass, in torchbearer there is a persistent state dictionary which can be used to conveniently hold such intermediate tensors.

By default the model forward pass does not have access to the state dictionary, but setting the ``pass_state`` flag to true in the fit_generator_ call gives the model access to state on forward.

.. _fit_generator: https://torchbearer.readthedocs.io/en/latest/code/main.html#torchbearer.torchbearer.Model.fit_generator

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 123

We can then modify the model forward pass to store the mean and log-variance under suitable keys.

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 74-79

The loss can then be separated into a standard reconstruction loss and a separate KL-divergence loss using intermediate tensor values.

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 82-84

Since loss functions cannot access state, we utilise a simple callback to complete the loss calculation.

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 87-95


Visualising Results
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For auto-encoding problems it is often useful to visualise the reconstructions. We can do this in torchbearer by using another simple callback. We stack the first 8 images from the first validation batch and pass them to torchvisions_ save_image_ function which saves out visualisations.

.. _torchvisions: https://github.com/pytorch/vision
.. _save_image: https://pytorch.org/docs/stable/torchvision/utils.html?highlight=save#torchvision.utils.save_image

.. literalinclude:: /_static/examples/vae.py
:lines: 98-112

Training the Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We train the model by creating a torchmodel and a torchbearermodel and calling fit_generator_.

.. _fit_generator: https://torchbearer.readthedocs.io/en/latest/code/main.html#torchbearer.torchbearer.Model.fit_generator


.. literalinclude:: /_static/examples/vae.py
:lines: 115-120

The visualised results after ten epochs then look like this:

.. figure:: /_static/img/reconstruction_9.png
:scale: 200 %
:alt: VAE reconstructions after 10 epochs of mnist

Source Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The source code for the example are given below:

Standard:

:download:`Download Python source code: vae_standard.py </_static/examples/vae_standard.py>`

Using state:

:download:`Download Python source code: vae.py </_static/examples/vae.py>`

0 comments on commit 958cf70

Please sign in to comment.