Skip to content

Commit

Permalink
Example/gan add to loss (#283)
Browse files Browse the repository at this point in the history
* Update gan example with add to loss

* Update example

* Update docs

* Formatting
  • Loading branch information
MattPainter01 authored and ethanwharris committed Aug 3, 2018
1 parent 6243524 commit 77ecc58
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 13 deletions.
5 changes: 2 additions & 3 deletions docs/_static/examples/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,13 @@ def forward(self, real_imgs, state):
state[DISC_REAL] = self.discriminator(real_imgs)


@callbacks.on_criterion
@callbacks.add_to_loss
def loss_callback(state):
fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
real_loss = adversarial_loss(state[DISC_REAL], valid)
state[G_LOSS] = adversarial_loss(state[DISC_GEN], valid)
state[D_LOSS] = (real_loss + fake_loss) / 2
# This is the loss that backward is called on.
state[tb.LOSS] = state[G_LOSS] + state[D_LOSS]
return state[G_LOSS] + state[D_LOSS]


@callbacks.on_step_training
Expand Down
20 changes: 11 additions & 9 deletions docs/examples/gan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ We first define all constants for the example.
:language: python
:lines: 19-29

We then define a number of state keys for convenience. This is optional, however, it automatically avoids key conflicts.
We then define a number of state keys for convenience using :func:`.state_key`. This is optional, however, it automatically avoids key conflicts.

.. literalinclude:: /_static/examples/gan.py
:language: python
Expand All @@ -26,7 +26,7 @@ We then define the dataset and dataloader - for this example, MNIST.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 124-131
:lines: 123-130

Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -42,14 +42,15 @@ Note that we have to be careful to remove the gradient information from the disc
Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Since our loss is complicated in this example, we shall forgo the basic loss criterion used in normal torchbearer models.
Instead we use a callback to provide the loss. Since this callback is very simple we can use callback decorators on a function (which takes state) to tell torchbearer when it should be called.
Since our loss computation in this example is complicated, we shall forgo the basic loss criterion used in normal torchbearer models.
Instead we use a callback to provide the loss, in this case we use the :func:`.add_to_loss` callback decorator.
This decorates a function that returns a loss and automatically adds this to the main loss in training and validation.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 105-112
:lines: 105-111

Note that we have summed the separate discriminator and generator losses since their graphs are separated, this is allowable.
Note that we have summed the separate discriminator and generator losses, since their graphs are separated, this is allowable.

Metrics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -65,19 +66,20 @@ Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can then train the torchbearer model on the GPU in the standard way.
Note that when torchbearer is passed a ``None`` criterion it automatically sets the base loss to 0.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 159-161
:lines: 158-160

Visualising
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We borrow the image saving method from PyTorch_GAN_ and put it in a call back to save on training step - again using decorators.
We borrow the image saving method from PyTorch_GAN_ and put it in a call back to save :func:`~torchbearer.callbacks.decorators.on_step_training` - again using decorators.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 115-119
:lines: 114-118

Here is a Gif created from the saved images.

Expand Down
6 changes: 6 additions & 0 deletions torchbearer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""
Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: torchbearer.torchbearer
:members:
:undoc-members:
Utilities
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: torchbearer.state
:members:
:undoc-members:
Expand Down
7 changes: 7 additions & 0 deletions torchbearer/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@


def state_key(key):
"""Computes and returns a non-conflicting key for the state dictionary when given a seed key
:param key: The seed key - basis for new state key
:type key: String
:return: New state key
:rtype: String
"""
if key in STATE_KEYS:
count = 1
my_key = key + '_' + str(count)
Expand Down
2 changes: 1 addition & 1 deletion torchbearer/torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Model:
def __init__(self, model, optimizer, criterion=None, metrics=[]):
super().__init__()
if criterion is None:
criterion = lambda y_pred, y_true: torch.zeros(y_true.shape[0], device=y_true.device)
criterion = lambda y_pred, y_true: torch.zeros(1, device=y_true.device)

self.main_state = {
torchbearer.MODEL: model,
Expand Down

0 comments on commit 77ecc58

Please sign in to comment.