diff --git a/mypy.ini b/mypy.ini
index cbbd502fd4e5..248a83300c5e 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -53,18 +53,6 @@ ignore_errors = True
 [mypy-torch.backends._nnapi.*]
 ignore_errors = True
 
-[mypy-torch.distributed.*]
-ignore_errors = True
-
-[mypy-torch.distributed.rpc.*]
-ignore_errors = False
-
-[mypy-torch.distributed.distributed_c10d.*]
-ignore_errors = False
-
-[mypy-torch.distributed.nn.*]
-ignore_errors = False
-
 [mypy-torch.testing._internal.hypothesis_utils.*]
 ignore_errors = True
 
diff --git a/torch/distributed/_pipeline/sync/batchnorm.py b/torch/distributed/_pipeline/sync/batchnorm.py
index 487c3d096d98..4e1cf7b09879 100644
--- a/torch/distributed/_pipeline/sync/batchnorm.py
+++ b/torch/distributed/_pipeline/sync/batchnorm.py
@@ -27,6 +27,8 @@ class DeferredBatchNorm(_BatchNorm):
 
     sum: Tensor
     sum_squares: Tensor
+    running_mean: Tensor
+    running_var: Tensor
 
     def __init__(
         self,
@@ -149,6 +151,8 @@ def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModul
         if module.affine:
             module_output.register_parameter("weight", module.weight)
             module_output.register_parameter("bias", module.bias)
+        assert isinstance(module.running_mean, Tensor)
+        assert isinstance(module.running_var, Tensor)
         module_output.register_buffer("running_mean", module.running_mean)
         module_output.register_buffer("running_var", module.running_var)
         module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
diff --git a/torch/distributed/_pipeline/sync/checkpoint.py b/torch/distributed/_pipeline/sync/checkpoint.py
index 08e95e2d18fa..bad5eec19469 100644
--- a/torch/distributed/_pipeline/sync/checkpoint.py
+++ b/torch/distributed/_pipeline/sync/checkpoint.py
@@ -30,7 +30,7 @@
 from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
 
 import torch
-from torch import ByteTensor, Tensor
+from torch import Tensor
 import torch.autograd
 
 from .dependency import fork, join
@@ -45,7 +45,7 @@
 
 # Types for shared memory between Checkpoint and Recompute.
 Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
-RNGStates = Tuple[ByteTensor, Optional[ByteTensor]]  # (cpu_rng_state, gpu_rng_state)
+RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)
 
 
 if TYPE_CHECKING:
@@ -207,7 +207,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
     """
     cpu_rng_state = torch.get_rng_state()
 
-    gpu_rng_state: Optional[ByteTensor]
+    gpu_rng_state: Optional[Tensor]
     if device.type == "cuda":
         gpu_rng_state = torch.cuda.get_rng_state(device)
     else:
diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py
index 500b15b72771..d8c555cd2f74 100644
--- a/torch/distributed/_pipeline/sync/pipe.py
+++ b/torch/distributed/_pipeline/sync/pipe.py
@@ -30,7 +30,8 @@
 TensorOrTensors = Union[Tensor, Tensors]
 
 if TYPE_CHECKING:
-    Module = nn.Module[TensorOrTensors]
+    # Typechecking: nn.Module is not a Generic
+    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
     NamedModules = OrderedDict[str, Module]
 else:
     Module = nn.Module
diff --git a/torch/distributed/_pipeline/sync/skip/skippable.py b/torch/distributed/_pipeline/sync/skip/skippable.py
index b5d07ff9c7a0..9bb258382b9b 100644
--- a/torch/distributed/_pipeline/sync/skip/skippable.py
+++ b/torch/distributed/_pipeline/sync/skip/skippable.py
@@ -39,7 +39,8 @@
 StashPop = Union["stash", "pop"]
 StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
 if TYPE_CHECKING:
-    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]
+    # Typechecking: nn.Module is not a Generic
+    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]  # type: ignore[type-arg]
 else:
     SkippableModule = nn.Module
 
diff --git a/torch/distributed/_pipeline/sync/stream.py b/torch/distributed/_pipeline/sync/stream.py
index 0de4496808a0..41e1591793b6 100644
--- a/torch/distributed/_pipeline/sync/stream.py
+++ b/torch/distributed/_pipeline/sync/stream.py
@@ -104,7 +104,8 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
     #
     tensor = tensor.new_empty([0]).set_(tensor.storage())
 
-    tensor.record_stream(as_cuda(stream))
+    # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
+    tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
 
 
 def is_cuda(stream: AbstractStream) -> bool:
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index f3d69ccd51de..4e6cbd72aee6 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -3,7 +3,7 @@
 
 
 def allreduce_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     This DDP communication hook just calls ``allreduce`` using ``GradBucket``
@@ -32,7 +32,7 @@ def then_callback(fut):
 
 
 def fp16_compress_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     This DDP communication hook implements a simple gradient compression
@@ -79,7 +79,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):
 
 
 def _allgather_then_aggregate_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
index e9a8fa3d4674..cde3b79fc7ce 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -43,7 +43,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):
 
 
 def quantization_pertensor_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
@@ -118,7 +118,7 @@ def dequantize_and_aggregate(fut):
 
 
 def quantization_perchannel_hook(
-    process_group: object, bucket: dist._GradBucket, bucket_size=512
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
 ) -> torch.futures.Future:
     """
     Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py
index 6219506410e8..e90bc483c7f1 100644
--- a/torch/distributed/launch.py
+++ b/torch/distributed/launch.py
@@ -141,6 +141,7 @@
 import subprocess
 import os
 from argparse import ArgumentParser, REMAINDER
+from typing import Optional, IO
 
 node_local_rank_stdout_filename = "node_{}_local_rank_{}_stdout"
 node_local_rank_stderr_filename = "node_{}_local_rank_{}_stderr"
@@ -269,6 +270,8 @@ def main():
 
         cmd.extend(args.training_script_args)
 
+        stdout_handle: Optional[IO]
+        stderr_handle: Optional[IO]
         if args.logdir:
             directory_path = os.path.join(os.getcwd(), args.logdir)
             node_rank = args.node_rank
diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py
index f5845fa5a79b..55705f987a6e 100644
--- a/torch/distributed/rendezvous.py
+++ b/torch/distributed/rendezvous.py
@@ -1,12 +1,14 @@
 try:
     from urllib.parse import urlparse, urlunparse
 except ImportError:
-    from urlparse import urlparse, urlunparse
+    raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.")
 
 import torch._six as six
 import numbers
 import os
 import sys
+from datetime import timedelta
+from typing import Optional, Dict, Union
 
 from torch._C._distributed_c10d import FileStore
 from .constants import default_pg_timeout
@@ -47,7 +49,7 @@ def register_rendezvous_handler(scheme, handler):
     _rendezvous_handlers[scheme] = handler
 
 
-def rendezvous(url, rank=-1, world_size=-1, **kwargs):
+def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
     if not isinstance(url, six.string_classes):
         raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
 
@@ -60,8 +62,9 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
     # Append node-specific arguments.
     result = urlparse(url)
     if rank != -1 or world_size != -1:
-        query_dict = dict(
-            pair.split("=") for pair in filter(None, result.query.split("&"))
+        query_dict: Dict[str, Union[int, str]] = dict(
+            # mypy doesn't allow dict() to accept List of values (#257)
+            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
         )
         assert (
             "rank" not in query_dict and "world_size" not in query_dict
@@ -87,7 +90,7 @@ def _rendezvous_error(msg):
     return ValueError("Error initializing torch.distributed using " + msg)
 
 
-def _file_rendezvous_handler(url, **kwargs):
+def _file_rendezvous_handler(url: str, **kwargs):
     def _error(msg):
         return _rendezvous_error("file:// rendezvous: " + msg)
 
@@ -99,7 +102,9 @@ def _error(msg):
 
     if not path:
         raise _error("path missing")
-    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
+    query: Dict[str, str]
+    # mypy doesn't allow dict() to accept List of values (#257)
+    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
     if "rank" not in query:
         raise _error("rank parameter missing")
     if "world_size" not in query:
@@ -114,14 +119,16 @@ def _error(msg):
     raise RuntimeError("Unable to perform rerendezvous using file:// method")
 
 
-def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
+def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
     def _error(msg):
         return _rendezvous_error("tcp:// rendezvous: " + msg)
 
     result = urlparse(url)
     if not result.port:
         raise _error("port number missing")
-    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
+    query: Dict[str, Union[int, str]]
+    # mypy doesn't allow dict() to accept List of values (#257)
+    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
     if "rank" not in query:
         raise _error("rank parameter missing")
     if "world_size" not in query:
@@ -130,6 +137,7 @@ def _error(msg):
     rank = int(query["rank"])
     world_size = int(query["world_size"])
     start_daemon = rank == 0
+    assert result.hostname is not None
     store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
     yield (store, rank, world_size)
 
@@ -137,7 +145,7 @@ def _error(msg):
     raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
 
 
-def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
+def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
     def _error(msg):
         return _rendezvous_error("env:// rendezvous: " + msg)
 
@@ -145,7 +153,13 @@ def _env_error(var):
         return _error("environment variable %s expected, but not set" % var)
 
     result = urlparse(url)
-    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
+    query: Dict[str, Union[int, str]]
+    # mypy doesn't allow dict() to accept List of values (#257)
+    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
+
+    rank: Optional[Union[str, int]]
+    world_size: Optional[Union[str, int]]
+    master_port: Optional[Union[str, int]]
 
     if "rank" in query:
         rank = int(query["rank"])
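
Reviewer note (not part of the patch): a minimal usage sketch of the retyped DDP communication hooks. It assumes a PyTorch build where ``DistributedDataParallel.register_comm_hook`` is available and that the default process group has already been initialized; the wrapper name ``wrap_with_fp16_hook`` is hypothetical.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook


def wrap_with_fp16_hook(model: torch.nn.Module) -> DDP:
    # Assumes dist.init_process_group(...) has already been called.
    ddp_model = DDP(model)
    # The `state` argument is forwarded as the hook's first parameter, so it
    # should be a dist.ProcessGroup (or None to fall back to the default
    # group), matching the `process_group: dist.ProcessGroup` annotation
    # introduced by this patch.
    ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)
    return ddp_model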
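
Background on the repeated ``# type: ignore[misc, arg-type]`` comments in rendezvous.py: ``pair.split("=")`` returns ``List[str]``, which mypy cannot verify as the two-element key/value pairs ``dict()`` expects (the limitation referenced as (#257) in the patch comments). Below is a hedged sketch of an ignore-free alternative, not used by the patch and with a hypothetical helper name.

from typing import Dict


def _parse_query(query: str) -> Dict[str, str]:
    # str.partition always returns a (head, sep, tail) tuple, so mypy can type
    # this loop without the ignores needed for dict(pair.split("=") ...).
    parsed: Dict[str, str] = {}
    for pair in filter(None, query.split("&")):
        key, _, value = pair.partition("=")
        parsed[key] = value
    return parsed


# Example: _parse_query("rank=0&world_size=2") == {"rank": "0", "world_size": "2"}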