From 6d50b8114ac867d7f77f425907566c7bbcfbe5f4 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Sun, 4 Oct 2020 03:56:34 +0300 Subject: [PATCH 1/6] Improve typing for ignite.handlers module (1343) --- ignite/handlers/checkpoint.py | 27 ++++++++++++++++----------- mypy.ini | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 6870e1b0ff6f..2a4fb3a761fa 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -2,6 +2,7 @@ import numbers import os import tempfile +from tempfile import _TemporaryFileWrapper # type: ignore import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict, namedtuple @@ -235,7 +236,7 @@ def score_function(engine): def __init__( self, - to_save: Mapping, + to_save: Optional[Mapping], save_handler: Union[Callable, BaseSaveHandler], filename_prefix: str = "", score_function: Optional[Callable] = None, @@ -287,7 +288,7 @@ def __init__( self.ext = "pt" self.global_step_transform = global_step_transform self.filename_pattern = filename_pattern - self._saved = [] + self._saved: list = [] self.include_self = include_self @property @@ -378,10 +379,11 @@ def __call__(self, engine: Engine) -> None: def _setup_checkpoint(self) -> dict: checkpoint = {} - for k, obj in self.to_save.items(): - if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - obj = obj.module - checkpoint[k] = obj.state_dict() + if self.to_save is not None: + for k, obj in self.to_save.items(): + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + checkpoint[k] = obj.state_dict() return checkpoint @staticmethod @@ -572,7 +574,7 @@ def _save_native(self, checkpoint: Mapping, path: str): self._save_func(checkpoint, path, torch.save) def _save_xla(self, checkpoint: Mapping, path: str): - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm # type: ignore # all tpu procs should enter here as internally performs sync across device self._save_func(checkpoint, path, xm.save, rank=idist.get_rank()) @@ -582,8 +584,8 @@ def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = func(checkpoint, path, **self.kwargs) else: tmp_file = None - tmp_name = None - tmp = None + tmp_name = '' + tmp: _TemporaryFileWrapper = None if rank == 0: tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname) tmp_file = tmp.file @@ -728,9 +730,12 @@ def __init__( def last_checkpoint(self) -> Union[str, None]: if len(self._saved) < 1: return None - return os.path.join(self.save_handler.dirname, self._saved[-1].filename) + if isinstance(self.save_handler, DiskSaver): + return os.path.join(self.save_handler.dirname, self._saved[-1].filename) + else: + return None - def __call__(self, engine: Engine, to_save: Mapping) -> None: + def __call__(self, engine: Engine, to_save: Mapping) -> None: # type: ignore if len(to_save) == 0: raise RuntimeError("No objects to checkpoint found.") diff --git a/mypy.ini b/mypy.ini index 5ef86a5e30fe..299843badb37 100644 --- a/mypy.ini +++ b/mypy.ini @@ -13,7 +13,7 @@ ignore_errors = True [mypy-ignite.handlers.*] -ignore_errors = True +ignore_errors = False [mypy-ignite.engine.*] From 7465d4632f5ea00f799c2d2f6668f54c85ab64f6 Mon Sep 17 00:00:00 2001 From: AutoPEP8 <> Date: Sun, 4 Oct 2020 00:59:10 +0000 Subject: [PATCH 2/6] autopep8 fix --- ignite/handlers/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 2a4fb3a761fa..f7d29ba65f2a 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -2,10 +2,10 @@ import numbers import os import tempfile -from tempfile import _TemporaryFileWrapper # type: ignore import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict, namedtuple +from tempfile import _TemporaryFileWrapper # type: ignore from typing import Callable, Mapping, Optional, Union import torch @@ -584,7 +584,7 @@ def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = func(checkpoint, path, **self.kwargs) else: tmp_file = None - tmp_name = '' + tmp_name = "" tmp: _TemporaryFileWrapper = None if rank == 0: tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname) From f27500cb0b785828c462aff8dbd3c83f2640202b Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Sun, 4 Oct 2020 17:13:19 +0300 Subject: [PATCH 3/6] Fix typing for py35, remove handlers block from mypy.ini --- ignite/handlers/checkpoint.py | 4 ++-- mypy.ini | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index f7d29ba65f2a..e3f83e2d526b 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -288,7 +288,7 @@ def __init__( self.ext = "pt" self.global_step_transform = global_step_transform self.filename_pattern = filename_pattern - self._saved: list = [] + self._saved = [] # type: list self.include_self = include_self @property @@ -585,7 +585,7 @@ def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = else: tmp_file = None tmp_name = "" - tmp: _TemporaryFileWrapper = None + tmp = None # type: _TemporaryFileWrapper if rank == 0: tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname) tmp_file = tmp.file diff --git a/mypy.ini b/mypy.ini index 299843badb37..0863b902c305 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,10 +11,6 @@ ignore_errors = True ignore_errors = True -[mypy-ignite.handlers.*] - -ignore_errors = False - [mypy-ignite.engine.*] ignore_errors = True From db57a85ddd7934db6f8e54e401a55a305e168393 Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Mon, 5 Oct 2020 13:37:24 +0300 Subject: [PATCH 4/6] Add exception to ModelCheckpoint when saving last checkpoint --- ignite/handlers/checkpoint.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index e3f83e2d526b..6f4321b7812f 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -730,10 +730,13 @@ def __init__( def last_checkpoint(self) -> Union[str, None]: if len(self._saved) < 1: return None - if isinstance(self.save_handler, DiskSaver): - return os.path.join(self.save_handler.dirname, self._saved[-1].filename) - else: - return None + + if not isinstance(self.save_handler, DiskSaver): + raise RuntimeError( + "Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler)) + ) + + return os.path.join(self.save_handler.dirname, self._saved[-1].filename) def __call__(self, engine: Engine, to_save: Mapping) -> None: # type: ignore From 63142939f1b37e33a8cee60a1e852335f2d8f3ae Mon Sep 17 00:00:00 2001 From: Taras Savchyn Date: Tue, 6 Oct 2020 17:03:46 +0300 Subject: [PATCH 5/6] Add test for ModelCheckpoint with redefined save_handler case --- tests/ignite/handlers/test_checkpoint.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 6478ff491040..cb07f783840e 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -563,6 +563,17 @@ def _test(ext, require_empty): _test(".pt", require_empty=False) +def test_model_checkpoint_invalid_save_handler(dirname): + h = ModelCheckpoint(dirname, _PREFIX) + to_save = {"model": DummyModel()} + # Redefine save_handler + h.save_handler = lambda x, y: None + h(Engine(lambda x, y: None), to_save) + + with pytest.raises(RuntimeError, match=r"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(h.save_handler))): + h.last_checkpoint + + def test_disk_saver_atomic(dirname): model = DummyModel() From e01c95e3ca33f98889b1792ef5d2d504b62775d2 Mon Sep 17 00:00:00 2001 From: trsvchn Date: Tue, 6 Oct 2020 14:05:29 +0000 Subject: [PATCH 6/6] autopep8 fix --- tests/ignite/handlers/test_checkpoint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index cb07f783840e..11d044e2aa37 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -570,7 +570,10 @@ def test_model_checkpoint_invalid_save_handler(dirname): h.save_handler = lambda x, y: None h(Engine(lambda x, y: None), to_save) - with pytest.raises(RuntimeError, match=r"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(h.save_handler))): + with pytest.raises( + RuntimeError, + match=r"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(h.save_handler)), + ): h.last_checkpoint