Skip to content

Commit

Permalink
Fix false warning if iterator_valid__shuffle=False (#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Oct 13, 2022
1 parent 5ec195c commit cfe568b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
11 changes: 6 additions & 5 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,11 +1914,12 @@ def _validate_params(self):
# warn about usage of iterator_valid__shuffle=True, since this
# is almost certainly not what the user wants
if 'iterator_valid__shuffle' in self._params_to_validate:
warnings.warn(
"You set iterator_valid__shuffle=True; this is most likely not "
"what you want because the values returned by predict and "
"predict_proba will be shuffled.",
UserWarning)
if self.iterator_valid__shuffle:
warnings.warn(
"You set iterator_valid__shuffle=True; this is most likely not "
"what you want because the values returned by predict and "
"predict_proba will be shuffled.",
UserWarning)

# check for wrong arguments
unexpected_kwargs = []
Expand Down
19 changes: 12 additions & 7 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,19 +263,24 @@ def __init__(self, *args, **kwargs):
# "optimizer_2".
MyNet(module_cls, optimizer_2__lr=0.123) # should not raise

def test_net_init_with_iterator_valid_shuffle_true(
def test_net_init_with_iterator_valid_shuffle_false_no_warning(
self, net_cls, module_cls, recwarn):
# If a user sets iterator_valid__shuffle=False, everything is good and
# no warning should be issued, see
# https://github.com/skorch-dev/skorch/issues/907
net_cls(module_cls, iterator_valid__shuffle=False).initialize()
assert not recwarn.list

def test_net_init_with_iterator_valid_shuffle_true_warns(
self, net_cls, module_cls, recwarn):
# If a user sets iterator_valid__shuffle=True, they might be
# in for a surprise, since predict et al. will result in
# shuffled predictions. It is best to warn about this, since
# most of the times, this is not what users actually want.
expected = (
"You set iterator_valid__shuffle=True; this is most likely not what you want "
"because the values returned by predict and predict_proba will be shuffled.")

# no warning expected here
net_cls(module_cls, iterator_valid__shuffle=False)
assert not recwarn.list
"You set iterator_valid__shuffle=True; this is most likely not what you "
"want because the values returned by predict and predict_proba will be "
"shuffled.")

# warning expected here
with pytest.warns(UserWarning, match=expected):
Expand Down

0 comments on commit cfe568b

Please sign in to comment.