Skip to content

Commit

Permalink
Feature/progress bar (#273)
Browse files Browse the repository at this point in the history
* Add verbose level for epoch progress

* Update CHANGELOG.md
  • Loading branch information
ethanwharris committed Aug 2, 2018
1 parent 6fdd87b commit 172593b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added a verbose level (options are now 0,1,2) which will print progress for the entire fit call, updating every epoch. Useful when doing dynamic programming with little data.
### Changed
- Timer callback can now also be used as a metric which allows display of specified timings to printers and has been moved to metrics.
### Deprecated
Expand Down
2 changes: 1 addition & 1 deletion tests/test_torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def test_predict_generator_verbose(self):
torchbearermodel._test_loop = Mock()

torchbearermodel.predict_generator(generator, 1, steps, pass_state)
self.assertIsInstance(torchbearermodel._test_loop.call_args[0][1].callback_list[1], Tqdm)
self.assertIsInstance(torchbearermodel._test_loop.call_args[0][1].callback_list[0], Tqdm)
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

Expand Down
35 changes: 28 additions & 7 deletions torchbearer/callbacks/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ class Tqdm(Callback):
"""The Tqdm callback outputs the progress and metrics for training and validation loops to the console using TQDM.
"""

def __init__(self, validation_label_letter='v'):
def __init__(self, validation_label_letter='v', on_epoch=False):
"""Create Tqdm callback which uses the given key to label validation output.
:param validation_label_letter: The letter to use for validation outputs.
:type validation_label_letter: str
:param on_epoch: If True, output a single progress bar which tracks epochs
:type on_epoch: bool
"""
self._loader = None
self.validation_label = validation_label_letter
self._on_epoch = on_epoch

def _on_start(self, state, letter, steps):
bar_desc = '{:d}/{:d}({:s})'.format(state[torchbearer.EPOCH], state[torchbearer.MAX_EPOCHS], letter)
Expand All @@ -62,50 +65,68 @@ def _close(self, state):
self._loader.set_postfix(state[torchbearer.METRICS])
self._loader.close()

def on_start(self, state):
if self._on_epoch:
self._loader = tqdm(total=state[torchbearer.MAX_EPOCHS])

def on_end_epoch(self, state):
if self._on_epoch:
self._update(state)

def on_end(self, state):
if self._on_epoch:
self._close(state)

def on_start_training(self, state):
"""Initialise the TQDM bar for this training phase.
:param state: The Model state
:type state: dict
"""
self._on_start(state, 't', state[torchbearer.TRAIN_STEPS])
if not self._on_epoch:
self._on_start(state, 't', state[torchbearer.TRAIN_STEPS])

def on_step_training(self, state):
"""Update the bar with the metrics from this step.
:param state: The Model state
:type state: dict
"""
self._update(state)
if not self._on_epoch:
self._update(state)

def on_end_training(self, state):
"""Update the bar with the terminal training metrics and then close.
:param state: The Model state
:type state: dict
"""
self._close(state)
if not self._on_epoch:
self._close(state)

def on_start_validation(self, state):
"""Initialise the TQDM bar for this validation phase.
:param state: The Model state
:type state: dict
"""
self._on_start(state, self.validation_label, state[torchbearer.VALIDATION_STEPS])
if not self._on_epoch:
self._on_start(state, self.validation_label, state[torchbearer.VALIDATION_STEPS])

def on_step_validation(self, state):
"""Update the bar with the metrics from this step.
:param state: The Model state
:type state: dict
"""
self._update(state)
if not self._on_epoch:
self._update(state)

def on_end_validation(self, state):
"""Update the bar with the terminal validation metrics and then close.
:param state: The Model state
:type state: dict
"""
self._close(state)
if not self._on_epoch:
self._close(state)
56 changes: 36 additions & 20 deletions torchbearer/torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, model, optimizer, loss_criterion, metrics=[]):
torchbearer.CALLBACK_LIST: torchbearer.callbacks.CallbackList([])
}

def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=[], validation_split=None,
def fit(self, x, y, batch_size=None, epochs=1, verbose=2, callbacks=[], validation_split=None,
validation_data=None, shuffle=True, initial_epoch=0,
steps_per_epoch=None, validation_steps=None, workers=1, pass_state=False):
""" Perform fitting of a model to given data and label tensors
Expand All @@ -50,7 +50,7 @@ def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=[], validati
:type batch_size: int
:param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once)
:type epochs: int
:param verbose: If 1 use tqdm progress frontend, else display no training progress
:param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
:type verbose: int
:param callbacks: The list of torchbearer callbacks to be called during training and validation
:type callbacks: list
Expand Down Expand Up @@ -85,7 +85,7 @@ def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=[], validati
callbacks=callbacks, validation_generator=valloader, validation_steps=validation_steps,
initial_epoch=initial_epoch, pass_state=pass_state)

def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callbacks=[],
def fit_generator(self, generator, train_steps=None, epochs=1, verbose=2, callbacks=[],
validation_generator=None, validation_steps=None, initial_epoch=0, pass_state=False):
""" Perform fitting of a model to given data generator
Expand All @@ -95,7 +95,7 @@ def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callba
:type train_steps: int
:param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once)
:type epochs: int
:param verbose: If 1 use tqdm progress frontend, else display no training progress
:param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
:type verbose: int
:param callbacks: The list of torchbearer callbacks to be called during training and validation
:type callbacks: list
Expand All @@ -110,8 +110,7 @@ def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callba
:return: The final state context dictionary
:rtype: dict[str,any]
"""
if verbose == 1:
callbacks = [Tqdm()] + callbacks
callbacks = Model._add_printer(callbacks, verbose)
_callbacks = CallbackList(callbacks)

# Get train and validation steps
Expand Down Expand Up @@ -290,7 +289,7 @@ def _validate(self, state, _callbacks, pass_state):
"""
self._test_loop(state, _callbacks, pass_state, self._load_batch_standard, state[torchbearer.VALIDATION_STEPS])

def evaluate(self, x=None, y=None, batch_size=32, verbose=1, steps=None, pass_state=False):
def evaluate(self, x=None, y=None, batch_size=32, verbose=2, steps=None, pass_state=False):
""" Perform an evaluation loop on given data and label tensors to evaluate metrics
:param x: The input data tensor
Expand All @@ -299,7 +298,7 @@ def evaluate(self, x=None, y=None, batch_size=32, verbose=1, steps=None, pass_st
:type y: torch.Tensor
:param batch_size: The mini-batch size (number of samples processed for a single weight update)
:type batch_size: int
:param verbose: If 1 use tqdm progress frontend, else display no training progress
:param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
:type verbose: int
:param steps: The number of evaluation mini-batches to run
:type steps: int
Expand All @@ -311,12 +310,12 @@ def evaluate(self, x=None, y=None, batch_size=32, verbose=1, steps=None, pass_st
trainset = DataLoader(TensorDataset(x, y), batch_size, steps)
return self.evaluate_generator(trainset, verbose, pass_state=pass_state)

def evaluate_generator(self, generator, verbose=1, steps=None, pass_state=False):
def evaluate_generator(self, generator, verbose=2, steps=None, pass_state=False):
""" Perform an evaluation loop on given data generator to evaluate metrics
:param generator: The evaluation data generator (usually a pytorch DataLoader)
:type generator: DataLoader
:param verbose: If 1 use tqdm progress frontend, else display no training progress
:param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
:type verbose: int
:param steps: The number of evaluation mini-batches to run
:type steps: int
Expand All @@ -329,9 +328,7 @@ def evaluate_generator(self, generator, verbose=1, steps=None, pass_state=False)
state = {torchbearer.EPOCH: 0, torchbearer.MAX_EPOCHS: 1, torchbearer.STOP_TRAINING: False, torchbearer.VALIDATION_GENERATOR: generator}
state.update(self.main_state)

_callbacks = []
if verbose == 1:
_callbacks.append(Tqdm('e'))
_callbacks = Model._add_printer([], verbose, validation_label_letter='e')

if state[torchbearer.VALIDATION_GENERATOR] is None:
batch_loader = self._load_batch_none
Expand All @@ -342,14 +339,14 @@ def evaluate_generator(self, generator, verbose=1, steps=None, pass_state=False)

return state[torchbearer.METRICS]

def predict(self, x=None, batch_size=32, verbose=1, steps=None, pass_state=False):
def predict(self, x=None, batch_size=32, verbose=2, steps=None, pass_state=False):
""" Perform a prediction loop on given data tensor to predict labels
:param x: The input data tensor
:type x: torch.Tensor
:param batch_size: The mini-batch size (number of samples processed for a single weight update)
:type batch_size: int
:param verbose: If 1 use tqdm progress frontend, else display no training progress
:param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
:type verbose: int
:param steps: The number of evaluation mini-batches to run
:type steps: int
Expand All @@ -361,12 +358,12 @@ def predict(self, x=None, batch_size=32, verbose=1, steps=None, pass_state=False
pred_set = DataLoader(TensorDataset(x), batch_size, steps)
return self.predict_generator(pred_set, verbose, pass_state=pass_state)

def predict_generator(self, generator, verbose=1, steps=None, pass_state=False):
def predict_generator(self, generator, verbose=2, steps=None, pass_state=False):
"""Perform a prediction loop on given data generator to predict labels
:param generator: The prediction data generator (usually a pytorch DataLoader)
:type generator: DataLoader
:param verbose: If 1 use tqdm progress frontend, else display no training progress
:param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
:type verbose: int
:param steps: The number of evaluation mini-batches to run
:type steps: int
Expand All @@ -378,9 +375,8 @@ def predict_generator(self, generator, verbose=1, steps=None, pass_state=False):
state = {torchbearer.EPOCH: 0, torchbearer.MAX_EPOCHS: 1, torchbearer.STOP_TRAINING: False, torchbearer.VALIDATION_GENERATOR: generator}
state.update(self.main_state)

_callbacks = [AggregatePredictions()]
if verbose == 1:
_callbacks.append(Tqdm('p'))
_callbacks = Model._add_printer([AggregatePredictions()], verbose, validation_label_letter='p')

self._test_loop(state, CallbackList(_callbacks), pass_state, self._load_batch_predict, steps)

return state[torchbearer.FINAL_PREDICTIONS]
Expand Down Expand Up @@ -459,6 +455,26 @@ def state_dict(self, **kwargs):
}
return state_dict

@staticmethod
def _add_printer(callbacks, verbose, validation_label_letter='v'):
"""Static method used to add the printer callback to the given list for the given verbose level
:param callbacks: The list to add to
:type callbacks: list
:param verbose: 2, 1 or 0, Most -> Least verbose
:type verbose: int
:param validation_label_letter: Pass to Tqdm
:type validation_label_letter: str
:return: The updated list
:rtype: list
"""
if verbose >= 2:
return [Tqdm(validation_label_letter=validation_label_letter)] + callbacks
elif verbose >= 1:
return [Tqdm(validation_label_letter=validation_label_letter, on_epoch=True)] + callbacks
else:
return callbacks

@staticmethod
def _deep_to(batch, device, dtype):
""" Static method to call :func:`to` on tensors or tuples. All items in tuple will have :func:_deep_to called
Expand Down

0 comments on commit 172593b

Please sign in to comment.