Skip to content

Commit

Permalink
Add list support to tqdm precision (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and MattPainter01 committed Jan 24, 2019
1 parent eca7028 commit cbf42aa
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 27 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added an optional dimension argument to the mean, std and running_mean metric aggregators
- Added a var metric and decorator which can be used to calculate the variance of a metric
- Added an unbiased flag to the std and var metrics to optionally not apply Bessel's correction (consistent with torch.std / torch.var)
- Added support for rounding 1D lists to the Tqdm callback
### Changed
- Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
- Tqdm precision argument now rounds to decimal places rather than significant figures
### Deprecated
### Removed
### Fixed
Expand Down
8 changes: 4 additions & 4 deletions tests/callbacks/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
class TestFormatMetrics(TestCase):
def test_precision(self):
metrics = {'test': 1.2345}
res = torchbearer.callbacks.printer._format_metrics(metrics, 3)
self.assertEqual(res, 'test=1.23')
res = torchbearer.callbacks.printer._format_metrics(metrics, lambda x: round(x, 3))
self.assertEqual('test=1.234', res)

def test_string(self):
metrics = {'test': '1.2345'}
res = torchbearer.callbacks.printer._format_metrics(metrics, 3)
res = torchbearer.callbacks.printer._format_metrics(metrics, lambda x: round(x, 3))
self.assertEqual(res, 'test=1.2345')

def test_not_string(self):
metrics = {'test': {'hello': 2}}
res = torchbearer.callbacks.printer._format_metrics(metrics, 3)
res = torchbearer.callbacks.printer._format_metrics(metrics, lambda x: round(x, 3))
self.assertEqual(res, 'test={\'hello\': 2}')


Expand Down
43 changes: 20 additions & 23 deletions torchbearer/callbacks/printer.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
import torchbearer

from torchbearer.callbacks import Callback
from tqdm import tqdm
from collections import OrderedDict
from numbers import Number
from functools import partial

from tqdm import tqdm

def _format_num(n, precision):
# Adapted from https://github.com/tqdm/tqdm
f = ('{0:.' + str(precision) + 'g}').format(n).replace('+0', '+').replace('-0', '-')
n = str(n)
return f if len(f) < len(n) else n
import torchbearer
from torchbearer.callbacks import Callback


def _format_metrics(metrics, precision):
def _format_metrics(metrics, rounder):
# Adapted from https://github.com/tqdm/tqdm
postfix = OrderedDict([])
for key in sorted(metrics.keys()):
postfix[key] = metrics[key]

for key in postfix.keys():
if isinstance(postfix[key], Number):
postfix[key] = _format_num(postfix[key], precision)
elif not isinstance(postfix[key], str):
postfix[key] = str(postfix[key])
try:
postfix[key] = str(rounder(postfix[key]))
except TypeError:
try:
postfix[key] = str(list(map(rounder, postfix[key])))
except TypeError:
postfix[key] = str(postfix[key])
postfix = ', '.join(key + '=' + postfix[key].strip() for key in postfix.keys())
return postfix

Expand All @@ -34,7 +31,7 @@ class ConsolePrinter(Callback):
Args:
validation_label_letter (str): This is the letter displayed after the epoch number indicating the current phase
of training
precision (int): Precision of the number format in significant figures
precision (int): Precision of the number format in decimal places
State Requirements:
- :attr:`torchbearer.state.EPOCH`: The current epoch number
Expand All @@ -46,17 +43,17 @@ class ConsolePrinter(Callback):
def __init__(self, validation_label_letter='v', precision=4):
super().__init__()
self.validation_label = validation_label_letter
self.precision = precision
self.rounder = partial(round, ndigits=precision)

def _step(self, state, letter, steps):
epoch_str = '{:d}/{:d}({:s}): '.format(state[torchbearer.EPOCH], state[torchbearer.MAX_EPOCHS], letter)
batch_str = '{:d}/{:d} '.format(state[torchbearer.BATCH], steps)
stats_str = _format_metrics(state[torchbearer.METRICS], self.precision)
stats_str = _format_metrics(state[torchbearer.METRICS], self.rounder)
print('\r' + epoch_str + batch_str + stats_str, end='')

def _end(self, state, letter):
epoch_str = '{:d}/{:d}({:s}): '.format(state[torchbearer.EPOCH], state[torchbearer.MAX_EPOCHS], letter)
stats_str = _format_metrics(state[torchbearer.METRICS], self.precision)
stats_str = _format_metrics(state[torchbearer.METRICS], self.rounder)
print('\r' + epoch_str + stats_str)

def on_step_training(self, state):
Expand All @@ -78,7 +75,7 @@ class Tqdm(Callback):
Args:
validation_label_letter (str): The letter to use for validation outputs.
precision (int): Precision of the number format in significant figures
precision (int): Precision of the number format in decimal places
on_epoch (bool): If True, output a single progress bar which tracks epochs
tqdm_args: Any extra keyword args provided here will be passed through to the tqdm module constructor.
See `github.com/tqdm/tqdm#parameters <https://github.com/tqdm/tqdm#parameters>`_ for more details.
Expand All @@ -94,7 +91,7 @@ def __init__(self, tqdm_module=tqdm, validation_label_letter='v', precision=4, o
self.tqdm_module = tqdm_module
self._loader = None
self.validation_label = validation_label_letter
self.precision = precision
self.rounder = partial(round, ndigits=precision)
self._on_epoch = on_epoch
self.tqdm_args = tqdm_args

Expand All @@ -104,10 +101,10 @@ def _on_start(self, state, letter):

def _update(self, state):
self._loader.update(1)
self._loader.set_postfix_str(_format_metrics(state[torchbearer.METRICS], self.precision))
self._loader.set_postfix_str(_format_metrics(state[torchbearer.METRICS], self.rounder))

def _close(self, state):
self._loader.set_postfix_str(_format_metrics(state[torchbearer.METRICS], self.precision))
self._loader.set_postfix_str(_format_metrics(state[torchbearer.METRICS], self.rounder))
self._loader.close()

def on_start(self, state):
Expand Down

0 comments on commit cbf42aa

Please sign in to comment.