Add restore_best_weights argument to EarlyStopping callback (#809)
Add load_best option for EarlyStopping callback

This automatically loads the module weights from the best epoch at the end of training.
cedricrommel committed Nov 9, 2021
1 parent 617339b commit 9ac71e2
Showing 3 changed files with 126 additions and 0 deletions.
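For context, a minimal usage sketch of the new option (illustrative only: `MyModule`, the toy data from `make_classification`, and all hyperparameters are assumptions, not part of this commit):

```python
# Minimal sketch: stop training early on validation loss and keep the
# weights from the best epoch rather than the last one (load_best=True).
import numpy as np
import torch.nn as nn
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier
from skorch.callbacks import EarlyStopping

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
        )

    def forward(self, X):
        return self.seq(X)

X, y = make_classification(1000, 20, random_state=0)
X = X.astype(np.float32)

es = EarlyStopping(monitor='valid_loss', patience=5, load_best=True)
net = NeuralNetClassifier(
    MyModule,
    criterion=nn.CrossEntropyLoss,
    max_epochs=50,
    callbacks=[es],
)
net.fit(X, y)  # after training, net.module_ holds the best epoch's weights
```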
1 change: 1 addition & 0 deletions CHANGES.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training

### Changed

31 changes: 31 additions & 0 deletions skorch/callbacks/training.py
@@ -7,6 +7,7 @@
from fnmatch import fnmatch
from functools import partial
from itertools import product
from copy import deepcopy

import numpy as np
from skorch.callbacks import Callback
@@ -367,6 +368,13 @@ class EarlyStopping(Callback):
sent to. By default, the output is printed to stdout, but the
sink could also be a logger or :func:`~skorch.utils.noop`.
load_best: bool (default=False)
Whether to restore module weights from the epoch with the best value of
the monitored quantity. If False, the module weights obtained at the
last step of training are used. Note that only the module is restored.
Use the ``Checkpoint`` callback with the :attr:`~Checkpoint.load_best`
argument set to ``True`` if you need to restore the whole object.
"""
def __init__(
self,
@@ -376,6 +384,7 @@ def __init__(
threshold_mode='rel',
lower_is_better=True,
sink=print,
load_best=False,
):
self.monitor = monitor
self.lower_is_better = lower_is_better
@@ -385,6 +394,13 @@ def __init__(
self.misses_ = 0
self.dynamic_threshold_ = None
self.sink = sink
self.load_best = load_best

def __getstate__(self):
# Avoid saving the module_ weights twice when pickling
state = self.__dict__.copy()
state['best_model_weights_'] = None
return state

# pylint: disable=arguments-differ
def on_train_begin(self, net, **kwargs):
@@ -393,6 +409,8 @@ def on_train_begin(self, net, **kwargs):
.format(self.threshold_mode))
self.misses_ = 0
self.dynamic_threshold_ = np.inf if self.lower_is_better else -np.inf
self.best_model_weights_ = None
self.best_epoch_ = 0

def on_epoch_end(self, net, **kwargs):
current_score = net.history[-1, self.monitor]
@@ -401,13 +419,26 @@ def on_epoch_end(self, net, **kwargs):
else:
self.misses_ = 0
self.dynamic_threshold_ = self._calc_new_threshold(current_score)
self.best_epoch_ = net.history[-1, "epoch"]
if self.load_best:
self.best_model_weights_ = deepcopy(net.module_.state_dict())
if self.misses_ == self.patience:
if net.verbose:
self._sink("Stopping since {} has not improved in the last "
"{} epochs.".format(self.monitor, self.patience),
verbose=net.verbose)
raise KeyboardInterrupt

def on_train_end(self, net, **kwargs):
if (
self.load_best and (self.best_epoch_ != net.history[-1, "epoch"])
and (self.best_model_weights_ is not None)
):
net.module_.load_state_dict(self.best_model_weights_)
self._sink("Restoring best model from epoch {}.".format(
self.best_epoch_
), verbose=net.verbose)

def _is_score_improved(self, score):
if self.lower_is_better:
return score < self.dynamic_threshold_
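The `__getstate__` hook above keeps the cached best weights from being written out a second time when the callback is pickled. A short sketch of the resulting behavior (assuming a `net` trained with `load_best=True`, as in the sketch near the top):

```python
import pickle

# Pickling stores net.module_ weights once, as usual; the callback's cached
# copy is dropped by __getstate__ and comes back as None after unpickling.
restored = pickle.loads(pickle.dumps(net))
assert restored.callbacks[0].best_model_weights_ is None
```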
94 changes: 94 additions & 0 deletions skorch/tests/callbacks/test_training.py
@@ -5,11 +5,17 @@
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import call
from copy import deepcopy

import numpy as np
import pytest
from sklearn.base import clone
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
import torch
from torch.utils.data import TensorDataset

from skorch.helper import predefined_split


class TestCheckpoint:
@@ -487,6 +493,94 @@ def test_typical_use_case_nonstop(

assert len(net.history) == max_epochs

def test_weights_restore(
self, net_clf_cls, classifier_module, classifier_data,
early_stopping_cls):
patience = 3
max_epochs = 20
seed = 1

side_effect = []

def sink(x):
side_effect.append(x)

early_stopping_cb = early_stopping_cls(
patience=patience,
sink=sink,
load_best=True,
monitor="valid_acc",
lower_is_better=False,
)

# Split dataset to have a fixed validation
X_tr, X_val, y_tr, y_val = train_test_split(
*classifier_data, random_state=seed)
tr_dataset = TensorDataset(
torch.as_tensor(X_tr).float(), torch.as_tensor(y_tr))
val_dataset = TensorDataset(
torch.as_tensor(X_val).float(), torch.as_tensor(y_val))

# Fit the network once with early stopping and a fixed seed
net1 = net_clf_cls(
classifier_module,
callbacks=[early_stopping_cb],
max_epochs=max_epochs,
train_split=predefined_split(val_dataset),
)
torch.manual_seed(seed)
net1.fit(tr_dataset, y=None)

# Check training was stopped before the end
assert len(net1.history) < max_epochs

# check correct output messages
assert len(side_effect) == 2

msg = side_effect[0]
expected_msg = ("Stopping since valid_acc has not improved in "
"the last 3 epochs.")
assert msg == expected_msg

msg = side_effect[1]
expected_msg = "Restoring best model from epoch "
assert expected_msg in msg

# Recompute validation loss and store it together with module weights
y_proba = net1.predict_proba(val_dataset)
es_weights = deepcopy(net1.module_.state_dict())
es_loss = log_loss(y_val, y_proba)

# Retrain same classifier without ES, using the best epochs number
net2 = net_clf_cls(
classifier_module,
max_epochs=early_stopping_cb.best_epoch_,
train_split=predefined_split(val_dataset),
)
torch.manual_seed(seed)
net2.fit(tr_dataset, y=None)

# Check that the weights obtained match
assert all(
torch.equal(wi, wj)
for wi, wj in zip(
net2.module_.state_dict().values(),
es_weights.values()
)
)

# Check that the validation losses obtained match
y_proba_2 = net2.predict_proba(val_dataset)
assert es_loss == log_loss(y_val, y_proba_2)

# Check that best_model_weights_ is set to None when pickling
del net1.callbacks[0].sink
net1_pkl = pickle.dumps(net1)

reloaded_net1 = pickle.loads(net1_pkl)
assert reloaded_net1.callbacks[0].best_epoch_ == net1.callbacks[0].best_epoch_
assert reloaded_net1.callbacks[0].best_model_weights_ is None

def test_typical_use_case_stopping(
self, net_clf_cls, broken_classifier_module, classifier_data,
early_stopping_cls):
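As the new docstring notes, `load_best` restores only the module. A hedged sketch of the alternative it points to, combining `EarlyStopping` with the `Checkpoint` callback to roll back the whole checkpointed state (`MyModule`, `X`, `y`, and the imports are reused from the first sketch; `Checkpoint`'s default monitor and file names are assumed):

```python
from skorch.callbacks import Checkpoint, EarlyStopping

# Checkpoint saves the best state during training and, with load_best=True,
# reloads it (params, optimizer, history) from disk at the end of training.
cp = Checkpoint(monitor='valid_loss_best', load_best=True)
es = EarlyStopping(monitor='valid_loss', patience=5)
net = NeuralNetClassifier(
    MyModule,
    criterion=nn.CrossEntropyLoss,
    max_epochs=50,
    callbacks=[cp, es],
)
net.fit(X, y)
```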
