
Feature: skip data iteration when caching scoring (#557)

* Some cleanups in test_scoring

  * fix indentation levels
  * disable some pylint messages
  * remove unused fixtures

* Don't iterate over data when using cached scoring

Before, net.infer was cached when using a scoring callback with
use_caching=True. This saved the time needed for the inference step.
However, each scoring callback still iterated over the data; if
iteration is slow, this could incur a significant overhead.

Now net.forward_iter is cached instead. This way, the iteration over
the data is skipped entirely and the iteration overhead should be gone
(an illustrative usage sketch follows these commit notes).

* Add comment to explain attribute priority

Similar to the comment in cache_net_infer

* Move common functionality to skorch.utils.to_device

... instead of having it as a method on NeuralNet.

Add tests

* Remove unnecessary import in test_net
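
For illustration, a minimal sketch of how a user enables the cached
scoring path; MyModule is a hypothetical stand-in for an actual
torch.nn.Module, and EpochScoring's use_caching flag is the relevant
switch:

    from skorch import NeuralNetClassifier
    from skorch.callbacks import EpochScoring

    # With use_caching=True, the scoring callback reuses the predictions
    # gathered during the epoch instead of iterating over the data again.
    net = NeuralNetClassifier(
        MyModule,  # stand-in for an actual torch.nn.Module
        callbacks=[EpochScoring('accuracy', use_caching=True)],
    )
    # net.fit(X, y) would now trigger no extra iteration for scoring.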
BenjaminBossan authored and ottonemo committed Dec 16, 2019
1 parent ab8e6ed commit 09be626e74512124eb74c76b4cbabad4d3b1f274
Showing with 147 additions and 17 deletions.
  1. +2 −0 CHANGES.md
  2. +51 −6 skorch/callbacks/scoring.py
  3. +2 −4 skorch/net.py
  4. +31 −6 skorch/tests/callbacks/test_scoring.py
  5. +0 −1 skorch/tests/test_net.py
  6. +50 −0 skorch/tests/test_utils.py
  7. +11 −0 skorch/utils.py
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- When using caching in scoring callbacks, the data is no longer iterated over unnecessarily; this can save time if iteration is slow (#552, #557)

### Fixed

- Make skorch compatible with sklearn 0.22
@@ -3,6 +3,7 @@
from contextlib import contextmanager
from contextlib import suppress
from functools import partial
import warnings

import numpy as np
from sklearn.metrics.scorer import (
@@ -13,6 +14,7 @@
from skorch.utils import to_numpy
from skorch.callbacks import Callback
from skorch.utils import check_indexing
from skorch.utils import to_device
from skorch.utils import train_loss_score
from skorch.utils import valid_loss_score

@@ -22,10 +24,23 @@

@contextmanager
def cache_net_infer(net, use_caching, y_preds):
"""Caching context for ``skorch.NeuralNet`` instance. Returns
a modified version of the net whose ``infer`` method will
subsequently return cached predictions. Leaving the context
will undo the overwrite of the ``infer`` method."""
"""Caching context for ``skorch.NeuralNet`` instance.
Returns a modified version of the net whose ``infer`` method will
subsequently return cached predictions. Leaving the context will
undo the overwrite of the ``infer`` method.
Deprecated.
"""

warnings.warn(
"cache_net_infer is no longer uesd to provide caching for "
"the scoring callbacks and will hence be removed in a "
"future release.",
DeprecationWarning,
)

if not use_caching:
yield net
return
@@ -41,6 +56,36 @@ def cache_net_infer(net, use_caching, y_preds):
        del net.__dict__['infer']


@contextmanager
def _cache_net_forward_iter(net, use_caching, y_preds):
    """Caching context for ``skorch.NeuralNet`` instance.
    Returns a modified version of the net whose ``forward_iter``
    method will subsequently return cached predictions. Leaving the
    context will undo the overwrite of the ``forward_iter`` method.
    """
    if not use_caching:
        yield net
        return
    y_preds = iter(y_preds)

    # pylint: disable=unused-argument
    def cached_forward_iter(*args, device=net.device, **kwargs):
        for yp in y_preds:
            yield to_device(yp, device=device)

    net.forward_iter = cached_forward_iter
    try:
        yield net
    finally:
        # By setting net.forward_iter we define an attribute
        # `forward_iter` that takes precedence over the bound method
        # `forward_iter`. By deleting the entry from the attribute
        # dict we undo this.
        del net.__dict__['forward_iter']


def convert_sklearn_metric_function(scoring):
"""If ``scoring`` is a sklearn metric function, convert it to a
sklearn scorer and return it. Otherwise, return ``scoring`` unchanged."""
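
The cleanup in the ``finally`` block above works because an instance
attribute shadows the bound method of the same name, and deleting the
entry from the instance ``__dict__`` restores the original method. A
minimal, self-contained sketch of that mechanism (the class here is
hypothetical):

    class Net:
        def forward_iter(self):
            return 'real method'

    net = Net()
    # Assigning on the instance shadows the method defined on the class.
    net.forward_iter = lambda: 'cached'
    assert net.forward_iter() == 'cached'

    # Deleting the instance attribute un-shadows the class method.
    del net.__dict__['forward_iter']
    assert net.forward_iter() == 'real method'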
@@ -196,7 +241,7 @@ def on_batch_end(self, net, X, y, training, **kwargs):
            return

        y_preds = [kwargs['y_pred']]
-        with cache_net_infer(net, self.use_caching, y_preds) as cached_net:
+        with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
            # In case of y=None we will not have gathered any samples.
            # We expect the scoring function to deal with y=None.
            y = None if y is None else self.target_extractor(y)
@@ -418,7 +463,7 @@ def on_epoch_end(
        if X_test is None:
            return

-        with cache_net_infer(net, self.use_caching, y_pred) as cached_net:
+        with _cache_net_forward_iter(net, self.use_caching, y_pred) as cached_net:
            current_score = self._scoring(cached_net, X_test, y_test)

        self._record_score(net.history, current_score)
@@ -30,6 +30,7 @@
from skorch.utils import is_dataset
from skorch.utils import noop
from skorch.utils import params_for
from skorch.utils import to_device
from skorch.utils import to_numpy
from skorch.utils import to_tensor
from skorch.utils import train_loss_score
@@ -911,10 +912,7 @@ def forward_iter(self, X, training=False, device='cpu'):
        for data in iterator:
            Xi = unpack_data(data)[0]
            yp = self.evaluation_step(Xi, training=training)
-            if isinstance(yp, tuple):
-                yield tuple(n.to(device) for n in yp)
-            else:
-                yield yp.to(device)
+            yield to_device(yp, device=device)

    def forward(self, X, training=False, device='cpu'):
        """Gather and concatenate the output from forward call with
@@ -87,8 +87,8 @@ def test_scoring_uses_score_when_none(
    ])
    @pytest.mark.parametrize('initial_epochs', [1, 2, 3, 4])
    def test_scoring_uses_best_score_when_continuing_training(
-            self, net_cls, module_cls, scoring_cls, train_split, data,
-            lower_is_better, expected, tmpdir, initial_epochs
+            self, net_cls, module_cls, scoring_cls, data,
+            lower_is_better, expected, tmpdir, initial_epochs,
    ):
        # set scoring to None so that mocked net.score is used
        net = net_cls(
@@ -361,7 +361,7 @@ def net_input_is_scoring_input(
            train_split, expected_type, caching,
    ):
        score_calls = 0
-        def myscore(net, X, y=None):
+        def myscore(net, X, y=None):  # pylint: disable=unused-argument
            nonlocal score_calls
            score_calls += 1
            assert type(X) == expected_type
@@ -473,6 +473,31 @@ def test_multiple_scorings_with_dict(
        with pytest.raises(ValueError, match=msg):
            net.fit(*data)

    @pytest.mark.parametrize('use_caching, count', [(False, 1), (True, 0)])
    def test_with_caching_get_iterator_not_called(
            self, net_cls, module_cls, train_split, caching_scoring_cls, data,
            use_caching, count,
    ):
        max_epochs = 3
        net = net_cls(
            module=module_cls,
            callbacks=[
                ('acc', caching_scoring_cls('accuracy', use_caching=use_caching)),
            ],
            train_split=train_split,
            max_epochs=max_epochs,
        )

        get_iterator = net.get_iterator
        net.get_iterator = Mock(side_effect=get_iterator)
        net.fit(*data)

        # expected count is:
        # max_epochs * (1 (train) + 1 (valid) + 0 or 1 (from scoring,
        # depending on caching)); e.g. 3 * (1 + 1 + 0) = 6 with caching
        # and 3 * (1 + 1 + 1) = 9 without.
        count_expected = max_epochs * (1 + 1 + count)
        assert net.get_iterator.call_count == count_expected

    def test_subclassing_epoch_scoring(
            self, classifier_module, classifier_data):
        # This test's purpose is to check that it is possible to
@@ -618,8 +643,8 @@ def test_average_honors_weights(self, train_loss, history):
    ])
    @pytest.mark.parametrize('initial_epochs', [1, 2, 3, 4])
    def test_scoring_uses_best_score_when_continuing_training(
-            self, net_cls, module_cls, scoring_cls, train_split, data,
-            lower_is_better, expected, tmpdir, initial_epochs
+            self, net_cls, module_cls, scoring_cls, data,
+            lower_is_better, expected, tmpdir, initial_epochs,
    ):
        # set scoring to None so that mocked net.score is used
        net = net_cls(
@@ -856,7 +881,7 @@ def test_without_target_data_works(
    ):
        score_calls = 0

-        def myscore(net, X, y=None):
+        def myscore(net, X, y=None):  # pylint: disable=unused-argument
            nonlocal score_calls
            score_calls += 1
            assert y is None
@@ -28,7 +28,6 @@
from torch import nn
from flaky import flaky

-from skorch.exceptions import NotInitializedError
from skorch.tests.conftest import INFERENCE_METHODS
from skorch.utils import flatten
from skorch.utils import to_numpy
@@ -135,6 +135,56 @@ def test_sparse_tensor_not_accepted_raises(self, to_tensor, device):
        assert exc.value.args[0] == msg


class TestToDevice:
    @pytest.fixture
    def to_device(self):
        from skorch.utils import to_device
        return to_device

    @pytest.fixture
    def x(self):
        return torch.zeros(3)

    @pytest.fixture
    def x_tup(self):
        return torch.zeros(3), torch.ones((4, 5))

    @pytest.mark.parametrize('device_from, device_to', [
        ('cpu', 'cpu'),
        ('cpu', 'cuda'),
        ('cuda', 'cpu'),
        ('cuda', 'cuda'),
    ])
    def test_check_device_torch_tensor(self, to_device, x, device_from, device_to):
        if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
            pytest.skip()

        x = to_device(x, device=device_from)
        assert x.device.type == device_from

        x = to_device(x, device=device_to)
        assert x.device.type == device_to

    @pytest.mark.parametrize('device_from, device_to', [
        ('cpu', 'cpu'),
        ('cpu', 'cuda'),
        ('cuda', 'cpu'),
        ('cuda', 'cuda'),
    ])
    def test_check_device_tuple_torch_tensor(
            self, to_device, x_tup, device_from, device_to):
        if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
            pytest.skip()

        # use the tuple fixture so the tuple code path of to_device is covered
        x_tup = to_device(x_tup, device=device_from)
        for xi in x_tup:
            assert xi.device.type == device_from

        x_tup = to_device(x_tup, device=device_to)
        for xi in x_tup:
            assert xi.device.type == device_to

class TestDuplicateItems:
    @pytest.fixture
    def duplicate_items(self):
@@ -125,6 +125,17 @@ def to_numpy(X):
    return X.numpy()


def to_device(X, device):
    """Generic function to move module output(s) to a device.
    Deals with X being a torch tensor or a tuple of torch tensors.
    """
    if isinstance(X, tuple):
        return tuple(x.to(device) for x in X)
    return X.to(device)


def get_dim(y):
"""Return the number of dimensions of a torch tensor or numpy
array-like object.
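
A quick usage sketch of the new to_device helper (CPU-only so it runs
without a GPU; the tensors are arbitrary examples):

    import torch
    from skorch.utils import to_device

    # Single tensor: behaves like tensor.to(device).
    x = to_device(torch.zeros(3), device='cpu')
    assert x.device.type == 'cpu'

    # Tuple of tensors: each element is moved, matching modules that
    # return multiple outputs.
    tup = to_device((torch.zeros(2), torch.ones(2)), device='cpu')
    assert all(t.device.type == 'cpu' for t in tup)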
