Fix typing errors in torch.distributed.nn.* directory. (#47533)
Summary: Pull Request resolved: #47533

Test Plan: Imported from OSS

Reviewed By: walterddr

Differential Revision: D24952500

Pulled By: xuzhao9

fbshipit-source-id: 8e66784fd8f9f111b6329e0bb48d6cd61c690a4a
xuzhao9 authored and facebook-github-bot committed Nov 17, 2020
1 parent 915050e commit 7f66fa6
Showing 3 changed files with 36 additions and 29 deletions.
mypy.ini (3 changes: 3 additions & 0 deletions)
@@ -62,6 +62,9 @@ ignore_errors = False
[mypy-torch.distributed.distributed_c10d.*]
ignore_errors = False

+[mypy-torch.distributed.nn.*]
+ignore_errors = False
+
[mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True

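Un-ignoring the torch.distributed.nn.* pattern means mypy now reports every typing error in that package, which is what forced the fixes in the two files below. As a rough sketch (not part of this commit, and assuming mypy is installed), the check can be reproduced from the repository root through mypy's Python API:

# Hypothetical reproduction script, not part of this commit: run mypy with
# the repository's config against the newly un-ignored package.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy.ini", "torch/distributed/nn"]
)
print(stdout)                      # per-file error report, empty when clean
print("exit code:", exit_status)   # 0 means no type errors were found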
torch/distributed/nn/api/remote_module.py (57 changes: 29 additions & 28 deletions)
@@ -20,6 +20,7 @@
from torch.distributed.rpc.utils import _parse_remote_device
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
+from torch.nn import Module


_grad_t = Union[Tuple[Tensor, ...], Tensor]
@@ -52,7 +53,7 @@ def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls=


def _param_rrefs(module_rref, recurse):
-    ret = []
+    ret: List[rpc.RRef[Parameter]] = []
    for param in module_rref.local_value().parameters(recurse):
        ret.append(rpc.RRef(param))
    return ret
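The new annotation on ret is needed because mypy typically cannot infer an element type from an empty list literal and asks for an explicit annotation. A minimal illustration of the same pattern, with hypothetical names not taken from this file:

from typing import List

def collect_evens(limit: int) -> List[int]:
    # Without the annotation, mypy typically reports "Need type annotation"
    # for the empty literal; spelling out the element type lets it check
    # each append and the declared return type.
    out: List[int] = []
    for value in range(limit):
        if value % 2 == 0:
            out.append(value)
    return out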
@@ -216,45 +217,45 @@ def register_buffer(
    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        _raise_not_supported(self.register_parameter.__name__)

-    def add_module(self, name: str, module: Optional["Module"]) -> None:
+    def add_module(self, name: str, module: Optional[Module]) -> None:
        _raise_not_supported(self.add_module.__name__)

-    def apply(self: T, fn: Callable[["Module"], None]) -> T:
+    def apply(self: T, fn: Callable[[Module], None]) -> T:  # type: ignore[return]
        _raise_not_supported(self.apply.__name__)

-    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
+    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.cuda.__name__)

-    def cpu(self: T) -> T:
+    def cpu(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.cpu.__name__)

-    def type(self: T, dst_type: Union[dtype, str]) -> T:
+    def type(self: T, dst_type: Union[dtype, str]) -> T:  # type: ignore[return]
        _raise_not_supported(self.type.__name__)

-    def float(self: T) -> T:
+    def float(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.float.__name__)

-    def double(self: T) -> T:
+    def double(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.double.__name__)

-    def half(self: T) -> T:
+    def half(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.half.__name__)

-    def bfloat16(self: T) -> T:
+    def bfloat16(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.bfloat16.__name__)

-    def to(self, *args, **kwargs):
+    def to(self, *args, **kwargs) -> T:  # type: ignore[return]
        _raise_not_supported(self.to.__name__)

-    def register_backward_hook(
-        self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]]
+    def register_backward_hook(  # type: ignore[return]
+        self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        _raise_not_supported(self.register_backward_hook.__name__)

-    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
+    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:  # type: ignore[return]
        _raise_not_supported(self.register_forward_pre_hook.__name__)

-    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
+    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:  # type: ignore[return]
        _raise_not_supported(self.register_forward_hook.__name__)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
@@ -272,47 +273,47 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
            "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
        )

-    def named_parameters(
+    def named_parameters(  # type: ignore[return]
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_parameters.__name__)

-    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
+    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:  # type: ignore[return]
        _raise_not_supported(self.buffers.__name__)

-    def named_buffers(
+    def named_buffers(  # type: ignore[return]
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_buffers.__name__)

-    def children(self) -> Iterator["Module"]:
+    def children(self) -> Iterator[Module]:  # type: ignore[return]
        _raise_not_supported(self.children.__name__)

-    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
+    def named_children(self) -> Iterator[Tuple[str, Module]]:  # type: ignore[return]
        _raise_not_supported(self.named_children.__name__)

-    def modules(self) -> Iterator["Module"]:
+    def modules(self) -> Iterator[Module]:  # type: ignore[return]
        _raise_not_supported(self.modules.__name__)

-    def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""):
+    def named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = ""):
        _raise_not_supported(self.named_modules.__name__)

-    def train(self: T, mode: bool = True) -> T:
+    def train(self: T, mode: bool = True) -> T:  # type: ignore[return]
        _raise_not_supported(self.train.__name__)

-    def eval(self: T) -> T:
+    def eval(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.eval.__name__)

-    def requires_grad_(self: T, requires_grad: bool = True) -> T:
+    def requires_grad_(self: T, requires_grad: bool = True) -> T:  # type: ignore[return]
        _raise_not_supported(self.requires_grad_.__name__)

-    def zero_grad(self) -> None:
+    def zero_grad(self, set_to_none: bool = False) -> None:
        _raise_not_supported(self.zero_grad.__name__)

-    def share_memory(self: T) -> T:
+    def share_memory(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.share_memory.__name__)

-    def extra_repr(self) -> str:
+    def extra_repr(self) -> str:  # type: ignore[return]
        _raise_not_supported(self.extra_repr.__name__)


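Most of the "# type: ignore[return]" comments above follow one pattern: each override keeps the return annotation of the corresponding nn.Module method for interface compatibility, but its body only calls _raise_not_supported, so mypy reports a missing return statement (error code "return"). A minimal sketch of the situation, with hypothetical names rather than the real class:

from typing import TypeVar

T = TypeVar("T", bound="UnsupportedStub")

def _raise_not_supported(name: str) -> None:
    raise ValueError("Method ``{}`` is not supported.".format(name))

class UnsupportedStub:
    # The annotation promises a value of type T, but the body never returns;
    # mypy would flag "Missing return statement" here without the ignore.
    def train(self: T, mode: bool = True) -> T:  # type: ignore[return]
        _raise_not_supported(self.train.__name__)

An alternative would be annotating the helper with typing.NoReturn, which lets mypy see that the call never falls through; the narrower per-method ignore leaves the helper's signature unchanged.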
torch/distributed/nn/jit/instantiator.py (5 changes: 4 additions & 1 deletion)
@@ -6,6 +6,7 @@
import tempfile

import torch
+from typing import Optional
from torch.distributed.nn.jit.templates.remote_module_template import (
    REMOTE_MODULE_TEMPLATE,
)
@@ -37,11 +38,12 @@ def get_arg_return_types_from_interface(module_interface):

    arg_str_list = []
    arg_type_str_list = []
+    assert method_schema is not None
    for argument in method_schema.arguments:
        arg_str_list.append(argument.name)

        if argument.has_default_value():
-            default_value_str = " = {}".format(argument.default)
+            default_value_str = " = {}".format(argument.default_value)
        else:
            default_value_str = ""
        arg_type_str = "{name}: {type}{default_value}".format(
@@ -63,6 +65,7 @@ def get_arg_return_types_from_interface(module_interface):


def _write(out_path, text):
+    old_text: Optional[str]
    try:
        with open(out_path, "r") as f:
            old_text = f.read()
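Two of the instantiator fixes are standard mypy idioms: "assert method_schema is not None" narrows an Optional value before its attributes are used, and the bare "old_text: Optional[str]" declaration gives one type to a variable that is assigned on different branches of a try/except. A rough sketch of both idioms together, using hypothetical names rather than the real helpers:

from typing import Optional

def read_text(path: str, fallback: Optional[str] = None) -> str:
    # Declaring the type up front covers both assignments below, so mypy
    # does not infer a narrower type from whichever branch it sees first.
    text: Optional[str]
    try:
        with open(path, "r") as f:
            text = f.read()
    except FileNotFoundError:
        text = fallback
    # Narrow Optional[str] to str: mypy treats code after the assert as if
    # ``text`` can no longer be None.
    assert text is not None
    return text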
