diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 4bcebcab956c..d7da85273b73 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -98,6 +98,8 @@ class Checkpoint(Serializable): details. include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then there must not be another object in ``to_save`` with key ``checkpointer``. + greater_or_equal (bool): if `True`, the latest equally scored model is stored. Otherwise, the first model. + Default, `False`. .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/ torch.nn.parallel.DistributedDataParallel.html @@ -261,6 +263,7 @@ def __init__( global_step_transform: Optional[Callable] = None, filename_pattern: Optional[str] = None, include_self: bool = False, + greater_or_equal: bool = False, ) -> None: if to_save is not None: # for compatibility with ModelCheckpoint @@ -301,6 +304,7 @@ def __init__( self.filename_pattern = filename_pattern self._saved = [] # type: List["Checkpoint.Item"] self.include_self = include_self + self.greater_or_equal = greater_or_equal def reset(self) -> None: """Method to reset saved checkpoint names. @@ -339,6 +343,12 @@ def _check_lt_n_saved(self, or_equal: bool = False) -> bool: return True return len(self._saved) < self.n_saved + int(or_equal) + def _compare_fn(self, new: Union[int, float]) -> bool: + if self.greater_or_equal: + return new >= self._saved[0].priority + else: + return new > self._saved[0].priority + def __call__(self, engine: Engine) -> None: global_step = None @@ -354,7 +364,7 @@ def __call__(self, engine: Engine) -> None: global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) priority = global_step - if self._check_lt_n_saved() or self._saved[0].priority < priority: + if self._check_lt_n_saved() or self._compare_fn(priority): priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}" diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 8acc73407bb7..e5cd023ed341 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1526,3 +1526,36 @@ def test_checkpoint_reset_with_engine(dirname): expected += [f"{_PREFIX}_{name}_{i}.pt" for i in [1 * 2, 2 * 2]] assert sorted(os.listdir(dirname)) == sorted(expected) assert "PREFIX_model_4.pt" in handler.last_checkpoint + + +def test_greater_or_equal(): + scores = iter([1, 2, 2, 2]) + + def score_function(_): + return next(scores) + + class Saver: + def __init__(self): + self.counter = 0 + + def __call__(self, c, f, m): + if self.counter == 0: + assert f == "model_1.pt" + else: + assert f == "model_2.pt" + self.counter += 1 + + handler = Saver() + + checkpointer = Checkpoint( + to_save={"model": DummyModel()}, + save_handler=handler, + score_function=score_function, + n_saved=2, + greater_or_equal=True, + ) + trainer = Engine(lambda e, b: None) + + for _ in range(4): + checkpointer(trainer) + assert handler.counter == 4