Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] ENH Uses checkpoint internally #463

Merged
merged 2 commits on May 2, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 21 additions & 16 deletions skorch/callbacks/training.py
Expand Up @@ -604,7 +604,7 @@ def on_train_begin(self, net,
net.load_params(checkpoint=self.checkpoint)


class TrainEndCheckpoint(Checkpoint):
class TrainEndCheckpoint(Callback):
"""Saves the model parameters, optimizer state, and history at the end of
training. The default ``fn_prefix`` is 'train_end_'.

Expand Down Expand Up @@ -681,29 +681,34 @@ def __init__(
dirname='',
sink=noop,
):

# TODO: Remove warning in release 0.5.0
if fn_prefix is None:
warnings.warn(
"'fn_prefix' default value will change from 'final_' "
"to 'train_end_' in 0.5.0", FutureWarning)
fn_prefix = 'final_'

super().__init__(
self.f_params = f_params
self.f_optimizer = f_optimizer
self.f_history = f_history
self.f_pickle = f_pickle
self.fn_prefix = fn_prefix
self.dirname = dirname
self.sink = sink

def initialize(self):
self.checkpoint_ = Checkpoint(
monitor=None,
f_params=f_params,
f_optimizer=f_optimizer,
f_history=f_history,
f_pickle=f_pickle,
fn_prefix=fn_prefix,
dirname=dirname,
f_params=self.f_params,
f_optimizer=self.f_optimizer,
f_history=self.f_history,
f_pickle=self.f_pickle,
fn_prefix=self.fn_prefix,
dirname=self.dirname,
event_name=None,
sink=sink,
)

def on_epoch_end(self, net, **kwargs):
pass
sink=self.sink)
self.checkpoint_.initialize()

def on_train_end(self, net, **kwargs):
self.save_model(net)
self._sink("Final checkpoint triggered", net.verbose)
self.checkpoint_.save_model(net)
self.checkpoint_._sink("Final checkpoint triggered", net.verbose)
6 changes: 6 additions & 0 deletions skorch/tests/callbacks/test_training.py
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import pytest
from sklearn.base import clone


class TestCheckpoint:
Expand Down Expand Up @@ -891,3 +892,8 @@ def test_saves_at_end_with_custom_formatting(
call(f_optimizer='exp1/train_end_optimizer_10.pt'),
call(f_history='exp1/train_end_history.json')
])

def test_cloneable(self, finalcheckpoint_cls):
# reproduces bug #459
cp = finalcheckpoint_cls()
clone(cp) # does not raise