Skip to content

Commit

Permalink
FIX Allows TrainEndCheckpoint to be unpickled (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan committed May 30, 2021
1 parent 54796f1 commit 37313f5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed a few bugs in the `net.history` implementation (#776)
- Fixed a bug in `TrainEndCheckpoint` that prevented it from being unpickled (#773)

## [0.10.0] - 2021-03-23

Expand Down
12 changes: 7 additions & 5 deletions skorch/callbacks/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,11 @@ def on_train_begin(self, net,
X=None, y=None, **kwargs):
if not self.did_load_:
self.did_load_ = True
with suppress(Exception):
net.load_params(checkpoint=self.checkpoint)
with suppress(FileNotFoundError):
if isinstance(self.checkpoint, TrainEndCheckpoint):
net.load_params(checkpoint=self.checkpoint.checkpoint_)
else:
net.load_params(checkpoint=self.checkpoint)


class TrainEndCheckpoint(Callback):
Expand Down Expand Up @@ -752,10 +755,9 @@ def initialize(self):
**self._f_kwargs()
)
self.checkpoint_.initialize()
return self

def on_train_end(self, net, **kwargs):
self.checkpoint_.save_model(net)
self.checkpoint_._sink("Final checkpoint triggered", net.verbose)

def __getattr__(self, attr):
return getattr(self.checkpoint_, attr)
return self
15 changes: 15 additions & 0 deletions skorch/tests/callbacks/test_training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for callbacks in training.py"""

from functools import partial
import pickle
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import call
Expand Down Expand Up @@ -1125,3 +1126,17 @@ def initialize_module(self, *args, **kwargs):

assert save_params_mock.call_count == 1
save_params_mock.assert_has_calls([call(f_mymodule='train_end_mymodule.pt')])

def test_pickle_uninitialized_callback(self, trainendcheckpoint_cls):
# isuue 773
cp = trainendcheckpoint_cls()
# does not raise
s = pickle.dumps(cp)
pickle.loads(s)

def test_pickle_initialized_callback(self, trainendcheckpoint_cls):
# issue 773
cp = trainendcheckpoint_cls().initialize()
# does not raise
s = pickle.dumps(cp)
pickle.loads(s)

0 comments on commit 37313f5

Please sign in to comment.