Skip to content

Commit

Permalink
Merge pull request #124 from dnouri/feature/progress-bar-auto
Browse files Browse the repository at this point in the history
Add 'auto' parameter to progress bar
  • Loading branch information
benjamin-work committed Dec 4, 2017
2 parents 6818c00 + 034166b commit 2bbb638
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 10 deletions.
43 changes: 35 additions & 8 deletions skorch/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,10 +545,13 @@ class ProgressBar(Callback):
Parameters:
-----------
batches_per_epoch : int (default=None)
batches_per_epoch : int, str (default='count')
The progress bar determines the number of batches per epoch
automatically after one epoch but you can also specify this
number yourself using this parameter.
by itself in ``'count'`` mode where the number of iterations is
determined after one epoch which will leave you without a progress
bar at the first epoch. To fix that you can provide this number manually
or set ``'auto'`` where the callback attempts to compute the
number of batches per epoch beforehand.
detect_notebook : bool (default=True)
If enabled, the progress bar determines if its current environment
Expand All @@ -565,7 +568,7 @@ class ProgressBar(Callback):

def __init__(
self,
batches_per_epoch=None,
batches_per_epoch='count',
detect_notebook=True,
postfix_keys=None
):
Expand All @@ -582,6 +585,21 @@ def in_ipynb(self):
def _use_notebook(self):
return self.in_ipynb() if self.detect_notebook else False

def _get_batch_size(self, net, training):
name = 'iterator_train' if training else 'iterator_valid'
net_params = net.get_params()
return net_params.get(name + '__batch_size', net_params['batch_size'])

def _get_batches_per_epoch_phase(self, net, X, training):
if X is None:
return 0
batch_size = self._get_batch_size(net, training)
return int(np.ceil(len(X) / batch_size))

def _get_batches_per_epoch(self, net, X, X_valid):
return (self._get_batches_per_epoch_phase(net, X, True) +
self._get_batches_per_epoch_phase(net, X_valid, False))

def _get_postfix_dict(self, net):
postfix = {}
for key in self.postfix_keys:
Expand All @@ -595,13 +613,22 @@ def on_batch_end(self, net, **kwargs):
self.pbar.set_postfix(self._get_postfix_dict(net))
self.pbar.update()

def on_epoch_begin(self, net, **kwargs):
def on_epoch_begin(self, net, X=None, X_valid=None, **kwargs):
# Assume it is a number until proven otherwise.
batches_per_epoch = self.batches_per_epoch

if self.batches_per_epoch == 'auto':
batches_per_epoch = self._get_batches_per_epoch(net, X, X_valid)
elif self.batches_per_epoch == 'count':
# No limit is known until the end of the first epoch.
batches_per_epoch = None

if self._use_notebook():
self.pbar = tqdm.tqdm_notebook(total=self.batches_per_epoch)
self.pbar = tqdm.tqdm_notebook(total=batches_per_epoch)
else:
self.pbar = tqdm.tqdm(total=self.batches_per_epoch)
self.pbar = tqdm.tqdm(total=batches_per_epoch)

def on_epoch_end(self, net, **kwargs):
if self.batches_per_epoch is None:
if self.batches_per_epoch == 'count':
self.batches_per_epoch = self.pbar.n
self.pbar.close()
19 changes: 17 additions & 2 deletions skorch/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,8 +853,8 @@ def forward(self, x):
@pytest.fixture(scope='module')
def data(self):
# have 10 examples so we can do a nice CV split
X = np.zeros((10, 1), dtype='float32')
y = np.zeros((10, 1), dtype='float32')
X = np.zeros((20, 1), dtype='float32')
y = np.zeros((20, 1), dtype='float32')
return X, y

@pytest.mark.parametrize('postfix', [
Expand All @@ -869,3 +869,18 @@ def test_invalid_postfix(self, postfix, net_cls, progressbar_cls, data):
progressbar_cls(postfix_keys=postfix),
])
net.fit(*data)

@pytest.mark.parametrize('scheme', [
'count',
'auto',
None,
2, # correct number of batches_per_epoch (20 // 10)
3, # offset by +1, should still work
1, # offset by -1, should still work
])
def test_different_count_schemes(
self, scheme, net_cls, progressbar_cls, data):
net = net_cls(callbacks=[
progressbar_cls(batches_per_epoch=scheme),
])
net.fit(*data)

0 comments on commit 2bbb638

Please sign in to comment.