Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ python_library(
],
typing = True,
deps = [
"fbcode//executorch/backends/cadence/aot:utils",
"fbcode//caffe2:torch",
"fbcode//executorch/exir:scalar_type",
],
Expand Down
80 changes: 80 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

# pyre-strict


from typing import Optional

import torch

from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

Expand All @@ -21,6 +23,8 @@
ScalarType.QINT32: torch.qint32,
}

_Number = bool | int | float


@impl(m, "quantize_per_tensor")
def quantize_per_tensor(
Expand Down Expand Up @@ -294,6 +298,82 @@ def quantized_layer_norm_per_tensor(
)


@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: tuple[int, int],
padding: tuple[int, int],
dilation: tuple[int, int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
) -> torch.Tensor:
"""
Quantized convolution operation.

Args:
- input_tensor (Tensor): The activations tensor
- weight (Tensor): The weight tensor
- bias (Tensor): The bias tensor
- stride (Tuple[int]): The stride of the convolution
- padding (Tuple[int]): The padding of the convolution
- dilation (Tuple[int]): The dilation of the convolution
- groups (int): The number of groups
- in_zero_point (int): The quantized mapping of zero for the input
- weight_zero_point (Tensor): The quantized mapping of zero for the weight
- bias_scale (Tensor): The quantized bias scale
- output_scale (float): The scale of the output
- output_zero_point (int): The zero point of the output
- out_multiplier (Tensor): Unused
- out_shift (Tensor): Unused
"""
if weight_zero_point.view(-1).shape != (1,):
raise ValueError("Weight zero point must be a scalar")

if bias_scale.view(-1).shape != (1,):
raise ValueError("Bias scale must be a scalar")

if len(input_tensor.shape) == 3:
float_out = torch.nn.functional.conv1d(
(input_tensor - in_zero_point).float(),
(weight - weight_zero_point).float(),
(bias * bias_scale).float(),
stride[1],
padding[1],
dilation[1],
groups,
)

elif len(input_tensor.shape) == 4:
float_out = torch.nn.functional.conv2d(
(input_tensor - in_zero_point).float(),
(weight - weight_zero_point).float(),
(bias * bias_scale).float(),
stride,
padding,
dilation,
groups,
)
else:
raise ValueError("Input tensor must be 3D or 4D")

return quantize_per_tensor(
float_out,
1.0 / output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)


@impl(m, "requantize")
def requantize(
input: torch.Tensor,
Expand Down
Loading
Loading