From 915ecbba75b9de1c24dc28c97b6af3913cc990d2 Mon Sep 17 00:00:00 2001 From: gruebel Date: Sun, 4 Oct 2020 22:29:05 +0200 Subject: [PATCH 1/6] Activate mypy in ignite.distributed --- ignite/distributed/auto.py | 89 +++++++++++++--------- ignite/distributed/comp_models/__init__.py | 2 +- ignite/distributed/comp_models/base.py | 66 +++++++++++----- ignite/distributed/comp_models/horovod.py | 24 +++--- ignite/distributed/comp_models/native.py | 67 +++++++++------- ignite/distributed/comp_models/xla.py | 26 ++++--- ignite/distributed/launcher.py | 21 +++-- ignite/distributed/utils.py | 37 +++++---- mypy.ini | 8 +- 9 files changed, 207 insertions(+), 133 deletions(-) diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index 49ae70cae296..273301ea1b04 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -1,4 +1,5 @@ import warnings +from typing import Any, Callable, Iterator, List, Optional, Union import torch import torch.nn as nn @@ -16,7 +17,42 @@ __all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"] -def auto_dataloader(dataset, **kwargs): +if idist.has_xla_support: + + import torch_xla.core.xla_model as xm + from torch_xla.distributed.parallel_loader import ParallelLoader + + class _MpDeviceLoader: + # https://github.com/pytorch/xla/pull/2117 + # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available + def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None: + self._loader = loader + self._device = device + self._parallel_loader_kwargs = kwargs + + def __setattr__(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + + def __getattr__(self, name: str) -> Any: + super().__getattribute__(name) + + def __iter__(self) -> Iterator: + parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs) + return parallel_loader.per_device_loader(self._device) + + def __len__(self) -> int: + return len(self._loader) + + class _XLADistributedOptimizer(Optimizer): + def __init__(self, optimizer: Optimizer) -> None: + super(self.__class__, self).__init__(optimizer.param_groups, {}) + self.wrapped_optimizer = optimizer + + def step(self, closure: Optional[Callable] = None) -> None: + xm.optimizer_step(self.wrapped_optimizer, barrier=True) + + +def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, _MpDeviceLoader]: """Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`). 
@@ -73,6 +109,8 @@ def auto_dataloader(dataset, **kwargs): kwargs["num_workers"] = (kwargs["num_workers"] + nproc - 1) // nproc if "batch_sampler" not in kwargs: + sampler: Union[DistributedProxySampler, DistributedSampler] + if kwargs.get("sampler", None) is not None: sampler = DistributedProxySampler(kwargs["sampler"], num_replicas=world_size, rank=rank) else: @@ -101,7 +139,7 @@ def auto_dataloader(dataset, **kwargs): kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type) logger.info("Use data loader kwargs for dataset '{}': \n\t{}".format(repr(dataset)[:20].strip(), kwargs)) - dataloader = DataLoader(dataset, **kwargs) + dataloader: Union[DataLoader, _MpDeviceLoader] = DataLoader(dataset, **kwargs) if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and world_size > 1: @@ -115,7 +153,7 @@ def auto_dataloader(dataset, **kwargs): except ImportError: pass - sampler = dataloader.sampler + sampler = dataloader.sampler # type: ignore[union-attr] dataloader = mp_device_loader_cls(dataloader, idist.device()) dataloader.sampler = sampler @@ -266,7 +304,7 @@ class DistributedProxySampler(DistributedSampler): """ - def __init__(self, sampler: Sampler, num_replicas=None, rank=None): + def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None) -> None: if not isinstance(sampler, Sampler): raise TypeError("Argument sampler should be instance of torch Sampler, but given: {}".format(type(sampler))) @@ -274,14 +312,22 @@ def __init__(self, sampler: Sampler, num_replicas=None, rank=None): if not hasattr(sampler, "__len__"): raise TypeError("Argument sampler should have length") - super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False) + super(DistributedProxySampler, self).__init__( + sampler, num_replicas=num_replicas, rank=rank, shuffle=False # type: ignore[arg-type] + ) self.sampler = sampler - def __iter__(self): + def __setattr__(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + + def __getattr__(self, name: str) -> Any: + super().__getattribute__(name) + + def __iter__(self) -> Iterator: # deterministically shuffle based on epoch torch.manual_seed(self.epoch) - indices = [] + indices: List = [] while len(indices) < self.total_size: indices += list(self.sampler) @@ -294,32 +340,3 @@ def __iter__(self): raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) return iter(indices) - - -if idist.has_xla_support: - - import torch_xla.core.xla_model as xm - from torch_xla.distributed.parallel_loader import ParallelLoader - - class _MpDeviceLoader: - # https://github.com/pytorch/xla/pull/2117 - # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available - def __init__(self, loader, device, **kwargs): - self._loader = loader - self._device = device - self._parallel_loader_kwargs = kwargs - - def __iter__(self): - parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs) - return parallel_loader.per_device_loader(self._device) - - def __len__(self): - return len(self._loader) - - class _XLADistributedOptimizer(Optimizer): - def __init__(self, optimizer): - super(self.__class__, self).__init__(optimizer.param_groups) - self.wrapped_optimizer = optimizer - - def step(self, closure=None): - xm.optimizer_step(self.wrapped_optimizer, barrier=True) diff --git a/ignite/distributed/comp_models/__init__.py b/ignite/distributed/comp_models/__init__.py index 
3001edcb0671..c9227701078c 100644 --- a/ignite/distributed/comp_models/__init__.py +++ b/ignite/distributed/comp_models/__init__.py @@ -4,7 +4,7 @@ from ignite.distributed.comp_models.xla import has_xla_support -def setup_available_computation_models(): +def setup_available_computation_models(): # type: ignore # inhomogeneous Tuple types are not supported models = [ _SerialModel, ] diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py index cd3fad630436..f39bb71fe758 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -1,7 +1,7 @@ import warnings from abc import ABCMeta, abstractmethod from numbers import Number -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, cast, overload import torch @@ -13,7 +13,7 @@ class ComputationModel(metaclass=ABCMeta): """ # this is an additional local rank storage used when idist is setup from existing native torch dist context - _ext_local_rank = None + _ext_local_rank: Optional[int] = None def __init__(self): self._backend = None @@ -21,7 +21,7 @@ def __init__(self): self._nnodes = None self._node = None - def _setup_attrs(self): + def _setup_attrs(self) -> None: if self._nproc_per_node is None: self._nproc_per_node = self._compute_nproc_per_node() if self.get_world_size() > 1 else 1 if self._nnodes is None: @@ -66,7 +66,7 @@ def backend(self) -> Optional[str]: pass @abstractmethod - def finalize(self): + def finalize(self) -> None: pass @staticmethod @@ -76,15 +76,15 @@ def create_from_context() -> Optional["ComputationModel"]: @staticmethod @abstractmethod - def create_from_backend(backend: str, **kwargs) -> "ComputationModel": + def create_from_backend(backend: str, **kwargs: Any) -> "ComputationModel": pass @staticmethod @abstractmethod - def spawn(*args, **kwargs): + def spawn(*args: Any, **kwargs: Any) -> None: pass - _collective_op_dtype = None + _collective_op_dtype: Any = None @staticmethod def _encode_str(x: str, device: torch.device) -> torch.Tensor: @@ -107,7 +107,9 @@ def _decode_str(xs: torch.Tensor) -> List[str]: out = [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs] return out - def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args, **kwargs) -> torch.Tensor: + def _apply_op( + self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args: Any, **kwargs: Any + ) -> torch.Tensor: out_dtype = None tensor_device = None @@ -132,9 +134,20 @@ def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *a return tensor.to(device=tensor_device) return tensor + @overload def _collective_op( - self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args, **kwargs + self, tensor: Union[torch.Tensor, Number], fn: Callable, *args: Any, **kwargs: Any + ) -> Union[torch.Tensor, Number]: + ... + + @overload + def _collective_op( + self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any ) -> Union[torch.Tensor, Number, List[str]]: + ... 
+ + # mypy doesn't support overload for no-untyped-def check + def _collective_op(self, tensor, fn, *args, **kwargs): # type: ignore tensor_to_number = tensor_to_str = False device = self.device() if isinstance(tensor, Number): @@ -164,7 +177,15 @@ def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Te return self._collective_op(tensor, self._do_all_gather) + @overload + def broadcast(self, tensor: Union[torch.Tensor, Number], src: int = 0) -> Union[torch.Tensor, Number]: + ... + + @overload def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]: + ... + + def broadcast(self, tensor, src=0): # type: ignore # mypy doesn't support overload for no-untyped-def check if not isinstance(tensor, (torch.Tensor, Number, str)): raise TypeError("Unhandled input type {}".format(type(tensor))) @@ -208,7 +229,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: pass @abstractmethod - def barrier(self): + def barrier(self) -> None: pass @@ -219,6 +240,9 @@ class _SerialModel(ComputationModel): name = "serial" available_backends = () + def __init__(self, _backend: Optional[str] = None, **_kwargs: Any) -> None: + super(_SerialModel, self).__init__() + def get_local_rank(self) -> int: return 0 @@ -242,10 +266,10 @@ def device(self) -> torch.device: return torch.device("cuda") return torch.device("cpu") - def backend(self) -> None: + def backend(self) -> Optional[str]: return None - def finalize(self): + def finalize(self) -> None: pass def _compute_nproc_per_node(self) -> int: @@ -256,20 +280,28 @@ def create_from_context() -> "_SerialModel": return _SerialModel() @staticmethod - def create_from_backend(backend: Optional[str] = None, **kwargs) -> "_SerialModel": + def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_SerialModel": return _SerialModel() @staticmethod - def spawn(*args, **kwargs): + def spawn(*args: Any, **kwargs: Any) -> None: raise NotImplementedError("Serial computation model does not implement spawn method") def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]: return tensor - def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]: - return tensor + def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]: + return cast(Union[torch.Tensor, Number], tensor) + @overload + def broadcast(self, tensor: Union[torch.Tensor, Number], src: int = 0) -> Union[torch.Tensor, Number]: + ... + + @overload def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]: + ... 
+ + def broadcast(self, tensor, src=0): # type: ignore # mypy doesn't support overload for no-untyped-def check return tensor def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor: @@ -281,5 +313,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor: def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: pass - def barrier(self): + def barrier(self) -> None: pass diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py index 1bdcb1402ad7..1b54f1f42f98 100644 --- a/ignite/distributed/comp_models/horovod.py +++ b/ignite/distributed/comp_models/horovod.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Mapping, Optional, Tuple +from typing import Any, Callable, Mapping, Optional, Tuple import torch @@ -33,7 +33,7 @@ class _HorovodDistModel(ComputationModel): available_backends = (HOROVOD,) @staticmethod - def _get_hvd_rank(): + def _get_hvd_rank() -> int: try: rank = hvd.rank() except ValueError as e: @@ -48,7 +48,7 @@ def create_from_context() -> Optional["_HorovodDistModel"]: return _HorovodDistModel() @staticmethod - def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel": + def create_from_backend(backend: str, **kwargs: Any) -> "_HorovodDistModel": if backend not in _HorovodDistModel.available_backends: raise ValueError("Backend should be one of '{}'".format(_HorovodDistModel.available_backends)) @@ -57,7 +57,7 @@ def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel": raise RuntimeError("Can not re-initialize Horovod if it is already initialized") return _HorovodDistModel(do_init=True, **kwargs) - def __init__(self, do_init=False, **kwargs): + def __init__(self, do_init: bool = False, **kwargs: Any) -> None: """This is a private method. 
Please, use `create_from_backend` or `create_from_context` """ super(_HorovodDistModel, self).__init__() @@ -73,7 +73,7 @@ def __init__(self, do_init=False, **kwargs): self._setup_attrs() - def _compute_nproc_per_node(self): + def _compute_nproc_per_node(self) -> int: return hvd.local_size() def get_local_rank(self) -> int: @@ -103,11 +103,11 @@ def device(self) -> torch.device: def backend(self) -> str: return self._backend - def finalize(self): + def finalize(self) -> None: hvd.shutdown() @staticmethod - def _dist_worker_task_fn(backend, fn, args, kwargs_dict): + def _dist_worker_task_fn(backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping) -> None: from ignite.distributed.utils import _set_model, finalize model = _HorovodDistModel.create_from_backend(backend) @@ -116,15 +116,15 @@ def _dist_worker_task_fn(backend, fn, args, kwargs_dict): finalize() @staticmethod - def spawn( + def spawn( # type: ignore[override] fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping] = None, nproc_per_node: int = 1, - hosts=None, + hosts: Optional[str] = None, backend: str = HOROVOD, - **kwargs - ): + **kwargs: Any + ) -> None: c1 = "nnodes" in kwargs and kwargs["nnodes"] > 1 c2 = "node_rank" in kwargs and kwargs["node_rank"] > 0 if c1 or c2: @@ -166,7 +166,7 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor: def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: return hvd.broadcast(tensor, root_rank=src) - def barrier(self): + def barrier(self) -> None: # https://github.com/horovod/horovod/issues/159#issuecomment-424834603 # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier") hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier") diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index be44b8211b5a..acdd4c715e35 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -2,7 +2,7 @@ import subprocess import warnings from distutils.version import LooseVersion -from typing import Callable, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, overload import torch import torch.distributed as dist @@ -48,22 +48,22 @@ def create_from_context() -> Optional["_NativeDistModel"]: return _NativeDistModel() @staticmethod - def create_from_backend(backend: str, **kwargs) -> "_NativeDistModel": + def create_from_backend(backend: str, **kwargs: Any) -> "_NativeDistModel": if dist.is_available() and dist.is_initialized(): raise RuntimeError("Can not create new distributed process group if default one is already initialized") return _NativeDistModel(backend=backend, **kwargs) - def __init__(self, backend=None, timeout=None, **kwargs): + def __init__(self, backend: Optional[str] = None, timeout: Optional[int] = None, **kwargs: Any) -> None: """This is a private method. 
Please, use `create_from_backend` or `create_from_context` """ super(_NativeDistModel, self).__init__() - self._env_backup = None + self._env_backup: Optional[Dict[str, str]] = None if backend is not None: self._create_from_backend(backend, timeout=timeout, **kwargs) else: self._init_from_context() - def _create_from_backend(self, backend, timeout=None, **kwargs): + def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **_kwargs: Any) -> None: if backend == dist.Backend.NCCL and not torch.cuda.is_available(): raise RuntimeError("Nccl backend is required but no cuda capable devices") @@ -71,8 +71,8 @@ def _create_from_backend(self, backend, timeout=None, **kwargs): self._local_rank = int(os.environ["LOCAL_RANK"]) # for debug purposes - self._master_port = int(os.environ["MASTER_PORT"]) - self._master_addr = os.environ["MASTER_ADDR"] + self._master_port: Optional[int] = int(os.environ["MASTER_PORT"]) + self._master_addr: Optional[str] = os.environ["MASTER_ADDR"] init_pg_kwargs = {} if timeout is not None: @@ -87,7 +87,7 @@ def _create_from_backend(self, backend, timeout=None, **kwargs): self._setup_attrs() - def _init_from_context(self): + def _init_from_context(self) -> None: self._identify_local_rank() @@ -96,39 +96,38 @@ def _init_from_context(self): self._master_addr = None self._setup_attrs() - def _compute_nproc_per_node(self): + def _compute_nproc_per_node(self) -> int: tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device()) dist.all_reduce(tensor, op=dist.ReduceOp.MAX) - return tensor.item() + return int(tensor.item()) - def _get_all_hostnames(self): + def _get_all_hostnames(self) -> List[Tuple[str, ...]]: import socket device = "cpu" if self.backend() == dist.Backend.NCCL: index = torch.cuda.current_device() device = "cuda:{}".format(index) - name = socket.gethostname() - name = torch.tensor(bytearray(name, "utf-8")).to(device) + hostname = socket.gethostname() + name = torch.tensor(bytearray(hostname, "utf-8")).to(device) padded_t_name = torch.zeros(256, device=device, dtype=torch.long) padded_t_name[: len(name)] = name out_t_names = [torch.zeros_like(padded_t_name) for _ in range(self.get_world_size())] dist.all_gather(out_t_names, padded_t_name) - out_t_names = [tuple(t.cpu().tolist()) for t in out_t_names] - return out_t_names + return [tuple(t.cpu().tolist()) for t in out_t_names] @staticmethod - def _compute_node_and_local_ranks(rank, hostnames): + def _compute_node_and_local_ranks(rank: int, hostnames: List[Tuple[str, ...]]) -> Tuple[int, int]: from collections import Counter - c = Counter(hostnames) + c: Counter = Counter(hostnames) sizes = torch.tensor([0,] + list(c.values())) cumsum_sizes = torch.cumsum(sizes, dim=0) node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item() local_rank = rank - cumsum_sizes[node_rank].item() - return local_rank, node_rank + return int(local_rank), node_rank - def _compute_local_rank_via_hostname(self): + def _compute_local_rank_via_hostname(self) -> int: # get all hostnames hostnames = self._get_all_hostnames() local_rank, self._node = self._compute_node_and_local_ranks(self.get_rank(), hostnames) @@ -142,7 +141,7 @@ def _compute_local_rank_via_hostname(self): ) return local_rank - def _identify_local_rank(self): + def _identify_local_rank(self) -> None: if "SLURM_JOBID" in os.environ: os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] @@ -161,7 +160,7 @@ def _identify_local_rank(self): # use socket gethostname heuristic to determine number of nodes => local rank self._local_rank = 
self._compute_local_rank_via_hostname() - def setup_env_vars(self): + def setup_env_vars(self) -> None: self._env_backup = os.environ.copy() @@ -184,7 +183,7 @@ def setup_env_vars(self): os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "15000") - def _setup_env_in_slurm(self): + def _setup_env_in_slurm(self) -> None: for k in ["SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS", "SLURM_JOB_NODELIST"]: if k not in os.environ: raise RuntimeError("SLURM distributed configuration is missing '{}' in env variables".format(k)) @@ -227,7 +226,7 @@ def device(self) -> torch.device: def backend(self) -> str: return dist.get_backend() - def finalize(self): + def finalize(self) -> None: dist.destroy_process_group() # restore backed-up env if self._env_backup is not None: @@ -236,8 +235,18 @@ def finalize(self): @staticmethod def _dist_worker_task_fn( - local_rank, backend, fn, args, kw_dict, world_size, nprocs_per_node, node_rank, master_addr, master_port, kw - ): + local_rank: int, + backend: str, + fn: Callable, + args: Tuple, + kw_dict: Mapping, + world_size: int, + nprocs_per_node: int, + node_rank: int, + master_addr: str, + master_port: str, + **kw: Any, + ) -> None: from ignite.distributed.utils import _set_model, finalize copy_env_vars = os.environ.copy() @@ -257,7 +266,7 @@ def _dist_worker_task_fn( os.environ.update(copy_env_vars) @staticmethod - def spawn( + def spawn( # type: ignore[override] fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping] = None, @@ -267,8 +276,8 @@ def spawn( master_addr: str = "127.0.0.1", master_port: int = 2222, backend: str = "nccl", - **kwargs - ): + **kwargs: Any + ) -> None: world_size = nnodes * nproc_per_node spawn_kwargs = { @@ -327,5 +336,5 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: dist.broadcast(tensor, src=src) return tensor - def barrier(self): + def barrier(self) -> None: dist.barrier() diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py index 27c38ee959ca..b40b2f437b00 100644 --- a/ignite/distributed/comp_models/xla.py +++ b/ignite/distributed/comp_models/xla.py @@ -1,4 +1,4 @@ -from typing import Callable, Mapping, Optional, Tuple +from typing import Any, Callable, Mapping, Optional, Tuple import torch @@ -38,10 +38,10 @@ def create_from_context() -> Optional["_XlaDistModel"]: return _XlaDistModel() @staticmethod - def create_from_backend(backend: str = XLA_TPU, **kwargs) -> "_XlaDistModel": + def create_from_backend(backend: str = XLA_TPU, **kwargs: Any) -> "_XlaDistModel": return _XlaDistModel(backend=backend, **kwargs) - def __init__(self, backend=None, **kwargs): + def __init__(self, backend: Optional[str] = None, **kwargs: Any): """This is a private method. 
Please, use `create_from_backend` or `create_from_context` """ super(_XlaDistModel, self).__init__() @@ -50,17 +50,17 @@ def __init__(self, backend=None, **kwargs): else: self._init_from_context() - def _create_from_backend(self, backend, **kwargs): + def _create_from_backend(self, backend: str, **_kwargs: Any) -> None: xm.rendezvous("init") self._backend = backend self._setup_attrs() - def _init_from_context(self): + def _init_from_context(self) -> None: self._backend = XLA_TPU self._setup_attrs() - def _compute_nproc_per_node(self): + def _compute_nproc_per_node(self) -> int: tensor = torch.tensor([self.get_local_rank() + 1.0], dtype=torch.float).to(self.device()) xm.all_reduce("max", [tensor,]) return int(tensor.item()) @@ -90,11 +90,13 @@ def device(self) -> torch.device: def backend(self) -> str: return self._backend - def finalize(self): + def finalize(self) -> None: pass @staticmethod - def _dist_worker_task_fn(local_rank, backend, fn, args, kwargs_dict): + def _dist_worker_task_fn( + local_rank: int, backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping + ) -> None: from ignite.distributed.utils import _set_model, finalize model = _XlaDistModel.create_from_backend(backend) @@ -103,7 +105,7 @@ def _dist_worker_task_fn(local_rank, backend, fn, args, kwargs_dict): finalize() @staticmethod - def spawn( + def spawn( # type: ignore[override] fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping] = None, @@ -111,8 +113,8 @@ def spawn( nnodes: int = 1, node_rank: int = 0, backend: str = XLA_TPU, - **kwargs - ): + **kwargs: Any + ) -> None: if "start_method" not in kwargs: kwargs["start_method"] = "fork" @@ -155,5 +157,5 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: xm.all_reduce("sum", [tensor,]) return tensor - def barrier(self): + def barrier(self) -> None: xm.rendezvous("barrier") diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py index f170aa7516ed..9077781c8504 100644 --- a/ignite/distributed/launcher.py +++ b/ignite/distributed/launcher.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Any, Callable, Dict, Optional from ignite.distributed import utils as idist from ignite.utils import setup_logger @@ -181,8 +181,8 @@ def __init__( node_rank: Optional[int] = None, master_addr: Optional[str] = None, master_port: Optional[str] = None, - **spawn_kwargs - ): + **spawn_kwargs: Any + ) -> None: if backend is not None: if backend not in idist.available_backends(): raise ValueError( @@ -214,7 +214,14 @@ def __init__( self.logger.info("- Parameters to spawn processes: \n\t{}".format(msg)) @staticmethod - def _setup_spawn_params(nproc_per_node, nnodes, node_rank, master_addr, master_port, **spawn_kwargs): + def _setup_spawn_params( + nproc_per_node: int, + nnodes: Optional[int], + node_rank: Optional[int], + master_addr: Optional[str], + master_port: Optional[str], + **spawn_kwargs: Any + ) -> Dict: if nproc_per_node < 1: raise ValueError("Argument nproc_per_node should positive, but given {}".format(nproc_per_node)) if nnodes is None: @@ -244,7 +251,7 @@ def _setup_spawn_params(nproc_per_node, nnodes, node_rank, master_addr, master_p params.update(spawn_kwargs) return {k: v for k, v in params.items() if v is not None} - def run(self, func: Callable, *args, **kwargs): + def run(self, func: Callable, *args: Any, **kwargs: Any) -> None: """Execute ``func`` with provided arguments in distributed context. 
Example @@ -276,7 +283,7 @@ def training(local_rank, config, **kwargs): self.logger.info("End of run") - def __enter__(self): + def __enter__(self) -> "Parallel": if (self.backend is not None) and self._spawn_params is None: idist.initialize(self.backend) self.logger = setup_logger(__name__ + "." + self.__class__.__name__) @@ -284,7 +291,7 @@ def __enter__(self): return self - def __exit__(self, *args, **kwargs): + def __exit__(self, *args: Any, **kwargs: Any) -> None: if (self.backend is not None) and self._spawn_params is None: self.logger.info("Finalized processing group with backend: '{}'".format(self.backend)) idist.finalize() diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index 60481d03952b..0e870ca952f1 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -1,7 +1,7 @@ import socket from functools import wraps from numbers import Number -from typing import Callable, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, List, Mapping, Optional, Tuple, Union import torch @@ -48,7 +48,7 @@ _need_to_sync = True -def sync(temporary=False): +def sync(temporary: bool = False) -> None: """Helper method to force this module to synchronize with current distributed context. This method should be used when distributed context is manually created or destroyed. @@ -102,10 +102,10 @@ def backend() -> Optional[str]: return _model.backend() -def available_backends() -> Tuple[str]: +def available_backends() -> Tuple[str, ...]: """Returns available backends. """ - out = () + out: Tuple[str, ...] = () for m in registered_computation_models: out += m.available_backends return out @@ -190,8 +190,13 @@ def hostname() -> str: def spawn( - backend: str, fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping] = None, nproc_per_node: int = 1, **kwargs -): + backend: Optional[str], + fn: Callable, + args: Tuple, + kwargs_dict: Optional[Mapping] = None, + nproc_per_node: int = 1, + **kwargs: Any +) -> None: """Spawns ``nproc_per_node`` processes that run ``fn`` with ``args``/``kwargs_dict`` and initialize distributed configuration defined by ``backend``. @@ -390,7 +395,7 @@ def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[t return _model.broadcast(tensor, src=src) -def barrier(): +def barrier() -> None: """Helper method to synchronize all processes. """ if _need_to_sync and isinstance(_model, _SerialModel): @@ -399,7 +404,7 @@ def barrier(): _model.barrier() -def set_local_rank(index: int): +def set_local_rank(index: int) -> None: """Method to hint the local rank in case if torch native distributed context is created by user without using :meth:`~ignite.distributed.initialize` or :meth:`~ignite.distributed.spawn`. 
@@ -427,7 +432,7 @@ def run(local_rank, *args, **kwargs): ComputationModel._ext_local_rank = index -def _set_model(model, temporary=False): +def _set_model(model: Any, temporary: bool = False) -> None: global _model, _need_to_sync _model = model _need_to_sync = True @@ -435,13 +440,13 @@ def _set_model(model, temporary=False): _need_to_sync = False -def _assert_backend(backend): +def _assert_backend(backend: Optional[str]) -> None: backends = available_backends() if backend not in backends: raise ValueError("Backend should be one of '{}'".format(backends)) -def initialize(backend: str, **kwargs): +def initialize(backend: str, **kwargs: Any) -> None: """Initializes distributed configuration according to provided ``backend`` Examples: @@ -495,7 +500,7 @@ def train_fn(local_rank, a, b, c): _set_model(comp_model_cls(backend, **kwargs)) -def finalize(): +def finalize() -> None: """Finalizes distributed configuration. For example, in case of native pytorch distributed configuration, it calls ``dist.destroy_process_group()``. """ @@ -503,7 +508,7 @@ def finalize(): _set_model(_SerialModel()) -def show_config(): +def show_config() -> None: """Helper method to display distributed configuration via ``logging``. """ @@ -522,7 +527,7 @@ def show_config(): logger.info("node rank: {}".format(get_node_rank())) -def one_rank_only(rank: int = 0, with_barrier: bool = False): +def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Optional[Callable]: """Decorator to filter handlers wrt a rank number Args: @@ -544,9 +549,9 @@ def some_handler(_): ... """ - def _one_rank_only(func): + def _one_rank_only(func: Callable) -> Optional[Callable]: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Optional[Callable]: ret = None if get_rank() == rank: ret = func(*args, **kwargs) diff --git a/mypy.ini b/mypy.ini index 586ae4633cf3..3a53733e935c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,9 +23,11 @@ ignore_errors = True ignore_errors = True -[mypy-ignite.distributed.*] - -ignore_errors = True +[mypy-horovod.*] +ignore_missing_imports = True [mypy-numpy.*] ignore_missing_imports = True + +[mypy-torch_xla.*] +ignore_missing_imports = True From d627d6ab596cded3eec7cebbc108fdd6c740ae15 Mon Sep 17 00:00:00 2001 From: gruebel Date: Sun, 4 Oct 2020 23:45:44 +0200 Subject: [PATCH 2/6] Fix tests & py3.5 inline type hints --- ignite/distributed/auto.py | 82 ++++++++++++------------ ignite/distributed/comp_models/base.py | 4 +- ignite/distributed/comp_models/native.py | 8 +-- ignite/distributed/launcher.py | 4 +- ignite/distributed/utils.py | 2 +- 5 files changed, 50 insertions(+), 50 deletions(-) diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index 273301ea1b04..a4fce442f0b5 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -17,42 +17,7 @@ __all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"] -if idist.has_xla_support: - - import torch_xla.core.xla_model as xm - from torch_xla.distributed.parallel_loader import ParallelLoader - - class _MpDeviceLoader: - # https://github.com/pytorch/xla/pull/2117 - # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available - def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None: - self._loader = loader - self._device = device - self._parallel_loader_kwargs = kwargs - - def __setattr__(self, name: str, value: Any) -> None: - super().__setattr__(name, value) - - def __getattr__(self, name: str) -> Any: - 
super().__getattribute__(name) - - def __iter__(self) -> Iterator: - parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs) - return parallel_loader.per_device_loader(self._device) - - def __len__(self) -> int: - return len(self._loader) - - class _XLADistributedOptimizer(Optimizer): - def __init__(self, optimizer: Optimizer) -> None: - super(self.__class__, self).__init__(optimizer.param_groups, {}) - self.wrapped_optimizer = optimizer - - def step(self, closure: Optional[Callable] = None) -> None: - xm.optimizer_step(self.wrapped_optimizer, barrier=True) - - -def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, _MpDeviceLoader]: +def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDeviceLoader"]: """Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`). @@ -109,10 +74,10 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, _MpDev kwargs["num_workers"] = (kwargs["num_workers"] + nproc - 1) // nproc if "batch_sampler" not in kwargs: - sampler: Union[DistributedProxySampler, DistributedSampler] - if kwargs.get("sampler", None) is not None: - sampler = DistributedProxySampler(kwargs["sampler"], num_replicas=world_size, rank=rank) + sampler = DistributedProxySampler( + kwargs["sampler"], num_replicas=world_size, rank=rank + ) # type: Union[DistributedProxySampler, DistributedSampler] else: sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=kwargs.get("shuffle", True) @@ -139,7 +104,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, _MpDev kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type) logger.info("Use data loader kwargs for dataset '{}': \n\t{}".format(repr(dataset)[:20].strip(), kwargs)) - dataloader: Union[DataLoader, _MpDeviceLoader] = DataLoader(dataset, **kwargs) + dataloader = DataLoader(dataset, **kwargs) # type: Union[DataLoader, "_MpDeviceLoader"] if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and world_size > 1: @@ -327,7 +292,7 @@ def __iter__(self) -> Iterator: # deterministically shuffle based on epoch torch.manual_seed(self.epoch) - indices: List = [] + indices = [] # type: List while len(indices) < self.total_size: indices += list(self.sampler) @@ -340,3 +305,38 @@ def __iter__(self) -> Iterator: raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) return iter(indices) + + +if idist.has_xla_support: + + import torch_xla.core.xla_model as xm + from torch_xla.distributed.parallel_loader import ParallelLoader + + class _MpDeviceLoader: + # https://github.com/pytorch/xla/pull/2117 + # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available + def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None: + self._loader = loader + self._device = device + self._parallel_loader_kwargs = kwargs + + def __setattr__(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + + def __getattr__(self, name: str) -> Any: + super().__getattribute__(name) + + def __iter__(self) -> Iterator: + parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs) + return parallel_loader.per_device_loader(self._device) + + def __len__(self) -> int: + return len(self._loader) + + class _XLADistributedOptimizer(Optimizer): + def 
__init__(self, optimizer: Optimizer) -> None: + super(self.__class__, self).__init__(optimizer.param_groups, {}) + self.wrapped_optimizer = optimizer + + def step(self, closure: Optional[Callable] = None) -> None: + xm.optimizer_step(self.wrapped_optimizer, barrier=True) diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py index f39bb71fe758..ac0b318d8be9 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -13,7 +13,7 @@ class ComputationModel(metaclass=ABCMeta): """ # this is an additional local rank storage used when idist is setup from existing native torch dist context - _ext_local_rank: Optional[int] = None + _ext_local_rank = None # type: Optional[int] def __init__(self): self._backend = None @@ -84,7 +84,7 @@ def create_from_backend(backend: str, **kwargs: Any) -> "ComputationModel": def spawn(*args: Any, **kwargs: Any) -> None: pass - _collective_op_dtype: Any = None + _collective_op_dtype = None # type: Any @staticmethod def _encode_str(x: str, device: torch.device) -> torch.Tensor: diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index acdd4c715e35..cbf31e43ee7b 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -57,7 +57,7 @@ def __init__(self, backend: Optional[str] = None, timeout: Optional[int] = None, """This is a private method. Please, use `create_from_backend` or `create_from_context` """ super(_NativeDistModel, self).__init__() - self._env_backup: Optional[Dict[str, str]] = None + self._env_backup = None # type: Optional[Dict[str, str]] if backend is not None: self._create_from_backend(backend, timeout=timeout, **kwargs) else: @@ -71,8 +71,8 @@ def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **_k self._local_rank = int(os.environ["LOCAL_RANK"]) # for debug purposes - self._master_port: Optional[int] = int(os.environ["MASTER_PORT"]) - self._master_addr: Optional[str] = os.environ["MASTER_ADDR"] + self._master_port = int(os.environ["MASTER_PORT"]) # type: Optional[int] + self._master_addr = os.environ["MASTER_ADDR"] # type: Optional[str] init_pg_kwargs = {} if timeout is not None: @@ -120,7 +120,7 @@ def _get_all_hostnames(self) -> List[Tuple[str, ...]]: def _compute_node_and_local_ranks(rank: int, hostnames: List[Tuple[str, ...]]) -> Tuple[int, int]: from collections import Counter - c: Counter = Counter(hostnames) + c = Counter(hostnames) # type: Counter sizes = torch.tensor([0,] + list(c.values())) cumsum_sizes = torch.cumsum(sizes, dim=0) node_rank = (rank // cumsum_sizes[1:]).clamp(0, 1).sum().item() diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py index 9077781c8504..0cd72ebe4824 100644 --- a/ignite/distributed/launcher.py +++ b/ignite/distributed/launcher.py @@ -180,7 +180,7 @@ def __init__( nnodes: Optional[int] = None, node_rank: Optional[int] = None, master_addr: Optional[str] = None, - master_port: Optional[str] = None, + master_port: Optional[int] = None, **spawn_kwargs: Any ) -> None: if backend is not None: @@ -219,7 +219,7 @@ def _setup_spawn_params( nnodes: Optional[int], node_rank: Optional[int], master_addr: Optional[str], - master_port: Optional[str], + master_port: Optional[int], **spawn_kwargs: Any ) -> Dict: if nproc_per_node < 1: diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index 0e870ca952f1..7eadf918cbd9 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ 
-105,7 +105,7 @@ def backend() -> Optional[str]: def available_backends() -> Tuple[str, ...]: """Returns available backends. """ - out: Tuple[str, ...] = () + out = () # type: Tuple[str, ...] for m in registered_computation_models: out += m.available_backends return out From b977b746edcbec9c29b8b88f40157f7d7db2601f Mon Sep 17 00:00:00 2001 From: gruebel Date: Mon, 5 Oct 2020 10:55:53 +0200 Subject: [PATCH 3/6] Remove typing,overload --- ignite/distributed/comp_models/base.py | 35 +++--------------------- ignite/distributed/comp_models/native.py | 4 +-- 2 files changed, 6 insertions(+), 33 deletions(-) diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py index ac0b318d8be9..1a6509d2669d 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -1,7 +1,7 @@ import warnings from abc import ABCMeta, abstractmethod from numbers import Number -from typing import Any, Callable, List, Optional, Union, cast, overload +from typing import Any, Callable, List, Optional, Union, cast import torch @@ -134,20 +134,9 @@ def _apply_op( return tensor.to(device=tensor_device) return tensor - @overload - def _collective_op( - self, tensor: Union[torch.Tensor, Number], fn: Callable, *args: Any, **kwargs: Any - ) -> Union[torch.Tensor, Number]: - ... - - @overload def _collective_op( self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any ) -> Union[torch.Tensor, Number, List[str]]: - ... - - # mypy doesn't support overload for no-untyped-def check - def _collective_op(self, tensor, fn, *args, **kwargs): # type: ignore tensor_to_number = tensor_to_str = False device = self.device() if isinstance(tensor, Number): @@ -160,7 +149,7 @@ def _collective_op(self, tensor, fn, *args, **kwargs): # type: ignore tensor = self._apply_op(tensor, device, fn, *args, **kwargs) if tensor_to_number and tensor.numel() == 1: - return tensor.item() + return cast(Number, tensor.item()) elif tensor_to_str: return self._decode_str(tensor) return tensor @@ -169,7 +158,7 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un if not isinstance(tensor, (torch.Tensor, Number)): raise TypeError("Unhandled input type {}".format(type(tensor))) - return self._collective_op(tensor, self._do_all_reduce, op) + return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op)) def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]: if not isinstance(tensor, (torch.Tensor, Number, str)): @@ -177,15 +166,7 @@ def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Te return self._collective_op(tensor, self._do_all_gather) - @overload - def broadcast(self, tensor: Union[torch.Tensor, Number], src: int = 0) -> Union[torch.Tensor, Number]: - ... - - @overload def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]: - ... 
- - def broadcast(self, tensor, src=0): # type: ignore # mypy doesn't support overload for no-untyped-def check if not isinstance(tensor, (torch.Tensor, Number, str)): raise TypeError("Unhandled input type {}".format(type(tensor))) @@ -210,7 +191,7 @@ def broadcast(self, tensor, src=0): # type: ignore # mypy doesn't support overl tensor = self._apply_op(tensor, device, self._do_broadcast, src) if tensor_to_number: - return tensor.item() + return cast(Number, tensor.item()) if tensor_to_str: list_str = self._decode_str(tensor) return list_str[0] @@ -293,15 +274,7 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]: return cast(Union[torch.Tensor, Number], tensor) - @overload - def broadcast(self, tensor: Union[torch.Tensor, Number], src: int = 0) -> Union[torch.Tensor, Number]: - ... - - @overload def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]: - ... - - def broadcast(self, tensor, src=0): # type: ignore # mypy doesn't support overload for no-untyped-def check return tensor def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor: diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index cbf31e43ee7b..677f9666b2bd 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -2,7 +2,7 @@ import subprocess import warnings from distutils.version import LooseVersion -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, overload +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple import torch import torch.distributed as dist @@ -245,7 +245,7 @@ def _dist_worker_task_fn( node_rank: int, master_addr: str, master_port: str, - **kw: Any, + kw: Any, ) -> None: from ignite.distributed.utils import _set_model, finalize From 8ae6caa7a9b9467bda117e3b7870db18cb4b8bd1 Mon Sep 17 00:00:00 2001 From: gruebel Date: Tue, 6 Oct 2020 00:03:20 +0200 Subject: [PATCH 4/6] Fix multiple typing issues --- ignite/distributed/auto.py | 8 +------- ignite/distributed/comp_models/base.py | 6 +++--- ignite/distributed/comp_models/native.py | 2 +- ignite/distributed/comp_models/xla.py | 2 +- ignite/distributed/launcher.py | 10 +++++----- ignite/distributed/utils.py | 12 ++++++------ 6 files changed, 17 insertions(+), 23 deletions(-) diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index a4fce442f0b5..57ab894ccacf 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -120,7 +120,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDe sampler = dataloader.sampler # type: ignore[union-attr] dataloader = mp_device_loader_cls(dataloader, idist.device()) - dataloader.sampler = sampler + dataloader.sampler = sampler # type: ignore[attr-defined] return dataloader @@ -320,12 +320,6 @@ def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None: self._device = device self._parallel_loader_kwargs = kwargs - def __setattr__(self, name: str, value: Any) -> None: - super().__setattr__(name, value) - - def __getattr__(self, name: str) -> Any: - super().__getattribute__(name) - def __iter__(self) -> Iterator: parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs) return parallel_loader.per_device_loader(self._device) diff --git a/ignite/distributed/comp_models/base.py 
b/ignite/distributed/comp_models/base.py index 1a6509d2669d..f31b074c398e 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -221,7 +221,7 @@ class _SerialModel(ComputationModel): name = "serial" available_backends = () - def __init__(self, _backend: Optional[str] = None, **_kwargs: Any) -> None: + def __init__(self, _backend: Optional[str] = None, **kwargs: Any) -> None: super(_SerialModel, self).__init__() def get_local_rank(self) -> int: @@ -271,8 +271,8 @@ def spawn(*args: Any, **kwargs: Any) -> None: def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]: return tensor - def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]: - return cast(Union[torch.Tensor, Number], tensor) + def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]: # type: ignore + return tensor def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]: return tensor diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index 677f9666b2bd..e64ddeb7a697 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -63,7 +63,7 @@ def __init__(self, backend: Optional[str] = None, timeout: Optional[int] = None, else: self._init_from_context() - def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **_kwargs: Any) -> None: + def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **kwargs: Any) -> None: if backend == dist.Backend.NCCL and not torch.cuda.is_available(): raise RuntimeError("Nccl backend is required but no cuda capable devices") diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py index b40b2f437b00..533defdb61db 100644 --- a/ignite/distributed/comp_models/xla.py +++ b/ignite/distributed/comp_models/xla.py @@ -50,7 +50,7 @@ def __init__(self, backend: Optional[str] = None, **kwargs: Any): else: self._init_from_context() - def _create_from_backend(self, backend: str, **_kwargs: Any) -> None: + def _create_from_backend(self, backend: str, **kwargs: Any) -> None: xm.rendezvous("init") self._backend = backend diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py index 0cd72ebe4824..643650fd14fa 100644 --- a/ignite/distributed/launcher.py +++ b/ignite/distributed/launcher.py @@ -216,10 +216,10 @@ def __init__( @staticmethod def _setup_spawn_params( nproc_per_node: int, - nnodes: Optional[int], - node_rank: Optional[int], - master_addr: Optional[str], - master_port: Optional[int], + nnodes: Optional[int] = None, + node_rank: Optional[int] = None, + master_addr: Optional[str] = None, + master_port: Optional[int] = None, **spawn_kwargs: Any ) -> Dict: if nproc_per_node < 1: @@ -273,7 +273,7 @@ def training(local_rank, config, **kwargs): **kwargs: keyword arguments of ``func``. 
""" - if self._spawn_params is not None: + if self._spawn_params is not None and self.backend is not None: self.logger.info("Spawn function '{}' in {} processes".format(func, self._spawn_params["nproc_per_node"])) idist.spawn(self.backend, func, args=args, kwargs_dict=kwargs, **self._spawn_params) else: diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index 7eadf918cbd9..ec17dbe9a29b 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -190,7 +190,7 @@ def hostname() -> str: def spawn( - backend: Optional[str], + backend: str, fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping] = None, @@ -349,7 +349,7 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, if _need_to_sync and isinstance(_model, _SerialModel): sync(temporary=True) - return _model.all_gather(tensor) + return _model.all_gather(tensor) # type: ignore[arg-type] def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]: @@ -440,7 +440,7 @@ def _set_model(model: Any, temporary: bool = False) -> None: _need_to_sync = False -def _assert_backend(backend: Optional[str]) -> None: +def _assert_backend(backend: str) -> None: backends = available_backends() if backend not in backends: raise ValueError("Backend should be one of '{}'".format(backends)) @@ -527,7 +527,7 @@ def show_config() -> None: logger.info("node rank: {}".format(get_node_rank())) -def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Optional[Callable]: +def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Callable: """Decorator to filter handlers wrt a rank number Args: @@ -549,9 +549,9 @@ def some_handler(_): ... """ - def _one_rank_only(func: Callable) -> Optional[Callable]: + def _one_rank_only(func: Callable) -> Callable: @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Optional[Callable]: + def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]: ret = None if get_rank() == rank: ret = func(*args, **kwargs) From e3693635dc5898f6ae1fd8bfdf718cac43f61505 Mon Sep 17 00:00:00 2001 From: gruebel Date: Tue, 6 Oct 2020 14:07:20 +0200 Subject: [PATCH 5/6] Fix typing issues --- ignite/distributed/auto.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index 57ab894ccacf..11d25199ecda 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -77,7 +77,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDe if kwargs.get("sampler", None) is not None: sampler = DistributedProxySampler( kwargs["sampler"], num_replicas=world_size, rank=rank - ) # type: Union[DistributedProxySampler, DistributedSampler] + ) # type: Union[DistributedProxySampler, DistributedSampler, Sampler] else: sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=kwargs.get("shuffle", True) @@ -104,7 +104,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDe kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type) logger.info("Use data loader kwargs for dataset '{}': \n\t{}".format(repr(dataset)[:20].strip(), kwargs)) - dataloader = DataLoader(dataset, **kwargs) # type: Union[DataLoader, "_MpDeviceLoader"] + dataloader = DataLoader(dataset, **kwargs) if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and world_size > 1: @@ -118,9 +118,9 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> 
Union[DataLoader, "_MpDe except ImportError: pass - sampler = dataloader.sampler # type: ignore[union-attr] - dataloader = mp_device_loader_cls(dataloader, idist.device()) - dataloader.sampler = sampler # type: ignore[attr-defined] + mp_dataloader = mp_device_loader_cls(dataloader, idist.device()) + mp_dataloader.sampler = dataloader.sampler # type: ignore[attr-defined] + return mp_dataloader return dataloader @@ -282,27 +282,21 @@ def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: O ) self.sampler = sampler - def __setattr__(self, name: str, value: Any) -> None: - super().__setattr__(name, value) - - def __getattr__(self, name: str) -> Any: - super().__getattribute__(name) - def __iter__(self) -> Iterator: # deterministically shuffle based on epoch - torch.manual_seed(self.epoch) + torch.manual_seed(self.epoch) # type: ignore[attr-defined] indices = [] # type: List - while len(indices) < self.total_size: + while len(indices) < self.total_size: # type: ignore[attr-defined] indices += list(self.sampler) - if len(indices) > self.total_size: - indices = indices[: self.total_size] + if len(indices) > self.total_size: # type: ignore[attr-defined] + indices = indices[: self.total_size] # type: ignore[attr-defined] # subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - if len(indices) != self.num_samples: - raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) + indices = indices[self.rank : self.total_size : self.num_replicas] # type: ignore[attr-defined] + if len(indices) != self.num_samples: # type: ignore[attr-defined] + raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) # type: ignore[attr-defined] return iter(indices) From 2b6805bcd7d1e17c7f9ecda277f77839c1a3bfc4 Mon Sep 17 00:00:00 2001 From: gruebel Date: Tue, 6 Oct 2020 15:05:49 +0200 Subject: [PATCH 6/6] Fix TPU test --- ignite/distributed/auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index 11d25199ecda..cc888ce0346c 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -323,7 +323,7 @@ def __len__(self) -> int: class _XLADistributedOptimizer(Optimizer): def __init__(self, optimizer: Optimizer) -> None: - super(self.__class__, self).__init__(optimizer.param_groups, {}) + super(self.__class__, self).__init__(optimizer.param_groups) # type: ignore[call-arg] self.wrapped_optimizer = optimizer def step(self, closure: Optional[Callable] = None) -> None:
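
Note on the annotation style adopted in PATCH 2/6: the series replaces Python 3.6+ variable annotations (`x: T = v`) with comment-style annotations (`x = v  # type: T`) so the modules still import on Python 3.5 while type-checking under mypy. Below is a minimal standalone sketch of the two forms; the names are illustrative only and are not taken from the patches.

    from typing import List, Optional

    # Python 3.6+ variable annotations (PEP 526): a SyntaxError on Python 3.5
    #   indices: List[int] = []
    #   local_rank: Optional[int] = None

    # Comment-style annotations: plain comments at runtime, but read by mypy
    indices = []  # type: List[int]
    local_rank = None  # type: Optional[int]

    def scale(values, factor=2):
        # type: (List[int], int) -> List[int]
        """Multiply every element; the type comment replaces inline annotations."""
        return [v * factor for v in values]

mypy treats both spellings identically, which is why later patches can keep tightening the types without changing runtime behaviour. The bracketed codes on comments such as `# type: ignore[arg-type]` seen throughout the series scope the suppression to a single error class instead of silencing the whole line.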