Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,24 @@ class Checkpoint(Serializable):
lr_scheduler = ...

to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

if (checkpoint_iters):
# A: Output is "checkpoint_<iteration>.pt"
handler = Checkpoint(
to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2
)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
else:
# B: Output is "checkpoint_<epoch>.pt"
gst = lambda *_: trainer.state.epoch
handler = Checkpoint(
to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2, global_step_transform=gst
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

trainer.run(data_loader, max_epochs=6)
> ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
> A: ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
> B: ["checkpoint_5.pt", "checkpoint_6.pt", ]

Attach the handler to an evaluator to save best model during the training
according to computed validation metric:
Expand Down Expand Up @@ -288,6 +302,32 @@ def __init__(
self._saved = [] # type: List["Checkpoint.Item"]
self.include_self = include_self

def reset(self) -> None:
    """Clear the internal bookkeeping of previously saved checkpoints.

    Use this when the same engine will be run several independent times
    (e.g. one run per cross-validation fold) so that checkpoints from an
    earlier run do not count against ``n_saved``:

    .. code-block:: python

        from ignite.handlers import Checkpoint

        trainer = ...
        checkpointer = Checkpoint(...)

        trainer.add_event_handler(Events.COMPLETED, checkpointer)
        trainer.add_event_handler(Events.STARTED, checkpointer.reset)

        # fold 0
        trainer.run(data0, max_epochs=max_epochs)
        print("Last checkpoint:", checkpointer.last_checkpoint)

        # fold 1
        trainer.run(data1, max_epochs=max_epochs)
        print("Last checkpoint:", checkpointer.last_checkpoint)

    """
    # Rebind (rather than mutate) so the saved-items list starts fresh.
    self._saved = []

@property
def last_checkpoint(self) -> Optional[str]:
if len(self._saved) < 1:
Expand Down Expand Up @@ -663,11 +703,11 @@ class ModelCheckpoint(Checkpoint):
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
>>> trainer.run([0], max_epochs=6)
>>> trainer.run([0, 1, 2, 3, 4], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
>>> handler.last_checkpoint
['/tmp/models/myprefix_mymodel_6.pt']
['/tmp/models/myprefix_mymodel_30.pt']
"""

def __init__(
Expand Down
68 changes: 62 additions & 6 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,9 @@ def update_fn(_1, _2):
model = DummyModel()
to_save = {"model": model}
engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
engine.run([0], max_epochs=4)
engine.run([0, 1,], max_epochs=4)

expected = [f"{_PREFIX}_{name}_{i}.pt" for i in [3, 4]]
expected = sorted([f"{_PREFIX}_{name}_{i}.pt" for i in [3 * 2, 4 * 2]])

assert sorted(os.listdir(dirname)) == expected

Expand All @@ -755,7 +755,7 @@ def update_fn(_1, _2):
model = DummyModel()
to_save = {"model": model}
engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
engine.run([0], max_epochs=4)
engine.run([0, 1, 2], max_epochs=4)

saved_model = os.path.join(dirname, os.listdir(dirname)[0])
load_model = torch.load(saved_model)
Expand Down Expand Up @@ -827,7 +827,7 @@ def update_fn(engine, batch):
Events.EPOCH_COMPLETED, handler, {"model": model, "optimizer": optim, "lr_scheduler": lr_scheduler}
)

engine.run([0], max_epochs=4)
engine.run([0, 1, 2], max_epochs=4)

idist.barrier()

Expand Down Expand Up @@ -937,7 +937,7 @@ def score_function(engine):
return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer

trainer, evaluator, model, optim, scheduler, early, checkpointer = _build_objects([0.2, 0.3, 0.2])
trainer.run([0], max_epochs=3)
trainer.run([0, 1, 2], max_epochs=3)

saved_objects = sorted(os.listdir(dirname))
saved_checkpoint = os.path.join(dirname, saved_objects[0])
Expand Down Expand Up @@ -994,7 +994,7 @@ def _check_state_dict(original, loaded):
_check_state_dict(early, early2)
_check_state_dict(checkpointer, checkpointer2)

trainer2.run([0], max_epochs=6)
trainer2.run([0, 1, 2], max_epochs=6)

# early stopping should have triggered
assert trainer2.state.epoch == 4
Expand Down Expand Up @@ -1470,3 +1470,59 @@ def _test(n_saved):
_test(None)
_test(1)
_test(3)


def test_checkpoint_reset():
    """Checkpoint.reset() should drop all bookkeeping of saved items
    while leaving the handler usable for subsequent saves."""
    to_save = {"model": DummyModel()}
    saver = MagicMock(spec=BaseSaveHandler)

    ckpt = Checkpoint(to_save, save_handler=saver, n_saved=2)
    assert ckpt.last_checkpoint is None

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=123)

    # Two saves at different iterations fill the n_saved=2 window.
    ckpt(engine)
    engine.state.iteration = 234
    ckpt(engine)

    assert saver.call_count == 2
    assert ckpt.last_checkpoint == "model_234.pt"
    assert len(ckpt._saved) == 2
    assert sorted(item.filename for item in ckpt._saved) == ["model_123.pt", "model_234.pt"]

    # After reset, previously saved items must be forgotten.
    ckpt.reset()
    assert len(ckpt._saved) == 0

    engine.state.iteration = 124
    ckpt(engine)

    assert saver.call_count == 3
    assert ckpt.last_checkpoint == "model_124.pt"
    assert len(ckpt._saved) == 1
    assert [item.filename for item in ckpt._saved] == ["model_124.pt"]


def test_checkpoint_reset_with_engine(dirname):
    """ModelCheckpoint.reset() lets the same engine run again without
    checkpoints from the earlier run counting against ``n_saved``."""
    name = "model"
    engine = Engine(lambda e, b: None)
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2)

    model = DummyModel()
    to_save = {"model": model}
    engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
    # 10 epochs x 2 iterations/epoch -> last two checkpoints at iterations 18 and 20
    engine.run([0, 1], max_epochs=10)

    expected = sorted(f"{_PREFIX}_{name}_{i}.pt" for i in [9 * 2, 10 * 2])
    assert sorted(os.listdir(dirname)) == expected
    # Build the filename from _PREFIX/name like the expectations above,
    # instead of hard-coding "PREFIX_model_20.pt" — stays correct if the
    # module-level prefix ever changes.
    assert f"{_PREFIX}_{name}_20.pt" in handler.last_checkpoint

    handler.reset()
    # max_epochs must be cleared so the finished engine can run again.
    engine.state.max_epochs = None
    engine.run([0, 1], max_epochs=2)

    # The new run's checkpoints (iterations 2 and 4) coexist with the old ones.
    expected += [f"{_PREFIX}_{name}_{i}.pt" for i in [1 * 2, 2 * 2]]
    assert sorted(os.listdir(dirname)) == sorted(expected)
    assert f"{_PREFIX}_{name}_4.pt" in handler.last_checkpoint