Skip to content

Commit

Permalink
Trim net for prediction (#847)
Browse files Browse the repository at this point in the history
After finishing training, a net has many attributes that are not needed
anymore when only inference is performed. A user might want to get rid
of these attributes to keep the size of the net small. E.g., depending
on the chosen optimizer, it can be quite big.

(In addition, there is the advantage that the less attributes exist, the
more likely a net can be unpickled that used a different version of
skorch, PyTorch, etc.).

This commit adds a convenience method, trim_for_prediction, that takes
care of removing all attributes that are no longer required.

Implementation

In addition to removing unneeded attributes, make sure to clear
callbacks (which, at least at the moment, are used exclusively during
training), the history, the train_split, and the iterator_train. The training
state is set to False.

Also, set an attribute after trimming. When the net is
initialized/fitted, check if the net is trimmed and raise a useful error
message for the user. This check does not require the attribute to be
set, to prevent possible compatibility issues. For this, a new training
readiness check and exception are introduced.
  • Loading branch information
BenjaminBossan committed Apr 16, 2022
1 parent f99eea9 commit b1cacb1
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training
- Added a method, `trim_for_prediction`, on the net classes, which trims the net from everything not required for using it for prediction; call this after fitting to reduce the size of the net
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch

### Changed
Expand Down
4 changes: 4 additions & 0 deletions skorch/exceptions.py
Expand Up @@ -22,3 +22,7 @@ class SkorchWarning(UserWarning):

class DeviceWarning(SkorchWarning):
"""A problem with a device (e.g. CUDA) was detected."""


class SkorchTrainingImpossibleError(SkorchException):
"""The net cannot be used for training"""
67 changes: 67 additions & 0 deletions skorch/net.py
Expand Up @@ -30,6 +30,7 @@
from skorch.dataset import unpack_data
from skorch.exceptions import DeviceWarning
from skorch.exceptions import SkorchAttributeError
from skorch.exceptions import SkorchTrainingImpossibleError
from skorch.history import History
from skorch.setter import optimizer_setter
from skorch.utils import _identity
Expand Down Expand Up @@ -812,6 +813,8 @@ def _initialize_history(self):

def initialize(self):
"""Initializes all of its components and returns self."""
self.check_training_readiness()

self._initialize_virtual_params()
self._initialize_callbacks()
self._initialize_module()
Expand All @@ -824,6 +827,16 @@ def initialize(self):
self.initialized_ = True
return self

def check_training_readiness(self):
"""Check that the net is ready to train"""
is_trimmed_for_prediction = getattr(self, '_trimmed_for_prediction', False)
if is_trimmed_for_prediction:
msg = (
"The net's attributes were trimmed for prediction, thus it cannot "
"be used for training anymore"
)
raise SkorchTrainingImpossibleError(msg)

def check_data(self, X, y=None):
pass

Expand Down Expand Up @@ -1073,6 +1086,7 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
"""
self.check_data(X, y)
self.check_training_readiness()
epochs = epochs if epochs is not None else self.max_epochs

dataset_train, dataset_valid = self.get_split_datasets(
Expand Down Expand Up @@ -1242,6 +1256,59 @@ def check_is_fitted(self, attributes=None, *args, **kwargs):
attributes = attributes or ['module_']
check_is_fitted(self, attributes, *args, **kwargs)

def trim_for_prediction(self):
"""Remove all attributes not required for prediction.
Use this method after you finished training your net, with the goal of
reducing its size. All attributes only required during training (e.g.
the optimizer) are set to None. This can lead to a considerable decrease
in memory footprint. It also makes it more likely that the net can be
loaded with different library versions.
After calling this function, the net can only be used for prediction
(e.g. ``net.predict`` or ``net.predict_proba``) but no longer for
training (e.g. ``net.fit(X, y)`` will raise an exception).
This operation is irreversible. Once the net has been trimmed for
prediction, it is no longer possible to restore the original state.
Morevoer, this operation mutates the net. If you need the unmodified
net, create a deepcopy before trimming:
.. code:: python
from copy import deepcopy
net = NeuralNet(...)
net.fit(X, y)
# training finished
net_original = deepcopy(net)
net.trim_for_prediction()
net.predict(X)
"""
# pylint: disable=protected-access
if getattr(self, '_trimmed_for_prediction', False):
return

self.check_is_fitted()
# pylint: disable=attribute-defined-outside-init
self._trimmed_for_prediction = True
self._set_training(False)

if isinstance(self.callbacks, list):
self.callbacks.clear()
self.callbacks_.clear()

self.train_split = None
self.iterator_train = None
self.history.clear()

attrs_to_trim = self._optimizers[:] + self._criteria[:]

for name in attrs_to_trim:
setattr(self, name + '_', None)
if hasattr(self, name):
setattr(self, name, None)

def forward_iter(self, X, training=False, device='cpu'):
"""Yield outputs of module forward calls on each batch of data.
The storage device of the yielded tensors is determined
Expand Down
139 changes: 139 additions & 0 deletions skorch/tests/test_net.py
Expand Up @@ -3698,3 +3698,142 @@ def test_fit_sparse_csr_learns_cuda(self, model, X, y):
score_end = net.history[-1]['train_loss']

assert score_start > 1.25 * score_end


class TestTrimForPrediction:
@pytest.fixture
def net_untrained(self, classifier_module):
"""A net with a custom 'module2_' criterion and a progress bar callback"""
from skorch import NeuralNetClassifier
from skorch.callbacks import ProgressBar

net = NeuralNetClassifier(
classifier_module,
max_epochs=2,
callbacks=[ProgressBar()],
)
return net

@pytest.fixture
def net(self, net_untrained, classifier_data):
X, y = classifier_data
return net_untrained.fit(X[:100], y[:100])

@pytest.fixture
def net_2_criteria(self, classifier_module, classifier_data):
"""A net with a custom 'module2_' criterion and disabled callbacks"""
# Check that not only the standard components are trimmed and that
# callbacks don't need to be lists.

from skorch import NeuralNetClassifier

class MyNet(NeuralNetClassifier):
def initialize_criterion(self):
super().initialize_criterion()
# pylint: disable=attribute-defined-outside-init
self.criterion2_ = classifier_module()
return self

X, y = classifier_data
net = MyNet(classifier_module, max_epochs=2, callbacks='disable')
net.fit(X, y)
return net

def test_trimmed_net_less_memory(self, net):
# very rough way of checking for smaller memory footprint
size_before = len(pickle.dumps(net))
net.trim_for_prediction()
size_after = len(pickle.dumps(net))
# check if there is at least 10% size gain
assert 0.9 * size_before > size_after

def test_trim_untrained_net_raises(self, net_untrained):
from skorch.exceptions import NotInitializedError

with pytest.raises(NotInitializedError):
net_untrained.trim_for_prediction()

def test_try_fitting_trimmed_net_raises(self, net, classifier_data):
from skorch.exceptions import SkorchTrainingImpossibleError

X, y = classifier_data
msg = (
"The net's attributes were trimmed for prediction, thus it cannot "
"be used for training anymore")

net.trim_for_prediction()
with pytest.raises(SkorchTrainingImpossibleError, match=msg):
net.fit(X, y)

def test_try_trimmed_net_partial_fit_raises(
self, net, classifier_data
):
from skorch.exceptions import SkorchTrainingImpossibleError

X, y = classifier_data
msg = (
"The net's attributes were trimmed for prediction, thus it cannot "
"be used for training anymore"
)

net.trim_for_prediction()
with pytest.raises(SkorchTrainingImpossibleError, match=msg):
net.partial_fit(X, y)

def test_inference_works(self, net, classifier_data):
# does not raise
net.trim_for_prediction()
X, _ = classifier_data
net.predict(X)
net.predict_proba(X)
net.forward(X)

def test_trim_twice_works(self, net):
# does not raise
net.trim_for_prediction()
net.trim_for_prediction()

def test_callbacks_trimmed(self, net):
net.trim_for_prediction()
assert not net.callbacks
assert not net.callbacks_

def test_optimizer_trimmed(self, net):
net.trim_for_prediction()
assert net.optimizer is None
assert net.optimizer_ is None

def test_criteria_trimmed(self, net_2_criteria):
net_2_criteria.trim_for_prediction()
assert net_2_criteria.criterion is None
assert net_2_criteria.criterion_ is None
assert net_2_criteria.criterion2_ is None

def test_history_trimmed(self, net):
net.trim_for_prediction()
assert not net.history

def test_train_iterator_trimmed(self, net):
net.trim_for_prediction()
assert net.iterator_train is None

def test_module_training(self, net):
# pylint: disable=protected-access
net._set_training(True)
net.trim_for_prediction()
assert net.module_.training is False

def test_can_be_pickled(self, net):
pickle.dumps(net)
net.trim_for_prediction()
pickle.dumps(net)

def test_can_be_copied(self, net):
copy.deepcopy(net)
net.trim_for_prediction()
copy.deepcopy(net)

def test_can_be_cloned(self, net):
clone(net)
net.trim_for_prediction()
clone(net)

0 comments on commit b1cacb1

Please sign in to comment.