Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class Checkpoint(Serializable):
details.
include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
there must not be another object in ``to_save`` with key ``checkpointer``.
greater_or_equal (bool): if `True`, the latest equally scored model is stored. Otherwise, the first
equally scored model is kept. Default: `False`.

.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
Expand Down Expand Up @@ -261,6 +263,7 @@ def __init__(
global_step_transform: Optional[Callable] = None,
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
) -> None:

if to_save is not None: # for compatibility with ModelCheckpoint
Expand Down Expand Up @@ -301,6 +304,7 @@ def __init__(
self.filename_pattern = filename_pattern
self._saved = [] # type: List["Checkpoint.Item"]
self.include_self = include_self
self.greater_or_equal = greater_or_equal

def reset(self) -> None:
"""Method to reset saved checkpoint names.
Expand Down Expand Up @@ -339,6 +343,12 @@ def _check_lt_n_saved(self, or_equal: bool = False) -> bool:
return True
return len(self._saved) < self.n_saved + int(or_equal)

def _compare_fn(self, new: Union[int, float]) -> bool:
    """Tell whether ``new`` beats the priority of the first saved item.

    The comparison is inclusive (``>=``) when ``self.greater_or_equal`` is
    truthy, so a tie counts as a win; otherwise it is strict (``>``).

    Args:
        new: candidate priority to compare against ``self._saved[0].priority``.

    Returns:
        The result of the comparison described above.
    """
    inclusive = self.greater_or_equal
    # Indexing assumes at least one item has already been saved.
    reference = self._saved[0].priority
    return new >= reference if inclusive else new > reference

def __call__(self, engine: Engine) -> None:

global_step = None
Expand All @@ -354,7 +364,7 @@ def __call__(self, engine: Engine) -> None:
global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
priority = global_step

if self._check_lt_n_saved() or self._saved[0].priority < priority:
if self._check_lt_n_saved() or self._compare_fn(priority):

priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"

Expand Down
33 changes: 33 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,3 +1526,36 @@ def test_checkpoint_reset_with_engine(dirname):
expected += [f"{_PREFIX}_{name}_{i}.pt" for i in [1 * 2, 2 * 2]]
assert sorted(os.listdir(dirname)) == sorted(expected)
assert "PREFIX_model_4.pt" in handler.last_checkpoint


def test_greater_or_equal():
    """Check that ``greater_or_equal=True`` makes equal scores trigger a save.

    The score sequence is 1, 2, 2, 2; the save handler must be invoked on
    every one of the four calls.
    """
    score_values = iter([1, 2, 2, 2])

    class Saver:
        """Save-handler stub that checks the requested filename and counts calls."""

        def __init__(self):
            self.counter = 0

        def __call__(self, c, f, m):
            # Only the very first save carries the score 1; all later ones carry 2.
            expected_name = "model_1.pt" if self.counter == 0 else "model_2.pt"
            assert f == expected_name
            self.counter += 1

    handler = Saver()
    checkpointer = Checkpoint(
        to_save={"model": DummyModel()},
        save_handler=handler,
        score_function=lambda _: next(score_values),
        n_saved=2,
        greater_or_equal=True,
    )
    trainer = Engine(lambda e, b: None)

    for _ in range(4):
        checkpointer(trainer)

    # Every call, including the two tied scores, must have reached the handler.
    assert handler.counter == 4