From a8c0a06ae67172c657db503cf4468603ceba0010 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Feb 2021 17:54:53 +0800 Subject: [PATCH 1/5] [NVIDIA] add greater_or_equal option to checkpoint handler Signed-off-by: Nic Ma --- ignite/handlers/checkpoint.py | 12 ++++++++- tests/ignite/handlers/test_checkpoint.py | 33 ++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 4bcebcab956c..968ea48297f4 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -98,6 +98,8 @@ class Checkpoint(Serializable): details. include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then there must not be another object in ``to_save`` with key ``checkpointer``. + greater_or_equal (bool): when comparing priority, whether to save if new priority equals to _saved[0], + default to `False` to only save when new priority is greater. .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/ torch.nn.parallel.DistributedDataParallel.html @@ -261,6 +263,7 @@ def __init__( global_step_transform: Optional[Callable] = None, filename_pattern: Optional[str] = None, include_self: bool = False, + greater_or_equal: bool = False, ) -> None: if to_save is not None: # for compatibility with ModelCheckpoint @@ -301,6 +304,7 @@ def __init__( self.filename_pattern = filename_pattern self._saved = [] # type: List["Checkpoint.Item"] self.include_self = include_self + self.greater_or_equal = greater_or_equal def reset(self) -> None: """Method to reset saved checkpoint names. @@ -354,7 +358,13 @@ def __call__(self, engine: Engine) -> None: global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) priority = global_step - if self._check_lt_n_saved() or self._saved[0].priority < priority: + def _compare_fn(old, new): + if self.greater_or_equal: + return new >= old + else: + return new > old + + if self._check_lt_n_saved() or _compare_fn(self._saved[0].priority, priority): priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}" diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 8acc73407bb7..1b27363964a0 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1526,3 +1526,36 @@ def test_checkpoint_reset_with_engine(dirname): expected += [f"{_PREFIX}_{name}_{i}.pt" for i in [1 * 2, 2 * 2]] assert sorted(os.listdir(dirname)) == sorted(expected) assert "PREFIX_model_4.pt" in handler.last_checkpoint + + +def test_greater_or_equal(): + scores = iter([1, 2, 2, 2]) + + def score_function(_): + return next(scores) + + class Saver: + def __init__(self): + self.counter = 0 + + def __call__(self, c, f, m): + if self.counter == 0: + assert f == "model_1.pt" + else: + assert f == "model_2.pt" + self.counter += 1 + + handler = Saver() + + checkpointer = Checkpoint( + to_save={"model": DummyModel()}, + save_handler=handler, + score_function=score_function, + n_saved=2, + greater_or_equal=True, + ) + trainer = Engine(lambda e, b: None) + + for _ in range(4): + checkpointer(trainer) + assert handler.counter == 3 From 330c54037584ae98e9ff5cd9d4cf4b341216e4a4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Feb 2021 18:10:17 +0800 Subject: [PATCH 2/5] [NVIDIA] fix mypy errors Signed-off-by: Nic Ma --- ignite/handlers/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 968ea48297f4..b9a288a29236 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -358,7 +358,7 @@ def __call__(self, engine: Engine) -> None: global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) priority = global_step - def _compare_fn(old, new): + def _compare_fn(old: Union[int, float], new: Union[int, float]) -> bool: if self.greater_or_equal: return new >= old else: From 0ad512a00dc2c79bb4b1efc80fb4486db616b2d6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Feb 2021 18:52:03 +0800 Subject: [PATCH 3/5] [NVIDIA] update doc-string Signed-off-by: Nic Ma --- ignite/handlers/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index b9a288a29236..366ed7706ff8 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -98,8 +98,8 @@ class Checkpoint(Serializable): details. include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then there must not be another object in ``to_save`` with key ``checkpointer``. - greater_or_equal (bool): when comparing priority, whether to save if new priority equals to _saved[0], - default to `False` to only save when new priority is greater. + greater_or_equal (bool): if `True`, the latest equally scored model is stored. Otherwise, the first model. + Default, `False`. .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/ torch.nn.parallel.DistributedDataParallel.html From b29fed6f8867ed990825fc4dbbbd7e7782d15e2d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Feb 2021 18:58:37 +0800 Subject: [PATCH 4/5] [NVIDIA] update according to comments Signed-off-by: Nic Ma --- ignite/handlers/checkpoint.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 366ed7706ff8..d7da85273b73 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -343,6 +343,12 @@ def _check_lt_n_saved(self, or_equal: bool = False) -> bool: return True return len(self._saved) < self.n_saved + int(or_equal) + def _compare_fn(self, new: Union[int, float]) -> bool: + if self.greater_or_equal: + return new >= self._saved[0].priority + else: + return new > self._saved[0].priority + def __call__(self, engine: Engine) -> None: global_step = None @@ -358,13 +364,7 @@ def __call__(self, engine: Engine) -> None: global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) priority = global_step - def _compare_fn(old: Union[int, float], new: Union[int, float]) -> bool: - if self.greater_or_equal: - return new >= old - else: - return new > old - - if self._check_lt_n_saved() or _compare_fn(self._saved[0].priority, priority): + if self._check_lt_n_saved() or self._compare_fn(priority): priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}" From 762a8dd96f0aa545356f3e57899eceeb3a827454 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Feb 2021 19:12:43 +0800 Subject: [PATCH 5/5] [NVIDIA] fix typo Signed-off-by: Nic Ma --- tests/ignite/handlers/test_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 1b27363964a0..e5cd023ed341 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1558,4 +1558,4 @@ def __call__(self, c, f, m): for _ in range(4): checkpointer(trainer) - assert handler.counter == 3 + assert handler.counter == 4