Skip to content

Commit

Permalink
Example/gan none loss (#274)
Browse files Browse the repository at this point in the history
* Add none loss option for torchbearer models

* Fix error in docs

* Change loss_criterion -> criterion

* Update changelog
  • Loading branch information
MattPainter01 authored and ethanwharris committed Aug 2, 2018
1 parent 172593b commit eecff42
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a verbose level (options are now 0,1,2) which will print progress for the entire fit call, updating every epoch. Useful when doing dynamic programming with little data.
### Changed
- Timer callback can now also be used as a metric which allows display of specified timings to printers and has been moved to metrics.
- The loss_criterion is renamed to criterion in `torchbearer.Model` arguments.
- The criterion in `torchbearer.Model` is now optional and will provide a zero loss tensor if it is not given.
### Deprecated
### Removed
### Fixed
Expand Down
8 changes: 2 additions & 6 deletions docs/_static/examples/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def forward(self, real_imgs, state):
z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[tb.DEVICE])
state[GEN_IMGS] = self.generator(z)
state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
# We don't want to keep discriminator gradients on the generator forward pass
# This clears the function graph built up for the discriminator
self.discriminator.zero_grad()

# Discriminator Forward
Expand Down Expand Up @@ -156,10 +156,6 @@ def process(self, state):
return state[D_LOSS]


def zero_loss(y_pred, y_true):
return torch.zeros(y_true.shape[0], 1)


torchbearermodel = tb.Model(model, optim, zero_loss, ['loss', g_loss(), d_loss()])
torchbearermodel = tb.Model(model, optim, loss_criterion=None, metrics=['loss', g_loss(), d_loss()])
torchbearermodel.to(device)
torchbearermodel.fit_generator(dataloader, epochs=200, pass_state=True, callbacks=[loss_callback, saver_callback])
9 changes: 2 additions & 7 deletions docs/examples/gan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Since our loss is complicated in this example, we shall forgo the basic loss criterion used in normal torchbearer models.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 159-160

Instead use a callback to provide the loss. Since this callback is very simple we can use callback decorators on a function (which takes state) to tell torchbearer when it should be called.
Instead we use a callback to provide the loss. Since this callback is very simple we can use callback decorators on a function (which takes state) to tell torchbearer when it should be called.

.. literalinclude:: /_static/examples/gan.py
:language: python
Expand All @@ -73,7 +68,7 @@ We can then train the torchbearer model on the GPU in the standard way.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 163-165
:lines: 159-161

Visualising
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 0 additions & 4 deletions torchbearer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
:members:
:undoc-members:
.. automodule:: torchbearer.callbacks.timer
:members:
:undoc-members:
Tensorboard
------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions torchbearer/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
.. automodule:: torchbearer.metrics.roc_auc_score
:members:
Timer
------------------------------------
.. automodule:: torchbearer.metrics.timer
:members:
:undoc-members:
"""

from .metrics import *
Expand Down
11 changes: 7 additions & 4 deletions torchbearer/torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,25 @@
class Model:
""" Torchbearermodel to wrap base torch model and provide training environment around it
"""
def __init__(self, model, optimizer, loss_criterion, metrics=[]):
def __init__(self, model, optimizer, criterion=None, metrics=[]):
""" Create torchbearermodel which wraps a base torchmodel and provides a training environment surrounding it
:param model: The base pytorch model
:type model: torch.nn.Module
:param optimizer: The optimizer used for pytorch model weight updates
:type optimizer: torch.optim.Optimizer
:param loss_criterion: The final loss criterion that provides a loss value to the optimizer
:type loss_criterion: function
:param criterion: The final loss criterion that provides a loss value to the optimizer
:type criterion: function or None
:param metrics: Additional metrics for display and use within callbacks
:type metrics: list
"""
super().__init__()
if criterion is None:
criterion = lambda y_pred, y_true: torch.zeros(y_true.shape[0], device=y_true.device)

self.main_state = {
torchbearer.MODEL: model,
torchbearer.CRITERION: loss_criterion,
torchbearer.CRITERION: criterion,
torchbearer.OPTIMIZER: optimizer,
torchbearer.DEVICE: 'cpu',
torchbearer.DATA_TYPE: torch.float32,
Expand Down

0 comments on commit eecff42

Please sign in to comment.