From 8998c243703427f5efdb6dfeaced1f463c516cfc Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 31 Jan 2020 11:10:57 +0100 Subject: [PATCH] Fixes #756 - integer score value is not converted to float anymore --- ignite/handlers/checkpoint.py | 12 +++-- tests/ignite/handlers/test_checkpoint.py | 56 ++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 43c0b9928b6c..4f806cd5b282 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -1,5 +1,6 @@ import os import tempfile +import numbers from collections import namedtuple import collections.abc as collections @@ -177,21 +178,26 @@ def __call__(self, engine: Engine) -> None: if self._score_function is not None: priority = self._score_function(engine) + if not isinstance(priority, numbers.Number): + raise ValueError("Output of score_function should be a number") else: priority = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) if self._check_lt_n_saved() or self._saved[0].priority < priority: + priority_str = "{}".format(priority) if isinstance(priority, numbers.Integral) \ + else "{:.4f}".format(priority) + if self._score_name is not None: if len(suffix) > 0: suffix += "_" - suffix = "{}{}={:.4f}".format(suffix, self._score_name, priority) + suffix = "{}{}={}".format(suffix, self._score_name, priority_str) elif self._score_function is not None: if len(suffix) > 0: suffix += "_" - suffix = "{}{:.4f}".format(suffix, priority) + suffix = "{}{}".format(suffix, priority_str) elif len(suffix) == 0: - suffix = "{}".format(priority) + suffix = "{}".format(priority_str) checkpoint = self._setup_checkpoint() diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index b36d5fc106e2..208698bca458 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -47,6 +47,17 @@ def test_checkpoint_wrong_input(): Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", global_step_transform=123) +def test_checkpoint_score_function_wrong_output(): + model = DummyModel() + to_save = {'model': model} + + checkpointer = Checkpoint(to_save, lambda x: x, score_function=lambda e: {"1": 1}, score_name="acc") + trainer = Engine(lambda e, b: None) + trainer.state = State(epoch=0, iteration=0) + with pytest.raises(ValueError, match=r"Output of score_function should be a number"): + checkpointer(trainer) + + def test_checkpoint_default(): def _test(to_save, obj, name): @@ -199,6 +210,51 @@ def _test(to_save, obj, name): _test(to_save, {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint') +def test_checkpoint_with_int_score(): + + def _test(to_save, obj, name, score_name=None): + save_handler = MagicMock() + save_handler.remove = MagicMock() + + checkpointer = Checkpoint(to_save, save_handler=save_handler, + score_name=score_name, + score_function=lambda e: e.state.epoch) + + if score_name is None: + score_name = "" + else: + score_name += "=" + + trainer = Engine(lambda e, b: None) + trainer.state = State(epoch=1, iteration=1) + + checkpointer(trainer) + assert save_handler.call_count == 1 + + save_handler.assert_called_with(obj, "{}_{}1.pth".format(name, score_name)) + + trainer.state.epoch = 12 + trainer.state.iteration = 1234 + + checkpointer(trainer) + assert save_handler.call_count == 2 + save_handler.assert_called_with(obj, "{}_{}12.pth".format(name, score_name)) + assert save_handler.remove.call_count == 1 + save_handler.remove.assert_called_with("{}_{}1.pth".format(name, score_name)) + assert checkpointer.last_checkpoint == "{}_{}12.pth".format(name, score_name) + + model = DummyModel() + to_save = {'model': model} + _test(to_save, model.state_dict(), 'model') + _test(to_save, model.state_dict(), 'model', "epoch") + + model = DummyModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + to_save = {'model': model, 'optimizer': optimizer} + _test(to_save, {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint') + _test(to_save, {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint', "epoch") + + def test_checkpoint_with_score_function_and_trainer_epoch(): def _test(to_save, obj, name):