Clean up some type annotations in torch/jit (#49939)
Summary:
Pull Request resolved: #49939

Upgrades type annotations from Python 2 style type comments to Python 3 inline annotations.
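
The change is mechanical throughout: each Python 2 style "# type:" signature comment is replaced with the equivalent Python 3 inline annotation. A minimal sketch of the before/after pattern, using a hypothetical lerp helper rather than code from this diff:

from torch import Tensor

# Before: Python 2 style, with parameter and return types in a comment.
def lerp_old(start, end, weight=0.5):
    # type: (Tensor, Tensor, float) -> Tensor
    return start + weight * (end - start)

# After: Python 3 style, with the same types written inline.
def lerp_new(start: Tensor, end: Tensor, weight: float = 0.5) -> Tensor:
    return start + weight * (end - start)

TorchScript accepts both spellings, so the rewrite does not change behavior; the inline form simply keeps each signature self-describing.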

Test Plan: Sandcastle tests

Reviewed By: xush6528

Differential Revision: D25717573

fbshipit-source-id: 7d5c98fafaa224e0504b73dc69b1e4a6410c0494
r-barnes authored and facebook-github-bot committed Jan 7, 2021
1 parent e49372d commit 6838ece
Showing 1 changed file with 24 additions and 36 deletions.
torch/jit/quantized.py (60 changes: 24 additions, 36 deletions)
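
A detail worth noting before the hunks: many of the new signatures declare hx: Optional[Tensor] = None, and TorchScript only lets an Optional value be used after an explicit None check narrows it, which is why each forward body begins with one. A minimal runnable sketch of that pattern under the same legacy ScriptModule API (the Cell module here is hypothetical, not part of torch/jit/quantized.py):

from typing import Optional

import torch
from torch import Tensor


class Cell(torch.jit.ScriptModule):
    __constants__ = ['hidden_size']

    def __init__(self, hidden_size):
        super(Cell, self).__init__()
        self.hidden_size = hidden_size

    @torch.jit.script_method
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        # The None check narrows Optional[Tensor] to Tensor for the compiler.
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size,
                             dtype=input.dtype, device=input.device)
        return hx

Calling Cell(4)(torch.randn(2, 3)) returns a fresh 2 x 4 zero state, while an explicitly supplied hx is returned unchanged.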
@@ -130,8 +130,7 @@ def check_forward_input(self, input):
                     input.size(1), self.input_size))
 
     @torch.jit.script_method
-    def check_forward_hidden(self, input, hx, hidden_label=''):
-        # type: (Tensor, Tensor, str) -> None
+    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -169,8 +168,7 @@ def __init__(self, other):
         self.nonlinearity = other.nonlinearity
 
     @torch.jit.script_method
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -201,8 +199,7 @@ def __init__(self, other):
         super(QuantizedLSTMCell, self).__init__(other)
 
     @torch.jit.script_method
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
         self.check_forward_input(input)
         if hx is None:
             zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -222,8 +219,7 @@ def __init__(self, other):
         super(QuantizedGRUCell, self).__init__(other)
 
     @torch.jit.script_method
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -236,8 +232,7 @@ def forward(self, input, hx=None):
         )
 
 
-def apply_permutation(tensor, permutation, dim=1):
-    # type: (Tensor, Tensor, int) -> Tensor
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
     return tensor.index_select(dim, permutation)
 
 
@@ -303,8 +298,7 @@ def get_weight_bias(ihhh):
             self.all_weights.append(cell_params)
 
     @torch.jit.script_method
-    def check_input(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> None
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
             raise RuntimeError(
@@ -316,8 +310,7 @@ def check_input(self, input, batch_sizes):
                     self.input_size, input.size(-1)))
 
     @torch.jit.script_method
-    def get_expected_hidden_size(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
         if batch_sizes is not None:
             mini_batch = int(batch_sizes[0])
         else:
@@ -328,21 +321,19 @@ def get_expected_hidden_size(self, input, batch_sizes):
         return expected_hidden_size
 
     @torch.jit.script_method
-    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
-        # type: (Tensor, Tuple[int, int, int], str) -> None
+    def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+                          msg: str = 'Expected hidden size {}, got {}') -> None:
         if hx.size() != expected_hidden_size:
             raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
 
     @torch.jit.script_method
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tensor, Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
         self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}')
 
     @torch.jit.script_method
-    def permute_hidden(self, hx, permutation):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
         if permutation is None:
             return hx
         return apply_permutation(hx, permutation)
@@ -355,8 +346,9 @@ def __init__(self, other, dtype):
         super(QuantizedLSTM, self).__init__(other, dtype)
 
     @torch.jit.script_method
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
+    def forward_impl(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]], batch_sizes: Optional[Tensor],
+                     max_batch_size: int, sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        # noqa
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             zeros = torch.zeros(self.num_layers * num_directions,
@@ -379,8 +371,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         return output, hidden
 
     @torch.jit.script_method
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
@@ -391,8 +382,8 @@ def forward_tensor(self, input, hx=None):
         return output, self.permute_hidden(hidden, unsorted_indices)
 
     @torch.jit.script_method
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
+    def forward_packed(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+                       ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
         input, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
@@ -404,15 +395,13 @@ def forward_packed(self, input, hx=None):
 
 
     @torch.jit.script_method
-    def permute_hidden(self, hx, permutation):
-        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
         if permutation is None:
             return hx
         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
 
     @torch.jit.script_method
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
 
@@ -432,8 +421,9 @@ class QuantizedGRU(QuantizedRNNBase):
     __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
 
     @torch.jit.script_method
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
+    def forward_impl(self, input: Tensor, hx: Optional[Tensor], batch_sizes: Optional[Tensor], max_batch_size: int,
+                     sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
+        # noqa
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             hx = torch.zeros(self.num_layers * num_directions,
@@ -459,8 +449,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         return output, hidden
 
     @torch.jit.script_method
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
@@ -470,8 +459,7 @@ def forward_tensor(self, input, hx=None):
         return output, self.permute_hidden(hidden, unsorted_indices)
 
     @torch.jit.script_method
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
+    def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
         input, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
