Clean up type annotations in torch/nn/quantized/modules (#49941)
Summary: Pull Request resolved: #49941

Test Plan: Sandcastle

Reviewed By: jerryzh168

Differential Revision: D25718715

fbshipit-source-id: 4c047215eed00ac593282587bcae95b623fcad48
r-barnes authored and facebook-github-bot committed Jan 6, 2021
1 parent 70734f1 commit f498d80
Showing 3 changed files with 27 additions and 54 deletions.
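
The change is mechanical throughout: each mypy-style "# type:" comment is replaced by an equivalent inline (PEP 484) signature, which both mypy and TorchScript accept. A minimal before/after sketch of the pattern (the class names below are placeholders for illustration, not code from this diff):

from typing import Optional

import torch

class ExampleBefore:  # placeholder class, not part of the PR
    def set_weight_bias(self, w, b):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        ...

class ExampleAfter:  # placeholder class, not part of the PR
    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        ...
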
21 changes: 7 additions & 14 deletions torch/nn/quantized/modules/conv.py
@@ -240,8 +240,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
     def _get_name(self):
         return 'QuantizedConv1d'
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params = torch.ops.quantized.conv1d_prepack(
             w, b, self.stride, self.padding, self.dilation, self.groups)
 
@@ -327,8 +326,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
     def _get_name(self):
         return 'QuantizedConv2d'
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params = torch.ops.quantized.conv2d_prepack(
             w, b, self.stride, self.padding, self.dilation, self.groups)
 
@@ -412,8 +410,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
     def _get_name(self):
         return 'QuantizedConv3d'
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params = torch.ops.quantized.conv3d_prepack(
             w, b, self.stride, self.padding, self.dilation, self.groups)
 
@@ -466,8 +463,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
             groups, bias, padding_mode)
 
-    def _input_padding(self, kernel_size, dilation, padding):
-        # type: (List[int], List[int], List[int]) -> List[int]
+    def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
         res = torch.jit.annotate(List[int], [])
         for kdx in range(len(kernel_size)):
             pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
@@ -561,8 +557,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
     def _get_name(self):
         return 'QuantizedConvTranpose1d'
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
             w, b, self.stride, self.padding, self.output_padding, self.dilation,
             self.groups)
@@ -645,8 +640,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
     def _get_name(self):
         return 'QuantizedConvTranpose2d'
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
             w, b, self.stride, self.padding, self.output_padding, self.dilation,
             self.groups)
@@ -730,8 +724,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
     def _get_name(self):
         return 'QuantizedConvTranpose3d'
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
             w, b, self.stride, self.padding, self.output_padding, self.dilation,
             self.groups)
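
For context, the annotated signatures stay TorchScript-scriptable. Below is a standalone sketch of the same arithmetic as _ConvTransposeNd._input_padding, written as a hypothetical free function (not part of this commit) to show that torch.jit.script accepts the inline List[int] annotations:

from typing import List

import torch

@torch.jit.script
def input_padding(kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
    # Mirrors the loop shown in the diff above: one value per spatial dimension.
    res = torch.jit.annotate(List[int], [])
    for kdx in range(len(kernel_size)):
        pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
        res.append(pad)
    return res

print(input_padding([3, 3], [1, 1], [1, 1]))  # expected: [1, 1]
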
6 changes: 2 additions & 4 deletions torch/nn/quantized/modules/embedding_ops.py
@@ -22,8 +22,7 @@ def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
             raise NotImplementedError('Unsupported dtype on quantized embedding! Supports quint8 and quint4x2.')
 
     @torch.jit.export
-    def set_weight(self, weight):
-        # type: (torch.Tensor) -> None
+    def set_weight(self, weight: torch.Tensor) -> None:
         if self.dtype in [torch.quint8, torch.quint4x2]:
             self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
         else:
@@ -126,8 +125,7 @@ def extra_repr(self):
 
         return extra_repr_str
 
-    def set_weight(self, w):
-        # type: (torch.Tensor) -> None
+    def set_weight(self, w: torch.Tensor) -> None:
         self._packed_params.set_weight(w)
 
     def weight(self):
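
The annotated set_weight is exercised indirectly whenever a quantized embedding is constructed. A brief usage sketch, assuming the torch.nn.quantized.Embedding module available in this era of PyTorch (sizes and indices are illustrative):

import torch
from torch.nn.quantized import Embedding

# Constructing the module packs a default quint8 weight through
# EmbeddingPackedParams.set_weight, whose signature this commit annotates.
qemb = Embedding(num_embeddings=10, embedding_dim=4)
indices = torch.tensor([0, 3, 7])
print(qemb(indices).shape)  # expected: torch.Size([3, 4])
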
54 changes: 18 additions & 36 deletions torch/nn/quantized/modules/functional_modules.py
@@ -40,45 +40,39 @@ def forward(self, x):
                            "'forward'. Please use the underlying operation")
 
     r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
-    def add(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def add(self, x: Tensor, y: Tensor) -> Tensor:
         r = torch.add(x, y)
         r = self.activation_post_process(r)
         return r
 
     r"""Operation equivalent to ``torch.add(Tensor, float)``"""
-    def add_scalar(self, x, y):
-        # type: (Tensor, float) -> Tensor
+    def add_scalar(self, x: Tensor, y: float) -> Tensor:
         r = torch.add(x, y)
         # Note: this operation is not observed because the observation is not
         # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
-    def mul(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def mul(self, x: Tensor, y: Tensor) -> Tensor:
         r = torch.mul(x, y)
         r = self.activation_post_process(r)
         return r
 
     r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
-    def mul_scalar(self, x, y):
-        # type: (Tensor, float) -> Tensor
+    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
         r = torch.mul(x, y)
         # Note: this operation is not observed because the observation is not
         # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.cat``"""
-    def cat(self, x, dim=0):
-        # type: (List[Tensor], int) -> Tensor
+    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
         r = torch.cat(x, dim=dim)
         r = self.activation_post_process(r)
         return r
 
     r"""Operation equivalent to ``relu(torch.add(x,y))``"""
-    def add_relu(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
         r = torch.add(x, y)
         r = torch.nn.functional.relu(r)
         r = self.activation_post_process(r)
@@ -101,38 +95,32 @@ def forward(self, x):
                            "'forward'. Please use the underlying operation")
 
     r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
-    def add(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def add(self, x: Tensor, y: Tensor) -> Tensor:
         r = torch.add(x, y)
         return r
 
     r"""Operation equivalent to ``torch.add(Tensor, float)``"""
-    def add_scalar(self, x, y):
-        # type: (Tensor, float) -> Tensor
+    def add_scalar(self, x: Tensor, y: float) -> Tensor:
         r = torch.add(x, y)
         return r
 
     r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
-    def mul(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def mul(self, x: Tensor, y: Tensor) -> Tensor:
         r = torch.mul(x, y)
         return r
 
     r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
-    def mul_scalar(self, x, y):
-        # type: (Tensor, float) -> Tensor
+    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
         r = torch.mul(x, y)
         return r
 
     r"""Operation equivalent to ``torch.cat``"""
-    def cat(self, x, dim=0):
-        # type: (List[Tensor], int) -> Tensor
+    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
         r = torch.cat(x, dim=dim)
         return r
 
     r"""Operation equivalent to ``relu(torch.add(x,y))``"""
-    def add_relu(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
         r = torch.add(x, y)
         r = torch.nn.functional.relu(r)
         return r
@@ -195,45 +183,39 @@ def forward(self, x):
                            "'forward'. Please use the underlying operation")
 
     r"""Operation equivalent to ``torch.ops.quantized.add``"""
-    def add(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def add(self, x: Tensor, y: Tensor) -> Tensor:
         r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
         r = self.activation_post_process(r)
         return r
 
     r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
-    def add_scalar(self, x, y):
-        # type: (Tensor, float) -> Tensor
+    def add_scalar(self, x: Tensor, y: float) -> Tensor:
         r = ops.quantized.add_scalar(x, y)
         # Note: this operation is not observed because the observation is not
         # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
-    def mul(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def mul(self, x: Tensor, y: Tensor) -> Tensor:
         r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
         r = self.activation_post_process(r)
         return r
 
     r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
-    def mul_scalar(self, x, y):
-        # type: (Tensor, float) -> Tensor
+    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
         r = ops.quantized.mul_scalar(x, y)
         # Note: this operation is not observed because the observation is not
         # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.ops.quantized.cat``"""
-    def cat(self, x, dim=0):
-        # type: (List[Tensor], int) -> Tensor
+    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
         r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
         r = self.activation_post_process(r)
         return r
 
     r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
-    def add_relu(self, x, y):
-        # type: (Tensor, Tensor) -> Tensor
+    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
         r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
         r = self.activation_post_process(r)
         return r
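
Call sites are unchanged by the annotation cleanup. A short, illustrative usage sketch of torch.nn.quantized.FloatFunctional with the signatures annotated above (the tensors here are arbitrary, not from the test plan):

import torch
from torch.nn.quantized import FloatFunctional

ff = FloatFunctional()
x = torch.randn(2, 3)
y = torch.randn(2, 3)

out_add = ff.add(x, y)              # (Tensor, Tensor) -> Tensor
out_scalar = ff.add_scalar(x, 0.5)  # (Tensor, float) -> Tensor
out_cat = ff.cat([x, y], dim=0)     # (List[Tensor], int) -> Tensor
out_relu = ff.add_relu(x, y)        # relu(add(x, y))
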
