Commit
Fix/no train steps (#566)
* Fix no train steps bug and add Mock model

* Update changelog

* Set test steps to 0 if None
MattPainter01 committed Jun 12, 2019
1 parent 924336d commit 1d88b8d
Showing 3 changed files with 21 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,12 +6,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [Unreleased]
 ### Added
 - Added ``with_loader`` trial method that allows running of custom batch loaders
+- Added a Mock Model which is set when None is passed as the model to a Trial. Mock Model always returns None.
 ### Changed
 ### Deprecated
 ### Removed
 - Removed the variational sub-package, this will now be packaged separately
 ### Fixed
 - Fixed a bug where list or dictionary metrics would cause the tensorboard callback to error
+- Fixed a bug where running a trial without training steps would error

 ## [0.3.2] - 2019-05-28
 ### Added
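For context, a minimal sketch of the behaviour the two new changelog entries describe, mirroring the tests added below. The `warnings.simplefilter("always")` call is an assumption to make the warning observable; everything else is taken directly from the diff:

```python
import warnings

import torch
import torchbearer

# Passing None as the model now substitutes a MockModel instead of erroring,
# and a run configured with only validation steps no longer fails.
trial = torchbearer.Trial(None)
trial.for_val_steps(10)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trial.run()  # warns that the model is None or not callable

# The mock model accepts any input and always returns None.
assert torchbearer.trial.MockModel()(torch.rand(1)) is None
```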
7 changes: 7 additions & 0 deletions tests/test_end_to_end.py
@@ -115,3 +115,10 @@ def test_no_model(self):
         with warnings.catch_warnings(record=True) as w:
             tbmodel.run()
         self.assertTrue(len(w) == 1)
+
+        self.assertTrue(torchbearer.trial.MockModel()(torch.rand(1)) is None)
+
+    def test_no_train_steps(self):
+        tbmodel = torchbearer.Trial(None)
+        tbmodel.for_val_steps(10)
+        tbmodel.run()
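Note on the design: the MockModel introduced in trial.py below subclasses torch.nn.Module and its forward takes an optional state argument, presumably so it can stand in anywhere the trial would call a real model, which is why the warning path can now install MockModel() rather than the previous bare lambda.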
16 changes: 12 additions & 4 deletions torchbearer/trial.py
@@ -17,6 +17,7 @@ def get_default(fcn, arg):
 import itertools

 import torch
+import torch.nn
 from torch.utils.data import DataLoader, TensorDataset
 from torch.optim import Optimizer

@@ -61,6 +62,11 @@ def zero_grad(self):
         pass  # Do Nothing


+class MockModel(torch.nn.Module):
+    def forward(self, x, state=None):
+        return None
+
+
 class CallbackListInjection(CallbackList):
     """This class allows for an callback to be injected into a callback list, without masking the methods available for
     mutating the list. In this way, callbacks (such as printers) can be injected seamlessly into the methods of the
@@ -719,11 +725,11 @@ def run(self, epochs=1, verbose=-1):
             torchbearer.STOP_TRAINING: False,
         })

-        state.update(self.state)  # TODO: Swap this for something which makes `self.state` still mutable
-
-        if state[torchbearer.MODEL] is None or not callable(state[torchbearer.MODEL]):
+        if self.state[torchbearer.MODEL] is None or not callable(self.state[torchbearer.MODEL]):
             warnings.warn('The Model is None or not callable which may cause issues if not deliberate')
-            state[torchbearer.MODEL] = lambda *args, **kwargs: None
+            self.state[torchbearer.MODEL] = MockModel()
+
+        state.update(self.state)  # TODO: Swap this for something which makes `self.state` still mutable

         if state[torchbearer.TRAIN_GENERATOR] is not None \
                 or state[torchbearer.TRAIN_STEPS] is not None \
@@ -773,6 +779,7 @@ def _fit_pass(self, state):
         state[torchbearer.METRIC_LIST].reset(state)
         state[torchbearer.METRICS] = {}

+        state[torchbearer.STEPS] = 0 if state[torchbearer.STEPS] is None else state[torchbearer.STEPS]
         state[torchbearer.CALLBACK_LIST].on_start_training(state)
         for state[torchbearer.BATCH] in (range(state[torchbearer.STEPS]) if state[torchbearer.STEPS] != -1 else itertools.count()):
             state[torchbearer.SAMPLER](state)
@@ -801,6 +808,7 @@ def _test_pass(self, state):

         state[torchbearer.CALLBACK_LIST].on_start_validation(state)

+        state[torchbearer.STEPS] = 0 if state[torchbearer.STEPS] is None else state[torchbearer.STEPS]
         for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
             state[torchbearer.SAMPLER](state)
             state[torchbearer.CALLBACK_LIST].on_sample_validation(state)
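A standalone sketch of what the two added STEPS guards do (illustration only, not the trial's real state machinery; the helper name batch_indices is hypothetical): treating an unset step count as 0 means the batch loop is simply skipped rather than failing on range(None), while -1 keeps its existing "run until stopped" meaning in _fit_pass.

```python
import itertools

def batch_indices(steps):
    # None (no steps configured) is now treated as 0, so the loop never runs;
    # -1 means "run until stopped" (existing _fit_pass behaviour);
    # any other value gives a bounded range, as before.
    steps = 0 if steps is None else steps
    return range(steps) if steps != -1 else itertools.count()

print(list(batch_indices(None)))  # [] - a trial with no training steps no longer errors
print(list(batch_indices(3)))     # [0, 1, 2]
```

Note also that state.update(self.state) in run now happens after the model check, so the substituted MockModel (rather than the previous lambda) is what ends up in the run state.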
