Add option to globally override callback caching (#971)
Resolves #957 (look there for motivation)

Description

I went with the option of adding an argument to the NeuralNet class itself
that can override the caching behavior of the scoring callbacks. This
seemed more straightforward than a context manager, which is not a
pattern we typically require of skorch users. The disadvantage is that
NeuralNet gains yet another parameter.

By default, the caching behavior of the callbacks is not changed, i.e.
this should be fully backwards compatible.

Implementation

Ideally, I would have liked to implement this in a way that new (or
user-defined) callbacks don't have to do anything special to honor the
parameter. However, this is not really possible. Although I moved the
logic that decides whether to use inference caching inside
skorch.callbacks.scoring._cache_net_forward_iter, so that the global
override is automatically taken into account there, there are other
places in the callbacks whose behavior changes depending on the use of
caching (e.g. target extraction). The check for whether there is a
caching override thus has to be performed in multiple places.
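
To make the resulting behavior concrete, here is a minimal usage sketch (``MyModule`` stands in for an arbitrary PyTorch module; everything else mirrors the API added in this commit):

```python
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring

# MyModule is a hypothetical placeholder for any PyTorch module.
net = NeuralNetClassifier(
    MyModule,
    callbacks=[EpochScoring(scoring='f1', use_caching=True)],
    # 'auto' (the default) defers to each callback; True/False overrides
    # the use_caching setting of every scoring callback at once.
    use_caching=False,
)
```

The two-line check that each callback performs to honor this setting appears several times in the diff below.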
BenjaminBossan committed Jun 29, 2023
1 parent c6c2ef5 commit 48cb4a1
Showing 4 changed files with 151 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- Add the option to globally override the use of caching in scoring callbacks by setting the `use_caching` argument on the net (this overrides the settings of individual callbacks)

### Changed
### Fixed

34 changes: 28 additions & 6 deletions skorch/callbacks/scoring.py
@@ -33,6 +33,7 @@ def cache_net_infer(net, use_caching, y_preds):
    Deprecated.
    """
    # TODO: remove this function

    warnings.warn(
        "cache_net_infer is no longer used to provide caching for "
@@ -64,7 +65,12 @@ def _cache_net_forward_iter(net, use_caching, y_preds):
    method will subsequently return cached predictions. Leaving the
    context will undo the overwrite of the ``forward_iter`` method.
    Note that the net may override the use of caching.
    """
    if net.use_caching != 'auto':
        use_caching = net.use_caching

    if not use_caching:
        yield net
        return
@@ -229,7 +235,8 @@ class BatchScoring(ScoringBase):
    use_caching : bool (default=True)
      Re-use the model's prediction for computing the loss to calculate
      the score. Turning this off will result in an additional inference
      step for each batch. Note that the net may override the use of
      caching.
    """
    # pylint: disable=unused-argument,arguments-differ
@@ -345,7 +352,7 @@ class EpochScoring(ScoringBase):
      in an additional inference step for each epoch and an
      inability to use arbitrary datasets as input (since we
      don't know how to extract ``y_true`` from an arbitrary
      dataset).
      dataset). Note that the net may override the use of caching.
    """
    def _initialize_cache(self):
@@ -364,7 +371,11 @@ def on_epoch_begin(self, net, dataset_train, dataset_valid, **kwargs):
    # pylint: disable=arguments-differ
    def on_batch_end(
            self, net, batch, y_pred, training, **kwargs):
        if not self.use_caching or training != self.on_train:
        use_caching = self.use_caching
        if net.use_caching != 'auto':
            use_caching = net.use_caching

        if (not use_caching) or (training != self.on_train):
            return

        # We collect references to the prediction and target data
@@ -379,7 +390,7 @@ def on_batch_end(
        self.y_trues_.append(y)
        self.y_preds_.append(y_pred)

    def get_test_data(self, dataset_train, dataset_valid):
    def get_test_data(self, dataset_train, dataset_valid, use_caching):
        """Return data needed to perform scoring.

        This is a convenience method that handles picking of
@@ -394,6 +405,9 @@ def get_test_data(self, dataset_train, dataset_valid):
        dataset_valid
          Incoming validation data or dataset.

        use_caching : bool
          Whether caching of inference is being used.

        Returns
        -------
        X_test
@@ -413,7 +427,7 @@ def get_test_data(self, dataset_train, dataset_valid, use_caching):
        """
        dataset = dataset_train if self.on_train else dataset_valid

        if self.use_caching:
        if use_caching:
            X_test = dataset
            y_pred = self.y_preds_
            y_test = [self.target_extractor(y) for y in self.y_trues_]
@@ -459,7 +473,15 @@ def on_epoch_end(
            dataset_train,
            dataset_valid,
            **kwargs):
        X_test, y_test, y_pred = self.get_test_data(dataset_train, dataset_valid)
        use_caching = self.use_caching
        if net.use_caching != 'auto':
            use_caching = net.use_caching

        X_test, y_test, y_pred = self.get_test_data(
            dataset_train,
            dataset_valid,
            use_caching=use_caching,
        )
        if X_test is None:
            return

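For context, a sketch of how the scoring callbacks consume the context manager shown above; ``net``, ``scorer``, ``X``, ``y``, and ``cached_preds`` are assumed to be defined elsewhere:

```python
from skorch.callbacks.scoring import _cache_net_forward_iter

# With use_caching=True, forward_iter is patched to replay cached_preds,
# so scoring triggers no additional inference; as of this commit, a
# net.use_caching value other than 'auto' supersedes the argument passed here.
with _cache_net_forward_iter(net, use_caching=True, y_preds=cached_preds) as cached_net:
    score = scorer(cached_net, X, y)
```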
23 changes: 23 additions & 0 deletions skorch/net.py
@@ -220,6 +220,20 @@ class NeuralNet:
    notation, e.g. when initializing the net with ``compile__dynamic=True``,
    ``torch.compile`` will be called with ``dynamic=True``.

    use_caching : bool or 'auto' (default='auto')
      Optionally override the caching behavior of scoring callbacks. Callbacks
      such as :class:`.EpochScoring` and :class:`.BatchScoring` allow caching
      the inference call to save time when calculating scores during training,
      at the expense of memory. In certain situations, e.g. when memory is
      tight, you may want to disable caching. As it is cumbersome to change
      the setting on each callback individually, this parameter allows you to
      override their behavior globally.

      By default (``'auto'``), the callbacks determine if caching is used or
      not. If this argument is set to ``False``, caching is disabled on all
      callbacks; if set to ``True``, it is enabled on all callbacks.

      Implementation note: It is the job of the callbacks to honor this
      setting.

    Attributes
    ----------
    prefixes_ : list of str
@@ -295,6 +309,7 @@ def __init__(
            verbose=1,
            device='cpu',
            compile=False,
            use_caching='auto',
            **kwargs
    ):
        self.module = module
@@ -313,6 +328,7 @@
        self.verbose = verbose
        self.device = device
        self.compile = compile
        self.use_caching = use_caching

        self._check_deprecated_params(**kwargs)
        history = kwargs.pop('history', None)
@@ -2020,6 +2036,13 @@ def _validate_params(self):
                suggestion = prefix + '__' + suffix
                msgs.append(tmpl.format(key, suggestion))

        valid_vals_use_caching = ('auto', False, True)
        if self.use_caching not in valid_vals_use_caching:
            msgs.append(
                f"Incorrect value for use_caching used ('{self.use_caching}'), "
                f"use one of: {', '.join(map(str, valid_vals_use_caching))}"
            )

        if msgs:
            full_msg = '\n'.join(msgs)
            raise ValueError(full_msg)
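As the validation above implies, an invalid value now fails fast when the net is initialized. A hedged example (``MyModule`` again a hypothetical placeholder):

```python
from skorch import NeuralNetClassifier

net = NeuralNetClassifier(MyModule, use_caching='yes')  # not 'auto', True, or False
net.initialize()
# ValueError: Incorrect value for use_caching used ('yes'),
# use one of: auto, False, True
```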
99 changes: 98 additions & 1 deletion skorch/tests/callbacks/test_scoring.py
@@ -1,5 +1,6 @@
"""Tests for scoring"""

import re
from functools import partial
from unittest.mock import Mock
from unittest.mock import patch
@@ -521,7 +522,7 @@ def on_epoch_end(
            dataset_valid,
            **kwargs):
        _, y_test, y_proba = self.get_test_data(
            dataset_train, dataset_valid)
            dataset_train, dataset_valid, use_caching=self.use_caching)
        y_pred = np.concatenate(y_proba).argmax(1)

        # record 2 valid scores
@@ -943,6 +944,102 @@ def interrupt_scoring(net, X, y):
        assert len(y_pred) == len(X)


class TestScoringCacheGlobalControl:
    """This test is about the possibility to control cache usage globally

    See this issue for more context:
    https://github.com/skorch-dev/skorch/issues/957

    """
    @pytest.fixture
    def net_cls(self):
        from skorch import NeuralNetClassifier
        return NeuralNetClassifier

    @pytest.mark.parametrize('net_use_caching', ['auto', True, False])
    def test_net_overrides_caching(
            self, net_cls, classifier_module, classifier_data, net_use_caching
    ):
        from skorch.callbacks import BatchScoring, EpochScoring

        call_count = 0

        class MyNet(net_cls):
            def infer(self, x, **kwargs):
                nonlocal call_count
                call_count += 1
                return super().infer(x, **kwargs)

        X, y = classifier_data
        X, y = X[:40], y[:40]  # small amount of data is sufficient
        batch_size = 4
        max_epochs = 3

        # calculation of expected call count of infer
        # net:
        # 40 samples with a batch size of 4 => 10 calls to net.infer per epoch
        # 3 epochs => 30 calls as a baseline
        # callbacks:
        # 32 samples for train => 8 calls if on_train=True => 24 for 3 epochs
        # 8 samples for valid => 2 calls if on_train=False => 6 for 3 epochs

        callbacks = [
            # this callback adds 24 calls
            BatchScoring(
                scoring='f1',
                use_caching=False,
                on_train=True,
            ),
            # this callback adds 6 calls
            BatchScoring(
                scoring='accuracy',
                use_caching=True,
                on_train=False,
            ),
            # this callback adds 24 calls
            EpochScoring(
                scoring='recall',
                use_caching=True,
                on_train=True,
            ),
            # this callback adds 6 calls
            EpochScoring(
                scoring='precision',
                use_caching=False,
                on_train=False,
            ),
        ]

        net = MyNet(
            classifier_module,
            batch_size=batch_size,
            max_epochs=max_epochs,
            callbacks=callbacks,
            use_caching=net_use_caching,
            # turn off default scorer to not mess with the numbers
            callbacks__valid_acc=None,
        )
        net.fit(X, y)

        if net_use_caching == 'auto':
            assert call_count == 30 + 24 + 0 + 0 + 6
        elif net_use_caching is True:
            assert call_count == 30 + 0 + 0 + 0 + 0
        elif net_use_caching is False:
            assert call_count == 30 + 24 + 6 + 24 + 6
        else:
            assert False, "incorrect parameter passed"

    def test_net_use_caching_wrong_value_raises(self, net_cls, classifier_module):
        net = net_cls(classifier_module, use_caching='wrong-value')
        msg = re.escape(
            "Incorrect value for use_caching used ('wrong-value'), "
            "use one of: auto, False, True"
        )
        with pytest.raises(ValueError, match=msg):
            net.initialize()


class TestPassthrougScoring:
    @pytest.fixture
    def scoring_cls(self, request):
