add type annotations to torch.nn.modules.module #49045

Closed
wants to merge 7 commits
3 changes: 0 additions & 3 deletions mypy.ini
@@ -76,9 +76,6 @@ ignore_errors = True
[mypy-torch.nn.modules.conv]
ignore_errors = True

-[mypy-torch.nn.modules.module]
-ignore_errors = True
-
[mypy-torch.nn.modules.normalization]
ignore_errors = True
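With this override gone, mypy now type-checks torch/nn/modules/module.py like any other file. As a minimal sketch (assuming mypy is installed and the command runs from the repo root), the same check can be driven programmatically through mypy's api module:

```python
# Run mypy on the newly un-ignored file using the repo's mypy.ini.
# The file path is an assumption about the checkout layout.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy.ini", "torch/nn/modules/module.py"]
)
print(stdout)       # per-line errors; empty once the module is clean
print(exit_status)  # 0 when the check passes
```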

8 changes: 7 additions & 1 deletion torch/_C/__init__.pyi.in
@@ -756,7 +756,13 @@ def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails

# Defined in torch/csrc/jit/python/python_tracer.cpp
-class TracingState: ...
+class TracingState:
+    def push_scope(self, scope_name: str) -> None: ...
+    def pop_scope(self) -> None: ...
+    def current_scope(self) -> str: ...
+    def set_graph(self, graph: Graph) -> None: ...
+    def graph(self) -> Graph: ...
+    ...
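For illustration, here is a hedged sketch of the kind of code the expanded stub makes checkable. It assumes the usual entry point torch._C._get_tracing_state(), which returns a TracingState during tracing and None otherwise:

```python
import torch

def tag_scope(scope_name: str) -> None:
    # _get_tracing_state() returns Optional[TracingState]; the methods
    # below are now visible to mypy thanks to the expanded stub.
    state = torch._C._get_tracing_state()
    if state is None:
        return  # not tracing, nothing to scope
    state.push_scope(scope_name)  # argument checked as str
    try:
        pass                      # traced ops recorded under this scope
    finally:
        state.pop_scope()
```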

def _create_graph_by_tracing(
func: Callable[..., Any],
18 changes: 15 additions & 3 deletions torch/_C/_nn.pyi.in
@@ -1,5 +1,6 @@
-from torch import Tensor
-from typing import Callable, Optional, List
+from torch import Tensor, memory_format
+from typing import Callable, Optional, List, overload, Tuple
+from torch.types import _bool, _dtype, _device

# Defined in tools/autograd/templates/python_nn_functions.cpp

@@ -10,4 +11,15 @@ def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...

# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
def mkldnn_reorder_conv2d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...

+# Defined at tools/autograd/templates/python_nn_functions.cpp
+@overload
+def _parse_to(device: _device, dtype: _dtype, non_blocking: _bool, copy: _bool, *,
+              memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ...
+@overload
+def _parse_to(dtype: _dtype, non_blocking: _bool, copy: _bool, *,
+              memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ...
+@overload
+def _parse_to(tensor: Tensor, non_blocking: _bool, copy: _bool, *,
+              memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ...
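These overloads mirror the three call shapes that Module.to() accepts: a target device, a target dtype, or a reference tensor. A usage sketch follows; _parse_to is a private helper, so the exact call shape here is an assumption based on the stubs above:

```python
import torch
from torch._C._nn import _parse_to  # private API, signatures per the stubs

# Normalize "move to CPU as float32" the way Module.to() does internally.
device, dtype, non_blocking, convert_to_format = _parse_to(
    device="cpu", dtype=torch.float32, non_blocking=False, copy=False
)
print(device, dtype, non_blocking, convert_to_format)
```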
2 changes: 1 addition & 1 deletion torch/distributed/nn/api/remote_module.py
@@ -279,7 +279,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:

def named_parameters( # type: ignore[return]
Contributor:
Shouldn't we remove this type ignore as well, since we fixed the return type to Parameter?

Collaborator (Author):

torch/distributed/nn/api/remote_module.py:280: error: Missing return statement  [return]

Contributor:
lol. Got it.

        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, Tensor]]:
+    ) -> Iterator[Tuple[str, Parameter]]:
        _raise_not_supported(self.named_parameters.__name__)

def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
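To unpack the exchange above: the ignore has to stay because the method body only calls a helper that raises, and mypy cannot tell that the helper never returns, hence "Missing return statement [return]". A standalone sketch of both the error and the usual NoReturn alternative (the names are illustrative, not the actual remote_module.py code):

```python
from typing import Iterator, NoReturn, Tuple

def _raise_not_supported(name: str) -> None:
    raise ValueError(f"Method ``{name}`` not supported for RemoteModule")

def named_parameters() -> Iterator[Tuple[str, str]]:  # type: ignore[return]
    # Without the ignore, mypy reports: Missing return statement [return],
    # since it assumes _raise_not_supported may return normally.
    _raise_not_supported("named_parameters")

def _raise_not_supported_alt(name: str) -> NoReturn:
    # Annotating the helper as NoReturn instead would let mypy prove the
    # body never falls through, making the ignore unnecessary.
    raise ValueError(f"Method ``{name}`` not supported for RemoteModule")
```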
6 changes: 3 additions & 3 deletions torch/fx/symbolic_trace.py
@@ -88,9 +88,9 @@ def create_arg(self, a: Any) -> 'Argument':
                    return self.create_node('get_attr', n, (), {})
            raise NameError('parameter is not a member of this module')
        elif isinstance(a, torch.Tensor):
-            for n, p in self.root.named_buffers():
-                if a is p:
-                    return self.create_node('get_attr', n, (), {})
+            for n_, p_ in self.root.named_buffers():
+                if a is p_:
+                    return self.create_node('get_attr', n_, (), {})

# For NamedTuple instances that appear literally as args, we emit
# a node to construct the NamedTuple and use that Node as the argument.
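The rename from n/p to n_/p_ changes no behavior; it sidesteps a mypy complaint. Earlier in create_arg the same variables iterate named_parameters(), which yields Parameter, while named_buffers() yields Tensor, and mypy pins a variable to its first inferred type. A minimal reproduction sketch:

```python
import torch

root = torch.nn.Linear(2, 2)

for n, p in root.named_parameters():  # p inferred as torch.nn.Parameter
    pass
for n, p in root.named_buffers():     # re-binding p to torch.Tensor here
    pass                              # draws: Incompatible types in assignment
```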
33 changes: 18 additions & 15 deletions torch/nn/modules/module.py
@@ -8,7 +8,7 @@
import torch.utils.hooks as hooks

from torch import Tensor, device, dtype
-from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
+from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
from ...utils.hooks import RemovableHandle

_grad_t = Union[Tuple[Tensor, ...], Tensor]
@@ -49,10 +49,10 @@ def _addindent(s_, numSpaces):
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
-_global_backward_hooks = OrderedDict()
-_global_is_full_backward_hook = None
-_global_forward_pre_hooks = OrderedDict()
-_global_forward_hooks = OrderedDict()
+_global_backward_hooks: Dict[int, Callable] = OrderedDict()
+_global_is_full_backward_hook: Optional[bool] = None
+_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
+_global_forward_hooks: Dict[int, Callable] = OrderedDict()


def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
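The new annotations make the storage scheme explicit: each global dict maps a RemovableHandle id to the registered callable. A small usage sketch of the module-level registration function shown above:

```python
import torch
from torch.nn.modules.module import register_module_forward_pre_hook

def announce(module, inputs):
    print(f"entering {type(module).__name__}")

handle = register_module_forward_pre_hook(announce)  # stored under handle.id
torch.nn.Linear(2, 2)(torch.randn(1, 2))             # prints: entering Linear
handle.remove()                                      # deletes the dict entry
```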
@@ -732,13 +732,13 @@ def _get_backward_hooks(self):
        It returns two lists, one with the full backward hooks and one with the non-full
        backward hooks.
        """
-        full_backward_hooks = []
+        full_backward_hooks: List[Callable] = []
        if (_global_is_full_backward_hook is True):
            full_backward_hooks += _global_backward_hooks.values()
        if (self._is_full_backward_hook is True):
            full_backward_hooks += self._backward_hooks.values()

-        non_full_backward_hooks = []
+        non_full_backward_hooks: List[Callable] = []
        if (_global_is_full_backward_hook is False):
            non_full_backward_hooks += _global_backward_hooks.values()
        if (self._is_full_backward_hook is False):
@@ -841,7 +841,9 @@ def _slow_forward(self, *input, **kwargs):
            return self.forward(*input, **kwargs)
        recording_scopes = torch.jit._trace._trace_module_map is not None
        if recording_scopes:
-            name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None
+            # The type: ignore is needed because mypy cannot see that, on this
+            # branch, torch.jit._trace._trace_module_map is not None.
+            name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore
            if name:
                tracing_state.push_scope(name)
            else:
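An aside on why the ignore above is hard to avoid: the None-check is stashed in the recording_scopes boolean, and mypy cannot connect that flag back to the Optional global, so the later subscript still looks unsafe. A self-contained sketch of the pattern (names are stand-ins):

```python
from typing import Any, Dict, Optional

_trace_map: Optional[Dict[Any, str]] = None  # stand-in for _trace_module_map

def lookup(key: Any) -> Optional[str]:
    recording = _trace_map is not None
    if recording:
        # mypy flags this line because the narrowing lives in `recording`,
        # not in the type of _trace_map; hence the blanket ignore.
        return _trace_map[key] if key in _trace_map else None  # type: ignore
    return None
```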
@@ -1158,7 +1160,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

-    def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
+    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
                        strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
@@ -1177,15 +1179,16 @@ def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys
        """
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
+        missing_keys: List[str] = []
+        unexpected_keys: List[str] = []
+        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
-            state_dict._metadata = metadata
+            # mypy isn't aware that "_metadata" exists in state_dict
+            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
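For context on the attr-defined ignore above: _metadata is an attribute that Module.state_dict() sets dynamically on a plain OrderedDict, and the stdlib stubs do not allow arbitrary attributes on that type. A quick sketch:

```python
import torch

sd = torch.nn.Linear(2, 2).state_dict()  # an OrderedDict[str, Tensor]
# Set dynamically by state_dict(); invisible to mypy's OrderedDict stubs.
print(getattr(sd, "_metadata", None))    # per-prefix version info, or None
```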
@@ -1196,7 +1199,7 @@ def load(module, prefix=''):
                    load(child, prefix + name + '.')

        load(self)
-        load = None  # break load->load reference cycle
+        del load

        if strict:
            if len(unexpected_keys) > 0:
@@ -1250,7 +1253,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        for name, param in self.named_parameters(recurse=recurse):
            yield param

-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
r"""Returns an iterator over module parameters, yielding both the
name of the parameter as well as the parameter itself.

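A quick check of the corrected annotation, which the remote_module.py fix above mirrors: named_parameters() yields nn.Parameter objects (a Tensor subclass), not plain Tensors:

```python
import torch

m = torch.nn.Linear(2, 2)
for name, p in m.named_parameters():
    assert isinstance(p, torch.nn.Parameter)  # Parameter, not just Tensor
    print(name, tuple(p.shape))               # weight (2, 2) / bias (2,)
```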