diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 0c5258615bfd..5fc6cdf456f7 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -921,9 +921,8 @@ def __setstate__(self, state): super(MultiheadAttention, self).__setstate__(state) - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index f22c35fa39ff..6a9c4dcd2ef6 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -530,8 +530,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, # dilation being an optional parameter is for backwards # compatibility - def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None): - # type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int] + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + dilation: Optional[List[int]] = None) -> List[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already else: diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 3e0b93c7afc0..97e4195619cb 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -26,8 +26,7 @@ def _reverse_repeat_tuple(t, n): return tuple(x for x in reversed(t) for _ in range(n)) -def _list_with_default(out_size, defaults): - # type: (List[int], List[int]) -> List[int] +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size):