More helpful error message when predicting with unfitted net (#488)
* More helpful error message when predicting with unfitted net

* Use check_is_fitted, similar to what is used in sklearn, except that
  a skorch.exceptions.NotInitializedError is raised.
* Only assumes the presence of 'module_' so that nets remain hackable,
  except where a different attribute is specifically required.
* Rewrote existing checks to use check_is_fitted.

* skorch check_is_fitted now calls sklearn check_is_fitted

  This only changes the error message and the exception type.
BenjaminBossan authored and ottonemo committed Aug 12, 2019
1 parent 158bd3d commit 7d275df
Showing 8 changed files with 138 additions and 21 deletions.
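For illustration only, and not part of the commit, here is a minimal sketch of the behavior this change introduces; `MyModule` and the random input below are made-up placeholders. Calling any inference method on a net that was never initialized or fitted now raises `skorch.exceptions.NotInitializedError` with the message asserted by the new tests further down.

# Illustrative sketch only; MyModule and the random data are placeholders.
import numpy as np
import torch
from torch import nn

from skorch import NeuralNetClassifier
from skorch.exceptions import NotInitializedError


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)


net = NeuralNetClassifier(MyModule)            # neither initialized nor fitted
X = np.random.randn(5, 20).astype('float32')

try:
    net.predict(X)                             # any inference method hits the new check
except NotInitializedError as err:
    # "This NeuralNetClassifier instance is not initialized yet. Call
    # 'initialize' or 'fit' with appropriate arguments before using this
    # method."
    print(err)
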
CHANGES.md (2 additions, 0 deletions)
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- More helpful error messages when trying to predict using an uninitialized model.
+
 ### Changed
 
 - Improve numerical stability when using `NLLLoss` in `NeuralNetClassifer` (#491)

skorch/classifier.py (1 addition, 0 deletions)
@@ -352,6 +352,7 @@ def predict_proba(self, X):
         """
         y_probas = []
+        self.check_is_fitted(attributes=['criterion_'])
         bce_logits_loss = isinstance(
             self.criterion_, torch.nn.BCEWithLogitsLoss)

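The single added line above guards `NeuralNetBinaryClassifier.predict_proba`, which inspects `self.criterion_` (to see whether it is `BCEWithLogitsLoss`) before iterating over any batch, so the attribute it strictly requires is `criterion_` rather than the default `module_`. A minimal sketch of that check in isolation, assuming a made-up `BinaryModule` placeholder:

# Illustrative sketch only; BinaryModule is a placeholder, not from the commit.
from torch import nn

from skorch.classifier import NeuralNetBinaryClassifier
from skorch.exceptions import NotInitializedError


class BinaryModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 1)

    def forward(self, X):
        # raw logits, as expected by the default BCEWithLogitsLoss criterion
        return self.dense(X).squeeze(-1)


net = NeuralNetBinaryClassifier(BinaryModule)

try:
    # the same call as in the diff above; criterion_ does not exist yet
    net.check_is_fitted(attributes=['criterion_'])
except NotInitializedError as err:
    print(err)

net.initialize()                                # creates module_ and criterion_
net.check_is_fitted(attributes=['criterion_'])  # passes silently now
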
skorch/net.py (43 additions, 20 deletions)
@@ -25,6 +25,7 @@
 from skorch.setter import optimizer_setter
 from skorch.utils import FirstStepAccumulator
 from skorch.utils import TeeGenerator
+from skorch.utils import check_is_fitted
 from skorch.utils import duplicate_items
 from skorch.utils import get_map_location
 from skorch.utils import is_dataset

@@ -679,6 +680,7 @@ def evaluation_step(self, Xi, training=False):
         like dropout by setting ``training=True``.
         """
+        self.check_is_fitted()
         with torch.set_grad_enabled(training):
             self.module_.train(training)
             return self.infer(Xi)

@@ -852,6 +854,27 @@ def fit(self, X, y=None, **fit_params):
         self.partial_fit(X, y, **fit_params)
         return self
 
+    def check_is_fitted(self, attributes=None, *args, **kwargs):
+        """Checks whether the net is initialized
+
+        Parameters
+        ----------
+        attributes : iterable of str or None (default=None)
+          All the attributes that are strictly required of a fitted
+          net. By default, this is the `module_` attribute.
+
+        Other arguments as in
+        ``sklearn.utils.validation.check_is_fitted``.
+
+        Raises
+        ------
+        skorch.exceptions.NotInitializedError
+          When the given attributes are not present.
+
+        """
+        attributes = attributes or ['module_']
+        check_is_fitted(self, attributes, *args, **kwargs)
+
     def forward_iter(self, X, training=False, device='cpu'):
         """Yield outputs of module forward calls on each batch of data.
         The storage device of the yielded tensors is determined

@@ -1466,19 +1489,19 @@ def save_params(
         """
         if f_params is not None:
-            if not hasattr(self, 'module_'):
-                raise NotInitializedError(
-                    "Cannot save parameters of an un-initialized model. "
-                    "Please initialize first by calling .initialize() "
-                    "or by fitting the model with .fit(...).")
+            msg = (
+                "Cannot save parameters of an un-initialized model. "
+                "Please initialize first by calling .initialize() "
+                "or by fitting the model with .fit(...).")
+            self.check_is_fitted(msg=msg)
             torch.save(self.module_.state_dict(), f_params)
 
         if f_optimizer is not None:
-            if not hasattr(self, 'optimizer_'):
-                raise NotInitializedError(
-                    "Cannot save state of an un-initialized optimizer. "
-                    "Please initialize first by calling .initialize() "
-                    "or by fitting the model with .fit(...).")
+            msg = (
+                "Cannot save state of an un-initialized optimizer. "
+                "Please initialize first by calling .initialize() "
+                "or by fitting the model with .fit(...).")
+            self.check_is_fitted(attributes=['optimizer_'], msg=msg)
             torch.save(self.optimizer_.state_dict(), f_optimizer)
 
         if f_history is not None:

@@ -1556,20 +1579,20 @@ def _get_state_dict(f):
         f_optimizer = f_optimizer or formatted_files['f_optimizer']
 
         if f_params is not None:
-            if not hasattr(self, 'module_'):
-                raise NotInitializedError(
-                    "Cannot load parameters of an un-initialized model. "
-                    "Please initialize first by calling .initialize() "
-                    "or by fitting the model with .fit(...).")
+            msg = (
+                "Cannot load parameters of an un-initialized model. "
+                "Please initialize first by calling .initialize() "
+                "or by fitting the model with .fit(...).")
+            self.check_is_fitted(msg=msg)
             state_dict = _get_state_dict(f_params)
             self.module_.load_state_dict(state_dict)
 
         if f_optimizer is not None:
-            if not hasattr(self, 'optimizer_'):
-                raise NotInitializedError(
-                    "Cannot load state of an un-initialized optimizer. "
-                    "Please initialize first by calling .initialize() "
-                    "or by fitting the model with .fit(...).")
+            msg = (
+                "Cannot load state of an un-initialized optimizer. "
+                "Please initialize first by calling .initialize() "
+                "or by fitting the model with .fit(...).")
+            self.check_is_fitted(attributes=['optimizer_'], msg=msg)
             state_dict = _get_state_dict(f_optimizer)
             self.optimizer_.load_state_dict(state_dict)

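Taken together, the `net.py` changes give `NeuralNet` one central `check_is_fitted` hook: by default it only requires `module_`, which keeps subclasses hackable, while `save_params` and `load_params` pass their more specific messages (and, for the optimizer, a different attribute) through the same hook instead of hand-rolled `hasattr` checks. A minimal sketch of both uses, assuming a made-up `RegressorModule` and file name:

# Illustrative sketch only; RegressorModule and 'params.pt' are placeholders.
from torch import nn

from skorch import NeuralNetRegressor
from skorch.exceptions import NotInitializedError


class RegressorModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 1)

    def forward(self, X):
        return self.dense(X)


net = NeuralNetRegressor(RegressorModule)

# Default check: only 'module_' is strictly required.
try:
    net.check_is_fitted()
except NotInitializedError as err:
    print(err)  # "This NeuralNetRegressor instance is not initialized yet. ..."

# save_params routes its own, more specific message through the same check,
# so the exception type stays consistent and nothing is written to disk.
try:
    net.save_params(f_params='params.pt')
except NotInitializedError as err:
    print(err)  # "Cannot save parameters of an un-initialized model. ..."
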
skorch/tests/conftest.py (3 additions, 0 deletions)
@@ -11,6 +11,9 @@
 F = nn.functional
 
 
+INFERENCE_METHODS = ['predict', 'predict_proba', 'forward', 'forward_iter']
+
+
 ###################
 # shared fixtures #
 ###################

skorch/tests/test_classifier.py (16 additions, 0 deletions)
@@ -12,6 +12,8 @@
 import torch
 from torch import nn
 
+from skorch.tests.conftest import INFERENCE_METHODS
+
 
 torch.manual_seed(0)

@@ -150,6 +152,20 @@ def test_fit(self, net_fit):
         # fitting does not raise anything
         pass
 
+    @pytest.mark.parametrize('method', INFERENCE_METHODS)
+    def test_not_fitted_raises(self, net_cls, module_cls, data, method):
+        from skorch.exceptions import NotInitializedError
+        net = net_cls(module_cls)
+        X = data[0]
+        with pytest.raises(NotInitializedError) as exc:
+            # we call `list` because `forward_iter` is lazy
+            list(getattr(net, method)(X))
+
+        msg = ("This NeuralNetBinaryClassifier instance is not initialized "
+               "yet. Call 'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+        assert exc.value.args[0] == msg
+
     @flaky(max_runs=3)
     def test_net_learns(self, net_cls, module_cls, data):
         X, y = data

skorch/tests/test_net.py (29 additions, 1 deletion)
@@ -18,16 +18,18 @@
 import numpy as np
 from packaging import version
 import pytest
+from sklearn.base import clone
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics import accuracy_score
 from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
-from sklearn.base import clone
 import torch
 from torch import nn
 from flaky import flaky
 
+from skorch.exceptions import NotInitializedError
+from skorch.tests.conftest import INFERENCE_METHODS
 from skorch.utils import flatten
 from skorch.utils import to_numpy
 from skorch.utils import is_torch_data_type

@@ -183,6 +185,32 @@ def test_fit(self, net_fit):
         # fitting does not raise anything
         pass
 
+    @pytest.mark.parametrize('method', INFERENCE_METHODS)
+    def test_not_fitted_raises(self, net_cls, module_cls, data, method):
+        from skorch.exceptions import NotInitializedError
+        net = net_cls(module_cls)
+        X = data[0]
+        with pytest.raises(NotInitializedError) as exc:
+            # we call `list` because `forward_iter` is lazy
+            list(getattr(net, method)(X))
+
+        msg = ("This NeuralNetClassifier instance is not initialized yet. "
+               "Call 'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+        assert exc.value.args[0] == msg
+
+    def test_not_fitted_other_attributes(self, module_cls):
+        # pass attributes to check for explicitly
+        with patch('skorch.net.check_is_fitted') as check:
+            from skorch import NeuralNetClassifier
+
+            net = NeuralNetClassifier(module_cls)
+            attributes = ['foo', 'bar_']
+
+            net.check_is_fitted(attributes=attributes)
+            args = check.call_args_list[0][0][1]
+            assert args == attributes
+
     @flaky(max_runs=3)
     def test_net_learns(self, net_cls, module_cls, data):
         X, y = data

skorch/tests/test_regressor.py (16 additions, 0 deletions)
@@ -9,6 +9,8 @@
 import pytest
 import torch
 
+from skorch.tests.conftest import INFERENCE_METHODS
+
 
 torch.manual_seed(0)

@@ -60,6 +62,20 @@ def test_fit(self, net_fit):
         # fitting does not raise anything
         pass
 
+    @pytest.mark.parametrize('method', INFERENCE_METHODS)
+    def test_not_fitted_raises(self, net_cls, module_cls, data, method):
+        from skorch.exceptions import NotInitializedError
+        net = net_cls(module_cls)
+        X = data[0]
+        with pytest.raises(NotInitializedError) as exc:
+            # we call `list` because `forward_iter` is lazy
+            list(getattr(net, method)(X))
+
+        msg = ("This NeuralNetRegressor instance is not initialized "
+               "yet. Call 'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+        assert exc.value.args[0] == msg
+
     @flaky(max_runs=3)
     def test_net_learns(self, net, net_cls, data, module_cls):
         X, y = data

skorch/utils.py (28 additions, 0 deletions)
@@ -15,11 +15,14 @@
 import numpy as np
 from scipy import sparse
 from sklearn.utils import safe_indexing
+from sklearn.exceptions import NotFittedError
+from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
 import torch
 from torch.nn.utils.rnn import PackedSequence
 from torch.utils.data.dataset import Subset
 
 from skorch.exceptions import DeviceWarning
+from skorch.exceptions import NotInitializedError
 
 
 class Ansi(Enum):

@@ -477,6 +480,31 @@ def get_map_location(target_device, fallback_device='cpu'):
     return map_location
 
 
+def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
+    """Checks whether the net is initialized.
+
+    Note: This calls ``sklearn.utils.validation.check_is_fitted``
+    under the hood, using exactly the same arguments and logic. The
+    only difference is that this function has an adapted error message
+    and raises a ``skorch.exceptions.NotInitializedError`` instead of
+    an ``sklearn.exceptions.NotFittedError``.
+
+    """
+    if msg is None:
+        msg = ("This %(name)s instance is not initialized yet. Call "
+               "'initialize' or 'fit' with appropriate arguments "
+               "before using this method.")
+    try:
+        sklearn_check_is_fitted(
+            estimator=estimator,
+            attributes=attributes,
+            msg=msg,
+            all_or_any=all_or_any,
+        )
+    except NotFittedError as e:
+        raise NotInitializedError(str(e))
+
+
 class TeeGenerator:
     """Stores a generator and calls ``tee`` on it to create new generators
     when ``TeeGenerator`` is iterated over to let you iterate over the given

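This wrapper is what produces the class-specific wording asserted in the tests above: sklearn substitutes `%(name)s` with the estimator's class name (hence `NeuralNetClassifier`, `NeuralNetBinaryClassifier`, or `NeuralNetRegressor` in the expected messages), and the resulting `NotFittedError` is re-raised as skorch's `NotInitializedError`. A minimal sketch with a stand-in estimator that is not part of skorch:

# Illustrative sketch only; FakeNet is a stand-in class, not part of skorch.
from skorch.exceptions import NotInitializedError
from skorch.utils import check_is_fitted


class FakeNet:
    # sklearn's check_is_fitted only accepts estimator-like objects,
    # i.e. objects that at least have a fit method
    def fit(self, X, y=None):
        self.module_ = object()
        return self


net = FakeNet()

try:
    check_is_fitted(net, attributes=['module_'])
except NotInitializedError as err:
    # %(name)s was filled with the class name:
    # "This FakeNet instance is not initialized yet. Call 'initialize' or
    # 'fit' with appropriate arguments before using this method."
    print(err)

check_is_fitted(net.fit(None), attributes=['module_'])  # passes silently now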
