diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 31828b5ac60d..870776032680 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -11,7 +11,7 @@ Complete list of handlers
     :autolist:

 .. autoclass:: Checkpoint
-    :members: load_objects
+    :members: reset, setup_filename_pattern, load_objects, state_dict, load_state_dict, get_default_score_fn

 .. autoclass:: ignite.handlers.checkpoint.BaseSaveHandler
     :members: __call__, remove
diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py
index 97fedbcbc16e..7b769ea71bf7 100644
--- a/ignite/contrib/engines/common.py
+++ b/ignite/contrib/engines/common.py
@@ -569,12 +569,7 @@ def setup_trains_logging(
     return setup_clearml_logging(trainer, optimizers, evaluators, log_every_iters, **kwargs)


-def get_default_score_fn(metric_name: str) -> Any:
-    def wrapper(engine: Engine) -> Any:
-        score = engine.state.metrics[metric_name]
-        return score
-
-    return wrapper
+get_default_score_fn = Checkpoint.get_default_score_fn


 def gen_save_best_models_by_val_score(
@@ -628,7 +623,7 @@ def gen_save_best_models_by_val_score(
         n_saved=n_saved,
         global_step_transform=global_step_transform,
         score_name=f"{tag}_{metric_name.lower()}",
-        score_function=get_default_score_fn(metric_name),
+        score_function=Checkpoint.get_default_score_fn(metric_name),
         **kwargs,
     )
     evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index 41a567eecd56..3fefa45347f5 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -231,8 +231,7 @@ class Checkpoint(Serializable):
             # Run evaluation on epoch completed event
             # ...

-            def score_function(engine):
-                return engine.state.metrics['accuracy']
+            score_function = Checkpoint.get_default_score_fn("accuracy")

             to_save = {'model': model}
             handler = Checkpoint(
@@ -329,6 +328,7 @@ def reset(self) -> None:
             trainer.run(data1, max_epochs=max_epochs)
             print("Last checkpoint:", checkpointer.last_checkpoint)

+        .. versionadded:: 0.4.3
         """
         self._saved = []
@@ -463,6 +463,8 @@ def setup_filename_pattern(
             print(filename_pattern)
             > "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"
+
+        .. versionadded:: 0.4.3
         """
         filename_pattern = "{name}"
@@ -561,12 +563,66 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
                 obj.load_state_dict(checkpoint[k])

     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
+        """Method returns state dict with saved items: a list of ``(priority, filename)`` pairs.
+        Can be used to save the internal state of the class.
+        """
         return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

     def load_state_dict(self, state_dict: Mapping) -> None:
+        """Method replaces the internal state of the class with the provided state dict data.
+
+        Args:
+            state_dict (Mapping): a dict with a "saved" key whose value is a list of
+                ``(priority, filename)`` pairs.
+        """
         super().load_state_dict(state_dict)
         self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]]

+    @staticmethod
+    def get_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable:
+        """Helper method to get default score function based on the metric name.
+
+        Args:
+            metric_name (str): metric name to get the value from ``engine.state.metrics``.
+                Engine is the one to which :class:`~ignite.handlers.checkpoint.Checkpoint` handler is added.
+            score_sign (float): sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
+                a negative score sign should be used (objects with larger score are retained). Default, 1.0.
+
+        Examples:
+
+        .. code-block:: python
+
+            from ignite.handlers import Checkpoint
+
+            best_acc_score = Checkpoint.get_default_score_fn("accuracy")
+
+            best_model_handler = Checkpoint(
+                to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
+            )
+            evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
+
+        Usage with an error-like metric:
+
+        .. code-block:: python
+
+            from ignite.handlers import Checkpoint
+
+            neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)
+
+            best_model_handler = Checkpoint(
+                to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
+            )
+            evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
+
+        .. versionadded:: 0.4.3
+        """
+        if score_sign not in (1.0, -1.0):
+            raise ValueError("Argument score_sign should be 1 or -1")
+
+        def wrapper(engine: Engine) -> float:
+            return score_sign * engine.state.metrics[metric_name]
+
+        return wrapper
+

 class DiskSaver(BaseSaveHandler):
     """Handler that saves input checkpoint on a disk.
diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py
index bd36f16b7539..2f1a3ff901ca 100644
--- a/tests/ignite/handlers/test_checkpoint.py
+++ b/tests/ignite/handlers/test_checkpoint.py
@@ -1556,3 +1556,21 @@ def __call__(self, c, f, m):
     for _ in range(4):
         checkpointer(trainer)
     assert handler.counter == 4
+
+
+def test_get_default_score_fn():
+
+    with pytest.raises(ValueError, match=r"Argument score_sign should be 1 or -1"):
+        Checkpoint.get_default_score_fn("acc", 2.0)
+
+    engine = Engine(lambda e, b: None)
+    engine.state.metrics["acc"] = 0.9
+    engine.state.metrics["loss"] = 0.123
+
+    score_fn = Checkpoint.get_default_score_fn("acc")
+    score = score_fn(engine)
+    assert score == 0.9
+
+    score_fn = Checkpoint.get_default_score_fn("loss", -1)
+    score = score_fn(engine)
+    assert score == -0.123
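
For context, a minimal end-to-end sketch of how the new ``Checkpoint.get_default_score_fn`` API fits together. The model, the trivial evaluator, and the ``/tmp/checkpoints`` output directory are illustrative assumptions, not part of this diff:

.. code-block:: python

    import torch.nn as nn

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver

    # Hypothetical model and evaluator; a real evaluator would compute
    # predictions and have an "accuracy" metric attached to it.
    model = nn.Linear(2, 2)
    evaluator = Engine(lambda engine, batch: None)

    # Keep the 2 best checkpoints, ranked by evaluator.state.metrics["accuracy"]
    # (larger is better with the default score_sign=1.0).
    handler = Checkpoint(
        {"model": model},
        DiskSaver("/tmp/checkpoints", create_dir=True),
        n_saved=2,
        score_name="val_accuracy",
        score_function=Checkpoint.get_default_score_fn("accuracy"),
    )
    evaluator.add_event_handler(Events.COMPLETED, handler)

For an error-like metric such as a loss, ``Checkpoint.get_default_score_fn("loss", -1.0)`` negates the metric value, so retaining the checkpoints with the largest score still keeps the ones with the smallest loss.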