From 5c869c3749b4d9523161d2f2909c39361a000813 Mon Sep 17 00:00:00 2001 From: gruebel Date: Tue, 3 Nov 2020 23:16:18 +0100 Subject: [PATCH] enable extra flags for stricter type checking --- ignite/contrib/metrics/gpu_info.py | 4 ++- ignite/distributed/comp_models/__init__.py | 13 ++++++-- ignite/distributed/comp_models/base.py | 10 +++--- ignite/distributed/comp_models/horovod.py | 10 +++--- ignite/distributed/comp_models/native.py | 8 ++--- ignite/distributed/comp_models/xla.py | 10 +++--- ignite/distributed/launcher.py | 2 +- ignite/distributed/utils.py | 2 +- ignite/handlers/__init__.py | 6 ++-- ignite/handlers/checkpoint.py | 36 +++++++++++----------- ignite/handlers/early_stopping.py | 8 ++--- ignite/handlers/timing.py | 25 +++++++-------- ignite/metrics/confusion_matrix.py | 20 +++++++----- mypy.ini | 21 +++++++++++++ 14 files changed, 106 insertions(+), 69 deletions(-) diff --git a/ignite/contrib/metrics/gpu_info.py b/ignite/contrib/metrics/gpu_info.py index 06af656ae343..7ab49447f395 100644 --- a/ignite/contrib/metrics/gpu_info.py +++ b/ignite/contrib/metrics/gpu_info.py @@ -59,7 +59,9 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: pass def compute(self) -> List[Dict[str, Any]]: - data = self.nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu") + data = self.nvsmi.DeviceQuery( + "memory.used, memory.total, utilization.gpu" + ) # type: Dict[str, List[Dict[str, Any]]] if len(data) == 0 or ("gpu" not in data): warnings.warn("No GPU information available") return [] diff --git a/ignite/distributed/comp_models/__init__.py b/ignite/distributed/comp_models/__init__.py index c9227701078c..587077570f13 100644 --- a/ignite/distributed/comp_models/__init__.py +++ b/ignite/distributed/comp_models/__init__.py @@ -1,13 +1,22 @@ +from typing import TYPE_CHECKING, List, Tuple, Type, Union + from ignite.distributed.comp_models.base import _SerialModel from ignite.distributed.comp_models.horovod import has_hvd_support from ignite.distributed.comp_models.native import has_native_dist_support from ignite.distributed.comp_models.xla import has_xla_support +if TYPE_CHECKING: + from ignite.distributed.comp_models.horovod import _HorovodDistModel + from ignite.distributed.comp_models.native import _NativeDistModel + from ignite.distributed.comp_models.xla import _XlaDistModel + -def setup_available_computation_models(): # type: ignore # inhomogeneous Tuple types are not supported +def setup_available_computation_models() -> Tuple[ + Type[Union[_SerialModel, "_NativeDistModel", "_XlaDistModel", "_HorovodDistModel"]], ... 
+]: models = [ _SerialModel, - ] + ] # type: List[Type[Union[_SerialModel, "_NativeDistModel", "_XlaDistModel", "_HorovodDistModel"]]] if has_native_dist_support: from ignite.distributed.comp_models.native import _NativeDistModel diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py index 0502971908b4..70549f9f1a9f 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -15,11 +15,11 @@ class ComputationModel(metaclass=ABCMeta): # this is an additional local rank storage used when idist is setup from existing native torch dist context _ext_local_rank = None # type: Optional[int] - def __init__(self): - self._backend = None - self._nproc_per_node = None - self._nnodes = None - self._node = None + def __init__(self) -> None: + self._backend = None # type: Optional[str] + self._nproc_per_node = None # type: Optional[int] + self._nnodes = None # type: Optional[int] + self._node = None # type: Optional[int] def _setup_attrs(self) -> None: if self._nproc_per_node is None: diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py index ef4600bb4d7b..5f36fd3d08f3 100644 --- a/ignite/distributed/comp_models/horovod.py +++ b/ignite/distributed/comp_models/horovod.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Any, Callable, Mapping, Optional, Tuple +from typing import Any, Callable, Mapping, Optional, Tuple, cast import torch @@ -62,7 +62,7 @@ def __init__(self, do_init: bool = False, **kwargs: Any) -> None: """This is a private method. Please, use `create_from_backend` or `create_from_context` """ super(_HorovodDistModel, self).__init__() - self._backend = HOROVOD + self._backend = HOROVOD # type: str if do_init: comm = kwargs.get("comm", None) hvd.init(comm=comm) @@ -87,13 +87,13 @@ def get_world_size(self) -> int: return hvd.size() def get_nproc_per_node(self) -> int: - return self._nproc_per_node + return cast(int, self._nproc_per_node) def get_nnodes(self) -> int: - return self._nnodes + return cast(int, self._nnodes) def get_node_rank(self) -> int: - return self._node + return cast(int, self._node) def device(self) -> torch.device: if torch.cuda.is_available(): diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index ff489d743a74..2d19e54a0fb4 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -2,7 +2,7 @@ import subprocess import warnings from distutils.version import LooseVersion -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast import torch import torch.distributed as dist @@ -214,13 +214,13 @@ def get_world_size(self) -> int: return dist.get_world_size() def get_nproc_per_node(self) -> int: - return self._nproc_per_node + return cast(int, self._nproc_per_node) def get_nnodes(self) -> int: - return self._nnodes + return cast(int, self._nnodes) def get_node_rank(self) -> int: - return self._node + return cast(int, self._node) def device(self) -> torch.device: if self.backend() == dist.Backend.NCCL: diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py index 533defdb61db..fa981367f51f 100644 --- a/ignite/distributed/comp_models/xla.py +++ b/ignite/distributed/comp_models/xla.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Mapping, Optional, Tuple +from typing import Any, Callable, Mapping, Optional, Tuple, cast 
import torch @@ -53,7 +53,7 @@ def __init__(self, backend: Optional[str] = None, **kwargs: Any): def _create_from_backend(self, backend: str, **kwargs: Any) -> None: xm.rendezvous("init") - self._backend = backend + self._backend = backend # type: str self._setup_attrs() def _init_from_context(self) -> None: @@ -75,13 +75,13 @@ def get_world_size(self) -> int: return xm.xrt_world_size() def get_nproc_per_node(self) -> int: - return self._nproc_per_node + return cast(int, self._nproc_per_node) def get_nnodes(self) -> int: - return self._nnodes + return cast(int, self._nnodes) def get_node_rank(self) -> int: - return self._node + return cast(int, self._node) def device(self) -> torch.device: dev = torch_xla._XLAC._xla_get_default_device() diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py index 643650fd14fa..de2b14c5a14b 100644 --- a/ignite/distributed/launcher.py +++ b/ignite/distributed/launcher.py @@ -175,7 +175,7 @@ def training(local_rank, config, **kwargs): def __init__( self, - backend: str = None, + backend: Optional[str] = None, nproc_per_node: Optional[int] = None, nnodes: Optional[int] = None, node_rank: Optional[int] = None, diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index ceecf684cb4c..a7cf8dc4ee05 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -496,7 +496,7 @@ def train_fn(local_rank, a, b, c): for comp_model_cls in registered_computation_models: if backend not in comp_model_cls.available_backends: continue - _set_model(comp_model_cls(backend, **kwargs)) + _set_model(comp_model_cls(backend, **kwargs)) # type: ignore[arg-type] def finalize() -> None: diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py index 973e1297fdca..cb37d8ce431f 100644 --- a/ignite/handlers/__init__.py +++ b/ignite/handlers/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union from ignite.engine import Engine from ignite.engine.events import Events @@ -18,7 +18,7 @@ ] -def global_step_from_engine(engine: Engine, custom_event_name=None) -> Callable: +def global_step_from_engine(engine: Engine, custom_event_name: Optional[Events] = None) -> Callable: """Helper method to setup `global_step_transform` function using another engine. This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator. 
@@ -30,7 +30,7 @@ def global_step_from_engine(engine: Engine, custom_event_name=None) -> Callable: global step """ - def wrapper(_: Any, event_name: Events): + def wrapper(_: Any, event_name: Events) -> int: if custom_event_name is not None: event_name = custom_event_name return engine.state.get_event_attrib_value(event_name) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 9bb936ac70ca..b793f4fd659c 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod from collections import OrderedDict, namedtuple from tempfile import _TemporaryFileWrapper # type: ignore[attr-defined] -from typing import Callable, Mapping, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union import torch import torch.nn as nn @@ -233,7 +233,7 @@ def score_function(engine): """ - Item = namedtuple("Item", ["priority", "filename"]) + Item = NamedTuple("Item", [("priority", int), ("filename", str)]) _state_dict_all_req_keys = ("saved",) def __init__( @@ -244,10 +244,10 @@ def __init__( score_function: Optional[Callable] = None, score_name: Optional[str] = None, n_saved: Optional[int] = 1, - global_step_transform: Callable = None, + global_step_transform: Optional[Callable] = None, filename_pattern: Optional[str] = None, include_self: bool = False, - ): + ) -> None: if to_save is not None: # for compatibility with ModelCheckpoint if not isinstance(to_save, collections.Mapping): @@ -287,7 +287,7 @@ def __init__( self.ext = "pt" self.global_step_transform = global_step_transform self.filename_pattern = filename_pattern - self._saved = [] # type: list + self._saved = [] # type: List["Checkpoint.Item"] self.include_self = include_self @property @@ -296,7 +296,7 @@ def last_checkpoint(self) -> Optional[str]: return None return self._saved[-1].filename - def _check_lt_n_saved(self, or_equal=False): + def _check_lt_n_saved(self, or_equal: bool = False) -> bool: if self.n_saved is None: return True return len(self._saved) < self.n_saved + int(or_equal) @@ -380,7 +380,7 @@ def __call__(self, engine: Engine) -> None: except TypeError: self.save_handler(checkpoint, filename) - def _setup_checkpoint(self) -> dict: + def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]: checkpoint = {} if self.to_save is not None: for k, obj in self.to_save.items(): @@ -446,7 +446,7 @@ def _check_objects(objs: Mapping, attr: str) -> None: raise TypeError("Object {} should have `{}` method".format(type(obj), attr)) @staticmethod - def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None: + def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None: """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``. 
Exemples: @@ -514,7 +514,7 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None: else: obj.load_state_dict(checkpoint[k]) - def state_dict(self) -> OrderedDict: + def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]": return OrderedDict([("saved", [(p, f) for p, f in self._saved])]) def load_state_dict(self, state_dict: Mapping) -> None: @@ -537,8 +537,8 @@ class DiskSaver(BaseSaveHandler): """ def __init__( - self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True, **kwargs - ): + self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True, **kwargs: Any + ) -> None: self.dirname = os.path.expanduser(dirname) self._atomic = atomic self._check_and_setup(dirname, create_dir, require_empty) @@ -546,7 +546,7 @@ def __init__( @staticmethod @idist.one_rank_only() - def _check_and_setup(dirname, create_dir, require_empty): + def _check_and_setup(dirname: str, create_dir: bool, require_empty: bool) -> None: if create_dir: if not os.path.exists(dirname): os.makedirs(dirname) @@ -573,16 +573,16 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin self._save_native(checkpoint, path) @idist.one_rank_only() - def _save_native(self, checkpoint: Mapping, path: str): + def _save_native(self, checkpoint: Mapping, path: str) -> None: self._save_func(checkpoint, path, torch.save) - def _save_xla(self, checkpoint: Mapping, path: str): - import torch_xla.core.xla_model as xm # type: ignore + def _save_xla(self, checkpoint: Mapping, path: str) -> None: + import torch_xla.core.xla_model as xm # all tpu procs should enter here as internally performs sync across device self._save_func(checkpoint, path, xm.save, rank=idist.get_rank()) - def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0): + def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0) -> None: if not self._atomic: func(checkpoint, path, **self.kwargs) else: @@ -686,8 +686,8 @@ def __init__( create_dir: bool = True, global_step_transform: Optional[Callable] = None, include_self: bool = False, - **kwargs - ): + **kwargs: Any + ) -> None: disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs) diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py index b1000b8339ac..b414883a2bd6 100644 --- a/ignite/handlers/early_stopping.py +++ b/ignite/handlers/early_stopping.py @@ -1,6 +1,6 @@ import logging from collections import OrderedDict -from typing import Callable, Mapping +from typing import Callable, Mapping, Optional, cast from ignite.base import Serializable from ignite.engine import Engine @@ -75,7 +75,7 @@ def __init__( self.cumulative_delta = cumulative_delta self.trainer = trainer self.counter = 0 - self.best_score = None + self.best_score = None # type: Optional[float] self.logger = logging.getLogger(__name__ + "." 
+ self.__class__.__name__) def __call__(self, engine: Engine) -> None: @@ -95,8 +95,8 @@ def __call__(self, engine: Engine) -> None: self.best_score = score self.counter = 0 - def state_dict(self) -> OrderedDict: - return OrderedDict([("counter", self.counter), ("best_score", self.best_score)]) + def state_dict(self) -> "OrderedDict[str, float]": + return OrderedDict([("counter", self.counter), ("best_score", cast(float, self.best_score))]) def load_state_dict(self, state_dict: Mapping) -> None: super().load_state_dict(state_dict) diff --git a/ignite/handlers/timing.py b/ignite/handlers/timing.py index f104ec4a489b..63d96d0a8596 100644 --- a/ignite/handlers/timing.py +++ b/ignite/handlers/timing.py @@ -1,5 +1,5 @@ from time import perf_counter -from typing import Optional +from typing import Any, Optional from ignite.engine import Engine, Events @@ -76,13 +76,10 @@ class Timer: ... step=Events.ITERATION_COMPLETED) """ - def __init__(self, average: bool = False): + def __init__(self, average: bool = False) -> None: self._average = average - self._t0 = perf_counter() - self.total = 0.0 - self.step_count = 0.0 - self.running = True + self.reset() def attach( self, @@ -91,7 +88,7 @@ def attach( pause: Events = Events.COMPLETED, resume: Optional[Events] = None, step: Optional[Events] = None, - ): + ) -> "Timer": """ Register callbacks to control the timer. Args: @@ -122,16 +119,20 @@ def attach( return self - def reset(self, *args): - self.__init__(self._average) + def reset(self, *args: Any) -> "Timer": + self._t0 = perf_counter() + self.total = 0.0 + self.step_count = 0.0 + self.running = True + return self - def pause(self, *args) -> None: + def pause(self, *args: Any) -> None: if self.running: self.total += self._elapsed() self.running = False - def resume(self, *args) -> None: + def resume(self, *args: Any) -> None: if not self.running: self.running = True self._t0 = perf_counter() @@ -148,7 +149,7 @@ def value(self) -> float: return total / denominator - def step(self, *args) -> None: + def step(self, *args: Any) -> None: self.step_count += 1.0 def _elapsed(self) -> float: diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index f4fd840699b7..c45fc6445681 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -166,7 +166,7 @@ def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambd # Increase floating point precision and pass to CPU cm = cm.type(torch.DoubleTensor) - iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15) + iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15) # type: MetricsLambda if ignore_index is not None: ignore_idx = ignore_index # type: int # used due to typing issues with mympy @@ -208,7 +208,8 @@ def mIoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLamb """ - return IoU(cm=cm, ignore_index=ignore_index).mean() + iou = IoU(cm=cm, ignore_index=ignore_index).mean() # type: MetricsLambda + return iou def cmAccuracy(cm: ConfusionMatrix) -> MetricsLambda: @@ -222,7 +223,8 @@ def cmAccuracy(cm: ConfusionMatrix) -> MetricsLambda: """ # Increase floating point precision and pass to CPU cm = cm.type(torch.DoubleTensor) - return cm.diag().sum() / (cm.sum() + 1e-15) + accuracy = cm.diag().sum() / (cm.sum() + 1e-15) # type: MetricsLambda + return accuracy def cmPrecision(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda: @@ -237,9 +239,10 @@ def cmPrecision(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda: # 
Increase floating point precision and pass to CPU cm = cm.type(torch.DoubleTensor) - precision = cm.diag() / (cm.sum(dim=0) + 1e-15) + precision = cm.diag() / (cm.sum(dim=0) + 1e-15) # type: MetricsLambda if average: - return precision.mean() + mean = precision.mean() # type: MetricsLambda + return mean return precision @@ -255,9 +258,10 @@ def cmRecall(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda: # Increase floating point precision and pass to CPU cm = cm.type(torch.DoubleTensor) - recall = cm.diag() / (cm.sum(dim=1) + 1e-15) + recall = cm.diag() / (cm.sum(dim=1) + 1e-15) # type: MetricsLambda if average: - return recall.mean() + mean = recall.mean() # type: MetricsLambda + return mean return recall @@ -278,7 +282,7 @@ def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> # Increase floating point precision and pass to CPU cm = cm.type(torch.DoubleTensor) - dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15) + dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15) # type: MetricsLambda if ignore_index is not None: ignore_idx = ignore_index # type: int # used due to typing issues with mympy diff --git a/mypy.ini b/mypy.ini index bd6d15eabfe1..bd790c5ecc8c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,6 +3,27 @@ files = ignite pretty = True show_error_codes = True +check_untyped_defs = True +; a lot of work needed to fix issues +disallow_any_generics = False +disallow_incomplete_defs = True +disallow_subclassing_any = True +; due to missing types in pytorch set to False +disallow_untyped_calls = False +disallow_untyped_decorators = True +disallow_untyped_defs = True +no_implicit_optional = True +; would need a more precise import of pytorch classes and methods, which is not possible, therefore set to False +no_implicit_reexport = False +strict_equality = True +warn_redundant_casts = True +; due to missing types in multiple libs set to False +warn_return_any = False +; results in too many false positives, therefore set to False +warn_unreachable = False +warn_unused_configs = True +warn_unused_ignores = True + [mypy-ignite.contrib.handlers.*] ignore_errors = True
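
The changes above follow one recurring pattern, sketched below in a minimal, illustrative form (the class here is a simplified stand-in for ComputationModel, not code copied from ignite). With `disallow_untyped_defs = True` and `no_implicit_optional = True` enabled, an attribute that starts out as None needs an explicit Optional annotation, and a getter that promises a concrete return type has to narrow that Optional — done in the comp_models changes with `typing.cast`:

    # Illustrative sketch only: a cut-down stand-in mirroring the pattern used in
    # this patch for ComputationModel and its subclasses.
    from typing import Optional, cast


    class _Model:
        def __init__(self) -> None:
            # An attribute initialized to None needs an explicit Optional type;
            # a `# type:` comment is used to match the style of the touched modules.
            self._nproc_per_node = None  # type: Optional[int]

        def _setup_attrs(self) -> None:
            # Filled in later, e.g. from the distributed context.
            self._nproc_per_node = 4

        def get_nproc_per_node(self) -> int:
            # The getter returns a plain int, so the Optional is narrowed with
            # cast(); callers are expected to have run _setup_attrs() first.
            return cast(int, self._nproc_per_node)

cast() is one possible choice here; an assert or an explicit runtime check would also satisfy mypy, at the cost of extra code in every getter.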