annotate a few torch.nn.modules.* modules (#45772)
Summary:
Fixes #45771

Pull Request resolved: #45772

Reviewed By: mruberry

Differential Revision: D24682013

Pulled By: albanD

fbshipit-source-id: e32bc4fe9c586c079f7070924a874c70f3d127fa
guilhermeleobas authored and facebook-github-bot committed Nov 2, 2020
1 parent 7178790 commit 9b52654
Showing 7 changed files with 25 additions and 26 deletions.
12 changes: 0 additions & 12 deletions mypy.ini
@@ -77,9 +77,6 @@ ignore_errors = True
[mypy-torch._tensor_str]
ignore_errors = True

[mypy-torch.nn.modules.batchnorm]
ignore_errors = True

[mypy-torch.nn.modules.container]
ignore_errors = True

@@ -89,12 +86,6 @@ ignore_errors = True
[mypy-torch.nn.modules.fold]
ignore_errors = True

[mypy-torch.nn.modules.instancenorm]
ignore_errors = True

[mypy-torch.nn.modules.linear]
ignore_errors = True

[mypy-torch.nn.modules.loss]
ignore_errors = True

@@ -113,9 +104,6 @@ ignore_errors = True
[mypy-torch.nn.modules.rnn]
ignore_errors = True

[mypy-torch.nn.modules.sparse]
ignore_errors = True

[mypy-torch.nn.parallel._functions]
ignore_errors = True

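With these sections removed from mypy.ini, torch.nn.modules.batchnorm, instancenorm, linear, and sparse are now expected to type-check cleanly. A minimal sketch of verifying that programmatically, assuming mypy is installed and the working directory is the PyTorch repo root (illustrative only, not part of the commit):

```python
# Run mypy on one of the newly un-ignored modules using the repository's
# mypy.ini. mypy.api.run returns (report, errors, exit_status).
from mypy import api

report, errors, exit_status = api.run([
    "--config-file", "mypy.ini",
    "torch/nn/modules/batchnorm.py",
])
print(report)
print("clean" if exit_status == 0 else "type errors found")
```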
3 changes: 2 additions & 1 deletion torch/nn/functional.pyi.in
@@ -189,7 +189,8 @@ def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., m

def embedding_bag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = ..., max_norm: Optional[float] = ...,
norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ...,
sparse: bool = ...) -> Tensor: ...
sparse: bool = ..., per_sample_weights: Optional[Tensor] = ...,
include_last_offset: bool = ...) -> Tensor: ...

def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor],
weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., training: bool = ...,
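The stub now declares the per_sample_weights and include_last_offset keyword arguments that torch.nn.functional.embedding_bag already accepts at runtime. A short usage sketch of both (illustrative only):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                  # embedding table: 10 rows of dim 3
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])               # two bags: indices[0:4] and indices[4:8]

# per_sample_weights scales each looked-up row; it is only supported with mode='sum'.
scales = torch.rand(8)
out = F.embedding_bag(indices, weight, offsets, mode='sum',
                      per_sample_weights=scales)
print(out.shape)                             # torch.Size([2, 3])

# include_last_offset=True expects offsets of length num_bags + 1,
# where the final entry marks the end of the last bag.
offsets_incl = torch.tensor([0, 4, 8])
out2 = F.embedding_bag(indices, weight, offsets_incl, mode='sum',
                       include_last_offset=True)
print(out2.shape)                            # torch.Size([2, 3])
```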
17 changes: 12 additions & 5 deletions torch/nn/modules/batchnorm.py
@@ -54,9 +54,11 @@ def __init__(

def reset_running_stats(self) -> None:
if self.track_running_stats:
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
# running_mean/running_var/num_batches... are registered at runtime depending
# on whether self.track_running_stats is on
self.running_mean.zero_() # type: ignore[operator]
self.running_var.fill_(1) # type: ignore[operator]
self.num_batches_tracked.zero_() # type: ignore[operator]

def reset_parameters(self) -> None:
self.reset_running_stats()
@@ -107,8 +109,8 @@ def forward(self, input: Tensor) -> Tensor:

if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked = self.num_batches_tracked + 1
if self.num_batches_tracked is not None: # type: ignore
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
@@ -128,6 +130,8 @@ def forward(self, input: Tensor) -> Tensor:
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
@@ -487,6 +491,7 @@ def forward(self, input: Tensor) -> Tensor:
exponential_average_factor = self.momentum

if self.training and self.track_running_stats:
assert self.num_batches_tracked is not None
self.num_batches_tracked = self.num_batches_tracked + 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
@@ -508,6 +513,8 @@ def forward(self, input: Tensor) -> Tensor:
used for normalization (i.e. in eval mode when buffers are not None).
"""
# If buffers are not to be tracked, ensure that they won't be updated
assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
running_mean = self.running_mean if not self.training or self.track_running_stats else None
running_var = self.running_var if not self.training or self.track_running_stats else None

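The `# type: ignore` comments and asserts above are needed because running_mean, running_var, and num_batches_tracked are registered at runtime and may be None when track_running_stats is False, so mypy only knows them as Optional[Tensor]. A minimal sketch of the same narrowing pattern, not part of the diff:

```python
from typing import Optional

import torch
from torch import Tensor, nn


class RunningStats(nn.Module):
    # The buffer is created at runtime, so the type checker only sees
    # Optional[Tensor]; it really is None when stats tracking is off.
    running_mean: Optional[Tensor]

    def __init__(self, num_features: int, track_running_stats: bool = True) -> None:
        super().__init__()
        if track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
        else:
            self.register_buffer('running_mean', None)

    def reset(self) -> None:
        # Narrow the Optional before calling tensor methods, mirroring the
        # asserts added to batchnorm.py; otherwise mypy flags the call on a
        # possibly-None value.
        assert self.running_mean is not None
        self.running_mean.zero_()
```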
2 changes: 2 additions & 0 deletions torch/nn/modules/instancenorm.py
@@ -52,6 +52,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)

assert self.running_mean is None or isinstance(self.running_mean, Tensor)
assert self.running_var is None or isinstance(self.running_var, Tensor)
return F.instance_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats, self.momentum, self.eps)
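As in batchnorm.py, the running statistics here are optional, and the asserts narrow them before the F.instance_norm call. For illustration, under the default InstanceNorm configuration the buffers really are None:

```python
import torch
from torch import nn

# InstanceNorm defaults to track_running_stats=False, so running_mean and
# running_var are None; that is the case the Optional narrowing above handles.
norm = nn.InstanceNorm2d(3)
print(norm.running_mean)                     # None

norm_stats = nn.InstanceNorm2d(3, track_running_stats=True)
print(norm_stats.running_mean.shape)         # torch.Size([3])

out = norm(torch.randn(2, 3, 8, 8))
print(out.shape)                             # torch.Size([2, 3, 8, 8])
```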
9 changes: 5 additions & 4 deletions torch/nn/modules/linear.py
@@ -102,10 +102,10 @@ def extra_repr(self) -> str:
# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript
class _LinearWithBias(Linear):
bias: Tensor
bias: Tensor # type: ignore

def __init__(self, in_features: int, out_features: int) -> None:
super().__init__(in_features, out_features, bias=True)
super().__init__(in_features, out_features, bias=True) # type: ignore


class Bilinear(Module):
@@ -208,7 +208,8 @@ class LazyLinear(LazyModuleMixin, Linear):
"""

cls_to_become = Linear
cls_to_become = Linear # type: ignore[assignment]
weight: UninitializedParameter

def __init__(self, out_features: int, bias: bool = True) -> None:
super().__init__(0, out_features, bias)
@@ -218,7 +219,7 @@ def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()

def initialize_parameters(self, input) -> None:
def initialize_parameters(self, input) -> None: # type: ignore
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]
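LazyLinear declares weight as an UninitializedParameter and defers shape inference to initialize_parameters, which runs on the first forward call; the Tuple[int, ...] fix to materialize in torch/nn/parameter.pyi further below exists to accept exactly this kind of multi-dimensional shape. A small usage sketch, assuming the LazyLinear module shown above is available as torch.nn.LazyLinear:

```python
import torch
from torch import nn

# in_features is unknown up front; the weight starts uninitialized and is
# materialized with shape (out_features, in_features) on the first forward.
layer = nn.LazyLinear(out_features=10)
x = torch.randn(4, 7)
out = layer(x)                                 # infers in_features = 7
print(layer.in_features, layer.weight.shape)   # 7 torch.Size([10, 7])
print(out.shape)                               # torch.Size([4, 10])
```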
6 changes: 3 additions & 3 deletions torch/nn/modules/sparse.py
@@ -99,8 +99,8 @@ class Embedding(Module):

num_embeddings: int
embedding_dim: int
padding_idx: int
max_norm: float
padding_idx: Optional[int]
max_norm: Optional[float]
norm_type: float
scale_grad_by_freq: bool
weight: Tensor
@@ -284,7 +284,7 @@ class EmbeddingBag(Module):

num_embeddings: int
embedding_dim: int
max_norm: float
max_norm: Optional[float]
norm_type: float
scale_grad_by_freq: bool
weight: Tensor
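padding_idx and max_norm are re-annotated as Optional because both default to None on nn.Embedding (and max_norm likewise on nn.EmbeddingBag). A quick illustration:

```python
import torch
from torch import nn

# Both arguments default to None, matching the Optional annotations above.
emb = nn.Embedding(10, 3)
print(emb.padding_idx, emb.max_norm)         # None None

# When supplied, padding_idx marks a row whose embedding is not updated
# during training, and max_norm renormalizes looked-up rows in place.
emb_pad = nn.Embedding(10, 3, padding_idx=0, max_norm=1.0)
out = emb_pad(torch.tensor([[0, 2, 5]]))
print(out.shape)                             # torch.Size([1, 3, 3])
```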
2 changes: 1 addition & 1 deletion torch/nn/parameter.pyi
@@ -11,5 +11,5 @@ class Parameter(Tensor):
class UninitializedParameter(Tensor):
def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...

def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
...
