diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 3cb61be00c4f..13720adb205c 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -27,7 +27,8 @@ class Checkpoint: retained. score_name (str, optional): If `score_function` not None, it is possible to store its absolute value using `score_name`. See Notes for more details. - n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. + n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set to + `None`, all objects are kept. global_step_transform (callable, optional): global step transform function to output a desired global step. Input of the function is `(engine, event_name)`. Output of function should be an integer. Default is None, global_step based on attached engine. If provided, uses function output as global_step. @@ -154,6 +155,11 @@ def last_checkpoint(self): return None return self._saved[0].filename + def _check_lt_n_saved(self, or_equal=False): + if self._n_saved is None: + return True + return len(self._saved) < self._n_saved + int(or_equal) + def __call__(self, engine): suffix = "" @@ -166,8 +172,7 @@ def __call__(self, engine): else: priority = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) - if len(self._saved) < self._n_saved or \ - self._saved[0].priority < priority: + if self._check_lt_n_saved() or self._saved[0].priority < priority: if self._score_name is not None: if len(suffix) > 0: @@ -194,7 +199,7 @@ def __call__(self, engine): self._saved.append(Checkpoint.Item(priority, filename)) self._saved.sort(key=lambda item: item[0]) - if len(self._saved) > self._n_saved: + if not self._check_lt_n_saved(or_equal=True): item = self._saved.pop(0) self.save_handler.remove(item.filename) @@ -316,7 +321,8 @@ class ModelCheckpoint(Checkpoint): retained. score_name (str, optional): if `score_function` not None, it is possible to store its absolute value using `score_name`. See Notes for more details. - n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. + n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set to + `None`, all objects are kept. atomic (bool, optional): If True, objects are serialized to a temporary file, and then moved to final destination, so that files are guaranteed to not be damaged (for example if exception occurs during saving). diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 9652933394c3..f11e9a6f259a 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -403,6 +403,27 @@ def test_last_k(dirname): assert sorted(os.listdir(dirname)) == expected, "{} vs {}".format(sorted(os.listdir(dirname)), expected) +def test_disabled_n_saved(dirname): + + h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=None) + engine = Engine(lambda e, b: None) + engine.state = State(epoch=0, iteration=0) + + model = DummyModel() + to_save = {'model': model} + + num_iters = 100 + for i in range(num_iters): + engine.state.iteration = i + h(engine, to_save) + + saved_files = sorted(os.listdir(dirname)) + assert len(saved_files) == num_iters, "{}".format(saved_files) + + expected = sorted(['{}_{}_{}.pth'.format(_PREFIX, 'model', i) for i in range(num_iters)]) + assert saved_files == expected, "{} vs {}".format(saved_files, expected) + + def test_best_k(dirname): scores = iter([1.2, -2., 3.1, -4.0])