Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class Checkpoint:
retained.
score_name (str, optional): If `score_function` not None, it is possible to store its absolute value using
`score_name`. See Notes for more details.
n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed.
n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set to
`None`, all objects are kept.
global_step_transform (callable, optional): global step transform function to output a desired global step.
Input of the function is `(engine, event_name)`. Output of function should be an integer.
Default is None, global_step based on attached engine. If provided, uses function output as global_step.
Expand Down Expand Up @@ -154,6 +155,11 @@ def last_checkpoint(self):
return None
return self._saved[0].filename

def _check_lt_n_saved(self, or_equal=False):
if self._n_saved is None:
return True
return len(self._saved) < self._n_saved + int(or_equal)

def __call__(self, engine):

suffix = ""
Expand All @@ -166,8 +172,7 @@ def __call__(self, engine):
else:
priority = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)

if len(self._saved) < self._n_saved or \
self._saved[0].priority < priority:
if self._check_lt_n_saved() or self._saved[0].priority < priority:

if self._score_name is not None:
if len(suffix) > 0:
Expand All @@ -194,7 +199,7 @@ def __call__(self, engine):
self._saved.append(Checkpoint.Item(priority, filename))
self._saved.sort(key=lambda item: item[0])

if len(self._saved) > self._n_saved:
if not self._check_lt_n_saved(or_equal=True):
item = self._saved.pop(0)
self.save_handler.remove(item.filename)

Expand Down Expand Up @@ -316,7 +321,8 @@ class ModelCheckpoint(Checkpoint):
retained.
score_name (str, optional): if `score_function` not None, it is possible to store its absolute value using
`score_name`. See Notes for more details.
n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed.
n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set to
`None`, all objects are kept.
atomic (bool, optional): If True, objects are serialized to a temporary file, and then moved to final
destination, so that files are guaranteed to not be damaged (for example if exception
occurs during saving).
Expand Down
21 changes: 21 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,27 @@ def test_last_k(dirname):
assert sorted(os.listdir(dirname)) == expected, "{} vs {}".format(sorted(os.listdir(dirname)), expected)


def test_disabled_n_saved(dirname):

h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=None)
engine = Engine(lambda e, b: None)
engine.state = State(epoch=0, iteration=0)

model = DummyModel()
to_save = {'model': model}

num_iters = 100
for i in range(num_iters):
engine.state.iteration = i
h(engine, to_save)

saved_files = sorted(os.listdir(dirname))
assert len(saved_files) == num_iters, "{}".format(saved_files)

expected = sorted(['{}_{}_{}.pth'.format(_PREFIX, 'model', i) for i in range(num_iters)])
assert saved_files == expected, "{} vs {}".format(saved_files, expected)


def test_best_k(dirname):
scores = iter([1.2, -2., 3.1, -4.0])

Expand Down