From d80d38cf87a527fded154293038d252d595c7100 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 5 Jan 2021 19:02:38 -0800 Subject: [PATCH] Clean up type annotations in caffe2/torch/nn/modules (#49957) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49957 Test Plan: Sandcastle Reviewed By: xush6528 Differential Revision: D25729745 fbshipit-source-id: 85810e2c18ca6856480bef81217da1359b63d8a3 --- torch/nn/modules/activation.py | 5 ++--- torch/nn/modules/conv.py | 5 +++-- torch/nn/modules/utils.py | 3 +-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 073c95c28619..837ecca6fe9d 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -922,9 +922,8 @@ def __setstate__(self, state): super(MultiheadAttention, self).__setstate__(state) - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index f22c35fa39ff..6a9c4dcd2ef6 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -530,8 +530,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, # dilation being an optional parameter is for backwards # compatibility - def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None): - # type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int] + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + dilation: Optional[List[int]] = None) -> List[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already else: diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 3e0b93c7afc0..97e4195619cb 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -26,8 +26,7 @@ def _reverse_repeat_tuple(t, n): return tuple(x for x in reversed(t) for _ in range(n)) -def _list_with_default(out_size, defaults): - # type: (List[int], List[int]) -> List[int] +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size):