Skip to content
Permalink
Browse files

More helpful error message when predicting with unfitted net (#488)

* More helpful error message when predicting with unfitted net

* Use check_is_fitted similar to what is used in sklearn, except that
  a skorch.exceptions.NotInitializedError is raised.
* Only assumes the presence of 'module_' so that nets remain hackable,
  except where a different attribute is specifically required.
* Re-wrote existing checks to now use check_is_fitted.

* skorch chech_is_fitted not calls sklearn check_is_fitted

Only changes error message and exception type.
  • Loading branch information...
BenjaminBossan authored and ottonemo committed Aug 12, 2019
1 parent 158bd3d commit 7d275df28d49de782373174a0882bc8c31fda844
Showing with 138 additions and 21 deletions.
  1. +2 −0 CHANGES.md
  2. +1 −0 skorch/classifier.py
  3. +43 −20 skorch/net.py
  4. +3 −0 skorch/tests/conftest.py
  5. +16 −0 skorch/tests/test_classifier.py
  6. +29 −1 skorch/tests/test_net.py
  7. +16 −0 skorch/tests/test_regressor.py
  8. +28 −0 skorch/utils.py
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- More helpful error messages when trying to predict using an uninitialized model.

### Changed

- Improve numerical stability when using `NLLLoss` in `NeuralNetClassifer` (#491)
@@ -352,6 +352,7 @@ def predict_proba(self, X):
"""
y_probas = []
self.check_is_fitted(attributes=['criterion_'])
bce_logits_loss = isinstance(
self.criterion_, torch.nn.BCEWithLogitsLoss)

@@ -25,6 +25,7 @@
from skorch.setter import optimizer_setter
from skorch.utils import FirstStepAccumulator
from skorch.utils import TeeGenerator
from skorch.utils import check_is_fitted
from skorch.utils import duplicate_items
from skorch.utils import get_map_location
from skorch.utils import is_dataset
@@ -679,6 +680,7 @@ def evaluation_step(self, Xi, training=False):
like dropout by setting ``training=True``.
"""
self.check_is_fitted()
with torch.set_grad_enabled(training):
self.module_.train(training)
return self.infer(Xi)
@@ -852,6 +854,27 @@ def fit(self, X, y=None, **fit_params):
self.partial_fit(X, y, **fit_params)
return self

def check_is_fitted(self, attributes=None, *args, **kwargs):
"""Checks whether the net is initialized
Parameters
----------
attributes : iterable of str or None (default=None)
All the attributes that are strictly required of a fitted
net. By default, this is the `module_` attribute.
Other arguments as in
``sklearn.utils.validation.check_is_fitted``.
Raises
------
skorch.exceptions.NotInitializedError
When the given attributes are not present.
"""
attributes = attributes or ['module_']
check_is_fitted(self, attributes, *args, **kwargs)

def forward_iter(self, X, training=False, device='cpu'):
"""Yield outputs of module forward calls on each batch of data.
The storage device of the yielded tensors is determined
@@ -1466,19 +1489,19 @@ def save_params(
"""
if f_params is not None:
if not hasattr(self, 'module_'):
raise NotInitializedError(
"Cannot save parameters of an un-initialized model. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
msg = (
"Cannot save parameters of an un-initialized model. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
self.check_is_fitted(msg=msg)
torch.save(self.module_.state_dict(), f_params)

if f_optimizer is not None:
if not hasattr(self, 'optimizer_'):
raise NotInitializedError(
"Cannot save state of an un-initialized optimizer. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
msg = (
"Cannot save state of an un-initialized optimizer. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
self.check_is_fitted(attributes=['optimizer_'], msg=msg)
torch.save(self.optimizer_.state_dict(), f_optimizer)

if f_history is not None:
@@ -1556,20 +1579,20 @@ def _get_state_dict(f):
f_optimizer = f_optimizer or formatted_files['f_optimizer']

if f_params is not None:
if not hasattr(self, 'module_'):
raise NotInitializedError(
"Cannot load parameters of an un-initialized model. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
msg = (
"Cannot load parameters of an un-initialized model. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
self.check_is_fitted(msg=msg)
state_dict = _get_state_dict(f_params)
self.module_.load_state_dict(state_dict)

if f_optimizer is not None:
if not hasattr(self, 'optimizer_'):
raise NotInitializedError(
"Cannot load state of an un-initialized optimizer. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
msg = (
"Cannot load state of an un-initialized optimizer. "
"Please initialize first by calling .initialize() "
"or by fitting the model with .fit(...).")
self.check_is_fitted(attributes=['optimizer_'], msg=msg)
state_dict = _get_state_dict(f_optimizer)
self.optimizer_.load_state_dict(state_dict)

@@ -11,6 +11,9 @@
F = nn.functional


INFERENCE_METHODS = ['predict', 'predict_proba', 'forward', 'forward_iter']


###################
# shared fixtures #
###################
@@ -12,6 +12,8 @@
import torch
from torch import nn

from skorch.tests.conftest import INFERENCE_METHODS


torch.manual_seed(0)

@@ -150,6 +152,20 @@ def test_fit(self, net_fit):
# fitting does not raise anything
pass

@pytest.mark.parametrize('method', INFERENCE_METHODS)
def test_not_fitted_raises(self, net_cls, module_cls, data, method):
from skorch.exceptions import NotInitializedError
net = net_cls(module_cls)
X = data[0]
with pytest.raises(NotInitializedError) as exc:
# we call `list` because `forward_iter` is lazy
list(getattr(net, method)(X))

msg = ("This NeuralNetBinaryClassifier instance is not initialized "
"yet. Call 'initialize' or 'fit' with appropriate arguments "
"before using this method.")
assert exc.value.args[0] == msg

@flaky(max_runs=3)
def test_net_learns(self, net_cls, module_cls, data):
X, y = data
@@ -18,16 +18,18 @@
import numpy as np
from packaging import version
import pytest
from sklearn.base import clone
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.base import clone
import torch
from torch import nn
from flaky import flaky

from skorch.exceptions import NotInitializedError
from skorch.tests.conftest import INFERENCE_METHODS
from skorch.utils import flatten
from skorch.utils import to_numpy
from skorch.utils import is_torch_data_type
@@ -183,6 +185,32 @@ def test_fit(self, net_fit):
# fitting does not raise anything
pass

@pytest.mark.parametrize('method', INFERENCE_METHODS)
def test_not_fitted_raises(self, net_cls, module_cls, data, method):
from skorch.exceptions import NotInitializedError
net = net_cls(module_cls)
X = data[0]
with pytest.raises(NotInitializedError) as exc:
# we call `list` because `forward_iter` is lazy
list(getattr(net, method)(X))

msg = ("This NeuralNetClassifier instance is not initialized yet. "
"Call 'initialize' or 'fit' with appropriate arguments "
"before using this method.")
assert exc.value.args[0] == msg

def test_not_fitted_other_attributes(self, module_cls):
# pass attributes to check for explicitly
with patch('skorch.net.check_is_fitted') as check:
from skorch import NeuralNetClassifier

net = NeuralNetClassifier(module_cls)
attributes = ['foo', 'bar_']

net.check_is_fitted(attributes=attributes)
args = check.call_args_list[0][0][1]
assert args == attributes

@flaky(max_runs=3)
def test_net_learns(self, net_cls, module_cls, data):
X, y = data
@@ -9,6 +9,8 @@
import pytest
import torch

from skorch.tests.conftest import INFERENCE_METHODS


torch.manual_seed(0)

@@ -60,6 +62,20 @@ def test_fit(self, net_fit):
# fitting does not raise anything
pass

@pytest.mark.parametrize('method', INFERENCE_METHODS)
def test_not_fitted_raises(self, net_cls, module_cls, data, method):
from skorch.exceptions import NotInitializedError
net = net_cls(module_cls)
X = data[0]
with pytest.raises(NotInitializedError) as exc:
# we call `list` because `forward_iter` is lazy
list(getattr(net, method)(X))

msg = ("This NeuralNetRegressor instance is not initialized "
"yet. Call 'initialize' or 'fit' with appropriate arguments "
"before using this method.")
assert exc.value.args[0] == msg

@flaky(max_runs=3)
def test_net_learns(self, net, net_cls, data, module_cls):
X, y = data
@@ -15,11 +15,14 @@
import numpy as np
from scipy import sparse
from sklearn.utils import safe_indexing
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
import torch
from torch.nn.utils.rnn import PackedSequence
from torch.utils.data.dataset import Subset

from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError


class Ansi(Enum):
@@ -477,6 +480,31 @@ def get_map_location(target_device, fallback_device='cpu'):
return map_location


def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
"""Checks whether the net is initialized.
Note: This calls ``sklearn.utils.validation.check_is_fitted``
under the hood, using exactly the same arguments and logic. The
only difference is that this function has an adapted error message
and raises a ``skorch.exception.NotInitializedError`` instead of
an ``sklearn.exceptions.NotFittedError``.
"""
if msg is None:
msg = ("This %(name)s instance is not initialized yet. Call "
"'initialize' or 'fit' with appropriate arguments "
"before using this method.")
try:
sklearn_check_is_fitted(
estimator=estimator,
attributes=attributes,
msg=msg,
all_or_any=all_or_any,
)
except NotFittedError as e:
raise NotInitializedError(str(e))


class TeeGenerator:
"""Stores a generator and calls ``tee`` on it to create new generators
when ``TeeGenerator`` is iterated over to let you iterate over the given

0 comments on commit 7d275df

Please sign in to comment.
You can’t perform that action at this time.