docs/source/handlers.rst (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ Complete list of handlers
:autolist:

.. autoclass:: Checkpoint
:members: load_objects
:members: reset, setup_filename_pattern, load_objects, state_dict, load_state_dict, get_default_score_fn

.. autoclass:: ignite.handlers.checkpoint.BaseSaveHandler
:members: __call__, remove
ignite/contrib/engines/common.py (9 changes: 2 additions & 7 deletions)
@@ -569,12 +569,7 @@ def setup_trains_logging(
return setup_clearml_logging(trainer, optimizers, evaluators, log_every_iters, **kwargs)


def get_default_score_fn(metric_name: str) -> Any:
    def wrapper(engine: Engine) -> Any:
        score = engine.state.metrics[metric_name]
        return score

    return wrapper
get_default_score_fn = Checkpoint.get_default_score_fn


def gen_save_best_models_by_val_score(
@@ -628,7 +623,7 @@ def gen_save_best_models_by_val_score(
n_saved=n_saved,
global_step_transform=global_step_transform,
score_name=f"{tag}_{metric_name.lower()}",
score_function=get_default_score_fn(metric_name),
score_function=Checkpoint.get_default_score_fn(metric_name),
**kwargs,
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
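Aside (not part of the diff): after this change the module-level name is just an alias for the static method, so existing imports from ignite.contrib.engines.common keep working and the returned callable still reads the metric from engine.state.metrics. A minimal sketch, assuming an engine whose metrics dict has already been populated:

    from ignite.contrib.engines.common import get_default_score_fn
    from ignite.engine import Engine

    # Stand-in evaluator with a pre-filled metric, just to exercise the score function
    evaluator = Engine(lambda engine, batch: None)
    evaluator.state.metrics["accuracy"] = 0.95

    score_fn = get_default_score_fn("accuracy")
    assert score_fn(evaluator) == 0.95  # same behaviour as the removed local helper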
ignite/handlers/checkpoint.py (60 changes: 58 additions & 2 deletions)
@@ -231,8 +231,7 @@ class Checkpoint(Serializable):
# Run evaluation on epoch completed event
# ...

def score_function(engine):
    return engine.state.metrics['accuracy']
score_function = Checkpoint.get_default_score_fn("accuracy")

to_save = {'model': model}
handler = Checkpoint(
@@ -329,6 +328,7 @@ def reset(self) -> None:
trainer.run(data1, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

.. versionadded:: 0.4.3
"""
self._saved = []

@@ -463,6 +463,8 @@ def setup_filename_pattern(

print(filename_pattern)
> "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

.. versionadded:: 0.4.3
"""
filename_pattern = "{name}"

@@ -561,12 +563,66 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
obj.load_state_dict(checkpoint[k])

def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
    """Method returns a state dict with saved items: a list of ``(priority, filename)`` pairs.
    Can be used to save the internal state of the class.
    """
    return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

def load_state_dict(self, state_dict: Mapping) -> None:
    """Method replaces the internal state of the class with the provided state dict data.

    Args:
        state_dict (Mapping): a dict with a "saved" key and a list of ``(priority, filename)`` pairs as its value.
    """
    super().load_state_dict(state_dict)
    self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]]

@staticmethod
def get_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable:
    """Helper method to get the default score function based on the metric name.

    Args:
        metric_name (str): metric name used to fetch the value from ``engine.state.metrics``.
            The engine is the one to which the :class:`~ignite.handlers.checkpoint.Checkpoint` handler is added.
        score_sign (float): sign of the score: 1.0 or -1.0. For error-like metrics, where a smaller value is
            better, a negative score sign should be used (objects with the larger score are retained).
            Default, 1.0.

    Examples:

    .. code-block:: python

        from ignite.handlers import Checkpoint

        best_acc_score = Checkpoint.get_default_score_fn("accuracy")

        best_model_handler = Checkpoint(
            to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
        )
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

    Usage with an error-like metric:

    .. code-block:: python

        from ignite.handlers import Checkpoint

        neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

        best_model_handler = Checkpoint(
            to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
        )
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

    .. versionadded:: 0.4.3
    """
    if score_sign not in (1.0, -1.0):
        raise ValueError("Argument score_sign should be 1 or -1")

    def wrapper(engine: Engine) -> float:
        return score_sign * engine.state.metrics[metric_name]

    return wrapper


class DiskSaver(BaseSaveHandler):
"""Handler that saves input checkpoint on a disk.
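Aside (not part of the diff): the new state_dict/load_state_dict pair above only serializes the handler's bookkeeping, i.e. the list of (priority, filename) pairs, so a resumed run can continue the n_saved rotation. A minimal sketch, assuming a writable checkpoint directory and that some checkpoints were produced before the restart:

    import torch.nn as nn
    from ignite.handlers import Checkpoint, DiskSaver

    to_save = {"model": nn.Linear(2, 2)}
    saver = DiskSaver("/tmp/ckpts", create_dir=True, require_empty=False)
    handler = Checkpoint(to_save, saver, n_saved=2)

    # ... the handler is attached to an engine and called a few times ...

    bookkeeping = handler.state_dict()  # OrderedDict with a "saved" list of (priority, filename) pairs

    # After a restart, a fresh handler is told which files already exist,
    # so rotation and last_checkpoint behave as if the run never stopped.
    new_handler = Checkpoint(to_save, DiskSaver("/tmp/ckpts", create_dir=True, require_empty=False), n_saved=2)
    new_handler.load_state_dict(bookkeeping)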
tests/ignite/handlers/test_checkpoint.py (18 changes: 18 additions & 0 deletions)
@@ -1556,3 +1556,21 @@ def __call__(self, c, f, m):
for _ in range(4):
    checkpointer(trainer)
assert handler.counter == 4


def test_get_default_score_fn():

    with pytest.raises(ValueError, match=r"Argument score_sign should be 1 or -1"):
        Checkpoint.get_default_score_fn("acc", 2.0)

    engine = Engine(lambda e, b: None)
    engine.state.metrics["acc"] = 0.9
    engine.state.metrics["loss"] = 0.123

    score_fn = Checkpoint.get_default_score_fn("acc")
    score = score_fn(engine)
    assert score == 0.9

    score_fn = Checkpoint.get_default_score_fn("loss", -1)
    score = score_fn(engine)
    assert score == -0.123
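Aside (not part of the diff): the score_sign semantics tested above matter because Checkpoint retains the objects with the largest score, so negating an error-like metric makes the smallest loss win. A minimal sketch with hypothetical loss values:

    from ignite.engine import Engine
    from ignite.handlers import Checkpoint

    neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

    evaluator = Engine(lambda engine, batch: None)

    evaluator.state.metrics["loss"] = 0.5
    assert neg_loss_score(evaluator) == -0.5

    evaluator.state.metrics["loss"] = 0.3
    assert neg_loss_score(evaluator) == -0.3  # larger score than -0.5, so this checkpoint would be kept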