Skip to content

Commit

Permalink
Example/gan update (#225)
Browse files Browse the repository at this point in the history
* Add callback decorators

* Update gan example with callback decorators

* Update ignore file

* Add tests for callback decorators
  • Loading branch information
MattPainter01 committed Jul 20, 2018
1 parent 67af6c0 commit da8a481
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 43 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,4 @@ ENV/
# Ignore examples/data
/docs/_static/examples/results
/docs/_static/examples/data
/docs/_static_examples/images
69 changes: 37 additions & 32 deletions docs/_static/examples/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from torchvision.utils import save_image

import torchbearer as tb
from torchbearer.callbacks import Callback
import torchbearer.callbacks as callbacks
from torchbearer import state_key

os.makedirs('images', exist_ok=True)

Expand All @@ -22,6 +23,18 @@
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
adversarial_loss = torch.nn.BCELoss()
device = 'cuda'
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

# Register state keys (optional)
GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')


class Generator(nn.Module):
Expand Down Expand Up @@ -79,39 +92,31 @@ def __init__(self):
def forward(self, real_imgs, state):
# Generator Forward
z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[tb.DEVICE])
state['gen_imgs'] = self.generator(z)
state['disc_gen'] = self.discriminator(state['gen_imgs'])
state[GEN_IMGS] = self.generator(z)
state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
# We don't want to discriminator gradients on the generator forward pass
self.discriminator.zero_grad()

# Discriminator Forward
state['disc_gen_det'] = self.discriminator(state['gen_imgs'].detach())
state['disc_real'] = self.discriminator(real_imgs)

state[DISC_GEN_DET] = self.discriminator(state[GEN_IMGS].detach())
state[DISC_REAL] = self.discriminator(real_imgs)

class LossCallback(Callback):
def on_start(self, state):
super().on_start(state)
self.adversarial_loss = torch.nn.BCELoss()
state['valid'] = torch.ones(batch_size, 1, device=state[tb.DEVICE])
state['fake'] = torch.zeros(batch_size, 1, device=state[tb.DEVICE])

def on_criterion(self, state):
super().on_criterion(state)
fake_loss = self.adversarial_loss(state['disc_gen_det'], state['fake'])
real_loss = self.adversarial_loss(state['disc_real'], state['valid'])
state['g_loss'] = self.adversarial_loss(state['disc_gen'], state['valid'])
state['d_loss'] = (real_loss + fake_loss) / 2
# This is the loss that backward is called on.
state[tb.LOSS] = state['g_loss'] + state['d_loss']
@callbacks.on_criterion
def loss_callback(state):
fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
real_loss = adversarial_loss(state[DISC_REAL], valid)
state[G_LOSS] = adversarial_loss(state[DISC_GEN], valid)
state[D_LOSS] = (real_loss + fake_loss) / 2
# This is the loss that backward is called on.
state[tb.LOSS] = state[G_LOSS] + state[D_LOSS]


class SaverCallback(Callback):
def on_step_training(self, state):
super().on_step_training(state)
batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH]
if batches_done % sample_interval == 0:
save_image(state['gen_imgs'].data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)
@callbacks.on_step_training
def saver_callback(state):
batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH]
if batches_done % sample_interval == 0:
save_image(state[GEN_IMGS].data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)


# Configure data loader
Expand All @@ -135,26 +140,26 @@ def on_step_training(self, state):
@tb.metrics.mean
class g_loss(tb.metrics.Metric):
def __init__(self):
super().__init__('g_loss')
super().__init__(G_LOSS)

def process(self, state):
return state['g_loss']
return state[G_LOSS]


@tb.metrics.running_mean
@tb.metrics.mean
class d_loss(tb.metrics.Metric):
def __init__(self):
super().__init__('d_loss')
super().__init__(D_LOSS)

def process(self, state):
return state['d_loss']
return state[D_LOSS]


def zero_loss(y_pred, y_true):
return torch.zeros(y_true.shape[0], 1)


torchbearermodel = tb.Model(model, optim, zero_loss, ['loss', g_loss(), d_loss()])
torchbearermodel.to('cuda')
torchbearermodel.fit_generator(dataloader, epochs=200, pass_state=True, callbacks=[LossCallback(), SaverCallback()])
torchbearermodel.to(device)
torchbearermodel.fit_generator(dataloader, epochs=200, pass_state=True, callbacks=[loss_callback, saver_callback])
27 changes: 16 additions & 11 deletions docs/examples/gan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@ We first define all constants for the example.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 18-24
:lines: 19-29

We then define a number of state keys for convenience. This is optional, however, it automatically avoids key conflicts.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 32-37

We then define the dataset and dataloader - for this example, MNIST.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 119-126
:lines: 124-131

Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -29,7 +35,7 @@ We use the generator and discriminator from PyTorch_GAN_ and combine them into a

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 73-89
:lines: 86-102

Note that we have to be careful to remove the gradient information from the discriminator after doing the generator forward pass.

Expand All @@ -40,14 +46,13 @@ Since our loss is complicated in this example, we shall forgo the basic loss cri

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 154-155
:lines: 159-160

Instead use a callback to provide the loss.
We also utilise this callback to add constants to state.
Instead use a callback to provide the loss. Since this callback is very simple we can use callback decorators on a function (which takes state) to tell torchbearer when it should be called.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 92-106
:lines: 105-112

Note that we have summed the separate discriminator and generator losses since their graphs are separated, this is allowable.

Expand All @@ -59,7 +64,7 @@ We can then create metrics from these by decorating simple state fetcher metrics

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 134-141
:lines: 139-146

Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -68,16 +73,16 @@ We can then train the torchbearer model on the GPU in the standard way.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 158-160
:lines: 163-165

Visualising
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We borrow the image saving method from PyTorch_GAN_ and put it in a call back to save on training step.
We borrow the image saving method from PyTorch_GAN_ and put it in a call back to save on training step - again using decorators.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 109-114
:lines: 115-119

After 172400 iterations we see the following.

Expand Down
105 changes: 105 additions & 0 deletions tests/callbacks/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import unittest

import torchbearer.callbacks as callbacks


class TestDecorators(unittest.TestCase):

def test_on_start(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_start(example).on_start(state) == state)

def test_on_start_epoch(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_start_epoch(example).on_start_epoch(state) == state)

def test_on_start_training(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_start_training(example).on_start_training(state) == state)

def test_on_sample(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_sample(example).on_sample(state) == state)

def test_on_forward(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_forward(example).on_forward(state) == state)

def test_on_criterion(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_criterion(example).on_criterion(state) == state)

def test_on_backward(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_backward(example).on_backward(state) == state)

def test_on_step_training(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_step_training(example).on_step_training(state) == state)

def test_on_end_training(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_end_training(example).on_end_training(state) == state)

def test_on_end_epoch(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_end_epoch(example).on_end_epoch(state) == state)

def test_on_end(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_end(example).on_end(state) == state)

def test_on_start_validation(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_start_validation(example).on_start_validation(state) == state)

def test_on_sample_validation(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_sample_validation(example).on_sample_validation(state) == state)

def test_on_forward_validation(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_forward_validation(example).on_forward_validation(state) == state)

def test_on_end_validation(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_end_validation(example).on_end_validation(state) == state)

def test_on_step_validation(self):
def example(state):
return state
state = 'test'
self.assertTrue(callbacks.on_step_validation(example).on_step_validation(state) == state)



1 change: 1 addition & 0 deletions torchbearer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@
from .torch_scheduler import *
from .weight_decay import *
from .aggregate_predictions import *
from .decorators import *

0 comments on commit da8a481

Please sign in to comment.