Skip to content

Commit 75e2042

Browse files
sdesrozisDesroziersvfdev-5
authored
Save model with same filename (#1423)
* save model with same filename * Update checkpoint.py * use elif * refactor to have only one comprehension list * refactoring * improve test * autopep8 fix Co-authored-by: Desroziers <sylvain.desroziers@ifpen.fr> Co-authored-by: vfdev <vfdev.5@gmail.com> Co-authored-by: sdesrozis <sdesrozis@users.noreply.github.com>
1 parent 5228846 commit 75e2042

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

ignite/handlers/checkpoint.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ class Checkpoint(Serializable):
147147
``30000-checkpoint-94.pt``
148148
149149
**Warning:** Please, keep in mind that if filename collide with already used one to saved a checkpoint,
150-
new checkpoint will not be stored. This means that filename like ``checkpoint.pt`` will be saved only once
151-
and will not be overwritten by newer checkpoints.
150+
new checkpoint will replace the older one. This means that filename like ``checkpoint.pt`` will be saved
151+
every call and will always be overwritten by newer checkpoints.
152152
153153
Note:
154154
To get the last stored filename, handler exposes attribute ``last_checkpoint``:
@@ -350,22 +350,26 @@ def __call__(self, engine: Engine) -> None:
350350
}
351351
filename = filename_pattern.format(**filename_dict)
352352

353-
if any(item.filename == filename for item in self._saved):
354-
return
355-
356353
metadata = {
357354
"basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name),
358355
"score_name": self.score_name,
359356
"priority": priority,
360357
}
361358

362-
if not self._check_lt_n_saved():
363-
item = self._saved.pop(0)
359+
try:
360+
index = list(map(lambda it: it.filename == filename, self._saved)).index(True)
361+
to_remove = True
362+
except ValueError:
363+
index = 0
364+
to_remove = not self._check_lt_n_saved()
365+
366+
if to_remove:
367+
item = self._saved.pop(index)
364368
if isinstance(self.save_handler, BaseSaveHandler):
365369
self.save_handler.remove(item.filename)
366370

367371
self._saved.append(Checkpoint.Item(priority, filename))
368-
self._saved.sort(key=lambda item: item[0])
372+
self._saved.sort(key=lambda it: it[0])
369373

370374
if self.include_self:
371375
# Now that we've updated _saved, we can add our own state_dict.

tests/ignite/handlers/test_checkpoint.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,9 @@ def _test(
13091309
)
13101310
assert res == "best_model_12_acc=0.9999.pt"
13111311

1312+
pattern = "{name}.{ext}"
1313+
assert _test(to_save, filename_pattern=pattern) == "model.pt"
1314+
13121315
pattern = "chk-{name}--{global_step}.{ext}"
13131316
assert _test(to_save, to_save, filename_pattern=pattern) == "chk-model--203.pt"
13141317
pattern = "chk-{filename_prefix}--{name}--{global_step}.{ext}"
@@ -1446,3 +1449,25 @@ def test_checkpoint_load_state_dict():
14461449
sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]}
14471450
checkpointer.load_state_dict(sd)
14481451
assert checkpointer._saved == true_checkpointer._saved
1452+
1453+
1454+
def test_checkpoint_fixed_filename():
1455+
model = DummyModel()
1456+
to_save = {"model": model}
1457+
1458+
def _test(n_saved):
1459+
save_handler = MagicMock(spec=BaseSaveHandler)
1460+
checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=n_saved, filename_pattern="{name}.{ext}")
1461+
1462+
trainer = Engine(lambda e, b: None)
1463+
1464+
for i in range(10):
1465+
trainer.state = State(epoch=i, iteration=i)
1466+
checkpointer(trainer)
1467+
assert save_handler.call_count == i + 1
1468+
metadata = {"basename": "model", "score_name": None, "priority": i}
1469+
save_handler.assert_called_with(model.state_dict(), "model.pt", metadata)
1470+
1471+
_test(None)
1472+
_test(1)
1473+
_test(3)

0 commit comments

Comments
 (0)