From 08143c5ec0dd10309f229d08b0bdccfcc8390bd9 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 25 Dec 2020 11:30:43 +0000
Subject: [PATCH] Added reset method to Checkpoint/ModelCheckpoint

Fixes #1422
---
 ignite/handlers/checkpoint.py            | 52 +++++++++++++++---
 tests/ignite/handlers/test_checkpoint.py | 68 +++++++++++++++++++++---
 2 files changed, 108 insertions(+), 12 deletions(-)

diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index 994f89b0033c..035eff43b784 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -196,10 +196,24 @@ class Checkpoint(Serializable):
             lr_scheduler = ...
 
             to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
-            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
-            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
+
+            if checkpoint_iters:
+                # A: Output is "checkpoint_<iteration>.pt"
+                handler = Checkpoint(
+                    to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2
+                )
+                trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
+            else:
+                # B: Output is "checkpoint_<epoch>.pt"
+                gst = lambda *_: trainer.state.epoch
+                handler = Checkpoint(
+                    to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2, global_step_transform=gst
+                )
+                trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
+
             trainer.run(data_loader, max_epochs=6)
-            > ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
+            > A: ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
+            > B: ["checkpoint_5.pt", "checkpoint_6.pt", ]
 
         Attach the handler to an evaluator to save best model during the training
         according to computed validation metric:
@@ -288,6 +302,32 @@ def __init__(
         self._saved = []  # type: List["Checkpoint.Item"]
         self.include_self = include_self
 
+    def reset(self) -> None:
+        """Method to reset saved checkpoint names.
+
+        Use this method if the engine will run multiple times independently:
+
+        .. code-block:: python
+
+            from ignite.handlers import Checkpoint
+
+            trainer = ...
+            checkpointer = Checkpoint(...)
+
+            trainer.add_event_handler(Events.COMPLETED, checkpointer)
+            trainer.add_event_handler(Events.STARTED, checkpointer.reset)
+
+            # fold 0
+            trainer.run(data0, max_epochs=max_epochs)
+            print("Last checkpoint:", checkpointer.last_checkpoint)
+
+            # fold 1
+            trainer.run(data1, max_epochs=max_epochs)
+            print("Last checkpoint:", checkpointer.last_checkpoint)
+
+        """
+        self._saved = []
+
     @property
     def last_checkpoint(self) -> Optional[str]:
         if len(self._saved) < 1:
@@ -663,11 +703,11 @@ class ModelCheckpoint(Checkpoint):
         >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
         >>> model = nn.Linear(3, 3)
         >>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
-        >>> trainer.run([0], max_epochs=6)
+        >>> trainer.run([0, 1, 2, 3, 4], max_epochs=6)
        >>> os.listdir('/tmp/models')
-        ['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
+        ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
         >>> handler.last_checkpoint
-        ['/tmp/models/myprefix_mymodel_6.pt']
+        '/tmp/models/myprefix_mymodel_30.pt'
     """
 
     def __init__(
diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py
index b851353676db..8acc73407bb7 100644
--- a/tests/ignite/handlers/test_checkpoint.py
+++ b/tests/ignite/handlers/test_checkpoint.py
@@ -738,9 +738,9 @@ def update_fn(_1, _2):
     model = DummyModel()
     to_save = {"model": model}
     engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
-    engine.run([0], max_epochs=4)
+    engine.run([0, 1], max_epochs=4)
 
-    expected = [f"{_PREFIX}_{name}_{i}.pt" for i in [3, 4]]
+    expected = sorted([f"{_PREFIX}_{name}_{i}.pt" for i in [3 * 2, 4 * 2]])
 
     assert sorted(os.listdir(dirname)) == expected
@@ -755,7 +755,7 @@ def update_fn(_1, _2):
     model = DummyModel()
     to_save = {"model": model}
     engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
-    engine.run([0], max_epochs=4)
+    engine.run([0, 1, 2], max_epochs=4)
 
     saved_model = os.path.join(dirname, os.listdir(dirname)[0])
     load_model = torch.load(saved_model)
@@ -827,7 +827,7 @@ def update_fn(engine, batch):
         Events.EPOCH_COMPLETED, handler, {"model": model, "optimizer": optim, "lr_scheduler": lr_scheduler}
     )
 
-    engine.run([0], max_epochs=4)
+    engine.run([0, 1, 2], max_epochs=4)
 
     idist.barrier()
@@ -937,7 +937,7 @@ def score_function(engine):
         return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer
 
     trainer, evaluator, model, optim, scheduler, early, checkpointer = _build_objects([0.2, 0.3, 0.2])
-    trainer.run([0], max_epochs=3)
+    trainer.run([0, 1, 2], max_epochs=3)
 
     saved_objects = sorted(os.listdir(dirname))
     saved_checkpoint = os.path.join(dirname, saved_objects[0])
@@ -994,7 +994,7 @@ def _check_state_dict(original, loaded):
     _check_state_dict(early, early2)
     _check_state_dict(checkpointer, checkpointer2)
 
-    trainer2.run([0], max_epochs=6)
+    trainer2.run([0, 1, 2], max_epochs=6)
 
     # early stopping should have triggered
     assert trainer2.state.epoch == 4
@@ -1470,3 +1470,59 @@ def _test(n_saved):
     _test(None)
     _test(1)
     _test(3)
+
+
+def test_checkpoint_reset():
+    model = DummyModel()
+    to_save = {"model": model}
+
+    save_handler = MagicMock(spec=BaseSaveHandler)
+
+    checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=2)
+    assert checkpointer.last_checkpoint is None
+
+    trainer = Engine(lambda e, b: None)
+
+    trainer.state = State(epoch=0, iteration=123)
+    checkpointer(trainer)
+    trainer.state.iteration = 234
+    checkpointer(trainer)
+
+    assert save_handler.call_count == 2
+    assert checkpointer.last_checkpoint == "model_234.pt"
+    assert len(checkpointer._saved) == 2
+    assert sorted([item.filename for item in checkpointer._saved]) == sorted(["model_123.pt", "model_234.pt"])
+
+    checkpointer.reset()
+    assert len(checkpointer._saved) == 0
+
+    trainer.state.iteration = 124
+    checkpointer(trainer)
+
+    assert save_handler.call_count == 3
+    assert checkpointer.last_checkpoint == "model_124.pt"
+    assert len(checkpointer._saved) == 1
+    assert sorted([item.filename for item in checkpointer._saved]) == ["model_124.pt"]
+
+
+def test_checkpoint_reset_with_engine(dirname):
+    name = "model"
+    engine = Engine(lambda e, b: None)
+    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2)
+
+    model = DummyModel()
+    to_save = {"model": model}
+    engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
+    engine.run([0, 1], max_epochs=10)
+
+    expected = sorted([f"{_PREFIX}_{name}_{i}.pt" for i in [9 * 2, 10 * 2]])
+    assert sorted(os.listdir(dirname)) == expected
+    assert "PREFIX_model_20.pt" in handler.last_checkpoint
+
+    handler.reset()
+    engine.state.max_epochs = None
+    engine.run([0, 1], max_epochs=2)
+
+    expected += [f"{_PREFIX}_{name}_{i}.pt" for i in [1 * 2, 2 * 2]]
+    assert sorted(os.listdir(dirname)) == sorted(expected)
+    assert "PREFIX_model_4.pt" in handler.last_checkpoint