Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class Checkpoint(Serializable):
``30000-checkpoint-94.pt``

**Warning:** Please, keep in mind that if filename collide with already used one to saved a checkpoint,
new checkpoint will not be stored. This means that filename like ``checkpoint.pt`` will be saved only once
and will not be overwritten by newer checkpoints.
new checkpoint will replace the older one. This means that filename like ``checkpoint.pt`` will be saved
every call and will always be overwritten by newer checkpoints.

Note:
To get the last stored filename, handler exposes attribute ``last_checkpoint``:
Expand Down Expand Up @@ -350,22 +350,26 @@ def __call__(self, engine: Engine) -> None:
}
filename = filename_pattern.format(**filename_dict)

if any(item.filename == filename for item in self._saved):
return

metadata = {
"basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name),
"score_name": self.score_name,
"priority": priority,
}

if not self._check_lt_n_saved():
item = self._saved.pop(0)
try:
index = list(map(lambda it: it.filename == filename, self._saved)).index(True)
to_remove = True
except ValueError:
index = 0
to_remove = not self._check_lt_n_saved()

if to_remove:
item = self._saved.pop(index)
if isinstance(self.save_handler, BaseSaveHandler):
self.save_handler.remove(item.filename)

self._saved.append(Checkpoint.Item(priority, filename))
self._saved.sort(key=lambda item: item[0])
self._saved.sort(key=lambda it: it[0])

if self.include_self:
# Now that we've updated _saved, we can add our own state_dict.
Expand Down
25 changes: 25 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,9 @@ def _test(
)
assert res == "best_model_12_acc=0.9999.pt"

pattern = "{name}.{ext}"
assert _test(to_save, filename_pattern=pattern) == "model.pt"

pattern = "chk-{name}--{global_step}.{ext}"
assert _test(to_save, to_save, filename_pattern=pattern) == "chk-model--203.pt"
pattern = "chk-{filename_prefix}--{name}--{global_step}.{ext}"
Expand Down Expand Up @@ -1446,3 +1449,25 @@ def test_checkpoint_load_state_dict():
sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]}
checkpointer.load_state_dict(sd)
assert checkpointer._saved == true_checkpointer._saved


def test_checkpoint_fixed_filename():
model = DummyModel()
to_save = {"model": model}

def _test(n_saved):
save_handler = MagicMock(spec=BaseSaveHandler)
checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=n_saved, filename_pattern="{name}.{ext}")

trainer = Engine(lambda e, b: None)

for i in range(10):
trainer.state = State(epoch=i, iteration=i)
checkpointer(trainer)
assert save_handler.call_count == i + 1
metadata = {"basename": "model", "score_name": None, "priority": i}
save_handler.assert_called_with(model.state_dict(), "model.pt", metadata)

_test(None)
_test(1)
_test(3)