From e93f3fe3b805044da757475e26961efa43f1ae99 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Fri, 30 Oct 2020 18:25:29 +0100 Subject: [PATCH 1/7] save model with same filename --- ignite/handlers/checkpoint.py | 29 ++++++++++++++---------- tests/ignite/handlers/test_checkpoint.py | 23 +++++++++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 255a9775951b..6520ff3f8923 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -147,8 +147,8 @@ class Checkpoint(Serializable): ``30000-checkpoint-94.pt`` **Warning:** Please, keep in mind that if filename collide with already used one to saved a checkpoint, - new checkpoint will not be stored. This means that filename like ``checkpoint.pt`` will be saved only once - and will not be overwritten by newer checkpoints. + new checkpoint will replace the older one. This means that filename like ``checkpoint.pt`` will be saved + every call and will always be overwrittDiskSaveren by newer checkpoints. Note: To get the last stored filename, handler exposes attribute ``last_checkpoint``: @@ -350,8 +350,21 @@ def __call__(self, engine: Engine) -> None: } filename = filename_pattern.format(**filename_dict) - if any(item.filename == filename for item in self._saved): - return + filename_already_exists = any(item.filename == filename for item in self._saved) + + if filename_already_exists: + if isinstance(self.save_handler, BaseSaveHandler): + self.save_handler.remove(filename) + # list is purged + self._saved = [item for item in self._saved if item.filename != filename] + else: + if not self._check_lt_n_saved(): + item = self._saved.pop(0) + if isinstance(self.save_handler, BaseSaveHandler): + self.save_handler.remove(item.filename) + + self._saved.append(Checkpoint.Item(priority, filename)) + self._saved.sort(key=lambda item: item[0]) metadata = { "basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name), @@ -359,14 +372,6 @@ def __call__(self, engine: Engine) -> None: "priority": priority, } - if not self._check_lt_n_saved(): - item = self._saved.pop(0) - if isinstance(self.save_handler, BaseSaveHandler): - self.save_handler.remove(item.filename) - - self._saved.append(Checkpoint.Item(priority, filename)) - self._saved.sort(key=lambda item: item[0]) - if self.include_self: # Now that we've updated _saved, we can add our own state_dict. checkpoint["checkpointer"] = self.state_dict() diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index d646a4866372..65e269ef6f65 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1309,6 +1309,9 @@ def _test( ) assert res == "best_model_12_acc=0.9999.pt" + pattern = "{name}.{ext}" + assert _test(to_save, filename_pattern=pattern) == "model.pt" + pattern = "chk-{name}--{global_step}.{ext}" assert _test(to_save, to_save, filename_pattern=pattern) == "chk-model--203.pt" pattern = "chk-{filename_prefix}--{name}--{global_step}.{ext}" @@ -1446,3 +1449,23 @@ def test_checkpoint_load_state_dict(): sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]} checkpointer.load_state_dict(sd) assert checkpointer._saved == true_checkpointer._saved + + +def test_checkpoint_fixed_filename(): + save_handler = MagicMock(spec=BaseSaveHandler) + model = DummyModel() + to_save = {"model": model} + checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=None, filename_pattern="{name}.{ext}") + + trainer = Engine(lambda e, b: None) + trainer.state = State(epoch=0, iteration=0) + + checkpointer(trainer) + assert save_handler.call_count == 1 + metadata = {"basename": "model", "score_name": None, "priority": 0} + save_handler.assert_called_with(model.state_dict(), "model.pt", metadata) + + checkpointer(trainer) + assert save_handler.call_count == 2 + metadata = {"basename": "model", "score_name": None, "priority": 0} + save_handler.assert_called_with(model.state_dict(), "model.pt", metadata) From f56576db38c4370e777b0cabff866d693f321211 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 30 Oct 2020 19:22:08 +0100 Subject: [PATCH 2/7] Update checkpoint.py --- ignite/handlers/checkpoint.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 6520ff3f8923..1e78e911b7ba 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -148,7 +148,7 @@ class Checkpoint(Serializable): **Warning:** Please, keep in mind that if filename collide with already used one to saved a checkpoint, new checkpoint will replace the older one. This means that filename like ``checkpoint.pt`` will be saved - every call and will always be overwrittDiskSaveren by newer checkpoints. + every call and will always be overwritten by newer checkpoints. Note: To get the last stored filename, handler exposes attribute ``last_checkpoint``: @@ -350,6 +350,12 @@ def __call__(self, engine: Engine) -> None: } filename = filename_pattern.format(**filename_dict) + metadata = { + "basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name), + "score_name": self.score_name, + "priority": priority, + } + filename_already_exists = any(item.filename == filename for item in self._saved) if filename_already_exists: @@ -366,12 +372,6 @@ def __call__(self, engine: Engine) -> None: self._saved.append(Checkpoint.Item(priority, filename)) self._saved.sort(key=lambda item: item[0]) - metadata = { - "basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name), - "score_name": self.score_name, - "priority": priority, - } - if self.include_self: # Now that we've updated _saved, we can add our own state_dict. checkpoint["checkpointer"] = self.state_dict() From 4cbcefe40e1f63630034b82231aa1c034715043d Mon Sep 17 00:00:00 2001 From: Desroziers Date: Sat, 31 Oct 2020 00:00:07 +0100 Subject: [PATCH 3/7] use elif --- ignite/handlers/checkpoint.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 1e78e911b7ba..5be61de32096 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -363,11 +363,10 @@ def __call__(self, engine: Engine) -> None: self.save_handler.remove(filename) # list is purged self._saved = [item for item in self._saved if item.filename != filename] - else: - if not self._check_lt_n_saved(): - item = self._saved.pop(0) - if isinstance(self.save_handler, BaseSaveHandler): - self.save_handler.remove(item.filename) + elif not self._check_lt_n_saved(): + item = self._saved.pop(0) + if isinstance(self.save_handler, BaseSaveHandler): + self.save_handler.remove(item.filename) self._saved.append(Checkpoint.Item(priority, filename)) self._saved.sort(key=lambda item: item[0]) From 225a3e82cbbeac6c0dddb79cc54d99e53da8043e Mon Sep 17 00:00:00 2001 From: Desroziers Date: Sat, 31 Oct 2020 00:15:48 +0100 Subject: [PATCH 4/7] refactor to have only one comprehension list --- ignite/handlers/checkpoint.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 5be61de32096..3958f1cba89b 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -356,13 +356,12 @@ def __call__(self, engine: Engine) -> None: "priority": priority, } - filename_already_exists = any(item.filename == filename for item in self._saved) + saved = [item for item in self._saved if item.filename != filename] - if filename_already_exists: + if self._saved != saved: if isinstance(self.save_handler, BaseSaveHandler): self.save_handler.remove(filename) - # list is purged - self._saved = [item for item in self._saved if item.filename != filename] + self._saved = saved elif not self._check_lt_n_saved(): item = self._saved.pop(0) if isinstance(self.save_handler, BaseSaveHandler): From 342bafb4bb05e20c2c4be721f6188aaffa797502 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Sat, 31 Oct 2020 20:18:07 +0100 Subject: [PATCH 5/7] refactoring --- ignite/handlers/checkpoint.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 3958f1cba89b..804acf5935be 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -356,19 +356,20 @@ def __call__(self, engine: Engine) -> None: "priority": priority, } - saved = [item for item in self._saved if item.filename != filename] - - if self._saved != saved: - if isinstance(self.save_handler, BaseSaveHandler): - self.save_handler.remove(filename) - self._saved = saved - elif not self._check_lt_n_saved(): - item = self._saved.pop(0) + try: + index = list(map(lambda it: it.filename == filename, self._saved)).index(True) + to_remove = True + except ValueError: + index = 0 + to_remove = not self._check_lt_n_saved() + + if to_remove: + item = self._saved.pop(index) if isinstance(self.save_handler, BaseSaveHandler): self.save_handler.remove(item.filename) self._saved.append(Checkpoint.Item(priority, filename)) - self._saved.sort(key=lambda item: item[0]) + self._saved.sort(key=lambda it: it[0]) if self.include_self: # Now that we've updated _saved, we can add our own state_dict. From e659652e50a931bca826a8242f2329833c399e8a Mon Sep 17 00:00:00 2001 From: Desroziers Date: Sat, 31 Oct 2020 22:28:30 +0100 Subject: [PATCH 6/7] improve test --- tests/ignite/handlers/test_checkpoint.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 65e269ef6f65..7519bc235bee 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1452,20 +1452,22 @@ def test_checkpoint_load_state_dict(): def test_checkpoint_fixed_filename(): - save_handler = MagicMock(spec=BaseSaveHandler) model = DummyModel() to_save = {"model": model} - checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=None, filename_pattern="{name}.{ext}") - trainer = Engine(lambda e, b: None) - trainer.state = State(epoch=0, iteration=0) + def _test(n_saved): + save_handler = MagicMock(spec=BaseSaveHandler) + checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=n_saved, filename_pattern="{name}.{ext}") - checkpointer(trainer) - assert save_handler.call_count == 1 - metadata = {"basename": "model", "score_name": None, "priority": 0} - save_handler.assert_called_with(model.state_dict(), "model.pt", metadata) + trainer = Engine(lambda e, b: None) - checkpointer(trainer) - assert save_handler.call_count == 2 - metadata = {"basename": "model", "score_name": None, "priority": 0} - save_handler.assert_called_with(model.state_dict(), "model.pt", metadata) + for i in range(10): + trainer.state = State(epoch=i, iteration=i) + checkpointer(trainer) + assert save_handler.call_count == i+1 + metadata = {"basename": "model", "score_name": None, "priority": i} + save_handler.assert_called_with(model.state_dict(), "model.pt", metadata) + + _test(None) + _test(1) + _test(3) From 7c98266296176c15d4bdfebcde06cff8ac112377 Mon Sep 17 00:00:00 2001 From: sdesrozis Date: Sat, 31 Oct 2020 21:30:35 +0000 Subject: [PATCH 7/7] autopep8 fix --- tests/ignite/handlers/test_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 7519bc235bee..479096c7423a 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1464,7 +1464,7 @@ def _test(n_saved): for i in range(10): trainer.state = State(epoch=i, iteration=i) checkpointer(trainer) - assert save_handler.call_count == i+1 + assert save_handler.call_count == i + 1 metadata = {"basename": "model", "score_name": None, "priority": i} save_handler.assert_called_with(model.state_dict(), "model.pt", metadata)