Fix typing errors in torch.distributed.*, close issue #42967. (#47534)
Summary: Pull Request resolved: #47534

Test Plan: Imported from OSS

Reviewed By: walterddr

Differential Revision: D24952497

Pulled By: xuzhao9

fbshipit-source-id: 063bfd0707198436fcfd9431f72f9a392bc0017e
xuzhao9 authored and facebook-github-bot committed Nov 17, 2020
1 parent 7f66fa6 commit 49f0e5d
Showing 10 changed files with 45 additions and 33 deletions.
12 changes: 0 additions & 12 deletions mypy.ini
@@ -53,18 +53,6 @@ ignore_errors = True
 [mypy-torch.backends._nnapi.*]
 ignore_errors = True
 
-[mypy-torch.distributed.*]
-ignore_errors = True
-
-[mypy-torch.distributed.rpc.*]
-ignore_errors = False
-
-[mypy-torch.distributed.distributed_c10d.*]
-ignore_errors = False
-
-[mypy-torch.distributed.nn.*]
-ignore_errors = False
-
 [mypy-torch.testing._internal.hypothesis_utils.*]
 ignore_errors = True
 
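Removing the blanket `ignore_errors = True` override for `torch.distributed.*` means mypy now reports errors in those modules; the rpc/distributed_c10d/nn overrides, which already set it to False, become redundant and are dropped as well. A hypothetical illustration (not code from the PR) of the kind of mismatch such an override used to hide:

    from typing import Dict, Optional


    def get_rank(query: Dict[str, str]) -> int:
        rank: Optional[str] = query.get("rank")
        if rank is None:
            raise ValueError("rank parameter missing")
        # Returning the raw Optional[str] here (instead of converting) is the
        # kind of error mypy now flags in torch.distributed.* modules.
        return int(rank)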
4 changes: 4 additions & 0 deletions torch/distributed/_pipeline/sync/batchnorm.py
@@ -27,6 +27,8 @@ class DeferredBatchNorm(_BatchNorm):
 
     sum: Tensor
     sum_squares: Tensor
+    running_mean: Tensor
+    running_var: Tensor
 
     def __init__(
         self,
@@ -149,6 +151,8 @@ def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
             if module.affine:
                 module_output.register_parameter("weight", module.weight)
                 module_output.register_parameter("bias", module.bias)
+            assert isinstance(module.running_mean, Tensor)
+            assert isinstance(module.running_var, Tensor)
             module_output.register_buffer("running_mean", module.running_mean)
             module_output.register_buffer("running_var", module.running_var)
             module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
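The class-level `running_mean: Tensor` / `running_var: Tensor` annotations, plus the `isinstance` asserts before `register_buffer`, are the two usual ways to tell mypy that a buffer created dynamically (or declared Optional upstream) really is a `Tensor` at the point of use. A minimal sketch of both patterns with illustrative names (`RunningStats` is not from the PR):

    import torch
    from torch import Tensor, nn


    class RunningStats(nn.Module):
        # Class-level annotations tell mypy these buffers exist and are Tensors,
        # even though they are only created at runtime via register_buffer().
        mean: Tensor
        var: Tensor

        def __init__(self, num_features: int) -> None:
            super().__init__()
            self.register_buffer("mean", torch.zeros(num_features))
            self.register_buffer("var", torch.ones(num_features))

        def update(self, x: Tensor) -> None:
            self.mean += x.mean(0)  # checks as Tensor, not Optional[Tensor]
            self.var += x.var(0)


    def copy_mean(src: nn.Module, dst: RunningStats) -> None:
        # When the source attribute is untyped or Optional, narrow it with an
        # assert before use, as convert_deferred_batch_norm does above.
        mean = getattr(src, "mean", None)
        assert isinstance(mean, Tensor)
        dst.register_buffer("mean", mean)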
6 changes: 3 additions & 3 deletions torch/distributed/_pipeline/sync/checkpoint.py
@@ -30,7 +30,7 @@
 from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
 
 import torch
-from torch import ByteTensor, Tensor
+from torch import Tensor
 import torch.autograd
 
 from .dependency import fork, join
@@ -45,7 +45,7 @@
 
 # Types for shared memory between Checkpoint and Recompute.
 Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
-RNGStates = Tuple[ByteTensor, Optional[ByteTensor]]  # (cpu_rng_state, gpu_rng_state)
+RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)
 
 
 if TYPE_CHECKING:
@@ -207,7 +207,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
     """
     cpu_rng_state = torch.get_rng_state()
 
-    gpu_rng_state: Optional[ByteTensor]
+    gpu_rng_state: Optional[Tensor]
     if device.type == "cuda":
         gpu_rng_state = torch.cuda.get_rng_state(device)
     else:
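The RNG-state helpers return a plain `Tensor` as far as the type checker is concerned, hence the switch away from `ByteTensor` in the aliases and annotations. A small sketch of saving and restoring RNG state under the updated annotations; the CUDA branch is guarded so the sketch also runs on CPU-only machines:

    from typing import Optional, Tuple

    import torch
    from torch import Tensor

    RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)


    def save_rng_states(device: torch.device) -> RNGStates:
        cpu_rng_state = torch.get_rng_state()
        gpu_rng_state: Optional[Tensor] = None
        if device.type == "cuda" and torch.cuda.is_available():
            gpu_rng_state = torch.cuda.get_rng_state(device)
        return cpu_rng_state, gpu_rng_state


    def restore_rng_states(device: torch.device, states: RNGStates) -> None:
        cpu_rng_state, gpu_rng_state = states
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state(gpu_rng_state, device)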
3 changes: 2 additions & 1 deletion torch/distributed/_pipeline/sync/pipe.py
@@ -30,7 +30,8 @@
 TensorOrTensors = Union[Tensor, Tensors]
 
 if TYPE_CHECKING:
-    Module = nn.Module[TensorOrTensors]
+    # Typechecking: nn.Module is not a Generic
+    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
     NamedModules = OrderedDict[str, Module]
 else:
     Module = nn.Module
3 changes: 2 additions & 1 deletion torch/distributed/_pipeline/sync/skip/skippable.py
@@ -39,7 +39,8 @@
 StashPop = Union["stash", "pop"]
 StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
 if TYPE_CHECKING:
-    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]
+    # Typechecking: nn.Module is not a Generic
+    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]  # type: ignore[type-arg]
 else:
     SkippableModule = nn.Module
 
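In both pipe.py and skippable.py the subscripted `nn.Module[...]` alias is only meaningful to type checkers; `nn.Module` is not generic, so the subscript gets a targeted `# type: ignore[type-arg]` while the runtime branch keeps the plain class. A short sketch of the pattern with simplified aliases (not the PR's exact types):

    from typing import TYPE_CHECKING, Tuple, Union

    from torch import Tensor, nn

    Tensors = Tuple[Tensor, ...]
    TensorOrTensors = Union[Tensor, Tensors]

    if TYPE_CHECKING:
        # Seen only by type checkers; nn.Module is not a Generic, hence the
        # targeted ignore on the subscript.
        Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
    else:
        # At runtime the subscript is not supported, so fall back to the class.
        Module = nn.Module


    def run_twice(module: Module, x: Tensor) -> TensorOrTensors:
        return module(module(x))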
3 changes: 2 additions & 1 deletion torch/distributed/_pipeline/sync/stream.py
@@ -104,7 +104,8 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
     #
     tensor = tensor.new_empty([0]).set_(tensor.storage())
 
-    tensor.record_stream(as_cuda(stream))
+    # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
+    tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
 
 
 def is_cuda(stream: AbstractStream) -> bool:
6 changes: 3 additions & 3 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -3,7 +3,7 @@
 
 
 def allreduce_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     This DDP communication hook just calls ``allreduce`` using ``GradBucket``
@@ -32,7 +32,7 @@ def then_callback(fut):
 
 
 def fp16_compress_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     This DDP communication hook implements a simple gradient compression
@@ -79,7 +79,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):
 
 
 def _allgather_then_aggregate_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
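Typing `process_group` as `dist.ProcessGroup` instead of `object` lets the hook bodies call methods such as `process_group.size()` without mypy complaints, while keeping the signature DDP's communication-hook mechanism expects. A hypothetical no-op hook written against the same signature; the `bucket.get_tensors()` accessor is an assumption based on how the built-in hooks of this era read the bucket, and registration on the DDP model is left out:

    import torch
    import torch.distributed as dist


    def noop_hook(
        process_group: dist.ProcessGroup, bucket: dist._GradBucket
    ) -> torch.futures.Future:
        # With process_group typed as dist.ProcessGroup rather than object,
        # attribute access like process_group.size() now type checks.
        if process_group is not None:
            _ = process_group.size()
        fut: torch.futures.Future = torch.futures.Future()
        # Assumption: the bucket exposes its gradient tensors via get_tensors().
        fut.set_result(bucket.get_tensors())
        return fut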
4 changes: 2 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -43,7 +43,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):
 
 
 def quantization_pertensor_hook(
-    process_group: object, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket
 ) -> torch.futures.Future:
     """
     Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
@@ -118,7 +118,7 @@ def dequantize_and_aggregate(fut):
 
 
 def quantization_perchannel_hook(
-    process_group: object, bucket: dist._GradBucket, bucket_size=512
+    process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
 ) -> torch.futures.Future:
     """
     Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
3 changes: 3 additions & 0 deletions torch/distributed/launch.py
@@ -141,6 +141,7 @@
 import subprocess
 import os
 from argparse import ArgumentParser, REMAINDER
+from typing import Optional, IO
 
 node_local_rank_stdout_filename = "node_{}_local_rank_{}_stdout"
 node_local_rank_stderr_filename = "node_{}_local_rank_{}_stderr"
@@ -269,6 +270,8 @@ def main():
 
     cmd.extend(args.training_script_args)
 
+    stdout_handle: Optional[IO]
+    stderr_handle: Optional[IO]
     if args.logdir:
         directory_path = os.path.join(os.getcwd(), args.logdir)
         node_rank = args.node_rank
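The two handles are assigned a file object in the `logdir` branch and presumably left unset or None otherwise, so mypy needs a single declared type that covers both branches; hence the up-front `Optional[IO]` annotations. A standalone sketch of the same declare-then-branch pattern (the `open_log` helper is illustrative, not from launch.py):

    from typing import IO, Optional


    def open_log(logdir: Optional[str], name: str) -> Optional[IO]:
        # Declare the type once so both assignment branches agree for mypy.
        handle: Optional[IO]
        if logdir is not None:
            handle = open(f"{logdir}/{name}.log", "w")
        else:
            handle = None
        return handle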
34 changes: 24 additions & 10 deletions torch/distributed/rendezvous.py
@@ -1,12 +1,14 @@
 try:
     from urllib.parse import urlparse, urlunparse
 except ImportError:
-    from urlparse import urlparse, urlunparse
+    raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.")
 
 import torch._six as six
 import numbers
 import os
 import sys
 from datetime import timedelta
+from typing import Optional, Dict, Union
 from torch._C._distributed_c10d import FileStore
 from .constants import default_pg_timeout

@@ -47,7 +49,7 @@ def register_rendezvous_handler(scheme, handler):
     _rendezvous_handlers[scheme] = handler
 
 
-def rendezvous(url, rank=-1, world_size=-1, **kwargs):
+def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
     if not isinstance(url, six.string_classes):
         raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
 
@@ -60,8 +62,9 @@ def rendezvous(url, rank=-1, world_size=-1, **kwargs):
     # Append node-specific arguments.
     result = urlparse(url)
     if rank != -1 or world_size != -1:
-        query_dict = dict(
-            pair.split("=") for pair in filter(None, result.query.split("&"))
+        query_dict: Dict[str, Union[int, str]] = dict(
+            # mypy doesn't allow dict() to accept List of values (#257)
+            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
         )
         assert (
             "rank" not in query_dict and "world_size" not in query_dict
@@ -87,7 +90,7 @@ def _rendezvous_error(msg):
     return ValueError("Error initializing torch.distributed using " + msg)
 
 
-def _file_rendezvous_handler(url, **kwargs):
+def _file_rendezvous_handler(url: str, **kwargs):
     def _error(msg):
         return _rendezvous_error("file:// rendezvous: " + msg)
 
@@ -99,7 +102,9 @@ def _error(msg):
 
     if not path:
         raise _error("path missing")
-    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
+    query: Dict[str, str]
+    # mypy doesn't allow dict() to accept List of values (#257)
+    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
     if "rank" not in query:
         raise _error("rank parameter missing")
     if "world_size" not in query:
@@ -114,14 +119,16 @@ def _error(msg):
     raise RuntimeError("Unable to perform rerendezvous using file:// method")
 
 
-def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
+def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
     def _error(msg):
         return _rendezvous_error("tcp:// rendezvous: " + msg)
 
     result = urlparse(url)
     if not result.port:
         raise _error("port number missing")
-    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
+    query: Dict[str, Union[int, str]]
+    # mypy doesn't allow dict() to accept List of values (#257)
+    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
     if "rank" not in query:
         raise _error("rank parameter missing")
     if "world_size" not in query:
@@ -130,22 +137,29 @@ def _error(msg):
     rank = int(query["rank"])
     world_size = int(query["world_size"])
     start_daemon = rank == 0
+    assert result.hostname is not None
     store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
     yield (store, rank, world_size)
 
     # If this configuration is invalidated, there is nothing we can do about it
     raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
 
 
-def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
+def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
     def _error(msg):
         return _rendezvous_error("env:// rendezvous: " + msg)
 
     def _env_error(var):
         return _error("environment variable %s expected, but not set" % var)
 
     result = urlparse(url)
-    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
+    query: Dict[str, Union[int, str]]
+    # mypy doesn't allow dict() to accept List of values (#257)
+    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
 
+    rank: Optional[Union[str, int]]
+    world_size: Optional[Union[str, int]]
+    master_port: Optional[Union[str, int]]
+
     if "rank" in query:
         rank = int(query["rank"])
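The repeated `dict(pair.split("=") for pair in ...)` construction is what mypy cannot type precisely (the `(#257)` comment points at an upstream typing limitation for `dict()` over split results), hence the explicit `Dict[...]` declarations plus targeted ignores, and the up-front `Optional[...]` declarations for `rank`, `world_size`, and `master_port`, which are later filled from either the query string or environment variables. A runnable sketch of the same parsing with the annotation, using a made-up URL:

    from typing import Dict, Union
    from urllib.parse import urlparse

    url = "tcp://127.0.0.1:23456?rank=0&world_size=2"  # illustrative only

    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    # dict() over the list results of str.split is what mypy cannot infer,
    # so the type is declared and the call itself carries a targeted ignore.
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    print(rank, world_size)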
