From 7d275df28d49de782373174a0882bc8c31fda844 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 13 Aug 2019 00:15:03 +0200
Subject: [PATCH] More helpful error message when predicting with unfitted net
 (#488)

* More helpful error message when predicting with unfitted net

* Use check_is_fitted, similar to what is used in sklearn, except that a
  skorch.exceptions.NotInitializedError is raised.

* Only assumes the presence of 'module_' so that nets remain hackable,
  except where a different attribute is specifically required.

* Rewrote existing checks to use check_is_fitted.

* skorch check_is_fitted now calls sklearn check_is_fitted

  Only changes the error message and exception type.
---
 CHANGES.md                      |  2 ++
 skorch/classifier.py            |  1 +
 skorch/net.py                   | 63 ++++++++++++++++++++++-----------
 skorch/tests/conftest.py        |  3 ++
 skorch/tests/test_classifier.py | 16 +++++++++
 skorch/tests/test_net.py        | 30 +++++++++++++++-
 skorch/tests/test_regressor.py  | 16 +++++++++
 skorch/utils.py                 | 28 +++++++++++++++
 8 files changed, 138 insertions(+), 21 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index b98d6138a..c9f21f38e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- More helpful error messages when trying to predict using an uninitialized model.
+
 ### Changed
 
 - Improve numerical stability when using `NLLLoss` in `NeuralNetClassifer` (#491)
diff --git a/skorch/classifier.py b/skorch/classifier.py
index 0e27481bd..46a4497db 100644
--- a/skorch/classifier.py
+++ b/skorch/classifier.py
@@ -352,6 +352,7 @@ def predict_proba(self, X):
 
         """
         y_probas = []
+        self.check_is_fitted(attributes=['criterion_'])
         bce_logits_loss = isinstance(
             self.criterion_, torch.nn.BCEWithLogitsLoss)
 
diff --git a/skorch/net.py b/skorch/net.py
index 002f65c6d..7f00d741e 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -25,6 +25,7 @@
 from skorch.setter import optimizer_setter
 from skorch.utils import FirstStepAccumulator
 from skorch.utils import TeeGenerator
+from skorch.utils import check_is_fitted
 from skorch.utils import duplicate_items
 from skorch.utils import get_map_location
 from skorch.utils import is_dataset
@@ -679,6 +680,7 @@ def evaluation_step(self, Xi, training=False):
         like dropout by setting ``training=True``.
 
         """
+        self.check_is_fitted()
         with torch.set_grad_enabled(training):
             self.module_.train(training)
             return self.infer(Xi)
@@ -852,6 +854,27 @@ def fit(self, X, y=None, **fit_params):
         self.partial_fit(X, y, **fit_params)
         return self
 
+    def check_is_fitted(self, attributes=None, *args, **kwargs):
+        """Checks whether the net is initialized.
+
+        Parameters
+        ----------
+        attributes : iterable of str or None (default=None)
+          All the attributes that are strictly required of a fitted
+          net. By default, this is the `module_` attribute.
+
+        Other arguments as in
+        ``sklearn.utils.validation.check_is_fitted``.
+
+        Raises
+        ------
+        skorch.exceptions.NotInitializedError
+          When the given attributes are not present.
+
+        """
+        attributes = attributes or ['module_']
+        check_is_fitted(self, attributes, *args, **kwargs)
+
     def forward_iter(self, X, training=False, device='cpu'):
         """Yield outputs of module forward calls on each batch of data.
         The storage device of the yielded tensors is determined
@@ -1466,19 +1489,19 @@ def save_params(
 
         """
         if f_params is not None:
-            if not hasattr(self, 'module_'):
-                raise NotInitializedError(
-                    "Cannot save parameters of an un-initialized model. "
" - "Please initialize first by calling .initialize() " - "or by fitting the model with .fit(...).") + msg = ( + "Cannot save parameters of an un-initialized model. " + "Please initialize first by calling .initialize() " + "or by fitting the model with .fit(...).") + self.check_is_fitted(msg=msg) torch.save(self.module_.state_dict(), f_params) if f_optimizer is not None: - if not hasattr(self, 'optimizer_'): - raise NotInitializedError( - "Cannot save state of an un-initialized optimizer. " - "Please initialize first by calling .initialize() " - "or by fitting the model with .fit(...).") + msg = ( + "Cannot save state of an un-initialized optimizer. " + "Please initialize first by calling .initialize() " + "or by fitting the model with .fit(...).") + self.check_is_fitted(attributes=['optimizer_'], msg=msg) torch.save(self.optimizer_.state_dict(), f_optimizer) if f_history is not None: @@ -1556,20 +1579,20 @@ def _get_state_dict(f): f_optimizer = f_optimizer or formatted_files['f_optimizer'] if f_params is not None: - if not hasattr(self, 'module_'): - raise NotInitializedError( - "Cannot load parameters of an un-initialized model. " - "Please initialize first by calling .initialize() " - "or by fitting the model with .fit(...).") + msg = ( + "Cannot load parameters of an un-initialized model. " + "Please initialize first by calling .initialize() " + "or by fitting the model with .fit(...).") + self.check_is_fitted(msg=msg) state_dict = _get_state_dict(f_params) self.module_.load_state_dict(state_dict) if f_optimizer is not None: - if not hasattr(self, 'optimizer_'): - raise NotInitializedError( - "Cannot load state of an un-initialized optimizer. " - "Please initialize first by calling .initialize() " - "or by fitting the model with .fit(...).") + msg = ( + "Cannot load state of an un-initialized optimizer. " + "Please initialize first by calling .initialize() " + "or by fitting the model with .fit(...).") + self.check_is_fitted(attributes=['optimizer_'], msg=msg) state_dict = _get_state_dict(f_optimizer) self.optimizer_.load_state_dict(state_dict) diff --git a/skorch/tests/conftest.py b/skorch/tests/conftest.py index 36881a716..6450a778c 100644 --- a/skorch/tests/conftest.py +++ b/skorch/tests/conftest.py @@ -11,6 +11,9 @@ F = nn.functional +INFERENCE_METHODS = ['predict', 'predict_proba', 'forward', 'forward_iter'] + + ################### # shared fixtures # ################### diff --git a/skorch/tests/test_classifier.py b/skorch/tests/test_classifier.py index 1c5e29c10..5ba09167f 100644 --- a/skorch/tests/test_classifier.py +++ b/skorch/tests/test_classifier.py @@ -12,6 +12,8 @@ import torch from torch import nn +from skorch.tests.conftest import INFERENCE_METHODS + torch.manual_seed(0) @@ -150,6 +152,20 @@ def test_fit(self, net_fit): # fitting does not raise anything pass + @pytest.mark.parametrize('method', INFERENCE_METHODS) + def test_not_fitted_raises(self, net_cls, module_cls, data, method): + from skorch.exceptions import NotInitializedError + net = net_cls(module_cls) + X = data[0] + with pytest.raises(NotInitializedError) as exc: + # we call `list` because `forward_iter` is lazy + list(getattr(net, method)(X)) + + msg = ("This NeuralNetBinaryClassifier instance is not initialized " + "yet. 
+               "before using this method.")
+        assert exc.value.args[0] == msg
+
     @flaky(max_runs=3)
     def test_net_learns(self, net_cls, module_cls, data):
         X, y = data
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index 14b6b1c60..3769c1ff0 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -18,16 +18,18 @@
 import numpy as np
 from packaging import version
 import pytest
+from sklearn.base import clone
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics import accuracy_score
 from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
-from sklearn.base import clone
 import torch
 from torch import nn
 from flaky import flaky
 
+from skorch.exceptions import NotInitializedError
+from skorch.tests.conftest import INFERENCE_METHODS
 from skorch.utils import flatten
 from skorch.utils import to_numpy
 from skorch.utils import is_torch_data_type
@@ -183,6 +185,32 @@ def test_fit(self, net_fit):
         # fitting does not raise anything
         pass
 
+    @pytest.mark.parametrize('method', INFERENCE_METHODS)
+    def test_not_fitted_raises(self, net_cls, module_cls, data, method):
+        from skorch.exceptions import NotInitializedError
+        net = net_cls(module_cls)
+        X = data[0]
+        with pytest.raises(NotInitializedError) as exc:
+            # we call `list` because `forward_iter` is lazy
+            list(getattr(net, method)(X))
+
+        msg = ("This NeuralNetClassifier instance is not initialized yet. "
+               "Call 'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+        assert exc.value.args[0] == msg
+
+    def test_not_fitted_other_attributes(self, module_cls):
+        # pass attributes to check for explicitly
+        with patch('skorch.net.check_is_fitted') as check:
+            from skorch import NeuralNetClassifier
+
+            net = NeuralNetClassifier(module_cls)
+            attributes = ['foo', 'bar_']
+
+            net.check_is_fitted(attributes=attributes)
+            args = check.call_args_list[0][0][1]
+            assert args == attributes
+
     @flaky(max_runs=3)
     def test_net_learns(self, net_cls, module_cls, data):
         X, y = data
diff --git a/skorch/tests/test_regressor.py b/skorch/tests/test_regressor.py
index 69a0c079f..21e8e33be 100644
--- a/skorch/tests/test_regressor.py
+++ b/skorch/tests/test_regressor.py
@@ -9,6 +9,8 @@
 import pytest
 import torch
 
+from skorch.tests.conftest import INFERENCE_METHODS
+
 torch.manual_seed(0)
 
 
@@ -60,6 +62,20 @@ def test_fit(self, net_fit):
         # fitting does not raise anything
         pass
 
+    @pytest.mark.parametrize('method', INFERENCE_METHODS)
+    def test_not_fitted_raises(self, net_cls, module_cls, data, method):
+        from skorch.exceptions import NotInitializedError
+        net = net_cls(module_cls)
+        X = data[0]
+        with pytest.raises(NotInitializedError) as exc:
+            # we call `list` because `forward_iter` is lazy
+            list(getattr(net, method)(X))
+
+        msg = ("This NeuralNetRegressor instance is not initialized "
+               "yet. Call 'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+        assert exc.value.args[0] == msg
+
     @flaky(max_runs=3)
     def test_net_learns(self, net, net_cls, data, module_cls):
         X, y = data
diff --git a/skorch/utils.py b/skorch/utils.py
index 05be8aae6..4ca62dee5 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -15,11 +15,14 @@
 import numpy as np
 from scipy import sparse
 from sklearn.utils import safe_indexing
+from sklearn.exceptions import NotFittedError
+from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
 import torch
 from torch.nn.utils.rnn import PackedSequence
 from torch.utils.data.dataset import Subset
 
 from skorch.exceptions import DeviceWarning
+from skorch.exceptions import NotInitializedError
 
 
 class Ansi(Enum):
@@ -477,6 +480,31 @@ def get_map_location(target_device, fallback_device='cpu'):
     return map_location
 
 
+def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
+    """Checks whether the net is initialized.
+
+    Note: This calls ``sklearn.utils.validation.check_is_fitted``
+    under the hood, using exactly the same arguments and logic. The
+    only difference is that this function has an adapted error message
+    and raises a ``skorch.exceptions.NotInitializedError`` instead of
+    an ``sklearn.exceptions.NotFittedError``.
+
+    """
+    if msg is None:
+        msg = ("This %(name)s instance is not initialized yet. Call "
+               "'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+    try:
+        sklearn_check_is_fitted(
+            estimator=estimator,
+            attributes=attributes,
+            msg=msg,
+            all_or_any=all_or_any,
+        )
+    except NotFittedError as e:
+        raise NotInitializedError(str(e))
+
+
 class TeeGenerator:
     """Stores a generator and calls ``tee`` on it to create new generators
     when ``TeeGenerator`` is iterated over to let you iterate over the given