From 71a1371d1d0cbd9de2c1abeebca06ed5b563f512 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 4 Sep 2025 12:06:54 -0700 Subject: [PATCH 1/3] Update backend-agnostic quantize_per_tensor to use scale_inv to match the quantize_per_tensor_out implementation (#13846) Summary: Fixes an interface mismatch between the quantize_per_tensor_out implementation and the old quantize_per_tensor python implementation Reviewed By: hsharma35 Differential Revision: D81459313 --- backends/cadence/aot/ref_implementations.py | 32 +++++++++++++++---- .../aot/tests/test_ref_implementations.py | 5 +-- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index bda8f9425ea..abbfc41889b 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -24,7 +24,7 @@ @impl(m, "quantize_per_tensor") def quantize_per_tensor( - input: torch.Tensor, + input_tensor: torch.Tensor, scale: float, zero_point: int, quant_min: int, @@ -35,10 +35,10 @@ def quantize_per_tensor( Quantizes a floating-point tensor to an integral tensor. Args: - - input (Tensor): input tensor - - scale (float): Quantization scale. Derived from the ratio + - input_tensor (Tensor): input tensor + - scale (float): Inverse of quantization scale. Derived from the ratio between the min/max of the floating-point tensor and the - min/max of the quantized range. + min/max of the quantized range, and then inverted. - zero_point (int): The point which represents 0 in the quantized range. For example, consider the floating point range [-1., 2.] and quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from @@ -61,7 +61,12 @@ def quantize_per_tensor( raise ValueError( f"Unsupported dtype to quantize to. 
Supported dtypes must be one of {supported_quant_types}" ) - return torch.round(input / scale + zero_point).to(dtype) + + dequantized = torch.round(input_tensor * scale + zero_point).to(dtype) + return torch.max( + torch.min(dequantized, torch.tensor(quant_max)), + torch.tensor(quant_min), + ) @impl(m, "dequantize_per_tensor") @@ -173,9 +178,16 @@ def quantized_add( dequant_X = X_scale * (X - X_zero_point) dequant_Y = Y_scale * (Y - Y_zero_point) + out_scale_inv = 1 / out_scale + # q_min/q_max are unused args return quantize_per_tensor( - dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype + dequant_X + dequant_Y, + out_scale_inv, + out_zero_point, + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + dtype, ) @@ -206,6 +218,7 @@ def quantized_linear( - offset (Tensor): Unused """ out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0]) + out_scale_inv = 1 / out_scale N, K = weight.shape @@ -223,7 +236,12 @@ def quantized_linear( src - in_zero_point, weight - weight_zero_point, bias ) return quantize_per_tensor( - out, out_scale, out_zero_point, -128, 127, dtype + out, + out_scale_inv, + out_zero_point, + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + dtype, ).reshape(*leading_dims, N) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 1b02926b3f8..95e1a374463 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -40,11 +40,12 @@ def test_quantize_per_tensor( ) -> None: input_tensor = torch.tensor([input_value]) scale = (f_max - f_min) / (q_max - q_min) - zero_point = round(-f_min / scale) + q_min + inv_scale = 1.0 / scale + zero_point = round(-f_min * inv_scale) + q_min expected_output = torch.tensor([expected_value], dtype=target_dtype) output = quantize_per_tensor( - input_tensor, scale, zero_point, q_min, q_max, target_dtype + input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype ) self.assertEqual( From 0da9e0f1a3ef14a8a957008fd54bf87260120f89 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 4 Sep 2025 12:06:54 -0700 Subject: [PATCH 2/3] Backend-agnostic implementation of quantized_layer_norm_per_tensor (#13847) Summary: Continuing support for supporting backend-agnostic Cadence custom ops. Reviewed By: hsharma35 Differential Revision: D81459333 --- backends/cadence/aot/ref_implementations.py | 49 ++++++++++ .../aot/tests/test_ref_implementations.py | 95 +++++++++++++++++++ 2 files changed, 144 insertions(+) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index abbfc41889b..5595d74c9c8 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -245,6 +245,55 @@ def quantized_linear( ).reshape(*leading_dims, N) +@impl(m, "quantized_layer_norm_per_tensor") +def quantized_layer_norm_per_tensor( + input_tensor: torch.Tensor, + X_scale: float, + X_zero_point: int, + normalized_shape: int, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + output_scale: float, + output_zero_point: int, +) -> torch.Tensor: + """ + Quantized layer norm operation. 
+ + Args: + - input_tensor (Tensor): The activations tensor + - X_scale (float): The scale of the input + - X_zero_point (int): The zero point of the input + - normalized_shape (int): The shape of the input + - weight (Tensor): The weight tensor + - bias (Tensor): The bias tensor + - eps (float): The epsilon value + - output_scale (float): The scale of the output + - output_zero_point (int): The zero point of the output + """ + supported_dtypes = [torch.int8, torch.uint8] + if input_tensor.dtype not in supported_dtypes: + raise ValueError( + f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}" + ) + + float_input_tensor = dequantize_per_tensor( + input_tensor, X_scale, X_zero_point, -128, 127, torch.float32 + ) + out = torch.nn.functional.layer_norm( + float_input_tensor, (normalized_shape,), weight, bias, eps=eps + ) + + return quantize_per_tensor( + out, + 1 / output_scale, + output_zero_point, + torch.iinfo(input_tensor.dtype).min, + torch.iinfo(input_tensor.dtype).max, + input_tensor.dtype, + ) + + @impl(m, "requantize") def requantize( input: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 95e1a374463..52f6c308da8 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -15,6 +15,7 @@ dequantize_per_tensor, quantize_per_tensor, quantized_add, + quantized_layer_norm_per_tensor, quantized_linear, ) from executorch.backends.cadence.aot.typing_stubs import expand @@ -240,3 +241,97 @@ def test_quantized_linear( torch.equal(output, expected_output), f"Values don't match: got {output}, expected {expected_output}", ) + + @expand( + [ + # Test case 1: Simple case with int8, zero mean input + ( + torch.tensor( + [[-1, 1]], dtype=torch.int8 + ), # input: dequantized to [-0.1, 0.1] + 0.1, # X_scale + 0, # X_zero_point + 2, # normalized_shape (last dimension) + torch.tensor([1.0, 1.0]), # weight + torch.tensor([0.0, 0.0]), # bias + 1e-5, # eps + 0.1, # output_scale + 0, # output_zero_point + torch.int8, # dtype + torch.tensor([[-10, 10]], dtype=torch.int8), # expected_output + ), + # Test case 2: uint8 with zero_point offset + ( + torch.tensor( + [[127, 129]], dtype=torch.uint8 + ), # input: dequantized to [-0.05, 0.05] + 0.05, # X_scale + 128, # X_zero_point + 2, # normalized_shape (last dimension) + torch.tensor([1.0, 1.0]), # weight + torch.tensor([0.0, 0.0]), # bias + 1e-5, # eps + 0.05, # output_scale + 128, # output_zero_point + torch.uint8, # dtype + torch.tensor([[108, 148]], dtype=torch.uint8), # expected_output + ), + # Test case 3: Test with weight and bias scaling + ( + torch.tensor( + [[-2, 2]], dtype=torch.int8 + ), # input: dequantized to [-0.2, 0.2] + 0.1, # X_scale + 0, # X_zero_point + 2, # normalized_shape (last dimension) + torch.tensor( + [2.0, 0.5] + ), # weight: scale first element by 2, second by 0.5 + torch.tensor( + [0.1, -0.1] + ), # bias: add 0.1 to first, subtract 0.1 from second + 1e-5, # eps + 0.1, # output_scale + 0, # output_zero_point + torch.int8, # dtype + torch.tensor([[-19, 4]], dtype=torch.int8), # expected_output + ), + ] + ) + def test_quantized_layer_norm_per_tensor( + self, + input_tensor: torch.Tensor, + X_scale: float, + X_zero_point: int, + normalized_shape: int, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + output_scale: float, + output_zero_point: int, + dtype: torch.dtype, + expected_output: torch.Tensor, + ) -> None: + output = 
quantized_layer_norm_per_tensor( + input_tensor, + X_scale, + X_zero_point, + normalized_shape, + weight, + bias, + eps, + output_scale, + output_zero_point, + ) + + # Verify output properties + self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") + self.assertEqual( + output.shape, input_tensor.shape, "Output shape should match input shape" + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected. Got {output}, expected {expected_output}", + ) From 8d38923535a7a22099a4d1bbc63f69759efad103 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 4 Sep 2025 12:06:54 -0700 Subject: [PATCH 3/3] Backend-agnostic implementation of quantized_conv_nchw Summary: Continued support of backend agnostic custom Cadence ops Reviewed By: hsharma35 Differential Revision: D81465757 --- backends/cadence/aot/TARGETS | 1 + backends/cadence/aot/ref_implementations.py | 80 +++++ .../aot/tests/test_ref_implementations.py | 335 ++++++++++++++++++ 3 files changed, 416 insertions(+) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 1a2c5a9709f..eb0e17f9858 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -129,6 +129,7 @@ python_library( ], typing = True, deps = [ + "fbcode//executorch/backends/cadence/aot:utils", "fbcode//caffe2:torch", "fbcode//executorch/exir:scalar_type", ], diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 5595d74c9c8..05792a5cfa7 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -6,9 +6,11 @@ # pyre-strict + from typing import Optional import torch + from executorch.exir.scalar_type import ScalarType from torch.library import impl, Library @@ -21,6 +23,8 @@ ScalarType.QINT32: torch.qint32, } +_Number = bool | int | float + @impl(m, "quantize_per_tensor") def quantize_per_tensor( @@ -294,6 +298,82 @@ def quantized_layer_norm_per_tensor( ) +@impl(m, "quantized_conv_nchw") +def quantized_conv_nchw( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, +) -> torch.Tensor: + """ + Quantized convolution operation. 
+ + Args: + - input_tensor (Tensor): The activations tensor + - weight (Tensor): The weight tensor + - bias (Tensor): The bias tensor + - stride (Tuple[int]): The stride of the convolution + - padding (Tuple[int]): The padding of the convolution + - dilation (Tuple[int]): The dilation of the convolution + - groups (int): The number of groups + - in_zero_point (int): The quantized mapping of zero for the input + - weight_zero_point (Tensor): The quantized mapping of zero for the weight + - bias_scale (Tensor): The quantized bias scale + - output_scale (float): The scale of the output + - output_zero_point (int): The zero point of the output + - out_multiplier (Tensor): Unused + - out_shift (Tensor): Unused + """ + if weight_zero_point.view(-1).shape != (1,): + raise ValueError("Weight zero point must be a scalar") + + if bias_scale.view(-1).shape != (1,): + raise ValueError("Bias scale must be a scalar") + + if len(input_tensor.shape) == 3: + float_out = torch.nn.functional.conv1d( + (input_tensor - in_zero_point).float(), + (weight - weight_zero_point).float(), + (bias * bias_scale).float(), + stride[1], + padding[1], + dilation[1], + groups, + ) + + elif len(input_tensor.shape) == 4: + float_out = torch.nn.functional.conv2d( + (input_tensor - in_zero_point).float(), + (weight - weight_zero_point).float(), + (bias * bias_scale).float(), + stride, + padding, + dilation, + groups, + ) + else: + raise ValueError("Input tensor must be 3D or 4D") + + return quantize_per_tensor( + float_out, + 1.0 / output_scale, + output_zero_point, + torch.iinfo(input_tensor.dtype).min, + torch.iinfo(input_tensor.dtype).max, + input_tensor.dtype, + ) + + @impl(m, "requantize") def requantize( input: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 52f6c308da8..37a250c70f7 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -15,6 +15,7 @@ dequantize_per_tensor, quantize_per_tensor, quantized_add, + quantized_conv_nchw, quantized_layer_norm_per_tensor, quantized_linear, ) @@ -335,3 +336,337 @@ def test_quantized_layer_norm_per_tensor( torch.equal(output, expected_output), f"Output values don't match expected. 
Got {output}, expected {expected_output}", ) + + @expand( + [ + # Test case 1: Basic 2D convolution with int8 + ( + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8), # input: 1x1x2x2 + torch.tensor( + [[[[1, 0], [0, 1]]]], dtype=torch.int8 + ), # weight: 1x1x2x2 (identity-like) + torch.tensor([0], dtype=torch.int8), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.1, # output_scale + 0, # output_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int8), # out_shift + torch.int8, # dtype + torch.tensor( + [[[[50]]]], dtype=torch.int8 + ), # expected_output: (1*1 + 4*1) / 0.1 = 50 + ), + # Test case 2: 2D convolution with stride and padding + ( + torch.tensor( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.int8 + ), # input: 1x1x3x3 + torch.tensor( + [[[[1, 1], [1, 1]]]], dtype=torch.int8 + ), # weight: 1x1x2x2 (sum filter) + torch.tensor([0], dtype=torch.int8), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.25, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[48, 64], [96, 112]]]], dtype=torch.int8 + ), # expected_output: convolution results with output_scale=0.25 + ), + # Test case 3: uint8 with non-zero zero points + ( + torch.tensor( + [[[[130, 132], [134, 136]]]], dtype=torch.uint8 + ), # input: 1x1x2x2 + torch.tensor( + [[[[129, 128], [128, 129]]]], dtype=torch.uint8 + ), # weight: 1x1x2x2 (values close to zero_point) + torch.tensor([10], dtype=torch.uint8), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 128, # in_zero_point + torch.tensor([128], dtype=torch.uint8), # weight_zero_point + torch.tensor([0.1], dtype=torch.float32), # bias_scale + 0.1, # output_scale + 128, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.uint8, # dtype + torch.tensor( + [[[[238]]]], dtype=torch.uint8 + ), # (130 - 128) + (134 - 128) = 10 + # + bias -> 10 + 1 = 11 + # round(11 / 0.1 + 128) = 238 + ), + # Test case 4: 1D convolution (3D input tensor) + ( + torch.tensor( + [[[1, 2, 3, 4]]], dtype=torch.int8 + ), # input: 1x1x4 (N, C, W) + torch.tensor( + [[[1, 1]]], dtype=torch.int8 + ), # weight: 1x1x2 (OC, IC, KW) + torch.tensor([0], dtype=torch.int8), # bias + (1, 1), # stride (padding for 2D, actual stride is stride[1]) + (0, 0), # padding (padding for 2D, actual padding is padding[1]) + (1, 1), # dilation (padding for 2D, actual dilation is dilation[1]) + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.5, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[6, 10, 14]]], dtype=torch.int8 + ), # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14] + ), + # Test case 5: Multiple output channels + ( + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8), # input: 1x1x2x2 + torch.tensor( + [ + [[[1, 0], [0, 1]]], # first output channel + [[[0, 1], [1, 0]]], # second output channel + ], + 
dtype=torch.int8, + ), # weight: 2x1x2x2 + torch.tensor([0, 5], dtype=torch.int8), # bias for each output channel + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.2, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[25]], [[50]]]], dtype=torch.int8 + ), # expected_output: [5/0.2, 10/0.2] = [25, 50] + ), + # Test case 6: Multiple input channels + ( + torch.tensor( + [ + [ + [[1, 2], [3, 4]], # first input channel + [[5, 6], [7, 8]], + ] # second input channel + ], + dtype=torch.int16, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ + [[1, 0], [0, 1]], # weights for first input channel + [[0, 1], [1, 0]], + ] # weights for second input channel + ], + dtype=torch.int16, + ), # weight: 1x2x2x2 (1 output channel, 2 input channels) + torch.tensor([0], dtype=torch.int16), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int16), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.1, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int16, # dtype + torch.tensor( + [[[[180]]]], dtype=torch.int16 + ), # (1 + 4 + 6 + 7) / 0.1 = 180 + ), + # Test case 7: Multiple input and output channels + ( + torch.tensor( + [ + [ + [[1, 2], [3, 4]], # first input channel + [[2, 1], [4, 3]], + ] # second input channel + ], + dtype=torch.int16, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ + [ + [1, 1], + [1, 1], + ], # first output channel, first input channel + [[1, 1], [1, 1]], + ], # first output channel, second input channel + [ + [ + [1, 0], + [0, 1], + ], # second output channel, first input channel + [[0, 1], [1, 0]], + ], # second output channel, second input channel + ], + dtype=torch.int16, + ), # weight: 2x2x2x2 (2 output channels, 2 input channels) + torch.tensor([0, 0], dtype=torch.int16), # bias for each output channel + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor( + [0], dtype=torch.int16 + ), # weight_zero_point for each output channel + torch.tensor([1.0], dtype=torch.float32), # bias_scale for each channel + 0.05, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int16, # dtype + torch.tensor([[[[400]], [[200]]]], dtype=torch.int16), + ), + # Test case 8: Grouped convolution (groups=2) + ( + torch.tensor( + [ + [ + [[1, 2], [3, 4]], # first input channel (group 1) + [[5, 6], [7, 8]], + ] # second input channel (group 2) + ], + dtype=torch.int8, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ + [[1, 1], [1, 1]] + ], # first output channel (processes first input channel) + [ + [[1, 0], [0, 1]] + ], # second output channel (processes second input channel) + ], + dtype=torch.int8, + ), # weight: 2x1x2x2 (2 output channels, 1 input channel each due to groups=2) + torch.tensor([0, 0], dtype=torch.int8), # bias for each output channel + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 2, # groups (grouped convolution) + 0, # in_zero_point + torch.tensor( + [0], dtype=torch.int8 + ), # weight_zero_point for each output channel + torch.tensor([1.0], dtype=torch.float32), # bias_scale for each channel + 0.2, # output_scale + 0, 
# output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[50]], [[65]]]], dtype=torch.int8 + ), # expected_output: [(1+2+3+4)/0.2, (5+8)/0.2] = [50, 65] + ), + # Test case 9: Convolution with stride=2 and padding=1 + ( + torch.tensor( + [[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]], + dtype=torch.int8, + ), # input: 1x1x4x4 + torch.tensor( + [[[[1, 1], [1, 1]]]], dtype=torch.int8 + ), # weight: 1x1x2x2 (sum filter) + torch.tensor([0], dtype=torch.int8), # bias + (2, 2), # stride=2 + (1, 1), # padding=1 + (1, 1), # dilation + 1, # groups + 0, # in_zero_point + torch.tensor([0], dtype=torch.int8), # weight_zero_point + torch.tensor([1.0], dtype=torch.float32), # bias_scale + 0.5, # output_scale + 0, # output_zero_point + typing.cast(None, torch.Tensor), + typing.cast(None, torch.Tensor), + torch.int8, # dtype + torch.tensor( + [[[[2, 10, 8], [28, 68, 40], [26, 58, 32]]]], dtype=torch.int8 + ), + ), + ] + ) + def test_quantized_conv_nchw( + self, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, + dtype: torch.dtype, + expected_output: torch.Tensor, + ) -> None: + output = quantized_conv_nchw( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + + # Verify output properties + self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") + self.assertEqual( + output.shape, + expected_output.shape, + "Output shape should match expected shape", + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected. Got {output}, expected {expected_output}", + )
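
Note (illustration only, not part of the patches): the sketch below mirrors the arithmetic in test_quantize_per_tensor and shows the interface change that runs through all three diffs, namely that quantize_per_tensor now takes the inverse of the quantization scale while dequantize_per_tensor keeps the plain scale. It assumes both ops are importable from executorch.backends.cadence.aot.ref_implementations exactly as in the tests; treat it as a usage sketch, not as part of the reviewed code.

import torch

from executorch.backends.cadence.aot.ref_implementations import (
    dequantize_per_tensor,
    quantize_per_tensor,
)

f_min, f_max = -1.0, 2.0  # floating-point range to cover
q_min, q_max = -128, 127  # int8 quantized range

scale = (f_max - f_min) / (q_max - q_min)
inv_scale = 1.0 / scale  # quantize_per_tensor now expects the inverse scale
zero_point = round(-f_min * inv_scale) + q_min

x = torch.tensor([-1.0, 0.0, 2.0])
q = quantize_per_tensor(x, inv_scale, zero_point, q_min, q_max, torch.int8)

# dequantize_per_tensor is unchanged and still takes the non-inverted scale.
x_hat = dequantize_per_tensor(q, scale, zero_point, q_min, q_max, torch.float32)
# x_hat approximates x to within roughly one quantization step (~scale).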