Skip to content

Commit

Permalink
Address #925 (#926)
Browse files Browse the repository at this point in the history
Changing the `_get_param_names` method to return a list instead of a
generator to fix the exception error message when passing unknown
parameters to `set_params`. Before the error message just included
the generator `repr`-string as the list of possible parameters.
Now the string contains the possible parameter names instead.
  • Loading branch information
githubnemo committed May 8, 2023
1 parent 785b917 commit df84519
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed install command to work with recent changes in Google Colab (#928)
- Fixed a couple of bugs related to using non-default modules and criteria (#927)
- Fixed a bug when using `AccelerateMixin` in a multi-GPU setup (#947)
- `_get_param_names` returns a list instead of a generator so that subsequent
error messages return useful information instead of a generator `repr`
string (#925)

## [0.12.1] - 2022-11-18

Expand Down
2 changes: 1 addition & 1 deletion skorch/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def on_grad_computed(
"""

def _get_param_names(self):
return (key for key in self.__dict__ if not key.endswith('_'))
return [key for key in self.__dict__ if not key.endswith('_')]

def get_params(self, deep=True):
return BaseEstimator.get_params(self, deep=deep)
Expand Down
2 changes: 1 addition & 1 deletion skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,7 @@ def get_params_for_optimizer(self, prefix, named_parameters):
return args, kwargs

def _get_param_names(self):
return (k for k in self.__dict__ if not k.endswith('_'))
return [k for k in self.__dict__ if not k.endswith('_')]

def _get_params_callbacks(self, deep=True):
"""sklearn's .get_params checks for `hasattr(value,
Expand Down
21 changes: 21 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,27 @@ def test_set_params_works(self, net, data):
assert net.module_.sequential[3].in_features == 20
assert np.isclose(net.lr, 0.2)

def test_unknown_set_params_gives_helpful_message(self, net_fit):
# test that the error message of set_params includes helpful
# information instead of, e.g., generator expressions.
# sklearn 0.2x does not output the parameter names so we can
# skip detailled checks of the error message there.

sklearn_0_2x_string = "Check the list of available parameters with `estimator.get_params().keys()`"

with pytest.raises(ValueError) as e:
net_fit.set_params(invalid_parameter_xyz=42)

exception_str = str(e.value)

if sklearn_0_2x_string in exception_str:
return

expected_keys = ["module", "criterion"]

for key in expected_keys:
assert key in exception_str[exception_str.find("Valid parameters are: ") :]

def test_set_params_then_initialize_remembers_param(
self, net_cls, module_cls):
net = net_cls(module_cls)
Expand Down

0 comments on commit df84519

Please sign in to comment.