
[MRG] ENH Uses checkpoint internally (#463)

Fix a bug in TrainEndCheckpoint that prevented it from being cloned

TrainEndCheckpoint now wraps Checkpoint instead of inheriting from it, so there is no longer a mismatch between its `__init__` parameters and the parameters it stores and exposes.
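For context, a minimal, hypothetical sketch of the failure mode (the class names and `get_params` implementation below are illustrative, not skorch's actual code): `sklearn.base.clone` essentially rebuilds an object as `type(obj)(**obj.get_params())`, so every stored parameter must be a valid constructor argument of the concrete class.

```python
# Hypothetical minimal example of why a subclass with a narrower __init__
# cannot be cloned; not skorch's actual classes.
from sklearn.base import clone


class Base:
    def __init__(self, monitor='valid_loss'):
        self.monitor = monitor

    def get_params(self, deep=True):
        return vars(self)  # every stored attribute is reported as a parameter

    def set_params(self, **params):
        vars(self).update(params)
        return self


class Sub(Base):
    def __init__(self):                 # accepts no 'monitor' argument itself...
        super().__init__(monitor=None)  # ...but still stores one via the parent


clone(Base())  # works: 'monitor' round-trips through __init__
clone(Sub())   # TypeError: __init__() got an unexpected keyword argument 'monitor'
```

Storing only its own constructor arguments and delegating to a wrapped `Checkpoint` avoids exactly this mismatch.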
thomasjpfan authored and BenjaminBossan committed May 2, 2019
1 parent 910c560 commit 17256a1325ea463003a0e927a09a11065dc0b21a
Showing with 28 additions and 16 deletions.
  1. +1 −0 CHANGES.md
  2. +21 −16 skorch/callbacks/training.py
  3. +6 −0 skorch/tests/callbacks/test_training.py
CHANGES.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Include requirements in MANIFEST.in
 - Add `criterion_` to `NeuralNet.cuda_dependent_attributes_` to avoid issues with criterion
   weight tensors from, e.g., `NLLLoss` (#426)
+- `TrainEndCheckpoint` can be cloned by `sklearn.base.clone`. (#459)
 
 
 ## [0.5.0] - 2018-12-13
skorch/callbacks/training.py
@@ -604,7 +604,7 @@ def on_train_begin(self, net,
         net.load_params(checkpoint=self.checkpoint)
 
 
-class TrainEndCheckpoint(Checkpoint):
+class TrainEndCheckpoint(Callback):
     """Saves the model parameters, optimizer state, and history at the end of
     training. The default ``fn_prefix`` is 'train_end_'.
@@ -681,29 +681,34 @@ def __init__(
             dirname='',
             sink=noop,
     ):
 
         # TODO: Remove warning in release 0.5.0
         if fn_prefix is None:
             warnings.warn(
                 "'fn_prefix' default value will change from 'final_' "
                 "to 'train_end_' in 0.5.0", FutureWarning)
             fn_prefix = 'final_'
 
-        super().__init__(
+        self.f_params = f_params
+        self.f_optimizer = f_optimizer
+        self.f_history = f_history
+        self.f_pickle = f_pickle
+        self.fn_prefix = fn_prefix
+        self.dirname = dirname
+        self.sink = sink
+
+    def initialize(self):
+        self.checkpoint_ = Checkpoint(
             monitor=None,
-            f_params=f_params,
-            f_optimizer=f_optimizer,
-            f_history=f_history,
-            f_pickle=f_pickle,
-            fn_prefix=fn_prefix,
-            dirname=dirname,
+            f_params=self.f_params,
+            f_optimizer=self.f_optimizer,
+            f_history=self.f_history,
+            f_pickle=self.f_pickle,
+            fn_prefix=self.fn_prefix,
+            dirname=self.dirname,
             event_name=None,
-            sink=sink,
-        )
-
-    def on_epoch_end(self, net, **kwargs):
-        pass
+            sink=self.sink)
+        self.checkpoint_.initialize()
 
     def on_train_end(self, net, **kwargs):
-        self.save_model(net)
-        self._sink("Final checkpoint triggered", net.verbose)
+        self.checkpoint_.save_model(net)
+        self.checkpoint_._sink("Final checkpoint triggered", net.verbose)
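To illustrate the new composition-based design, here is a short usage sketch; it assumes a skorch build that includes this commit, and the `fn_prefix` default of 'train_end_' is taken from the docstring above.

```python
# Sketch assuming this commit is installed; it only exercises what the diff
# above introduces (stored __init__ params, checkpoint_ built in initialize()).
from sklearn.base import clone
from skorch.callbacks import TrainEndCheckpoint

cp = TrainEndCheckpoint(dirname='exp1')

# clone() rebuilds the callback from its __init__ parameters alone,
# which now match the attributes stored on the instance.
cp_clone = clone(cp)
assert cp_clone.dirname == 'exp1'

# The wrapped Checkpoint only comes into existence on initialize(),
# so an un-initialized clone carries no checkpoint state.
cp_clone.initialize()
assert cp_clone.checkpoint_.fn_prefix == 'train_end_'
```

Building the wrapped `Checkpoint` in `initialize()` rather than `__init__` is what keeps `get_params` limited to the constructor arguments and makes cloning safe.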
skorch/tests/callbacks/test_training.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import pytest
+from sklearn.base import clone
 
 
 class TestCheckpoint:
@@ -891,3 +892,8 @@ def test_saves_at_end_with_custom_formatting(
             call(f_optimizer='exp1/train_end_optimizer_10.pt'),
             call(f_history='exp1/train_end_history.json')
         ])
+
+    def test_cloneable(self, finalcheckpoint_cls):
+        # reproduces bug #459
+        cp = finalcheckpoint_cls()
+        clone(cp)  # does not raise
