
add type annotations to torch.nn.modules.normalization #49035

Closed · wants to merge 6 commits

Conversation

guilhermeleobas (Collaborator)

Fixes #49034

@guilhermeleobas added the "module: typing" label (Related to mypy type annotations) on Dec 8, 2020
@guilhermeleobas self-assigned this on Dec 8, 2020
codecov bot commented on Dec 8, 2020

Codecov Report

Merging #49035 (59aba29) into master (e29082b) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #49035      +/-   ##
==========================================
- Coverage   80.71%   80.70%   -0.01%     
==========================================
  Files        1904     1904              
  Lines      206598   206598              
==========================================
- Hits       166750   166742       -8     
- Misses      39848    39856       +8     

@guilhermeleobas marked this pull request as ready for review on December 9, 2020
@H-Huang added the "triaged" label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Dec 10, 2020
@rgommers (Collaborator) left a comment

Thanks @guilhermeleobas. Just a couple of minor comments.

torch/nn/modules/normalization.py: two review threads (outdated, resolved)
@rgommers (Collaborator) left a comment

LGTM now, thanks @guilhermeleobas

@@ -169,7 +170,7 @@ def forward(self, input: Tensor) -> Tensor:
         return F.layer_norm(
             input, self.normalized_shape, self.weight, self.bias, self.eps)

-    def extra_repr(self) -> Tensor:
+    def extra_repr(self) -> str:
Contributor:

Wonder why it was not caught before.
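
For context, extra_repr feeds the string that nn.Module splices into repr(self), which is why the return annotation has to be str rather than Tensor. A minimal sketch (class name hypothetical):

import torch.nn as nn

class Demo(nn.Module):
    # The returned string is embedded in the module's printed form.
    def extra_repr(self) -> str:
        return 'eps=1e-05'

print(Demo())  # prints: Demo(eps=1e-05)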

@@ -80,7 +80,7 @@ def extra_repr(self) -> str:
         return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)


-_shape_t = Union[int, List[int], Size]
+_shape_t = Union[int, Sequence[int], Size]
Contributor:

Do we have a test to validate JIT-ability? Since we know Sequence & Tuple don't work with the JIT.

@guilhermeleobas (Author):

Actually, this is not necessary. I am reverting this particular change.
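
For context, the revert keeps List[int] in the alias because TorchScript understands List[int] but, per the note above, not typing.Sequence. A small sketch of the constraint (function name hypothetical):

import torch
from typing import List

@torch.jit.script
def numel_of(shape: List[int]) -> int:
    # List[int] scripts fine; annotating shape as Sequence[int]
    # would fail under torch.jit.script.
    n = 1
    for s in shape:
        n *= s
    return n

print(numel_of([2, 3, 4]))  # 24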

@guilhermeleobas (Author):

Do not merge this PR until someone checks whether the annotations introduce any regressions. See:
#49564 (comment)

@facebook-github-bot (Contributor) left a comment

@walterddr has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@guilhermeleobas marked this pull request as ready for review on January 8, 2021
@guilhermeleobas (Author):

Tests are covered by test cases in TestJitGeneratedModule. See:

test/test_jit.py::TestJitGeneratedModule::test_nn_LayerNorm_1d_elementwise_affine
test/test_jit.py::TestJitGeneratedModule::test_nn_LayerNorm_1d_empty_elementwise_affine
test/test_jit.py::TestJitGeneratedModule::test_nn_LayerNorm_1d_no_elementwise_affine
test/test_jit.py::TestJitGeneratedModule::test_nn_LayerNorm_3d_elementwise_affine
test/test_jit.py::TestJitGeneratedModule::test_nn_LayerNorm_3d_no_elementwise_affine
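
Roughly, those generated tests script the module and run it against eager mode; a minimal sketch of the scriptability they exercise (not the actual harness):

import torch
import torch.nn as nn

# LayerNorm must remain scriptable after the annotation change,
# for both int and list normalized_shape arguments.
for shape in (10, [10]):
    m = torch.jit.script(nn.LayerNorm(shape))
    print(m(torch.randn(2, 10)).shape)  # torch.Size([2, 10])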

@walterddr (Contributor) left a comment

Looks good; one minor comment.

Also, could you confirm that the coverage for LayerNorm1d constructor arguments in torch.testing._internal.common_nn actually includes non-tuple input? From what I see, all of those constructor_args are lists; I am not sure whether they collapse into a single integer.

Comment on lines +152 to +153:

    normalized_shape = (normalized_shape,)  # type: ignore[assignment]
    self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
Contributor:

Use a renamed local normalized_shape_ to drop the type: ignore comments:

Suggested change:
-    normalized_shape = (normalized_shape,)  # type: ignore[assignment]
-    self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+    normalized_shape_ = (normalized_shape,)
+    self.normalized_shape = tuple(normalized_shape_)

Contributor:

Also wondering: do we even need the second tuple() wrapper?

@guilhermeleobas (Author):

I guess the second tuple() is to make sure self.normalized_shape is a tuple and not a list, for instance.

@guilhermeleobas (Author):

Also, I do not think the suggested change works as expected. If the input arg normalized_shape is a list, then the if block will never be executed and normalized_shape_ will not be defined.
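
For context, a condensed sketch of the constructor logic in question (hypothetical free function, simplified from LayerNorm.__init__), showing why the rename only binds on the int path:

from typing import List, Union

def to_normalized_shape(normalized_shape: Union[int, List[int]]) -> tuple:
    if isinstance(normalized_shape, int):
        # Only this branch would bind a renamed normalized_shape_,
        # so a list input would leave it undefined at the tuple() call.
        normalized_shape = (normalized_shape,)  # type: ignore[assignment]
    return tuple(normalized_shape)  # type: ignore[arg-type]

print(to_normalized_shape(10))      # (10,)
print(to_normalized_shape([3, 4]))  # (3, 4)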

@facebook-github-bot (Contributor) left a comment

@walterddr has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@walterddr merged this pull request in 4411b5a.

Labels: cla signed · Merged · module: typing (Related to mypy type annotations) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

Enable torch.nn.modules.normalization typechecks during CI
7 participants