Skip to content

Commit

Permalink
Fix wrong error message in Checkpoint (#869)
Browse files Browse the repository at this point in the history
When a key was monitored that was not found in history, the resulting
error message would be:

skorch.exceptions.SkorchException: Monitor value 'Key
'valid_loss_best' was not found in history.' cannot be found in history.
Make sure you have validation data if you use validation scores for
checkpointing.

This bugfix now returns the correct error message.
  • Loading branch information
BenjaminBossan committed Aug 1, 2022
1 parent ef23ce4 commit d5b0001
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
8 changes: 4 additions & 4 deletions skorch/callbacks/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def on_epoch_end(self, net, **kwargs):
try:
do_checkpoint = net.history[-1, self.monitor]
except KeyError as e:
raise SkorchException(
"Monitor value '{}' cannot be found in history. "
"Make sure you have validation data if you use "
"validation scores for checkpointing.".format(e.args[0]))
msg = (
f"{e.args[0]} Make sure you have validation data if you use "
"validation scores for checkpointing.")
raise SkorchException(msg)

if self.event_name is not None:
net.history.record(self.event_name, bool(do_checkpoint))
Expand Down
15 changes: 7 additions & 8 deletions skorch/tests/callbacks/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,14 @@ def test_default_without_validation_raises_meaningful_error(
train_split=None
)
from skorch.exceptions import SkorchException
with pytest.raises(SkorchException) as e:

msg_expected = (
r"Key 'valid_loss_best' was not found in history. "
r"Make sure you have validation data if you use "
r"validation scores for checkpointing."
)
with pytest.raises(SkorchException, match=msg_expected):
net.fit(*data)
expected = (
"Monitor value '{}' cannot be found in history. "
"Make sure you have validation data if you use "
"validation scores for checkpointing.".format(
'valid_loss_best')
)
assert str(e.value) == expected

def test_string_monitor_and_formatting(
self, save_params_mock, net_cls, checkpoint_cls, data):
Expand Down

0 comments on commit d5b0001

Please sign in to comment.