add type annotations to torch.nn.modules.normalization #49035
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master   #49035      +/-   ##
==========================================
- Coverage   80.71%   80.70%   -0.01%
==========================================
  Files        1904     1904
  Lines      206598   206598
==========================================
- Hits       166750   166742       -8
- Misses      39848    39856       +8
Thanks @guilhermeleobas. Just a couple of minor comments.
LGTM now, thanks @guilhermeleobas
@@ -169,7 +170,7 @@ def forward(self, input: Tensor) -> Tensor:
         return F.layer_norm(
             input, self.normalized_shape, self.weight, self.bias, self.eps)

-    def extra_repr(self) -> Tensor:
+    def extra_repr(self) -> str:
I wonder why this was not caught before.
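The fix above changes the annotated return type of `extra_repr` from `Tensor` to `str`. A minimal torch-free sketch (hypothetical class, not the actual torch source) of why `str` is the correct return type: `extra_repr` is concatenated into the module's printed representation, so it must produce a string.

```python
# Hypothetical stand-in for a normalization module; mirrors the
# extra_repr pattern seen in the diff, but is NOT the torch source.
class LRN:
    def __init__(self, size: int, alpha: float = 1e-4,
                 beta: float = 0.75, k: float = 1.0) -> None:
        self.size, self.alpha, self.beta, self.k = size, alpha, beta, k

    def extra_repr(self) -> str:
        # Returns a plain string summarizing the constructor arguments,
        # matching the format string shown in the diff below.
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
```

Since the wrong annotation (`-> Tensor`) only describes the return type and the body already returned a string, the code ran fine at runtime, which is likely why it went unnoticed until a type checker was pointed at the module.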
torch/nn/modules/normalization.py
Outdated
@@ -80,7 +80,7 @@ def extra_repr(self) -> str:
         return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)

-_shape_t = Union[int, List[int], Size]
+_shape_t = Union[int, Sequence[int], Size]
Do we have a test to validate JIT-ability? Since we know `Sequence` and `Tuple` don't work with the JIT.
Actually, this is not necessary. I am reverting this particular change.
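The reverted change above swapped `List[int]` for `Sequence[int]` in the shape alias; it was rolled back because TorchScript's type system supports `List` but not `Sequence`. A minimal torch-free sketch of the retained pattern (`torch.Size` is omitted here to keep the example self-contained; the real alias is `Union[int, List[int], Size]`):

```python
from typing import List, Union

# Sketch of the kept annotation style: List[int] rather than
# Sequence[int], because TorchScript cannot compile Sequence.
# (torch.Size omitted in this torch-free illustration.)
_shape_t = Union[int, List[int]]

def as_list(shape: _shape_t) -> List[int]:
    # Canonicalize either accepted form to a concrete List[int],
    # the container type TorchScript can reason about.
    if isinstance(shape, int):
        return [shape]
    return list(shape)
```

A concrete `List[int]` annotation is slightly less general for plain-Python callers, but it keeps the module scriptable, which is the constraint the reviewer raises above.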
Do not merge this PR until someone checks whether the annotations introduce any regression. See:
@walterddr has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Tests are covered by test cases in the
Looks good. One minor comment.
Also, could you confirm that the coverage for the LayerNorm1d constructor arguments in torch.test._internal.common_nn actually includes non-tuple input? From what I see, all of those constructor_args are lists; I am not sure whether they collapse into a single integer.
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
-            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
-        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+            normalized_shape_ = (normalized_shape,)
+        self.normalized_shape = tuple(normalized_shape_)
Also wondering: do we even need the second `tuple()` wrapper?
I guess the second `tuple()` is to make sure `self.normalized_shape` is a tuple and not, for instance, a list.
Also, I do not think the suggested change works as expected. If the input arg `normalized_shape` is a list, then the if block will never be executed and `normalized_shape_` will not be defined.
@walterddr has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@walterddr merged this pull request in 4411b5a.
Fixes #49034