4 changes: 3 additions & 1 deletion ignite/contrib/metrics/gpu_info.py
@@ -59,7 +59,9 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
pass

def compute(self) -> List[Dict[str, Any]]:
data = self.nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")
data = self.nvsmi.DeviceQuery(
"memory.used, memory.total, utilization.gpu"
) # type: Dict[str, List[Dict[str, Any]]]
if len(data) == 0 or ("gpu" not in data):
warnings.warn("No GPU information available")
return []
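
The hunk above pins the DeviceQuery result with a `# type:` comment, since the checker cannot infer the nested structure of the returned dictionary on its own. A minimal, self-contained sketch of that comment-style annotation; `device_query` is a hypothetical stand-in for `nvsmi.DeviceQuery`, and the dict-of-lists shape is taken from the annotation in the diff rather than from any documented schema:

from typing import Any, Dict, List

def device_query(fields: str) -> Dict[str, List[Dict[str, Any]]]:
    # Hypothetical stand-in for nvsmi.DeviceQuery, here only to give the
    # annotated call something concrete to return.
    return {"gpu": [{"requested": fields}]}

data = device_query(
    "memory.used, memory.total, utilization.gpu"
)  # type: Dict[str, List[Dict[str, Any]]]
if len(data) == 0 or ("gpu" not in data):
    print("No GPU information available")

Type comments keep the code valid on interpreters that predate variable annotation syntax while still giving mypy the full type; for a multi-line assignment the comment goes after the closing parenthesis, as in the hunk.
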
13 changes: 11 additions & 2 deletions ignite/distributed/comp_models/__init__.py
@@ -1,13 +1,22 @@
from typing import TYPE_CHECKING, List, Tuple, Type, Union

from ignite.distributed.comp_models.base import _SerialModel
from ignite.distributed.comp_models.horovod import has_hvd_support
from ignite.distributed.comp_models.native import has_native_dist_support
from ignite.distributed.comp_models.xla import has_xla_support

if TYPE_CHECKING:
from ignite.distributed.comp_models.horovod import _HorovodDistModel
from ignite.distributed.comp_models.native import _NativeDistModel
from ignite.distributed.comp_models.xla import _XlaDistModel


def setup_available_computation_models(): # type: ignore # inhomogeneous Tuple types are not supported
def setup_available_computation_models() -> Tuple[
Type[Union[_SerialModel, "_NativeDistModel", "_XlaDistModel", "_HorovodDistModel"]], ...
]:
models = [
_SerialModel,
]
] # type: List[Type[Union[_SerialModel, "_NativeDistModel", "_XlaDistModel", "_HorovodDistModel"]]]
if has_native_dist_support:
from ignite.distributed.comp_models.native import _NativeDistModel

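
The new return annotation spells the backend-specific classes as string forward references and imports them only under `TYPE_CHECKING`, so the annotation never forces the optional packages to be importable at runtime. A small sketch of the same idiom, with a hypothetical `optional_backend` package standing in for the real backends:

from typing import TYPE_CHECKING, List, Tuple, Type, Union

if TYPE_CHECKING:
    # Seen only by the type checker; never imported at runtime.
    from optional_backend import OptionalModel  # hypothetical optional package

class SerialModel:
    pass

def setup_available_models() -> Tuple[Type[Union[SerialModel, "OptionalModel"]], ...]:
    models = [SerialModel]  # type: List[Type[Union[SerialModel, "OptionalModel"]]]
    try:
        from optional_backend import OptionalModel
        models.append(OptionalModel)
    except ImportError:
        pass
    return tuple(models)

print(setup_available_models())

At runtime the guarded import is skipped and the quoted class name in the annotation is resolved only by the checker, so the function stays usable when the optional dependency is absent.
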
10 changes: 5 additions & 5 deletions ignite/distributed/comp_models/base.py
@@ -15,11 +15,11 @@ class ComputationModel(metaclass=ABCMeta):
# this is an additional local rank storage used when idist is setup from existing native torch dist context
_ext_local_rank = None # type: Optional[int]

def __init__(self):
self._backend = None
self._nproc_per_node = None
self._nnodes = None
self._node = None
def __init__(self) -> None:
self._backend = None # type: Optional[str]
self._nproc_per_node = None # type: Optional[int]
self._nnodes = None # type: Optional[int]
self._node = None # type: Optional[int]

def _setup_attrs(self) -> None:
if self._nproc_per_node is None:
10 changes: 5 additions & 5 deletions ignite/distributed/comp_models/horovod.py
@@ -1,6 +1,6 @@
import os
import warnings
from typing import Any, Callable, Mapping, Optional, Tuple
from typing import Any, Callable, Mapping, Optional, Tuple, cast

import torch

@@ -62,7 +62,7 @@ def __init__(self, do_init: bool = False, **kwargs: Any) -> None:
"""This is a private method. Please, use `create_from_backend` or `create_from_context`
"""
super(_HorovodDistModel, self).__init__()
self._backend = HOROVOD
self._backend = HOROVOD # type: str
if do_init:
comm = kwargs.get("comm", None)
hvd.init(comm=comm)
@@ -87,13 +87,13 @@ def get_world_size(self) -> int:
return hvd.size()

def get_nproc_per_node(self) -> int:
return self._nproc_per_node
return cast(int, self._nproc_per_node)

def get_nnodes(self) -> int:
return self._nnodes
return cast(int, self._nnodes)

def get_node_rank(self) -> int:
return self._node
return cast(int, self._node)

def device(self) -> torch.device:
if torch.cuda.is_available():
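
The getters above now return `cast(int, ...)` because the underlying attributes are declared `Optional[int]` in the base class (see the base.py hunk) and the checker cannot prove they are populated before the getters run. A compact sketch of that narrowing pattern, under the same assumption that setup always runs first; the class name here is hypothetical:

from typing import Optional, cast

class ComputationModelSketch:
    def __init__(self) -> None:
        # Declared Optional because the value is only known after setup.
        self._nproc_per_node = None  # type: Optional[int]

    def _setup_attrs(self) -> None:
        self._nproc_per_node = 1

    def get_nproc_per_node(self) -> int:
        # cast() is a no-op at runtime; it only tells the checker that the
        # Optional[int] attribute is an int by the time this getter is called.
        return cast(int, self._nproc_per_node)

model = ComputationModelSketch()
model._setup_attrs()
print(model.get_nproc_per_node())

An explicit `assert self._nproc_per_node is not None` would also narrow the type, with the difference that it checks the invariant at runtime while `cast` performs no check at all.
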
8 changes: 4 additions & 4 deletions ignite/distributed/comp_models/native.py
@@ -2,7 +2,7 @@
import subprocess
import warnings
from distutils.version import LooseVersion
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast

import torch
import torch.distributed as dist
@@ -214,13 +214,13 @@ def get_world_size(self) -> int:
return dist.get_world_size()

def get_nproc_per_node(self) -> int:
return self._nproc_per_node
return cast(int, self._nproc_per_node)

def get_nnodes(self) -> int:
return self._nnodes
return cast(int, self._nnodes)

def get_node_rank(self) -> int:
return self._node
return cast(int, self._node)

def device(self) -> torch.device:
if self.backend() == dist.Backend.NCCL:
10 changes: 5 additions & 5 deletions ignite/distributed/comp_models/xla.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Mapping, Optional, Tuple
from typing import Any, Callable, Mapping, Optional, Tuple, cast

import torch

@@ -53,7 +53,7 @@ def __init__(self, backend: Optional[str] = None, **kwargs: Any):
def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
xm.rendezvous("init")

self._backend = backend
self._backend = backend # type: str
self._setup_attrs()

def _init_from_context(self) -> None:
@@ -75,13 +75,13 @@ def get_world_size(self) -> int:
return xm.xrt_world_size()

def get_nproc_per_node(self) -> int:
return self._nproc_per_node
return cast(int, self._nproc_per_node)

def get_nnodes(self) -> int:
return self._nnodes
return cast(int, self._nnodes)

def get_node_rank(self) -> int:
return self._node
return cast(int, self._node)

def device(self) -> torch.device:
dev = torch_xla._XLAC._xla_get_default_device()
2 changes: 1 addition & 1 deletion ignite/distributed/launcher.py
@@ -175,7 +175,7 @@ def training(local_rank, config, **kwargs):

def __init__(
self,
backend: str = None,
backend: Optional[str] = None,
nproc_per_node: Optional[int] = None,
nnodes: Optional[int] = None,
node_rank: Optional[int] = None,
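
`backend: str = None` relies on implicit Optional, which strict mypy configurations (`no_implicit_optional`) reject; the launcher hunk makes the `Optional[str]` explicit. A minimal sketch of the corrected signature; the `launch` function and its fallback value are hypothetical:

from typing import Optional

def launch(backend: Optional[str] = None, nproc_per_node: Optional[int] = None) -> None:
    # Explicit Optional[...] is required for None defaults once implicit
    # Optional is disabled in the checker configuration.
    if backend is None:
        backend = "serial"  # hypothetical fallback, for illustration only
    print(backend, nproc_per_node)

launch()
launch("nccl", nproc_per_node=2)
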
2 changes: 1 addition & 1 deletion ignite/distributed/utils.py
@@ -496,7 +496,7 @@ def train_fn(local_rank, a, b, c):
for comp_model_cls in registered_computation_models:
if backend not in comp_model_cls.available_backends:
continue
_set_model(comp_model_cls(backend, **kwargs))
_set_model(comp_model_cls(backend, **kwargs)) # type: ignore[arg-type]


def finalize() -> None:
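
The utils.py hunk narrows the suppression to the scoped `# type: ignore[arg-type]`, which silences only that one error code on the line rather than every diagnostic. A small sketch of the kind of call that needs it; the class and the `set_model` helper here are hypothetical, not the ignite API:

from typing import Any, Optional

class NativeModelSketch:
    def __init__(self, backend: str, **kwargs: Any) -> None:
        self.backend = backend

def set_model(backend: Optional[str], **kwargs: Any) -> None:
    # backend is Optional[str] here while the constructor expects str, so the
    # checker reports an arg-type error; the scoped comment silences only that
    # error code instead of everything on the line.
    model = NativeModelSketch(backend, **kwargs)  # type: ignore[arg-type]
    print(model.backend)

set_model("gloo")
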
6 changes: 3 additions & 3 deletions ignite/handlers/__init__.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

from ignite.engine import Engine
from ignite.engine.events import Events
@@ -18,7 +18,7 @@
]


def global_step_from_engine(engine: Engine, custom_event_name=None) -> Callable:
def global_step_from_engine(engine: Engine, custom_event_name: Optional[Events] = None) -> Callable:
"""Helper method to setup `global_step_transform` function using another engine.
This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator.

@@ -30,7 +30,7 @@ def global_step_from_engine(engine: Engine, custom_event_name=None) -> Callable:
global step
"""

def wrapper(_: Any, event_name: Events):
def wrapper(_: Any, event_name: Events) -> int:
if custom_event_name is not None:
event_name = custom_event_name
return engine.state.get_event_attrib_value(event_name)
36 changes: 18 additions & 18 deletions ignite/handlers/checkpoint.py
@@ -6,7 +6,7 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, namedtuple
from tempfile import _TemporaryFileWrapper # type: ignore[attr-defined]
from typing import Callable, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -233,7 +233,7 @@ def score_function(engine):

"""

Item = namedtuple("Item", ["priority", "filename"])
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
_state_dict_all_req_keys = ("saved",)

def __init__(
@@ -244,10 +244,10 @@ def __init__(
score_function: Optional[Callable] = None,
score_name: Optional[str] = None,
n_saved: Optional[int] = 1,
global_step_transform: Callable = None,
global_step_transform: Optional[Callable] = None,
filename_pattern: Optional[str] = None,
include_self: bool = False,
):
) -> None:

if to_save is not None: # for compatibility with ModelCheckpoint
if not isinstance(to_save, collections.Mapping):
@@ -287,7 +287,7 @@ def __init__(
self.ext = "pt"
self.global_step_transform = global_step_transform
self.filename_pattern = filename_pattern
self._saved = [] # type: list
self._saved = [] # type: List["Checkpoint.Item"]
self.include_self = include_self

@property
@@ -296,7 +296,7 @@ def last_checkpoint(self) -> Optional[str]:
return None
return self._saved[-1].filename

def _check_lt_n_saved(self, or_equal=False):
def _check_lt_n_saved(self, or_equal: bool = False) -> bool:
if self.n_saved is None:
return True
return len(self._saved) < self.n_saved + int(or_equal)
@@ -380,7 +380,7 @@ def __call__(self, engine: Engine) -> None:
except TypeError:
self.save_handler(checkpoint, filename)

def _setup_checkpoint(self) -> dict:
def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
checkpoint = {}
if self.to_save is not None:
for k, obj in self.to_save.items():
@@ -446,7 +446,7 @@ def _check_objects(objs: Mapping, attr: str) -> None:
raise TypeError("Object {} should have `{}` method".format(type(obj), attr))

@staticmethod
def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None:
def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

Exemples:
@@ -514,7 +514,7 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None:
else:
obj.load_state_dict(checkpoint[k])

def state_dict(self) -> OrderedDict:
def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

def load_state_dict(self, state_dict: Mapping) -> None:
@@ -537,16 +537,16 @@ class DiskSaver(BaseSaveHandler):
"""

def __init__(
self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True, **kwargs
):
self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True, **kwargs: Any
) -> None:
self.dirname = os.path.expanduser(dirname)
self._atomic = atomic
self._check_and_setup(dirname, create_dir, require_empty)
self.kwargs = kwargs

@staticmethod
@idist.one_rank_only()
def _check_and_setup(dirname, create_dir, require_empty):
def _check_and_setup(dirname: str, create_dir: bool, require_empty: bool) -> None:
if create_dir:
if not os.path.exists(dirname):
os.makedirs(dirname)
@@ -573,16 +573,16 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
self._save_native(checkpoint, path)

@idist.one_rank_only()
def _save_native(self, checkpoint: Mapping, path: str):
def _save_native(self, checkpoint: Mapping, path: str) -> None:
self._save_func(checkpoint, path, torch.save)

def _save_xla(self, checkpoint: Mapping, path: str):
import torch_xla.core.xla_model as xm # type: ignore
def _save_xla(self, checkpoint: Mapping, path: str) -> None:
import torch_xla.core.xla_model as xm

# all tpu procs should enter here as internally performs sync across device
self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())

def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0):
def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0) -> None:
if not self._atomic:
func(checkpoint, path, **self.kwargs)
else:
@@ -686,8 +686,8 @@ def __init__(
create_dir: bool = True,
global_step_transform: Optional[Callable] = None,
include_self: bool = False,
**kwargs
):
**kwargs: Any
) -> None:

disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)

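
Two patterns from the checkpoint.py hunks are worth noting: the untyped `collections.namedtuple` becomes a `typing.NamedTuple` with per-field types, and the `state_dict` return type is written as a quoted string because subscripting `collections.OrderedDict` is not supported at runtime on older Python versions. A condensed sketch of both, using the field names from the diff but an otherwise hypothetical `SaverSketch` class:

from collections import OrderedDict
from typing import List, NamedTuple, Tuple

# Functional NamedTuple syntax with typed fields, replacing the untyped
# collections.namedtuple("Item", ["priority", "filename"]).
Item = NamedTuple("Item", [("priority", int), ("filename", str)])

class SaverSketch:
    def __init__(self) -> None:
        self._saved = []  # type: List[Item]

    def add(self, priority: int, filename: str) -> None:
        self._saved.append(Item(priority, filename))

    def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
        # Quoting the annotation keeps collections.OrderedDict unsubscripted
        # at runtime, which older interpreters do not support.
        return OrderedDict([("saved", [(item.priority, item.filename) for item in self._saved])])

saver = SaverSketch()
saver.add(0, "model_0.pt")
print(saver.state_dict())
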
8 changes: 4 additions & 4 deletions ignite/handlers/early_stopping.py
@@ -1,6 +1,6 @@
import logging
from collections import OrderedDict
from typing import Callable, Mapping
from typing import Callable, Mapping, Optional, cast

from ignite.base import Serializable
from ignite.engine import Engine
@@ -75,7 +75,7 @@ def __init__(
self.cumulative_delta = cumulative_delta
self.trainer = trainer
self.counter = 0
self.best_score = None
self.best_score = None # type: Optional[float]
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)

def __call__(self, engine: Engine) -> None:
@@ -95,8 +95,8 @@ def __call__(self, engine: Engine) -> None:
self.best_score = score
self.counter = 0

def state_dict(self) -> OrderedDict:
return OrderedDict([("counter", self.counter), ("best_score", self.best_score)])
def state_dict(self) -> "OrderedDict[str, float]":
return OrderedDict([("counter", self.counter), ("best_score", cast(float, self.best_score))])

def load_state_dict(self, state_dict: Mapping) -> None:
super().load_state_dict(state_dict)