diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 83884ef5536a..9bb936ac70ca 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -147,8 +147,8 @@ class Checkpoint(Serializable): ``30000-checkpoint-94.pt`` **Warning:** Please, keep in mind that if filename collide with already used one to saved a checkpoint, - new checkpoint will not be stored. This means that filename like ``checkpoint.pt`` will be saved only once - and will not be overwritten by newer checkpoints. + new checkpoint will replace the older one. This means that filename like ``checkpoint.pt`` will be saved + every call and will always be overwritten by newer checkpoints. Note: To get the last stored filename, handler exposes attribute ``last_checkpoint``: @@ -350,22 +350,26 @@ def __call__(self, engine: Engine) -> None: } filename = filename_pattern.format(**filename_dict) - if any(item.filename == filename for item in self._saved): - return - metadata = { "basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name), "score_name": self.score_name, "priority": priority, } - if not self._check_lt_n_saved(): - item = self._saved.pop(0) + try: + index = list(map(lambda it: it.filename == filename, self._saved)).index(True) + to_remove = True + except ValueError: + index = 0 + to_remove = not self._check_lt_n_saved() + + if to_remove: + item = self._saved.pop(index) if isinstance(self.save_handler, BaseSaveHandler): self.save_handler.remove(item.filename) self._saved.append(Checkpoint.Item(priority, filename)) - self._saved.sort(key=lambda item: item[0]) + self._saved.sort(key=lambda it: it[0]) if self.include_self: # Now that we've updated _saved, we can add our own state_dict. diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index d646a4866372..479096c7423a 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1309,6 +1309,9 @@ def _test( ) assert res == "best_model_12_acc=0.9999.pt" + pattern = "{name}.{ext}" + assert _test(to_save, filename_pattern=pattern) == "model.pt" + pattern = "chk-{name}--{global_step}.{ext}" assert _test(to_save, to_save, filename_pattern=pattern) == "chk-model--203.pt" pattern = "chk-{filename_prefix}--{name}--{global_step}.{ext}" @@ -1446,3 +1449,25 @@ def test_checkpoint_load_state_dict(): sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]} checkpointer.load_state_dict(sd) assert checkpointer._saved == true_checkpointer._saved + + +def test_checkpoint_fixed_filename(): + model = DummyModel() + to_save = {"model": model} + + def _test(n_saved): + save_handler = MagicMock(spec=BaseSaveHandler) + checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=n_saved, filename_pattern="{name}.{ext}") + + trainer = Engine(lambda e, b: None) + + for i in range(10): + trainer.state = State(epoch=i, iteration=i) + checkpointer(trainer) + assert save_handler.call_count == i + 1 + metadata = {"basename": "model", "score_name": None, "priority": i} + save_handler.assert_called_with(model.state_dict(), "model.pt", metadata) + + _test(None) + _test(1) + _test(3)