Feature/none generators (#260)
* Allow generators to be none

* Doc string

* Method name

* Update tests

* Update tests
MattPainter01 authored and ethanwharris committed Jul 28, 2018
1 parent 4654ee1 commit baacf95
Showing 2 changed files with 116 additions and 27 deletions.
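
For context, a minimal usage sketch of what this change enables — assuming the Model API exercised by the tests below; NoiseModel and the lambda criterion are invented for illustration and are not part of this commit. With a None generator, an explicit step count drives the loop and each batch is loaded as (None, None), so the model is expected to produce its own data:

import torch
import torch.nn as nn
from torchbearer import Model

class NoiseModel(nn.Module):
    # Hypothetical model for this sketch: ignores its (None) input and samples its own batch.
    def __init__(self):
        super(NoiseModel, self).__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, _x):
        return self.scale * torch.rand(4, 1)  # _x is None when no generator is supplied

torchmodel = NoiseModel()
optimizer = torch.optim.SGD(torchmodel.parameters(), lr=0.01)
criterion = lambda y_pred, y_true: y_pred.mean()  # y_true is None in this mode

torchbearermodel = Model(torchmodel, optimizer, criterion, [])
torchbearermodel.fit_generator(None, 8, 1, 0, [])  # generator=None, train_steps=8, epochs=1, verbose=0
torchbearermodel.evaluate_generator(None, 0, 4)    # generator=None, verbose=0, steps=4
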
103 changes: 84 additions & 19 deletions tests/test_torchbearer.py
@@ -23,7 +23,7 @@ def test_fit_valid_sets_args(self, gtvs):
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

gtvs.return_value = (1, 2)
@@ -48,7 +48,7 @@ def test_fit_no_valid(self):
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -76,7 +76,7 @@ def test_main_loop_metrics(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -102,7 +102,7 @@ def test_main_loop_verbose(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

import sys
@@ -133,7 +133,7 @@ def test_main_loop_train_steps_positive(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -157,7 +157,7 @@ def test_main_loop_train_steps_fractional(self, _):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -181,7 +181,7 @@ def test_main_loop_validation_setup(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -208,7 +208,7 @@ def test_main_loop_epochs_positive(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -231,7 +231,7 @@ def test_main_loop_epochs_zero(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -254,7 +254,7 @@ def test_main_loop_epochs_negative(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -278,7 +278,7 @@ def test_main_loop_epochs_fractional(self, _):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -302,14 +302,36 @@ def test_main_loop_epochs_none(self, warning):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearerstate = torchbearermodel.fit_generator(generator, train_steps, epochs, 0, [callback], initial_epoch=0, pass_state=False)

self.assertTrue(warning.call_count == 1)

def test_main_loop_none_gen(self):
metric = Metric('test')

generator = None
train_steps = 8

epochs = 1

callback = MagicMock()

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearerstate = torchbearermodel.fit_generator(generator, train_steps, epochs, 0, [callback], initial_epoch=0, pass_state=False)

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == train_steps)

def test_main_loop_train_steps_too_big(self):
metric = Metric('test')

@@ -325,7 +347,7 @@ def test_main_loop_train_steps_too_big(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -348,7 +370,7 @@ def test_main_loop_train_steps_negative(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -371,7 +393,7 @@ def test_main_loop_pass_state(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -394,7 +416,7 @@ def test_main_loop_optimizer(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -417,7 +439,7 @@ def test_main_loop_criterion(self):
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -796,6 +818,32 @@ def test_test_loop_stop_training(self):

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == 1)

def test_test_loop_none_gen(self):
metric = Metric('test')
metric_list = MetricList([metric])

validation_generator = None
validation_steps = 8

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = Mock(return_value=1)
optimizer = MagicMock()

criterion = Mock(return_value=2)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])

state = torchbearermodel.main_state.copy()
state.update({torchbearer.METRIC_LIST: metric_list, torchbearer.VALIDATION_GENERATOR: validation_generator,
torchbearer.CallbackList: callback_List, torchbearer.VALIDATION_STEPS: validation_steps,
torchbearer.CRITERION: criterion, torchbearer.STOP_TRAINING: False, torchbearer.METRICS: {}})

torchbearerstate = torchbearermodel._test_loop(state, callback_List, False, Model._load_batch_none, num_steps=validation_steps)

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == validation_steps)

def test_evaluate(self):
x = torch.rand(1,5)
y = torch.rand(1,5)
@@ -807,7 +855,7 @@ def test_evaluate(self):
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
@@ -836,6 +884,23 @@ def test_evaluate_generator_args(self):
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_evaluate_generator_none(self):
torchmodel = MagicMock()
optimizer = MagicMock()
generator = None

pass_state = False
steps = 10

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.METRICS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.evaluate_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][1].callback_list == [])
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_evaluate_generator_verbose(self):
from torchbearer.callbacks import Tqdm

@@ -893,7 +958,7 @@ def test_predict(self):
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
loss = torch.tensor([2.0], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
40 changes: 32 additions & 8 deletions torchbearer/torchbearer.py
@@ -117,7 +117,9 @@ def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callba
# Get train and validation steps
if validation_steps is None and validation_generator is not None:
validation_steps = len(validation_generator)
if train_steps is None or train_steps > len(generator):
if train_steps is None:
train_steps = len(generator)
if generator is not None and train_steps > len(generator):
train_steps = len(generator)
if not isinstance(train_steps, int):
train_steps = int(train_steps)
@@ -146,7 +148,8 @@ def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callba
for state[torchbearer.EPOCH] in range(initial_epoch, epochs):
_callbacks.on_start_epoch(state)

state[torchbearer.TRAIN_ITERATOR] = iter(state[torchbearer.GENERATOR])
if state[torchbearer.GENERATOR] is not None:
state[torchbearer.TRAIN_ITERATOR] = iter(state[torchbearer.GENERATOR])
self.train()

_callbacks.on_start_training(state)
@@ -155,7 +158,11 @@ def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callba

for state[torchbearer.BATCH] in range(0, state[torchbearer.TRAIN_STEPS]):
# Extract batch
self._load_batch_standard('train', state)
if state[torchbearer.GENERATOR] is None: # TODO: Replace with flag check
self._load_batch_none('train', state)
else:
self._load_batch_standard('train', state)

_callbacks.on_sample(state)

# Zero grads
@@ -191,7 +198,7 @@ def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callba
_callbacks.on_end_training(state)

# Validate
if validation_generator is not None:
if validation_generator is not None or validation_steps is not None:
state[torchbearer.VALIDATION_GENERATOR] = validation_generator
state[torchbearer.VALIDATION_STEPS] = validation_steps
self.eval()
@@ -226,15 +233,17 @@ def _test_loop(self, state, callbacks, pass_state, batch_loader, num_steps=None)
state[torchbearer.METRIC_LIST].reset(state)
state[torchbearer.METRICS] = {}


if num_steps is None or num_steps > len(state[torchbearer.VALIDATION_GENERATOR]):
if num_steps is None:
num_steps = len(state[torchbearer.VALIDATION_GENERATOR])
if state[torchbearer.VALIDATION_GENERATOR] is not None and num_steps > len(state[torchbearer.VALIDATION_GENERATOR]):
num_steps = len(state[torchbearer.VALIDATION_GENERATOR])
if not isinstance(num_steps, int):
num_steps = int(num_steps)
warnings.warn('Num test steps is not an int, converting to int.', Warning)

state[torchbearer.VALIDATION_STEPS] = num_steps
state[torchbearer.VALIDATION_ITERATOR] = iter(state[torchbearer.VALIDATION_GENERATOR])
if state[torchbearer.VALIDATION_GENERATOR] is not None:
state[torchbearer.VALIDATION_ITERATOR] = iter(state[torchbearer.VALIDATION_GENERATOR])

callbacks.on_start_validation(state)

@@ -321,7 +330,13 @@ def evaluate_generator(self, generator, verbose=1, steps=None, pass_state=False)
_callbacks = []
if verbose == 1:
_callbacks.append(Tqdm('e'))
self._test_loop(state, CallbackList(_callbacks), pass_state, self._load_batch_standard, steps)

if state[torchbearer.VALIDATION_GENERATOR] is None:
batch_loader = self._load_batch_none
else:
batch_loader = self._load_batch_standard

self._test_loop(state, CallbackList(_callbacks), pass_state, batch_loader, steps)

return state[torchbearer.METRICS]

@@ -481,6 +496,15 @@ def _load_batch_standard(iterator, state):
"""
state[torchbearer.X], state[torchbearer.Y_TRUE] = Model._deep_to(next(state[iterator + '_iterator']), state[torchbearer.DEVICE], state[torchbearer.DATA_TYPE])

@staticmethod
def _load_batch_none(_, state):
"""Static method to load a none (none, none) tuple mini-batch into state
:param state: The current state dict of the :class:`Model`.
:type state: dict[str,any]
"""
state[torchbearer.X], state[torchbearer.Y_TRUE] = None, None

@staticmethod
def _load_batch_predict(iterator, state):
""" Static method to load a prediction (input data, target) or (input data) mini-batch from iterator into state

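As an aside, here is a standalone sketch of the batch-loader dispatch idea the patch introduces; the function names below are invented for illustration and do not exist in torchbearer. The loader is chosen once, based on whether a generator was supplied, and the none-loader simply yields (None, None) so the step count alone drives the loop:

def load_batch_standard(iterator):
    return next(iterator)  # real (x, y_true) pair from a data iterator

def load_batch_none(_iterator):
    return None, None  # stand-in batch when no generator was given

def run_steps(generator, num_steps):
    loader = load_batch_standard if generator is not None else load_batch_none
    iterator = iter(generator) if generator is not None else None
    batches = []
    for _ in range(num_steps):
        batches.append(loader(iterator))  # forward pass, loss and optimizer step would follow here
    return batches

print(run_steps(None, 3))              # [(None, None), (None, None), (None, None)]
print(run_steps([(1, 2), (3, 4)], 2))  # [(1, 2), (3, 4)]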