Skip to content

Commit

Permalink
Clean up type annotations in caffe2/torch/nn/modules (#49957)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #49957

Test Plan: Sandcastle

Reviewed By: xush6528

Differential Revision: D25729745

fbshipit-source-id: 85810e2c18ca6856480bef81217da1359b63d8a3
  • Loading branch information
r-barnes authored and facebook-github-bot committed Jan 6, 2021
1 parent 75028f2 commit d80d38c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions torch/nn/modules/activation.py
Expand Up @@ -922,9 +922,8 @@ def __setstate__(self, state):

super(MultiheadAttention, self).__setstate__(state)

def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
Expand Down
5 changes: 3 additions & 2 deletions torch/nn/modules/conv.py
Expand Up @@ -530,8 +530,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,

# dilation being an optional parameter is for backwards
# compatibility
def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None):
# type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int]
def _output_padding(self, input: Tensor, output_size: Optional[List[int]],
stride: List[int], padding: List[int], kernel_size: List[int],
dilation: Optional[List[int]] = None) -> List[int]:
if output_size is None:
ret = _single(self.output_padding) # converting to list if was not already
else:
Expand Down
3 changes: 1 addition & 2 deletions torch/nn/modules/utils.py
Expand Up @@ -26,8 +26,7 @@ def _reverse_repeat_tuple(t, n):
return tuple(x for x in reversed(t) for _ in range(n))


def _list_with_default(out_size, defaults):
# type: (List[int], List[int]) -> List[int]
def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
if isinstance(out_size, int):
return out_size
if len(defaults) <= len(out_size):
Expand Down

0 comments on commit d80d38c

Please sign in to comment.