Skip to content

Commit

Permalink
Merge pull request #108 from dnouri/feature/progress-bar
Browse files Browse the repository at this point in the history
Introduce ProgressBar callback
  • Loading branch information
benjamin-work committed Nov 15, 2017
2 parents 249c04a + 95ea0de commit d0b3f4b
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 1 deletion.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ dependencies:
- wheel=0.29.0=py36_0
- xz=5.2.2=1
- zlib=1.2.8=3
- tqdm=4.14.0
- pip:
- tabulate==0.7.7
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ numpy>=1.13.3
PyYAML==3.12
scikit-learn==0.18.1
scipy==0.19.0
tqdm==4.14.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'scikit-learn>=0.18',
'scipy',
'tabulate',
'tqdm',
]

tests_require = [
Expand Down
73 changes: 73 additions & 0 deletions skorch/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn.metrics.scorer import check_scoring
from sklearn.model_selection._validation import _score
from tabulate import tabulate
import tqdm

from skorch.utils import Ansi
from skorch.utils import to_numpy
Expand Down Expand Up @@ -527,3 +528,75 @@ def on_epoch_end(self, net, **kwargs):
if net.verbose > 0:
print("Checkpoint! Saving model to {}.".format(target))
net.save_params(target)


class ProgressBar(Callback):
"""Display a progress bar for each epoch including duration, estimated
remaining time and user-defined metrics.
For jupyter notebooks a non-ASCII progress bar is printed instead.
Parameters:
-----------
batches_per_epoch : int (default=None)
The progress bar determines the number of batches per epoch
automatically after one epoch but you can also specify this
number yourself using this parameter.
detect_notebook : bool (default=True)
If enabled, the progress bar determines if its current environment
is a jupyter notebook and switches to a non-ASCII progress bar.
postfix_keys : list of str (default=['train_loss', 'valid_loss'])
You can use this list to specify additional info displayed in the
progress bar such as metrics and losses. A prerequisite to this is
that these values are residing in the history on batch level already,
i.e. they must be accessible via
>>> net.history[-1, 'batches', -1, key]
"""

def __init__(
self,
batches_per_epoch=None,
detect_notebook=True,
postfix_keys=None
):
self.batches_per_epoch = batches_per_epoch
self.detect_notebook = detect_notebook
self.postfix_keys = postfix_keys or ['train_loss', 'valid_loss']
self.pbar = None

def in_ipynb(self):
try:
return get_ipython().__class__.__name__ == 'ZMQInteractiveShell'
except NameError:
return False

def _use_notebook(self):
return self.in_ipynb() if self.detect_notebook else False

def _get_postfix_dict(self, net):
postfix = {}
for key in self.postfix_keys:
try:
postfix[key] = net.history[-1, 'batches', -1, key]
except KeyError:
pass
return postfix

def on_batch_end(self, net, **kwargs):
self.pbar.set_postfix(self._get_postfix_dict(net))
self.pbar.update()

def on_epoch_begin(self, net, **kwargs):
if self._use_notebook():
self.pbar = tqdm.tqdm_notebook(total=self.batches_per_epoch)
else:
self.pbar = tqdm.tqdm(total=self.batches_per_epoch)

def on_epoch_end(self, net, **kwargs):
if self.batches_per_epoch is None:
self.batches_per_epoch = self.pbar.n
self.pbar.close()
49 changes: 48 additions & 1 deletion skorch/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = torch.nn.Linear(1, 1)

# pylint: disable=arguments-differ
def forward(self, x):
return self.p(x)

Expand Down Expand Up @@ -798,3 +798,50 @@ def epoch_3_scorer(net, *_):

assert save_params_mock.call_count == 1
save_params_mock.assert_called_with('model_3_10.pt')


class TestProgressBar:
@pytest.yield_fixture
def progressbar_cls(self):
from skorch.callbacks import ProgressBar
return ProgressBar

@pytest.fixture
def net_cls(self):
"""very simple network that trains for 2 epochs"""
from skorch.net import NeuralNetRegressor
import torch

class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = torch.nn.Linear(1, 1)
# pylint: disable=arguments-differ
def forward(self, x):
return self.p(x)

return partial(
NeuralNetRegressor,
module=Module,
max_epochs=2,
batch_size=10)

@pytest.fixture(scope='module')
def data(self):
# have 10 examples so we can do a nice CV split
X = np.zeros((10, 1), dtype='float32')
y = np.zeros((10, 1), dtype='float32')
return X, y

@pytest.mark.parametrize('postfix', [
[],
['train_loss'],
['train_loss', 'valid_loss'],
['doesnotexist'],
['train_loss', 'doesnotexist'],
])
def test_invalid_postfix(self, postfix, net_cls, progressbar_cls, data):
net = net_cls(callbacks=[
progressbar_cls(postfix_keys=postfix),
])
net.fit(*data)

0 comments on commit d0b3f4b

Please sign in to comment.