diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 1a2c5a9709f..eb0e17f9858 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -129,6 +129,7 @@ python_library( ], typing = True, deps = [ + "fbcode//executorch/backends/cadence/aot:utils", "fbcode//caffe2:torch", "fbcode//executorch/exir:scalar_type", ], diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 5595d74c9c8..05792a5cfa7 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -6,9 +6,11 @@ # pyre-strict + from typing import Optional import torch + from executorch.exir.scalar_type import ScalarType from torch.library import impl, Library @@ -21,6 +23,8 @@ ScalarType.QINT32: torch.qint32, } +_Number = bool | int | float + @impl(m, "quantize_per_tensor") def quantize_per_tensor( @@ -294,6 +298,82 @@ def quantized_layer_norm_per_tensor( ) +@impl(m, "quantized_conv_nchw") +def quantized_conv_nchw( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, +) -> torch.Tensor: + """ + Quantized convolution operation. + + Args: + - input_tensor (Tensor): The activations tensor + - weight (Tensor): The weight tensor + - bias (Tensor): The bias tensor + - stride (Tuple[int]): The stride of the convolution + - padding (Tuple[int]): The padding of the convolution + - dilation (Tuple[int]): The dilation of the convolution + - groups (int): The number of groups + - in_zero_point (int): The quantized mapping of zero for the input + - weight_zero_point (Tensor): The quantized mapping of zero for the weight + - bias_scale (Tensor): The quantized bias scale + - output_scale (float): The scale of the output + - output_zero_point (int): The zero point of the output + - out_multiplier (Tensor): Unused + - out_shift (Tensor): Unused + """ + if weight_zero_point.view(-1).shape != (1,): + raise ValueError("Weight zero point must be a scalar") + + if bias_scale.view(-1).shape != (1,): + raise ValueError("Bias scale must be a scalar") + + if len(input_tensor.shape) == 3: + float_out = torch.nn.functional.conv1d( + (input_tensor - in_zero_point).float(), + (weight - weight_zero_point).float(), + (bias * bias_scale).float(), + stride[1], + padding[1], + dilation[1], + groups, + ) + + elif len(input_tensor.shape) == 4: + float_out = torch.nn.functional.conv2d( + (input_tensor - in_zero_point).float(), + (weight - weight_zero_point).float(), + (bias * bias_scale).float(), + stride, + padding, + dilation, + groups, + ) + else: + raise ValueError("Input tensor must be 3D or 4D") + + return quantize_per_tensor( + float_out, + 1.0 / output_scale, + output_zero_point, + torch.iinfo(input_tensor.dtype).min, + torch.iinfo(input_tensor.dtype).max, + input_tensor.dtype, + ) + + @impl(m, "requantize") def requantize( input: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 52f6c308da8..37a250c70f7 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -15,6 +15,7 @@ dequantize_per_tensor, quantize_per_tensor, quantized_add, + quantized_conv_nchw, quantized_layer_norm_per_tensor, quantized_linear, ) @@ -335,3 +336,337 @@ def test_quantized_layer_norm_per_tensor( torch.equal(output, expected_output), f"Output values don't match expected. Got {output}, expected {expected_output}", ) + + @expand( + [ + # Test case 1: Basic 2D convolution with int8 + ( + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8), # input: 1x1x2x2 + torch.tensor( + [[[[1, 0], [0, 1]]]], dtype=torch.int8 + ), # weight: 1x1x2x2 (identity-like) + torch.tensor([0], dtype=torch.int8), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.1, # output_scale + 0, # output_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int8), # out_shift + torch.int8, # dtype + torch.tensor( + [[[[50]]]], dtype=torch.int8 + ), # expected_output: (1*1 + 4*1) / 0.1 = 50 + ), + # Test case 2: 2D convolution with stride and padding + ( + torch.tensor( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.int8 + ), # input: 1x1x3x3 + torch.tensor( + [[[[1, 1], [1, 1]]]], dtype=torch.int8 + ), # weight: 1x1x2x2 (sum filter) + torch.tensor([0], dtype=torch.int8), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.25, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[48, 64], [96, 112]]]], dtype=torch.int8 + ), # expected_output: convolution results with output_scale=0.25 + ), + # Test case 3: uint8 with non-zero zero points + ( + torch.tensor( + [[[[130, 132], [134, 136]]]], dtype=torch.uint8 + ), # input: 1x1x2x2 + torch.tensor( + [[[[129, 128], [128, 129]]]], dtype=torch.uint8 + ), # weight: 1x1x2x2 (values close to zero_point) + torch.tensor([10], dtype=torch.uint8), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 128, # in_zero_point + torch.tensor([128], dtype=torch.uint8), # weight_zero_point + torch.tensor([0.1], dtype=torch.float32), # bias_scale + 0.1, # output_scale + 128, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.uint8, # dtype + torch.tensor( + [[[[238]]]], dtype=torch.uint8 + ), # (130 - 128) + (134 - 128) = 10 + # + bias -> 10 + 1 = 11 + # round(11 / 0.1 + 128) = 238 + ), + # Test case 4: 1D convolution (3D input tensor) + ( + torch.tensor( + [[[1, 2, 3, 4]]], dtype=torch.int8 + ), # input: 1x1x4 (N, C, W) + torch.tensor( + [[[1, 1]]], dtype=torch.int8 + ), # weight: 1x1x2 (OC, IC, KW) + torch.tensor([0], dtype=torch.int8), # bias + (1, 1), # stride (padding for 2D, actual stride is stride[1]) + (0, 0), # padding (padding for 2D, actual padding is padding[1]) + (1, 1), # dilation (padding for 2D, actual dilation is dilation[1]) + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.5, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[6, 10, 14]]], dtype=torch.int8 + ), # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14] + ), + # Test case 5: Multiple output channels + ( + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8), # input: 1x1x2x2 + torch.tensor( + [ + [[[1, 0], [0, 1]]], # first output channel + [[[0, 1], [1, 0]]], # second output channel + ], + dtype=torch.int8, + ), # weight: 2x1x2x2 + torch.tensor([0, 5], dtype=torch.int8), # bias for each output channel + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.2, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[25]], [[50]]]], dtype=torch.int8 + ), # expected_output: [5/0.2, 10/0.2] = [25, 50] + ), + # Test case 6: Multiple input channels + ( + torch.tensor( + [ + [ + [[1, 2], [3, 4]], # first input channel + [[5, 6], [7, 8]], + ] # second input channel + ], + dtype=torch.int16, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ + [[1, 0], [0, 1]], # weights for first input channel + [[0, 1], [1, 0]], + ] # weights for second input channel + ], + dtype=torch.int16, + ), # weight: 1x2x2x2 (1 output channel, 2 input channels) + torch.tensor([0], dtype=torch.int16), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int16), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.1, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int16, # dtype + torch.tensor( + [[[[180]]]], dtype=torch.int16 + ), # (1 + 4 + 6 + 7) / 0.1 = 180 + ), + # Test case 7: Multiple input and output channels + ( + torch.tensor( + [ + [ + [[1, 2], [3, 4]], # first input channel + [[2, 1], [4, 3]], + ] # second input channel + ], + dtype=torch.int16, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ + [ + [1, 1], + [1, 1], + ], # first output channel, first input channel + [[1, 1], [1, 1]], + ], # first output channel, second input channel + [ + [ + [1, 0], + [0, 1], + ], # second output channel, first input channel + [[0, 1], [1, 0]], + ], # second output channel, second input channel + ], + dtype=torch.int16, + ), # weight: 2x2x2x2 (2 output channels, 2 input channels) + torch.tensor([0, 0], dtype=torch.int16), # bias for each output channel + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor( + [0], dtype=torch.int16 + ), # weight_zero_point for each output channel + torch.tensor([1.0], dtype=torch.float32), # bias_scale for each channel + 0.05, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int16, # dtype + torch.tensor([[[[400]], [[200]]]], dtype=torch.int16), + ), + # Test case 8: Grouped convolution (groups=2) + ( + torch.tensor( + [ + [ + [[1, 2], [3, 4]], # first input channel (group 1) + [[5, 6], [7, 8]], + ] # second input channel (group 2) + ], + dtype=torch.int8, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ + [[1, 1], [1, 1]] + ], # first output channel (processes first input channel) + [ + [[1, 0], [0, 1]] + ], # second output channel (processes second input channel) + ], + dtype=torch.int8, + ), # weight: 2x1x2x2 (2 output channels, 1 input channel each due to groups=2) + torch.tensor([0, 0], dtype=torch.int8), # bias for each output channel + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 2, # groups (grouped convolution) + 0, # in_zero_point + torch.tensor( + [0], dtype=torch.int8 + ), # weight_zero_point for each output channel + torch.tensor([1.0], dtype=torch.float32), # bias_scale for each channel + 0.2, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[50]], [[65]]]], dtype=torch.int8 + ), # expected_output: [(1+2+3+4)/0.2, (5+8)/0.2] = [50, 65] + ), + # Test case 9: Convolution with stride=2 and padding=1 + ( + torch.tensor( + [[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]], + dtype=torch.int8, + ), # input: 1x1x4x4 + torch.tensor( + [[[[1, 1], [1, 1]]]], dtype=torch.int8 + ), # weight: 1x1x2x2 (sum filter) + torch.tensor([0], dtype=torch.int8), # bias + (2, 2), # stride=2 + (1, 1), # padding=1 + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.5, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[2, 10, 8], [28, 68, 40], [26, 58, 32]]]], dtype=torch.int8 + ), + ), + ] + ) + def test_quantized_conv_nchw( + self, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, + dtype: torch.dtype, + expected_output: torch.Tensor, + ) -> None: + output = quantized_conv_nchw( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + + # Verify output properties + self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") + self.assertEqual( + output.shape, + expected_output.shape, + "Output shape should match expected shape", + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected. Got {output}, expected {expected_output}", + )