Skip to content
Browse files

Preparation for more correct get params change (#527)

* Introduce changes and tests in preparation for new get_params

The new behavior of get_params will be to not returned any "learned"
attributes such as "module_".

This PR implements the new behavior but doesn't switch to it yet to
give users time to adjust their code. This is a breaking change but it
is necessary since it is the "correct" behavior; the old one could
introduce subtle bugs in rare situations (e.g. `GridSearchCV` with a
net that has `warm_start=True`).

The PR also includes tests that are currently failing but that are
passing under the new behavior. When switching to the new behavior,
all tests, including these new ones, should pass (they currently

* Make net.history a property that refers to net.history_

Add setter and getter methods for net.history. That way, history now
ends on '_' like all other parameters that are not provided directly
by the user.

Co-Authored-By: ottonemo <>
  • Loading branch information...
BenjaminBossan and ottonemo committed Oct 10, 2019
1 parent 3aa9b43 commit 161f28d6e07a7470b9567a9c60a383f44c795984
Showing with 70 additions and 9 deletions.
  1. +5 −0
  2. +20 −7 skorch/
  3. +45 −2 skorch/tests/
@@ -21,6 +21,11 @@ and this project adheres to [Semantic Versioning](
- Improve numerical stability when using `NLLLoss` in `NeuralNetClassifer` (#491)
- Refactor code to make gradient accumulation easier to implement (#506)
- NeuralNetBinaryClassifier.predict_proba now returns a 2-dim array; to access the "old" y_proba, take y_proba[:, 1] (#515)
- net.history is now a property that accesses net.history_, which stores the History object (#527)

### Future Changes

- WARNING: In a future release, the behavior of method `net.get_params` will change to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`. Therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net. If you want a copy of a fitted net, use `copy.deepcopy` instead. Note that `net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)

### Fixed

@@ -20,7 +20,6 @@
from skorch.dataset import unpack_data
from skorch.dataset import uses_placeholder_y
from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError
from skorch.history import History
from skorch.setter import optimizer_setter
from skorch.utils import FirstStepAccumulator
@@ -233,10 +232,18 @@ def __init__(
kwargs = self._check_kwargs(kwargs)

self.history = history
self.history_ = history
self.initialized_ = initialized
self.virtual_params_ = virtual_params

def history(self):
return self.history_

def history(self, value):
self.history_ = value

def _default_callbacks(self):
return [
@@ -524,7 +531,7 @@ def initialize_optimizer(self, triggered_directly=True):

def initialize_history(self):
"""Initializes the history."""
self.history = History()
self.history_ = History()

def initialize(self):
"""Initializes all components of the :class:`.NeuralNet` and
@@ -1264,7 +1271,15 @@ def _get_params_for_optimizer(self, prefix, named_parameters):
return [pgroups], kwargs

def _get_param_names(self):
return self.__dict__.keys()
return (k for k in self.__dict__.keys() if k != 'history_')

def _get_param_names_new(self):
# TODO: This will be the new behavior for _get_param_names in
# a future release. This is to make get_params work as in
# sklearn, i.e. not returning "learned" attributes (ending on
# '_'). Once the transition period has passed, remove the old
# code and use the new one instead.
return (k for k in self.__dict__ if not k.endswith('_'))

def _get_params_callbacks(self, deep=True):
"""sklearn's .get_params checks for `hasattr(value,
@@ -1639,8 +1654,6 @@ def _get_state_dict(f):

def __repr__(self):
params = self.get_params(deep=False)

to_include = ['module']
to_exclude = []
parts = [str(self.__class__) + '[uninitialized](']
@@ -1649,7 +1662,7 @@ def __repr__(self):
to_include = ['module_']
to_exclude = ['module__']

for key, val in sorted(params.items()):
for key, val in sorted(self.__dict__.items()):
if not any(key.startswith(prefix) for prefix in to_include):
if any(key.startswith(prefix) for prefix in to_exclude):
@@ -36,6 +36,7 @@


# pylint: disable=too-many-public-methods
@@ -108,7 +109,7 @@ def net_pickleable(self, net_fit):
# remove mock callback
net_fit.callbacks_ = [(n, cb) for n, cb in net_fit.callbacks_
if not isinstance(cb, Mock)]
net_clone = clone(net_fit)
net_clone = copy.deepcopy(net_fit)
net_fit.callbacks = callbacks
net_fit.callbacks_ = callbacks_
return net_clone
@@ -271,7 +272,7 @@ def test_net_learns(self, net_cls, module_cls, data):
), y)
y_pred = net.predict(X)
assert accuracy_score(y, y_pred) > 0.65
assert accuracy_score(y, y_pred) > ACCURACY_EXPECTED

def test_forward(self, net_fit, data):
X = data[0]
@@ -1244,6 +1245,48 @@ def test_get_params_with_uninit_callbacks(self, net_cls, module_cls):

def test_get_params_no_learned_params(self, net_fit):
# TODO: This test should fail for now but should succeed once
# we change the behavior of get_params to be more in line with
# sklearn. At that point, remove the decorators.
params = net_fit.get_params()
params_learned = set(filter(lambda x: x.endswith('_'), params))
assert not params_learned

def test_clone_results_in_uninitialized_net(
self, net_fit, data):
# TODO: This test should fail for now but should succeed once
# we change the behavior of get_params to be more in line with
# sklearn. At that point, remove the decorators.
X, y = data
accuracy = accuracy_score(net_fit.predict(X), y)
assert accuracy > ACCURACY_EXPECTED # make sure net has learned

net_cloned = clone(net_fit).set_params(max_epochs=0)
net_cloned.callbacks_ = []
net_cloned.partial_fit(X, y)
accuracy_cloned = accuracy_score(net_cloned.predict(X), y)
assert accuracy_cloned < ACCURACY_EXPECTED

assert not net_cloned.history

def test_clone_copies_parameters(self, net_cls, module_cls):
kwargs = dict(
net = net_cls(module_cls, **kwargs)
net_cloned = clone(net)
params = net_cloned.get_params()
for key, val in kwargs.items():
assert params[key] == val

def test_with_initialized_module(self, net_cls, module_cls, data):
X, y = data
net = net_cls(module_cls(), max_epochs=1)

0 comments on commit 161f28d

Please sign in to comment.
You can’t perform that action at this time.