diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
index 615741f38da7..d853a55b3933 100644
--- a/torch/jit/quantized.py
+++ b/torch/jit/quantized.py
@@ -130,8 +130,7 @@ def check_forward_input(self, input):
                 input.size(1), self.input_size))

     @torch.jit.script_method
-    def check_forward_hidden(self, input, hx, hidden_label=''):
-        # type: (Tensor, Tensor, str) -> None
+    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -169,8 +168,7 @@ def __init__(self, other):
         self.nonlinearity = other.nonlinearity

     @torch.jit.script_method
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -201,8 +199,7 @@ def __init__(self, other):
         super(QuantizedLSTMCell, self).__init__(other)

     @torch.jit.script_method
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
         self.check_forward_input(input)
         if hx is None:
             zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -222,8 +219,7 @@ def __init__(self, other):
         super(QuantizedGRUCell, self).__init__(other)

     @torch.jit.script_method
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -236,8 +232,7 @@ def forward(self, input, hx=None):
         )


-def apply_permutation(tensor, permutation, dim=1):
-    # type: (Tensor, Tensor, int) -> Tensor
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
     return tensor.index_select(dim, permutation)


@@ -303,8 +298,7 @@ def get_weight_bias(ihhh):
             self.all_weights.append(cell_params)

     @torch.jit.script_method
-    def check_input(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> None
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
             raise RuntimeError(
@@ -316,8 +310,7 @@ def check_input(self, input, batch_sizes):
                     self.input_size, input.size(-1)))

     @torch.jit.script_method
-    def get_expected_hidden_size(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
         if batch_sizes is not None:
             mini_batch = int(batch_sizes[0])
         else:
@@ -328,21 +321,19 @@ def get_expected_hidden_size(self, input, batch_sizes):
         return expected_hidden_size

     @torch.jit.script_method
-    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
-        # type: (Tensor, Tuple[int, int, int], str) -> None
+    def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+                          msg: str = 'Expected hidden size {}, got {}') -> None:
         if hx.size() != expected_hidden_size:
             raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

     @torch.jit.script_method
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tensor, Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
         self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}')

     @torch.jit.script_method
-    def permute_hidden(self, hx, permutation):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
         if permutation is None:
             return hx
         return apply_permutation(hx, permutation)
@@ -355,8 +346,9 @@ def __init__(self, other, dtype):
         super(QuantizedLSTM, self).__init__(other, dtype)

     @torch.jit.script_method
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
+    def forward_impl(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]], batch_sizes: Optional[Tensor],
+                     max_batch_size: int, sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        # noqa
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             zeros = torch.zeros(self.num_layers * num_directions,
@@ -379,8 +371,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         return output, hidden

     @torch.jit.script_method
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
@@ -391,8 +382,8 @@ def forward_tensor(self, input, hx=None):
         return output, self.permute_hidden(hidden, unsorted_indices)

     @torch.jit.script_method
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
+    def forward_packed(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+                       ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
         input, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
@@ -404,15 +395,13 @@

     @torch.jit.script_method
-    def permute_hidden(self, hx, permutation):
-        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
         if permutation is None:
             return hx
         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

     @torch.jit.script_method
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

@@ -432,8 +421,9 @@ class QuantizedGRU(QuantizedRNNBase):
     __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

     @torch.jit.script_method
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
+    def forward_impl(self, input: Tensor, hx: Optional[Tensor], batch_sizes: Optional[Tensor], max_batch_size: int,
+                     sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
+        # noqa
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             hx = torch.zeros(self.num_layers * num_directions,
@@ -459,8 +449,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         return output, hidden

     @torch.jit.script_method
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
@@ -470,8 +459,7 @@ def forward_tensor(self, input, hx=None):
         return output, self.permute_hidden(hidden, unsorted_indices)

     @torch.jit.script_method
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
+    def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
         input, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
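
Every hunk above applies the same mechanical change: a Py2-style `# type:` comment is replaced by inline PEP 484 annotations, which TorchScript reads directly from the signature. A minimal self-contained sketch of the pattern follows; `DemoCell` and its sizes are illustrative only, not part of this patch.

from typing import Optional

import torch
from torch import Tensor


class DemoCell(torch.nn.Module):
    hidden_size: int

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size

    # Before this change, the contract lived in a comment that TorchScript parsed:
    #     def forward(self, input, hx=None):
    #         # type: (Tensor, Optional[Tensor]) -> Tensor
    # After, the same contract is spelled out as inline annotations, including
    # the Optional default:
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        if hx is None:
            # Same default-hidden-state idiom as the cells in this file.
            hx = torch.zeros(input.size(0), self.hidden_size,
                             dtype=input.dtype, device=input.device)
        return hx


scripted = torch.jit.script(DemoCell(4))
print(scripted(torch.randn(2, 3)).shape)  # torch.Size([2, 4])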
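
For context on the bodies the diff leaves untouched: `apply_permutation` (and the `permute_hidden` overloads built on it) reorders the batch dimension of a hidden state via `index_select`, e.g. to undo the sorting that packed sequences impose. A tiny illustration with made-up values:

import torch

hx = torch.arange(6.).reshape(1, 3, 2)      # (num_layers, batch, hidden)
unsorted_indices = torch.tensor([2, 0, 1])  # a permutation of the batch dim
# index_select along dim=1 reorders the batch entries to positions 2, 0, 1,
# which is exactly apply_permutation(hx, unsorted_indices, dim=1).
print(hx.index_select(1, unsorted_indices))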