Add typing annotations for torch.nn.quantized.dynamic.modules.rnn #43186

Closed · wants to merge 2 commits
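This PR migrates the module from comment-style (`# type: (...) -> ...`) annotations to inline Python 3 annotations, and removes the corresponding `ignore_errors` override from mypy.ini. The conversion follows one pattern throughout the file; a minimal sketch of it, using the `apply_permutation` helper that appears in the diff below (the two function names here are illustrative only):

```python
from torch import Tensor

# Before: the types live in a `# type:` comment next to an unannotated signature.
def apply_permutation_commented(tensor, permutation, dim=1):
    # type: (Tensor, Tensor, int) -> Tensor
    return tensor.index_select(dim, permutation)

# After: the same signature expressed as inline annotations.
def apply_permutation_inline(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)
```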
3 changes: 0 additions & 3 deletions mypy.ini
@@ -174,9 +174,6 @@ ignore_errors = True
[mypy-torch.nn.qat.modules.conv]
ignore_errors = True

[mypy-torch.nn.quantized.dynamic.modules.rnn]
ignore_errors = True

[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True

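Removing the `[mypy-torch.nn.quantized.dynamic.modules.rnn]` override means mypy no longer suppresses errors for that module; the annotations added below are what allow it to be checked cleanly.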
91 changes: 48 additions & 43 deletions torch/nn/quantized/dynamic/modules/rnn.py
@@ -6,12 +6,11 @@
import torch
import torch.nn as nn
from torch import Tensor # noqa: F401
from torch._jit_internal import Tuple, Optional, List # noqa: F401
from torch._jit_internal import Tuple, Optional, List, Union, Dict # noqa: F401
from torch.nn.utils.rnn import PackedSequence
from torch.nn.quantized.modules.utils import _quantize_weight

def apply_permutation(tensor, permutation, dim=1):
# type: (Tensor, Tensor, int) -> Tensor
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation)

class PackedParameter(torch.nn.Module):
@@ -53,12 +52,14 @@ def __init__(self, mode, input_size, hidden_size,
self.training = False
num_directions = 2 if bidirectional else 1

if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
isinstance(dropout, bool):
# "type: ignore" is required since ints and Numbers are not fully comparable
# https://github.com/python/mypy/issues/8566
if not isinstance(dropout, numbers.Number) \
or not 0 <= dropout <= 1 or isinstance(dropout, bool): # type: ignore
raise ValueError("dropout should be a number in range [0, 1] "
"representing the probability of an element being "
"zeroed")
if dropout > 0 and num_layers == 1:
if dropout > 0 and num_layers == 1: # type: ignore
warnings.warn("dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} and "
@@ -149,8 +150,7 @@ def __repr__(self):
main_str += ')'
return main_str

def check_input(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> None
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
raise RuntimeError(
@@ -161,33 +161,31 @@ def check_input(self, input, batch_sizes):
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
self.input_size, input.size(-1)))

def get_expected_hidden_size(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
if batch_sizes is not None:
mini_batch = batch_sizes[0]
mini_batch = int(mini_batch)
mini_batch = int(batch_sizes[0])
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (self.num_layers * num_directions,
mini_batch, self.hidden_size)
return expected_hidden_size

def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
# type: (Tensor, Tuple[int, int, int], str) -> None
def check_hidden_size(
self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
msg: str = 'Expected hidden size {}, got {}'
) -> None:
if hx.size() != expected_hidden_size:
raise RuntimeError(msg.format(
expected_hidden_size, list(hx.size())))

def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tensor, Optional[Tensor]) -> None
def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
self.check_hidden_size(hidden, expected_hidden_size,
msg='Expected hidden size {}, got {}')

def permute_hidden(self, hx, permutation):
# type: (Tensor, Optional[Tensor]) -> Tensor
def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
if permutation is None:
return hx
return apply_permutation(hx, permutation)
@@ -287,7 +285,7 @@ def quantize_and_pack(w, b):

def _weight_bias(self):
# Returns a dict of weights and biases
weight_bias_dict = {'weight' : {}, 'bias' : {}}
weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
count = 0
num_directions = 2 if self.bidirectional else 1
for layer in range(self.num_layers):
@@ -337,8 +335,11 @@ def __init__(self, *args, **kwargs):
def _get_name(self):
return 'DynamicQuantizedLSTM'

def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
def forward_impl(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
batch_sizes: Optional[Tensor], max_batch_size: int,
sorted_indices: Optional[Tensor]
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(self.num_layers * num_directions,
@@ -367,8 +368,9 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
return output, hidden

@torch.jit.export
def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
def forward_tensor(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
@@ -380,27 +382,32 @@ def forward_tensor(self, input, hx=None):
return output, self.permute_hidden(hidden, unsorted_indices)

@torch.jit.export
def forward_packed(self, input, hx=None):
# type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
input, batch_sizes, sorted_indices, unsorted_indices = input
def forward_packed(
self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa
input_, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)

output, hidden = self.forward_impl(
input, hx, batch_sizes, max_batch_size, sorted_indices)
output_, hidden = self.forward_impl(
input_, hx, batch_sizes, max_batch_size, sorted_indices)

output = PackedSequence(output, batch_sizes,
output = PackedSequence(output_, batch_sizes,
sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)

def permute_hidden(self, hx, permutation):
# type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
# "type: ignore" is required due to issue #43072
def permute_hidden( # type: ignore
self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
) -> Tuple[Tensor, Tensor]:
if permutation is None:
return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None
# "type: ignore" is required due to issue #43072
def check_forward_args( # type: ignore
self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
) -> None:
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

@@ -483,8 +490,7 @@ def check_forward_input(self, input):
"input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size))

def check_forward_hidden(self, input, hx, hidden_label=''):
# type: (Tensor, Tensor, str) -> None
def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
if input.size(0) != hx.size(0):
raise RuntimeError(
"Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -518,6 +524,8 @@ def from_float(cls, mod):
if dtype not in supported_scalar_types:
raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))

qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

if type(mod) == torch.nn.LSTMCell:
qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
elif type(mod) == torch.nn.GRUCell:
@@ -561,7 +569,7 @@ def process_weights(weight, bias, dtype):

def _weight_bias(self):
# Returns a dict of weights and biases
weight_bias_dict = {'weight' : {}, 'bias' : {}}
weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
w1, b1 = self._packed_weight_ih.__getstate__()[0]
w2, b2 = self._packed_weight_hh.__getstate__()[0]
weight_bias_dict['weight']['weight_ih'] = w1
@@ -614,8 +622,7 @@ def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtyp
def _get_name(self):
return 'DynamicQuantizedRNNCell'

def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -661,13 +668,12 @@ class LSTMCell(RNNCellBase):
"""

def __init__(self, *args, **kwargs):
super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)
super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs) # type: ignore

def _get_name(self):
return 'DynamicQuantizedLSTMCell'

def forward(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
self.check_forward_input(input)
if hx is None:
zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -707,8 +713,7 @@ def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
def _get_name(self):
return 'DynamicQuantizedGRUCell'

def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
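The `# type: ignore` comments that remain in the diff (both pointing at issue #43072) sit on overrides where `LSTM` replaces the base class's single-`Tensor` hidden state with a `(Tensor, Tensor)` pair, which mypy flags as an incompatible override. A toy sketch of the situation, using hypothetical class names and no dependence on torch:

```python
from typing import Optional, Sequence, Tuple

class Hidden:  # stand-in for a Tensor-like hidden state
    pass

class Base:
    def permute_hidden(self, hx: Hidden, permutation: Optional[Sequence[int]]) -> Hidden:
        return hx

class LSTMLike(Base):
    # mypy flags this override as incompatible with Base.permute_hidden,
    # because the hidden state is now a pair; hence the plain `# type: ignore`.
    def permute_hidden(  # type: ignore
        self, hx: Tuple[Hidden, Hidden], permutation: Optional[Sequence[int]]
    ) -> Tuple[Hidden, Hidden]:
        return hx
```

The `qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]` declaration in `from_float` serves a related purpose: annotating the variable before the `if`/`elif` chain lets mypy accept whichever cell type each branch assigns.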