Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
import numbers

from collections import namedtuple
import collections.abc as collections
Expand Down Expand Up @@ -177,21 +178,26 @@ def __call__(self, engine: Engine) -> None:

if self._score_function is not None:
priority = self._score_function(engine)
if not isinstance(priority, numbers.Number):
raise ValueError("Output of score_function should be a number")
else:
priority = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)

if self._check_lt_n_saved() or self._saved[0].priority < priority:

priority_str = "{}".format(priority) if isinstance(priority, numbers.Integral) \
else "{:.4f}".format(priority)

if self._score_name is not None:
if len(suffix) > 0:
suffix += "_"
suffix = "{}{}={:.4f}".format(suffix, self._score_name, priority)
suffix = "{}{}={}".format(suffix, self._score_name, priority_str)
elif self._score_function is not None:
if len(suffix) > 0:
suffix += "_"
suffix = "{}{:.4f}".format(suffix, priority)
suffix = "{}{}".format(suffix, priority_str)
elif len(suffix) == 0:
suffix = "{}".format(priority)
suffix = "{}".format(priority_str)

checkpoint = self._setup_checkpoint()

Expand Down
56 changes: 56 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def test_checkpoint_wrong_input():
Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", global_step_transform=123)


def test_checkpoint_score_function_wrong_output():
model = DummyModel()
to_save = {'model': model}

checkpointer = Checkpoint(to_save, lambda x: x, score_function=lambda e: {"1": 1}, score_name="acc")
trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=0, iteration=0)
with pytest.raises(ValueError, match=r"Output of score_function should be a number"):
checkpointer(trainer)


def test_checkpoint_default():

def _test(to_save, obj, name):
Expand Down Expand Up @@ -199,6 +210,51 @@ def _test(to_save, obj, name):
_test(to_save, {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint')


def test_checkpoint_with_int_score():

def _test(to_save, obj, name, score_name=None):
save_handler = MagicMock()
save_handler.remove = MagicMock()

checkpointer = Checkpoint(to_save, save_handler=save_handler,
score_name=score_name,
score_function=lambda e: e.state.epoch)

if score_name is None:
score_name = ""
else:
score_name += "="

trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=1, iteration=1)

checkpointer(trainer)
assert save_handler.call_count == 1

save_handler.assert_called_with(obj, "{}_{}1.pth".format(name, score_name))

trainer.state.epoch = 12
trainer.state.iteration = 1234

checkpointer(trainer)
assert save_handler.call_count == 2
save_handler.assert_called_with(obj, "{}_{}12.pth".format(name, score_name))
assert save_handler.remove.call_count == 1
save_handler.remove.assert_called_with("{}_{}1.pth".format(name, score_name))
assert checkpointer.last_checkpoint == "{}_{}12.pth".format(name, score_name)

model = DummyModel()
to_save = {'model': model}
_test(to_save, model.state_dict(), 'model')
_test(to_save, model.state_dict(), 'model', "epoch")

model = DummyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
to_save = {'model': model, 'optimizer': optimizer}
_test(to_save, {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint')
_test(to_save, {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint', "epoch")


def test_checkpoint_with_score_function_and_trainer_epoch():

def _test(to_save, obj, name):
Expand Down