Skip to content

Commit

Permalink
Update vae example (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 23, 2018
1 parent 2dd4628 commit a1bea97
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ ENV/
# Ignore examples/data
/docs/_static/examples/results
/docs/_static/examples/data
/docs/_static_examples/images
/docs/_static/examples/images
39 changes: 19 additions & 20 deletions docs/_static/examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,40 +84,39 @@ def bce_loss(y_pred, y_true):
return BCE


class AddKLDLoss(Callback):
def on_criterion(self, state):
super().on_criterion(state)
KLD = self.KLD_Loss(state['mu'], state['logvar'])
state[torchbearer.LOSS] = state[torchbearer.LOSS] + KLD
def kld_Loss(mu, logvar):
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return KLD

def KLD_Loss(self, mu, logvar):
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return KLD

@torchbearer.callbacks.add_to_loss
def add_kld_loss_callback(state):
KLD = kld_Loss(state['mu'], state['logvar'])
return KLD

class SaveReconstruction(Callback):
def __init__(self, num_images=8, folder='results/'):
super().__init__()
self.num_images = num_images
self.folder = folder

def on_step_validation(self, state):
super().on_step_validation(state)
def save_reconstruction_callback(num_images=8, folder='results/'):
import os
os.makedirs(os.path.dirname(folder), exist_ok=True)

@torchbearer.callbacks.on_step_validation
def saver(state):
if state[torchbearer.BATCH] == 0:
data = state[torchbearer.X]
recon_batch = state[torchbearer.Y_PRED]
comparison = torch.cat([data[:self.num_images],
recon_batch.view(128, 1, 28, 28)[:self.num_images]])
comparison = torch.cat([data[:num_images],
recon_batch.view(128, 1, 28, 28)[:num_images]])
save_image(comparison.cpu(),
str(self.folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png', nrow=self.num_images)
str(folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png', nrow=num_images)
return saver


model = VAE()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss = bce_loss

from torchbearer import Model

torchbearer_model = Model(model, optimizer, loss, metrics=['loss']).to('cuda')
torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen, callbacks=[AddKLDLoss(), SaveReconstruction()], pass_state=True)
torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen,
callbacks=[add_kld_loss_callback, save_reconstruction_callback()], pass_state=True)
20 changes: 14 additions & 6 deletions docs/_static/examples/vae_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,24 @@ def forward(self, x):
return self.decode(z), mu, logvar


def bce_plus_kld_loss(y_pred, y_true):
def bce_loss(recon_x, x):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
return BCE


def kld_Loss(mu, logvar):
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return KLD


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(y_pred, y_true):
recon_x, mu, logvar = y_pred
x = y_true
return loss_function(recon_x, x, mu, logvar)

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
BCE = bce_loss(recon_x, x)

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
KLD = kld_Loss(mu, logvar)

return BCE + KLD

Expand Down
29 changes: 17 additions & 12 deletions docs/examples/vae.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,32 @@ Defining the Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Now we have the model and data, we will need a loss function to optimize.
VAEs typically take the sum of a reconstruction loss and a KL-divergence loss to form the final loss value.

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 82-84

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 87-89

There are two ways this can be done in torchbearer - one is very similar to the PyTorch example method and the other utilises the torchbearer state.

PyTorch method
------------------------------------

The loss function from the PyTorch example is:
The loss function slightly modified from the PyTorch example is:

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 82-87
:lines: 87-95

This requires the packing of the reconstruction, mean and log-variance into the model output and unpacking it for the loss function to use.

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 70-73

.. literalinclude:: /_static/examples/vae_standard.py
:language: python
:lines: 76-79


Using Torchbearer State
------------------------------------
Expand All @@ -80,17 +85,17 @@ We can then modify the model forward pass to store the mean and log-variance und
:language: python
:lines: 74-79

The loss can then be separated into a standard reconstruction loss and a separate KL-divergence loss using intermediate tensor values.
The reconstruction loss is a standard loss taking network output and the true label

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 82-84
:lines: 116

Since loss functions cannot access state, we utilise a simple callback to complete the loss calculation.
Since loss functions cannot access state, we utilise a simple callback to combine the kld loss which does not act on network output or true label.

.. literalinclude:: /_static/examples/vae.py
:language: python
:lines: 87-95
:lines: 92-95


Visualising Results
Expand All @@ -102,7 +107,7 @@ For auto-encoding problems it is often useful to visualise the reconstructions.
.. _save_image: https://pytorch.org/docs/stable/torchvision/utils.html?highlight=save#torchvision.utils.save_image

.. literalinclude:: /_static/examples/vae.py
:lines: 98-112
:lines: 98-111

Training the Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -113,7 +118,7 @@ We train the model by creating a torchmodel and a torchbearermodel and calling f


.. literalinclude:: /_static/examples/vae.py
:lines: 115-120
:lines: 114-122

The visualised results after ten epochs then look like this:

Expand Down

0 comments on commit a1bea97

Please sign in to comment.