
add type annotations to torch.nn.modules.normalization #49035

Closed
wants to merge 6 commits
3 changes: 0 additions & 3 deletions mypy.ini
@@ -80,9 +80,6 @@ ignore_errors = True
 [mypy-torch.nn.modules.module]
 ignore_errors = True
 
-[mypy-torch.nn.modules.normalization]
-ignore_errors = True
-
 [mypy-torch.nn.modules.padding]
 ignore_errors = True

2 changes: 1 addition & 1 deletion torch/nn/functional.pyi.in
@@ -205,7 +205,7 @@ def instance_norm(input: Tensor, running_mean: Optional[Tensor] = ..., running_v
                   momentum: float = ..., eps: float = ...) -> Tensor: ...
 
 
-def layer_norm(input: Tensor, normalized_shape: List[int], weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ...,
+def layer_norm(input: Tensor, normalized_shape: Sequence[int], weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ...,
                eps: float = ...) -> Tensor: ...
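The stub change above swaps `List[int]` for `Sequence[int]` in `layer_norm`'s signature. A minimal sketch of why this matters for type checking (hypothetical helper names, not PyTorch code): `List[int]` is invariant, so mypy rejects passing a tuple such as `self.normalized_shape`, while `Sequence[int]` accepts any read-only sequence.

```python
from typing import List, Sequence

def takes_list(shape: List[int]) -> int:
    # mypy rejects takes_list((3, 4)): a tuple is not a List[int]
    return len(shape)

def takes_sequence(shape: Sequence[int]) -> int:
    # Sequence[int] accepts lists, tuples, and other read-only sequences
    return len(shape)

print(takes_sequence((3, 4)))  # tuple accepted under Sequence[int]
print(takes_sequence([3, 4]))  # list still works
```

Both calls run fine at runtime either way; the difference only shows up under a static checker, which is exactly what this PR enables for the module.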


15 changes: 8 additions & 7 deletions torch/nn/modules/normalization.py
@@ -7,7 +7,7 @@
 from .. import init
 
 from torch import Tensor, Size
-from typing import Union, List
+from typing import Union, List, Tuple


class LocalResponseNorm(Module):
@@ -141,20 +141,21 @@ class LayerNorm(Module):
>>> output = m(input)
"""
     __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
-    normalized_shape: _shape_t
+    normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool

     def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
         super(LayerNorm, self).__init__()
         if isinstance(normalized_shape, numbers.Integral):
-            normalized_shape = (normalized_shape,)
-        self.normalized_shape = tuple(normalized_shape)
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
Comment on lines +152 to +153

Contributor:

Suggested change:
-            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
-        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+            normalized_shape_ = (normalized_shape,)
+        self.normalized_shape = tuple(normalized_shape_)

Contributor: Also wondering, do we even need the second tuple() wrapper?

Collaborator (author): I guess the second tuple() is to make sure self.normalized_shape is a tuple and not a list, for instance.

Collaborator (author): Also, I do not think the suggested change works as expected. If the input arg normalized_shape is a list, then the if block will never be executed and normalized_shape_ will not be defined.
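The author's objection can be checked directly. A standalone sketch, simplified from the `__init__` logic above with no torch dependency (the helper names are invented for illustration): the suggested rename leaves `normalized_shape_` unbound when the argument is already a list, while the merged version rebinds the same name so `tuple()` always has a defined argument.

```python
import numbers

def coerce_suggested(normalized_shape):
    # Suggested change: binds normalized_shape_ only in the int branch
    if isinstance(normalized_shape, numbers.Integral):
        normalized_shape_ = (normalized_shape,)
    return tuple(normalized_shape_)  # unbound for list input

def coerce_merged(normalized_shape):
    # Merged version: rebinds the same name, so tuple() always has a value;
    # the outer tuple() also coerces list input into a tuple
    if isinstance(normalized_shape, numbers.Integral):
        normalized_shape = (normalized_shape,)
    return tuple(normalized_shape)

print(coerce_merged(5))       # (5,)
print(coerce_merged([3, 4]))  # (3, 4) — list input still coerced to a tuple
try:
    coerce_suggested([3, 4])
except NameError as err:      # UnboundLocalError is a NameError subclass
    print("suggested version fails:", err)
```

This also answers the second-tuple() question: it is what turns a list argument into a tuple, matching the new `Tuple[int, ...]` annotation.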

         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = Parameter(torch.Tensor(*normalized_shape))
-            self.bias = Parameter(torch.Tensor(*normalized_shape))
+            self.weight = Parameter(torch.Tensor(*self.normalized_shape))
+            self.bias = Parameter(torch.Tensor(*self.normalized_shape))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -169,7 +170,7 @@ def forward(self, input: Tensor) -> Tensor:
         return F.layer_norm(
             input, self.normalized_shape, self.weight, self.bias, self.eps)
 
-    def extra_repr(self) -> Tensor:
+    def extra_repr(self) -> str:
Contributor: Wonder why it was not caught before.

         return '{normalized_shape}, eps={eps}, ' \
             'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
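On why the wrong `-> Tensor` annotation on `extra_repr` went unnoticed: Python does not enforce return annotations at runtime, so the mismatch is invisible until a static checker covers the module — and mypy was skipping it while the `ignore_errors = True` entry removed by this PR was in place. A minimal illustration (hypothetical function, not the PyTorch source; the annotation is deliberately wrong):

```python
def extra_repr() -> int:  # wrong on purpose, mirroring the original -> Tensor bug
    return "normalized_shape=(3,), eps=1e-05"

# Runs without complaint: annotations are only metadata at runtime.
# Only a static checker such as mypy flags the str/int mismatch.
print(extra_repr())
print(extra_repr.__annotations__)
```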
