From eb4829ee1a8cbd345126f9f8ceec4696e8dd84c5 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sat, 6 Feb 2021 23:48:11 +0000 Subject: [PATCH 1/3] Added Checkpoint.get_default_score_fn to simplify best_model_handler creation --- ignite/contrib/engines/common.py | 9 ++------ ignite/handlers/checkpoint.py | 26 ++++++++++++++++++++++++ tests/ignite/handlers/test_checkpoint.py | 10 +++++++++ 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index 97fedbcbc16e..7b769ea71bf7 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -569,12 +569,7 @@ def setup_trains_logging( return setup_clearml_logging(trainer, optimizers, evaluators, log_every_iters, **kwargs) -def get_default_score_fn(metric_name: str) -> Any: - def wrapper(engine: Engine) -> Any: - score = engine.state.metrics[metric_name] - return score - - return wrapper +get_default_score_fn = Checkpoint.get_default_score_fn def gen_save_best_models_by_val_score( @@ -628,7 +623,7 @@ def gen_save_best_models_by_val_score( n_saved=n_saved, global_step_transform=global_step_transform, score_name=f"{tag}_{metric_name.lower()}", - score_function=get_default_score_fn(metric_name), + score_function=Checkpoint.get_default_score_fn(metric_name), **kwargs, ) evaluator.add_event_handler(Events.COMPLETED, best_model_handler) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 41a567eecd56..b8a8708af155 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -567,6 +567,32 @@ def load_state_dict(self, state_dict: Mapping) -> None: super().load_state_dict(state_dict) self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]] + @staticmethod + def get_default_score_fn(metric_name: str) -> Any: + """Helper method to get default score function based on the metric name. + + Exemples: + + .. code-block:: python + + from ignite.handlers import Checkpoint + + best_acc_score = Checkpoint.get_default_score_fn(accuracy) + + best_model_handler = Checkpoint( + to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score + ) + evaluator.add_event_handler(Events.COMPLETED, best_model_handler) + + .. versionadded:: 0.4.3 + """ + + def wrapper(engine: Engine) -> Any: + score = engine.state.metrics[metric_name] + return score + + return wrapper + class DiskSaver(BaseSaveHandler): """Handler that saves input checkpoint on a disk. diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index bd36f16b7539..b1b412178d81 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1556,3 +1556,13 @@ def __call__(self, c, f, m): for _ in range(4): checkpointer(trainer) assert handler.counter == 4 + + +def test_get_default_score_fn(): + + engine = Engine(lambda e, b: None) + engine.state.metrics["acc"] = 0.9 + + score_fn = Checkpoint.get_default_score_fn("acc") + score = score_fn(engine) + assert score == 0.9 From a4196258b0c29f0e9ac3c8cbe77b9612ba03a485 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sun, 7 Feb 2021 21:36:47 +0000 Subject: [PATCH 2/3] Added score_sign argument --- ignite/handlers/checkpoint.py | 41 ++++++++++++++++++++---- tests/ignite/handlers/test_checkpoint.py | 8 +++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index b8a8708af155..ec108e1097d6 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -231,8 +231,7 @@ class Checkpoint(Serializable): # Run evaluation on epoch completed event # ... - def score_function(engine): - return engine.state.metrics['accuracy'] + score_function = Checkpoint.get_default_score_fn("accuracy") to_save = {'model': model} handler = Checkpoint( @@ -561,35 +560,63 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None: obj.load_state_dict(checkpoint[k]) def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]": + """Method returns state dict with saved items: list of ``(priority, filename)`` pairs. + Can be used to save internal state of the class. + """ return OrderedDict([("saved", [(p, f) for p, f in self._saved])]) def load_state_dict(self, state_dict: Mapping) -> None: + """Method replace internal state of the class with provided state dict data. + + Args: + state_dict (Mapping): a dict with "saved" key and list of ``(priority, filename)`` pairs as values. + """ super().load_state_dict(state_dict) self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]] @staticmethod - def get_default_score_fn(metric_name: str) -> Any: + def get_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable: """Helper method to get default score function based on the metric name. + Args: + metric_name (str): metric name to get the value from ``engine.state.metrics``. + Engine is the one to which :class:`~ignite.handlers.checkpoint.Checkpoint` handler is added. + score_sign (float): sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better, + a negative score sign should be used (objects with larger score are retained). Default, 1.0. + Exemples: .. code-block:: python from ignite.handlers import Checkpoint - best_acc_score = Checkpoint.get_default_score_fn(accuracy) + best_acc_score = Checkpoint.get_default_score_fn("accuracy") best_model_handler = Checkpoint( to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score ) evaluator.add_event_handler(Events.COMPLETED, best_model_handler) + Usage with error-like metric: + + .. code-block:: python + + from ignite.handlers import Checkpoint + + neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0) + + best_model_handler = Checkpoint( + to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score + ) + evaluator.add_event_handler(Events.COMPLETED, best_model_handler) + .. versionadded:: 0.4.3 """ + if score_sign not in (1.0, -1.0): + raise ValueError("Argument score_sign should be 1 or -1") - def wrapper(engine: Engine) -> Any: - score = engine.state.metrics[metric_name] - return score + def wrapper(engine: Engine) -> float: + return score_sign * engine.state.metrics[metric_name] return wrapper diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index b1b412178d81..2f1a3ff901ca 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1560,9 +1560,17 @@ def __call__(self, c, f, m): def test_get_default_score_fn(): + with pytest.raises(ValueError, match=r"Argument score_sign should be 1 or -1"): + Checkpoint.get_default_score_fn("acc", 2.0) + engine = Engine(lambda e, b: None) engine.state.metrics["acc"] = 0.9 + engine.state.metrics["loss"] = 0.123 score_fn = Checkpoint.get_default_score_fn("acc") score = score_fn(engine) assert score == 0.9 + + score_fn = Checkpoint.get_default_score_fn("loss", -1) + score = score_fn(engine) + assert score == -0.123 From da7ac8aeff35fbbe741de09781dc2038c5833878 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sun, 7 Feb 2021 21:58:04 +0000 Subject: [PATCH 3/3] Updated docs --- docs/source/handlers.rst | 2 +- ignite/handlers/checkpoint.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 31828b5ac60d..870776032680 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -11,7 +11,7 @@ Complete list of handlers :autolist: .. autoclass:: Checkpoint - :members: load_objects + :members: reset, setup_filename_pattern, load_objects, state_dict, load_state_dict, get_default_score_fn .. autoclass:: ignite.handlers.checkpoint.BaseSaveHandler :members: __call__, remove diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index ec108e1097d6..3fefa45347f5 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -328,6 +328,7 @@ def reset(self) -> None: trainer.run(data1, max_epochs=max_epochs) print("Last checkpoint:", checkpointer.last_checkpoint) + .. versionadded:: 0.4.3 """ self._saved = [] @@ -462,6 +463,8 @@ def setup_filename_pattern( print(filename_pattern) > "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}" + + .. versionadded:: 0.4.3 """ filename_pattern = "{name}"