From 6f7b71b66ea428a7bd2da8cfef94fa0d7c1ebfc0 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 01/10] Add uint8/int8 specializations for conv per tensor Summary: Continued support of adding custom Cadence python references Differential Revision: D81720359 --- backends/cadence/aot/ops_registrations.py | 60 ++++++++ backends/cadence/aot/ref_implementations.py | 141 +++++++++++++++++- .../aot/tests/test_ref_implementations.py | 124 ++++++++++----- 3 files changed, 285 insertions(+), 40 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 68091e2d521..ce0fba47610 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -873,6 +873,11 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) out_channels, _, *kernel_size = weight.shape in_size = input.shape @@ -917,6 +922,11 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) out_channels, _, *kernel_size = weight.shape in_size = input.shape @@ -961,6 +971,11 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) out_channels, *kernel_size, _ = weight.shape in_size = input.shape @@ -1005,6 +1020,11 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) out_channels, *kernel_size, _ = weight.shape in_size = input.shape @@ -1049,6 +1069,11 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) out_channels, _, *kernel_size = weight.shape in_size = input.shape @@ -1093,6 +1118,11 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) out_channels, _, *kernel_size = weight.shape in_size = input.shape @@ -1137,6 +1167,11 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) out_channels, *kernel_size, _ = weight.shape in_size = input.shape @@ -1181,6 +1216,11 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) out_channels, *kernel_size, _ = weight.shape in_size = input.shape @@ -1225,6 +1265,11 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) out_channels, _, *kernel_size = weight.shape in_size = input.shape @@ -1269,6 +1314,11 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) out_channels, _, *kernel_size = weight.shape in_size = input.shape @@ -1313,6 +1363,11 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) out_channels, *kernel_size, _ = weight.shape in_size = input.shape @@ -1357,6 +1412,11 @@ def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) out_channels, *kernel_size, _ = weight.shape in_size = input.shape diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 790341f8f5a..0cd55326b86 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Optional +from typing import Callable, Optional import torch @@ -479,6 +479,145 @@ def quantized_conv_nhwc_per_tensor( ) +def quantized_conv_variant( + layout: str, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Create a quantized conv variant with type checking.""" + + def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + def variant( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, + ) -> torch.Tensor: + assert ( + input_tensor.dtype == input_dtype + ), f"Expected input dtype {input_dtype}, got {input_tensor.dtype}" + assert ( + weight.dtype == weight_dtype + ), f"Expected weight dtype {weight_dtype}, got {weight.dtype}" + + assert ( + bias.dtype == torch.int32 + ), f"Expected bias dtype int32, got {bias.dtype}" + + # Call the appropriate base function + match layout: + case "nchw": + return quantized_conv_nchw_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + case "nhwc": + return quantized_conv_nhwc_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + case _: + raise ValueError(f"Unknown layout {layout}") + + return variant + + return decorator + + +@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor") +@quantized_conv_variant("nchw", torch.int8, torch.int8) +def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor") +@quantized_conv_variant("nchw", torch.uint8, torch.uint8) +def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor") +@quantized_conv_variant("nhwc", torch.int8, torch.int8) +def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor") +@quantized_conv_variant("nhwc", torch.uint8, torch.uint8) +def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor") +@quantized_conv_variant("nchw", torch.int8, torch.int8) +def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor") +@quantized_conv_variant("nchw", torch.uint8, torch.uint8) +def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor") +@quantized_conv_variant("nhwc", torch.int8, torch.int8) +def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor") +@quantized_conv_variant("nhwc", torch.uint8, torch.uint8) +def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor") +@quantized_conv_variant("nchw", torch.int8, torch.int8) +def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor") +@quantized_conv_variant("nchw", torch.uint8, torch.uint8) +def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor") +@quantized_conv_variant("nhwc", torch.int8, torch.int8) +def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor") +@quantized_conv_variant("nhwc", torch.uint8, torch.uint8) +def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + @impl(m, "quantized_relu") def quantized_relu( X: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 54247e0b53b..4e2829a8460 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -15,7 +15,19 @@ dequantize_per_tensor, quantize_per_tensor, quantized_add, + quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor, quantized_conv_nchw_per_tensor, + quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor, quantized_conv_nhwc_per_tensor, quantized_layer_norm_per_tensor, quantized_linear, @@ -350,7 +362,7 @@ def test_quantized_layer_norm_per_tensor( torch.tensor( [[[[1, 0], [0, 1]]]], dtype=torch.int8 ), # weight: 1x1x2x2 (identity-like) - torch.tensor([0], dtype=torch.int8), # bias + torch.tensor([0], dtype=torch.int32), # bias (1, 1), # stride (0, 0), # padding (1, 1), # dilation @@ -381,7 +393,7 @@ def test_quantized_layer_norm_per_tensor( torch.tensor( [[[[1, 1], [1, 1]]]], dtype=torch.int8 ), # weight: 1x1x2x2 (sum filter) - torch.tensor([0], dtype=torch.int8), # bias + torch.tensor([0], dtype=torch.int32), # bias (1, 1), # stride (0, 0), # padding (1, 1), # dilation @@ -410,7 +422,7 @@ def test_quantized_layer_norm_per_tensor( torch.tensor( [[[[129, 128], [128, 129]]]], dtype=torch.uint8 ), # weight: 1x1x2x2 (values close to zero_point) - torch.tensor([10], dtype=torch.uint8), # bias + torch.tensor([10], dtype=torch.int32), # bias (1, 1), # stride (0, 0), # padding (1, 1), # dilation @@ -441,7 +453,7 @@ def test_quantized_layer_norm_per_tensor( torch.tensor( [[[1, 1]]], dtype=torch.int8 ), # weight: 1x1x2 (OC, IC, KW) - torch.tensor([0], dtype=torch.int8), # bias + torch.tensor([0], dtype=torch.int32), # bias (1, 1), # stride (padding for 2D, actual stride is stride[1]) (0, 0), # padding (padding for 2D, actual padding is padding[1]) (1, 1), # dilation (padding for 2D, actual dilation is dilation[1]) @@ -517,7 +529,7 @@ def test_quantized_layer_norm_per_tensor( ], dtype=torch.int16, ), # weight: 1x2x2x2 (1 output channel, 2 input channels) - torch.tensor([0], dtype=torch.int16), # bias + torch.tensor([0], dtype=torch.int32), # bias (1, 1), # stride (0, 0), # padding (1, 1), # dilation @@ -652,7 +664,7 @@ def test_quantized_layer_norm_per_tensor( torch.tensor( [[[[1, 1], [1, 1]]]], dtype=torch.int8 ), # weight: 1x1x2x2 (sum filter) - torch.tensor([0], dtype=torch.int8), # bias + torch.tensor([0], dtype=torch.int32), # bias (2, 2), # stride=2 (1, 1), # padding=1 (1, 1), # dilation @@ -701,42 +713,76 @@ def test_quantized_conv_per_tensor( input_tensor = input_tensor.to(memory_format=memory_format) - conv = ( - quantized_conv_nchw_per_tensor - if memory_format == torch.contiguous_format - else quantized_conv_nhwc_per_tensor - ) + convs = [ + ( + quantized_conv_nchw_per_tensor + if memory_format == torch.contiguous_format + else quantized_conv_nhwc_per_tensor + ) + ] - output = conv( - input_tensor, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out_multiplier, - out_shift, - ).to(memory_format=torch.contiguous_format) + optimized_convs = [] + if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8: + if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + optimized_convs = [ + quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor, + ] - # Verify output properties - self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") - self.assertEqual( - output.shape, - expected_output.shape, - "Output shape should match expected shape", - ) + else: + optimized_convs = [ + quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor, + quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor, + ] + elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8: + if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + optimized_convs = [ + quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor, + ] - # Verify output matches expected values - self.assertTrue( - torch.equal(output, expected_output), - f"Output values don't match expected. Got {output}, expected {expected_output}", - ) + else: + optimized_convs = [ + quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor, + quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor, + ] + + convs.extend(optimized_convs) + for conv in convs: + output = conv( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ).to(memory_format=torch.contiguous_format) + + # Verify output properties + self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") + self.assertEqual( + output.shape, + expected_output.shape, + "Output shape should match expected shape", + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected. Got {output}, expected {expected_output}", + ) @expand( [ From 207b475936e10a96d66bbd3335d13c00689f171c Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 02/10] Ensure we can call custom ops from torch cadence lib Summary: Fixes mismatches between op registration names and implementation names, fixes some type issues in tests where unexpected types are passed in given the op definition. Also fixes an incorrect layernorm meta op (normalized_shape should be list, not int). Tests corrected as well. Tests now use the torch cadence custom op library. Differential Revision: D81738196 --- backends/cadence/aot/TARGETS | 1 + backends/cadence/aot/ops_registrations.py | 4 +- backends/cadence/aot/ref_implementations.py | 38 +++--- .../aot/tests/test_ref_implementations.py | 112 +++++++----------- 4 files changed, 67 insertions(+), 88 deletions(-) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 1a2c5a9709f..54b4a8b83f3 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -614,6 +614,7 @@ python_unittest( typing = True, deps = [ ":typing_stubs", + "//executorch/backends/cadence/aot:ops_registrations", "//executorch/backends/cadence/aot:ref_implementations", "//caffe2:torch", ] diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index ce0fba47610..507562526c5 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -1449,7 +1449,7 @@ def quantized_layer_norm_meta( input: torch.Tensor, X_scale: torch.Tensor, X_zero_point: torch.Tensor, - normalized_shape: int, + normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float, @@ -1464,7 +1464,7 @@ def quantized_layer_norm_per_tensor_meta( input: torch.Tensor, X_scale: float, X_zero_point: int, - normalized_shape: int, + normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float, diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 0cd55326b86..aeb62a19784 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -64,9 +64,9 @@ def quantize_per_tensor( f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}" ) - dequantized = torch.round(input_tensor * scale + zero_point).to(dtype) + quantized = torch.round(input_tensor * scale + zero_point).to(dtype) return torch.max( - torch.min(dequantized, torch.tensor(quant_max)), + torch.min(quantized, torch.tensor(quant_max)), torch.tensor(quant_min), ) @@ -247,12 +247,12 @@ def quantized_linear( ).reshape(*leading_dims, N) -@impl(m, "quantized_layer_norm_per_tensor") +@impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, X_scale: float, X_zero_point: int, - normalized_shape: int, + normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float, @@ -283,7 +283,7 @@ def quantized_layer_norm_per_tensor( input_tensor, X_scale, X_zero_point, -128, 127, torch.float32 ) out = torch.nn.functional.layer_norm( - float_input_tensor, (normalized_shape,), weight, bias, eps=eps + float_input_tensor, normalized_shape, weight, bias, eps=eps ) return quantize_per_tensor( @@ -365,7 +365,7 @@ def quantized_conv_per_tensor( ) -@impl(m, "quantized_conv_nchw_per_tensor") +@impl(m, "quantized_conv_nchw.per_tensor") def quantized_conv_nchw_per_tensor( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -421,7 +421,7 @@ def quantized_conv_nchw_per_tensor( ) -@impl(m, "quantized_conv_nhwc_per_tensor") +@impl(m, "quantized_conv_nhwc.per_tensor") def quantized_conv_nhwc_per_tensor( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -558,62 +558,62 @@ def variant( return decorator -@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 4e2829a8460..918324876bf 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -8,31 +8,11 @@ import typing import unittest +import executorch.backends.cadence.aot.ops_registrations # noqa +import executorch.backends.cadence.aot.ref_implementations # noqa + import numpy as np import torch - -from executorch.backends.cadence.aot.ref_implementations import ( - dequantize_per_tensor, - quantize_per_tensor, - quantized_add, - quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nchw_per_tensor, - quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nhwc_per_tensor, - quantized_layer_norm_per_tensor, - quantized_linear, - quantized_relu, -) from executorch.backends.cadence.aot.typing_stubs import expand @@ -60,7 +40,7 @@ def test_quantize_per_tensor( zero_point = round(-f_min * inv_scale) + q_min expected_output = torch.tensor([expected_value], dtype=target_dtype) - output = quantize_per_tensor( + output = torch.ops.cadence.quantize_per_tensor( input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype ) @@ -104,7 +84,7 @@ def test_dequantize_per_tensor( zero_point = round(-f_min / scale) + q_min expected_output = torch.tensor([expected_value], dtype=torch.float32) - output = dequantize_per_tensor( + output = torch.ops.cadence.dequantize_per_tensor( input_tensor, scale, zero_point, q_min, q_max, torch.float32 ) @@ -142,7 +122,7 @@ def test_quantized_add( Y_tensor = torch.tensor([Y], dtype=dtype) expected_output = torch.tensor([expected_value], dtype=dtype) - output = quantized_add( + output = torch.ops.cadence.quantized_add( X_tensor, torch.tensor(X_scale), torch.tensor(X_zero_point, dtype=dtype), @@ -238,7 +218,7 @@ def test_quantized_linear( .to(expected_output.dtype) ) bias = torch.arange(weight_shape[0]).to(expected_output.dtype) - output = quantized_linear( + output = torch.ops.cadence.quantized_linear( src, weight, bias, @@ -266,7 +246,7 @@ def test_quantized_linear( ), # input: dequantized to [-0.1, 0.1] 0.1, # X_scale 0, # X_zero_point - 2, # normalized_shape (last dimension) + [2], # normalized_shape (last dimension) torch.tensor([1.0, 1.0]), # weight torch.tensor([0.0, 0.0]), # bias 1e-5, # eps @@ -282,7 +262,7 @@ def test_quantized_linear( ), # input: dequantized to [-0.05, 0.05] 0.05, # X_scale 128, # X_zero_point - 2, # normalized_shape (last dimension) + [2], # normalized_shape (last dimension) torch.tensor([1.0, 1.0]), # weight torch.tensor([0.0, 0.0]), # bias 1e-5, # eps @@ -298,7 +278,7 @@ def test_quantized_linear( ), # input: dequantized to [-0.2, 0.2] 0.1, # X_scale 0, # X_zero_point - 2, # normalized_shape (last dimension) + [2], # normalized_shape (last dimension) torch.tensor( [2.0, 0.5] ), # weight: scale first element by 2, second by 0.5 @@ -318,7 +298,7 @@ def test_quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, X_scale: float, X_zero_point: int, - normalized_shape: int, + normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float, @@ -327,7 +307,7 @@ def test_quantized_layer_norm_per_tensor( dtype: torch.dtype, expected_output: torch.Tensor, ) -> None: - output = quantized_layer_norm_per_tensor( + output = torch.ops.cadence.quantized_layer_norm.per_tensor( input_tensor, X_scale, X_zero_point, @@ -372,10 +352,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.1, # output_scale 0, # output_zero_point - torch.tensor( - [1073741824], dtype=torch.int32 - ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int8), # out_shift + 0, # unused out_multiplier + 0, # unused out_shift torch.int8, # dtype torch.tensor( [[[[50]]]], dtype=torch.int8 @@ -403,8 +381,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.25, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int8, # dtype torch.tensor( [[[[48, 64], [96, 112]]]], dtype=torch.int8 @@ -432,8 +410,8 @@ def test_quantized_layer_norm_per_tensor( 0.1, # bias_scale 0.1, # output_scale 128, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.uint8, # dtype torch.tensor( [[[[238]]]], dtype=torch.uint8 @@ -463,8 +441,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.5, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int8, # dtype torch.tensor( [[[6, 10, 14]]], dtype=torch.int8 @@ -498,8 +476,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.2, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int8, # dtype torch.tensor( [[[[25]], [[50]]]], dtype=torch.int8 @@ -539,8 +517,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.1, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int16, # dtype torch.tensor( [[[[180]]]], dtype=torch.int16 @@ -592,8 +570,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.05, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int16, # dtype torch.tensor([[[[400]], [[200]]]], dtype=torch.int16), memory_format, @@ -635,8 +613,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.2, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int8, # dtype torch.tensor( [[[[50]], [[65]]]], dtype=torch.int8 @@ -674,8 +652,8 @@ def test_quantized_layer_norm_per_tensor( 1.0, # bias_scale 0.5, # output_scale 0, # output_zero_point - typing.cast(None, torch.Tensor), - typing.cast(None, torch.Tensor), + 0, # unused out_multiplier + 0, # unused out_shift torch.int8, # dtype torch.tensor( [[[[2, 10, 8], [28, 68, 40], [26, 58, 32]]]], dtype=torch.int8 @@ -715,9 +693,9 @@ def test_quantized_conv_per_tensor( convs = [ ( - quantized_conv_nchw_per_tensor + torch.ops.cadence.quantized_conv_nchw.per_tensor if memory_format == torch.contiguous_format - else quantized_conv_nhwc_per_tensor + else torch.ops.cadence.quantized_conv_nhwc.per_tensor ) ] @@ -725,30 +703,30 @@ def test_quantized_conv_per_tensor( if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8: if input_tensor.is_contiguous(memory_format=torch.contiguous_format): optimized_convs = [ - quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor, + torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor, ] else: optimized_convs = [ - quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor, - quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor, + torch.ops.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, ] elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8: if input_tensor.is_contiguous(memory_format=torch.contiguous_format): optimized_convs = [ - quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor, + torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor, ] else: optimized_convs = [ - quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor, - quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor, + torch.ops.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor, ] convs.extend(optimized_convs) @@ -851,7 +829,7 @@ def test_quantized_relu( dtype: torch.dtype, expected_output: torch.Tensor, ) -> None: - output = quantized_relu( + output = torch.ops.cadence.quantized_relu( X, X_zero_point, out_zero_point, out_multiplier, out_shift ) From b3cfcae9bcf720c9e6d36d6ab29b117f74a39955 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 03/10] Update channels last python reference to not use memory_format=channels_last Summary: The default overload of custom channels last assumes that inputs and weights are permuted and contiguous in memory. Differential Revision: D81842686 --- backends/cadence/aot/ref_implementations.py | 9 ++++++--- backends/cadence/aot/tests/test_ref_implementations.py | 8 +++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index aeb62a19784..9f6dcbd8a2c 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -457,9 +457,12 @@ def quantized_conv_nhwc_per_tensor( - out_multiplier (int): Unused - out_shift (int): Unused """ - - if not input_tensor.is_contiguous(memory_format=torch.channels_last): - raise ValueError("Input tensor must be in NHWC format") + assert input_tensor.is_contiguous(memory_format=torch.contiguous_format) + assert weight.is_contiguous(memory_format=torch.contiguous_format) + input_tensor = torch.permute(input_tensor, (0, -1, 1, 2)).to( + memory_format=torch.channels_last + ) + weight = torch.permute(weight, (0, -1, 1, 2)) return quantized_conv_per_tensor( input_tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 918324876bf..0259b750d56 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -689,7 +689,9 @@ def test_quantized_conv_per_tensor( if len(input_tensor.shape) == 3 and memory_format == torch.channels_last: self.fail("Channels last format is not supported for 3D input tensors") - input_tensor = input_tensor.to(memory_format=memory_format) + if memory_format == torch.channels_last: + input_tensor = torch.permute(input_tensor, (0, 2, 3, 1)).contiguous() + weight = torch.permute(weight, (0, 2, 3, 1)).contiguous() convs = [ ( @@ -701,7 +703,7 @@ def test_quantized_conv_per_tensor( optimized_convs = [] if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8: - if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + if memory_format == torch.contiguous_format: optimized_convs = [ torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor, torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor, @@ -715,7 +717,7 @@ def test_quantized_conv_per_tensor( torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, ] elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8: - if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + if memory_format == torch.contiguous_format: optimized_convs = [ torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor, torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor, From 1bdf227f4dc329d7c867d09f8e2aa7e052905108 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 04/10] Support for all quantized linear ops Summary: Continued support for reference implementations of all custom Cadence ops. Differential Revision: D81940978 --- backends/cadence/aot/ref_implementations.py | 107 ++++++++- .../aot/tests/test_ref_implementations.py | 205 +++++++++++++----- 2 files changed, 251 insertions(+), 61 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 9f6dcbd8a2c..ed874239d68 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Callable, Optional +from typing import Callable import torch @@ -193,17 +193,15 @@ def quantized_add( ) -@impl(m, "quantized_linear") -def quantized_linear( +def quantized_linear_common( src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, in_zero_point: int, - weight_zero_point: torch.Tensor, - out_multiplier: torch.Tensor, - out_shift: torch.Tensor, + weight_zero_point: torch.Tensor | int, + out_multiplier: torch.Tensor | int, + out_shift: int, out_zero_point: int, - offset: Optional[torch.Tensor], ) -> torch.Tensor: """ Quantized linear (transposed matmul) operation. @@ -219,7 +217,7 @@ def quantized_linear( - out_zero_point (int): The quantized mapping of zero for the output - offset (Tensor): Unused """ - out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0]) + out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift) out_scale_inv = 1 / out_scale N, K = weight.shape @@ -235,7 +233,9 @@ def quantized_linear( ) out = torch.nn.functional.linear( - src - in_zero_point, weight - weight_zero_point, bias + (src - in_zero_point).float(), + (weight - weight_zero_point).float(), + bias.float(), ) return quantize_per_tensor( out, @@ -247,6 +247,95 @@ def quantized_linear( ).reshape(*leading_dims, N) +def quantized_linear_variant( + per_tensor: bool, + src_dtype: torch.dtype | None = None, + weight_dtype: torch.dtype | None = None, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + + def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + def variant( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: int, + weight_zero_point: torch.Tensor | int, + out_multiplier: torch.Tensor | int, + out_shift: torch.Tensor | int, + out_zero_point: int, + offset: torch.Tensor | None = None, + ) -> torch.Tensor: + if src_dtype and src.dtype != src_dtype: + raise ValueError( + f"src dtype must be {src_dtype}. Got {src.dtype} instead" + ) + if weight_dtype and weight.dtype != weight_dtype: + raise ValueError( + f"weight dtype must be {weight_dtype}. Got {weight.dtype} instead" + ) + if bias.dtype != torch.int32: + raise ValueError( + f"bias dtype must be torch.int32. Got {bias.dtype} instead" + ) + + if per_tensor: + assert isinstance(weight_zero_point, int) + assert isinstance(out_multiplier, int) + assert isinstance(out_shift, int) + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + ) + else: + assert isinstance(out_shift, torch.Tensor) + if out_shift.numel() != 1: + raise ValueError("out_shift must be a scalar") + + if out_shift.dtype != torch.int64: + raise ValueError("out_shift must be an int64") + + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + int(out_shift.item()), + out_zero_point, + ) + + return variant + + return decorator + + +@impl(m, "quantized_linear") +@quantized_linear_variant(False) +def quantized_linear() -> torch.Tensor: ... + + +@impl(m, "quantized_linear.per_tensor") +@quantized_linear_variant(True) +def quantized_linear_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") +@quantized_linear_variant(True, torch.int8, torch.int8) +def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") +@quantized_linear_variant(True, torch.uint8, torch.uint8) +def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... + + @impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 0259b750d56..253ea0f6f25 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -141,59 +141,139 @@ def test_quantized_add( @expand( [ # Test case 1: 1x2 input, 1x2 weight (1 output feature) - ( - torch.Size([1, 2]), # src_shape: 1 sample, 2 input features - torch.Size([1, 2]), # weight_shape: 1 output feature, 2 input features - 0, # in_zero_point - torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point - torch.tensor( - [1073741824], dtype=torch.int32 - ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int8), # out_shift - 0, # out_zero_point - torch.tensor([[-2]], dtype=torch.int8), # expected_output - ), + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [1, 2] + ), # weight_shape: 1 output feature, 2 input features + 0, # in_zero_point + torch.tensor([0, 0], dtype=dtype), # weight_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int64), # out_shift + 0, # out_zero_point + torch.tensor([[-2]], dtype=dtype), # expected_output + per_tensor, + ) + for (per_tensor, dtype) in ( + (False, torch.int8), + (True, torch.int8), + (True, torch.uint8), + ) + ], # Test case 2: 1x3 input, 2x3 weight (2 output features) - ( - torch.Size([1, 3]), # src_shape: 1 sample, 3 input features - torch.Size([2, 3]), # weight_shape: 2 output features, 3 input features - 0, # in_zero_point - torch.tensor([0, 0, 0], dtype=torch.int8), # weight_zero_point - torch.tensor( - [1073741824], dtype=torch.int32 - ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int8), # out_shift - 0, # out_zero_point - torch.tensor([[-10, -30]], dtype=torch.int8), # expected_output - ), + *[ + ( + torch.Size([1, 3]), # src_shape: 1 sample, 3 input features + torch.Size( + [2, 3] + ), # weight_shape: 2 output features, 3 input features + 0, # in_zero_point + torch.tensor([0, 0, 0], dtype=dtype), # weight_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int64), # out_shift + 0, # out_zero_point + torch.tensor([[-10, -30]], dtype=dtype), # expected_output + per_tensor, + ) + for (per_tensor, dtype) in ( + (False, torch.int8), + (True, torch.int8), + (True, torch.uint8), + ) + ], # Test case 3: Batch case with different dimensions - ( - torch.Size([1, 2, 2]), # src_shape: batch=1, seq=2, features=2 - torch.Size([3, 2]), # weight_shape: 3 output features, 2 input features - 0, # in_zero_point - torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point - torch.tensor( - [1073741824], dtype=torch.int32 - ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int8), # out_shift - 0, # out_zero_point - torch.tensor( - [[[-2, -8, -14], [-6, -28, -50]]], dtype=torch.int8 - ), # expected_output - ), + *[ + ( + torch.Size([1, 2, 2]), # src_shape: batch=1, seq=2, features=2 + torch.Size( + [3, 2] + ), # weight_shape: 3 output features, 2 input features + 0, # in_zero_point + torch.tensor([0, 0], dtype=dtype), # weight_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int64), # out_shift + 0, # out_zero_point + torch.tensor( + [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype + ), # expected_output + per_tensor, + ) + for (per_tensor, dtype) in ( + (False, torch.int8), + (True, torch.int8), + (True, torch.uint8), + ) + ], # Test case 4: Non-zero zero points - ( - torch.Size([1, 2]), # src_shape: 1 sample, 2 input features - torch.Size([2, 2]), # weight_shape: 2 output feature, 1 input feature - 2, # in_zero_point - torch.tensor([1, 1], dtype=torch.int8), # weight_zero_point - torch.tensor( - [268435456], dtype=torch.int32 - ), # out_multiplier (1.0 * 2^31) - torch.tensor([0]), # out_shift - 1, # out_zero_point - torch.tensor([[-15, 25]], dtype=torch.int8), # expected_output - ), + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2] + ), # weight_shape: 2 output feature, 1 input feature + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (1.0 * 2^31) + torch.tensor([0], dtype=torch.int64), # out_shift + 1, # out_zero_point + torch.tensor([[-15, 25]], dtype=dtype), # expected_output + per_tensor, + ) + for (per_tensor, dtype) in ( + (False, torch.int8), + (True, torch.int8), + (True, torch.uint8), + ) + ], + # Test case 5: Non-uniform weight zero points + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2] + ), # weight_shape: 2 output feature, 1 input feature + 2, # in_zero_point + torch.tensor([1, 2], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (1.0 * 2^31) + torch.tensor([0], dtype=torch.int64), # out_shift + 1, # out_zero_point + torch.tensor([[-23, 17]], dtype=dtype), # expected_output + False, + ) + for dtype in (torch.int8, torch.uint8) + ], + # Test case 6: Non-zero out_shift (shift=1) + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2] + ), # weight_shape: 2 output features, 2 input features + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (0.125 * 2^31) + torch.tensor( + [1], dtype=torch.int64 + ), # out_shift (shift=1, doubles the scale) + 1, # out_zero_point + torch.tensor([[-7, 13]], dtype=dtype), # expected_output + per_tensor, + ) + for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8)) + ], ] ) def test_quantized_linear( @@ -206,6 +286,7 @@ def test_quantized_linear( out_shift: torch.Tensor, out_zero_point: int, expected_output: torch.Tensor, + per_tensor: bool, ) -> None: src = ( torch.arange(np.prod(src_shape)) @@ -217,8 +298,28 @@ def test_quantized_linear( .reshape(weight_shape) .to(expected_output.dtype) ) - bias = torch.arange(weight_shape[0]).to(expected_output.dtype) - output = torch.ops.cadence.quantized_linear( + bias = torch.arange(weight_shape[0]).to(torch.int32) + if per_tensor: + weight_zero_point = weight_zero_point[0] + out_multiplier = out_multiplier[0] + out_shift = out_shift[0] + + if per_tensor: + match expected_output.dtype: + case torch.int8: + linear_op = ( + torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor + ) + case torch.uint8: + linear_op = ( + torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor + ) + case _: + linear_op = torch.ops.cadence.quantized_linear.per_tensor + else: + linear_op = torch.ops.cadence.quantized_linear + + output = linear_op( src, weight, bias, From 36bee68a0dffb2f925a31fa3b6067d40f17a55a6 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 05/10] Add quantized fully connected ops Summary: Quantized fully connected are just aliases for quantized_linear, so created all aliases. Differential Revision: D81942767 --- backends/cadence/aot/ops_registrations.py | 4 ++ backends/cadence/aot/ref_implementations.py | 33 ++++++++++-- .../aot/tests/test_ref_implementations.py | 53 +++++++++++-------- 3 files changed, 64 insertions(+), 26 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 507562526c5..35b4cbf3902 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -1771,6 +1771,7 @@ def quantized_fully_connected_meta( # src comes in shape [leading_dims, in_dim] # weight comes in shape [out_dim, in_dim] # output comes in empty with shape [leading_dims, out_dim] + assert src.shape[0] == 1 out_size = list(src.size()) weight_size = list(weight.size()) assert len(weight_size) == 2 @@ -1793,6 +1794,7 @@ def quantized_fully_connected_per_tensor_meta( # src comes in shape [leading_dims, in_dim] # weight comes in shape [out_dim, in_dim] # output comes in empty with shape [leading_dims, out_dim] + assert src.shape[0] == 1 out_size = list(src.size()) weight_size = list(weight.size()) assert len(weight_size) == 2 @@ -1815,6 +1817,7 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta( # src comes in shape [leading_dims, in_dim] # weight comes in shape [out_dim, in_dim] # output comes in empty with shape [leading_dims, out_dim] + assert src.shape[0] == 1 out_size = list(src.size()) weight_size = list(weight.size()) assert len(weight_size) == 2 @@ -1837,6 +1840,7 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta( # src comes in shape [leading_dims, in_dim] # weight comes in shape [out_dim, in_dim] # output comes in empty with shape [leading_dims, out_dim] + assert src.shape[0] == 1 out_size = list(src.size()) weight_size = list(weight.size()) assert len(weight_size) == 2 diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index ed874239d68..a9d7178b1df 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -249,6 +249,7 @@ def quantized_linear_common( def quantized_linear_variant( per_tensor: bool, + fully_connected: bool, src_dtype: torch.dtype | None = None, weight_dtype: torch.dtype | None = None, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: @@ -265,6 +266,10 @@ def variant( out_zero_point: int, offset: torch.Tensor | None = None, ) -> torch.Tensor: + if fully_connected and src.shape[0] != 1: + raise ValueError( + "Fully connected quantized linear only supports batch size of 1" + ) if src_dtype and src.dtype != src_dtype: raise ValueError( f"src dtype must be {src_dtype}. Got {src.dtype} instead" @@ -317,25 +322,45 @@ def variant( @impl(m, "quantized_linear") -@quantized_linear_variant(False) +@quantized_linear_variant(False, False) def quantized_linear() -> torch.Tensor: ... @impl(m, "quantized_linear.per_tensor") -@quantized_linear_variant(True) +@quantized_linear_variant(True, False) def quantized_linear_per_tensor() -> torch.Tensor: ... @impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") -@quantized_linear_variant(True, torch.int8, torch.int8) +@quantized_linear_variant(True, False, torch.int8, torch.int8) def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... @impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") -@quantized_linear_variant(True, torch.uint8, torch.uint8) +@quantized_linear_variant(True, False, torch.uint8, torch.uint8) def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... +@impl(m, "quantized_fully_connected") +@quantized_linear_variant(False, True) +def quantized_fully_connected() -> torch.Tensor: ... + + +@impl(m, "quantized_fully_connected.per_tensor") +@quantized_linear_variant(True, True) +def quantized_fully_connected_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") +@quantized_linear_variant(True, True, torch.int8, torch.int8) +def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") +@quantized_linear_variant(True, True, torch.uint8, torch.uint8) +def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... + + @impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 253ea0f6f25..f6eede591d3 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -307,36 +307,45 @@ def test_quantized_linear( if per_tensor: match expected_output.dtype: case torch.int8: - linear_op = ( - torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, ) case torch.uint8: - linear_op = ( - torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, ) case _: - linear_op = torch.ops.cadence.quantized_linear.per_tensor + linear_ops = ( + torch.ops.cadence.quantized_linear.per_tensor, + torch.ops.cadence.quantized_fully_connected.per_tensor, + ) else: - linear_op = torch.ops.cadence.quantized_linear + linear_ops = ( + torch.ops.cadence.quantized_linear, + torch.ops.cadence.quantized_fully_connected, + ) - output = linear_op( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - typing.cast(torch.Tensor, None), - ) + for linear_op in linear_ops: + output = linear_op( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + typing.cast(torch.Tensor, None), + ) - self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch") + self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch") - self.assertTrue( - torch.equal(output, expected_output), - f"Values don't match: got {output}, expected {expected_output}", - ) + self.assertTrue( + torch.equal(output, expected_output), + f"Values don't match: got {output}, expected {expected_output}", + ) @expand( [ From 412e915340d6b5b1edfae75445730284bed8f639 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 06/10] All variants of quantized relu Summary: Create a generic quantized relu and decorators for all custom quantized relu ops. Differential Revision: D81948125 --- backends/cadence/aot/ref_implementations.py | 84 ++++++++- .../aot/tests/test_ref_implementations.py | 159 ++++++++++++------ 2 files changed, 183 insertions(+), 60 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index a9d7178b1df..27a31c37503 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -735,13 +735,12 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tens def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_relu") -def quantized_relu( +def quantized_relu_common( X: torch.Tensor, - X_zero_point: torch.Tensor, + X_zero_point: torch.Tensor | int, out_zero_point: int, - out_multiplier: torch.Tensor, - out_shift: torch.Tensor, + out_multiplier: int, + out_shift: int, ) -> torch.Tensor: """ Quantized ReLU operation followed by requantization. @@ -757,7 +756,7 @@ def quantized_relu( if X.dtype not in supported_dtypes: raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}") - out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0]) + out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift) dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X)) return quantize_per_tensor( dequantized_X, @@ -769,6 +768,79 @@ def quantized_relu( ) +def quantized_relu_variant( + per_tensor: bool, + dtype: torch.dtype | None = None, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Create a quantized relu variant with type checking.""" + + def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + def variant( + X: torch.Tensor, + X_zero_point: torch.Tensor | int, + out_zero_point: int, + out_multiplier: torch.Tensor | int, + out_shift: torch.Tensor | int, + ) -> torch.Tensor: + if per_tensor: + if dtype and X.dtype != dtype: + raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}") + + assert isinstance(out_shift, int) + assert isinstance(out_multiplier, int) + _out_shift = out_shift + _out_multiplier = out_multiplier + else: + assert isinstance(out_multiplier, torch.Tensor) + if out_multiplier.numel() > 1: + raise ValueError("Only scalar out_multiplier is supported") + + assert isinstance(out_shift, torch.Tensor) + if out_shift.numel() > 1: + raise ValueError("Only scalar out_shift is supported") + + assert isinstance(X_zero_point, torch.Tensor) + if X_zero_point.shape != X.shape: + raise ValueError( + f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}" + ) + + _out_multiplier = int(out_multiplier.item()) + _out_shift = int(out_shift.item()) + + return quantized_relu_common( + X, + X_zero_point, + out_zero_point, + _out_multiplier, + _out_shift, + ) + + return variant + + return decorator + + +@impl(m, "quantized_relu") +@quantized_relu_variant(False) +def quantized_relu() -> torch.Tensor: ... + + +@impl(m, "quantized_relu.per_tensor") +@quantized_relu_variant(True) +def quantized_relu_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_relu_asym8s_asym8s.per_tensor") +@quantized_relu_variant(True, torch.int8) +def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_relu_asym8u_asym8u.per_tensor") +@quantized_relu_variant(True, torch.uint8) +def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ... + + @impl(m, "requantize") def requantize( input: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index f6eede591d3..e5ab6750f60 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -877,73 +877,124 @@ def test_quantized_conv_per_tensor( @expand( [ # Test case 1: Basic int8 case with negative scale - ( - "basic_int8", - torch.tensor([-1, 0, 1, 3], dtype=torch.int8), # input - torch.tensor([0], dtype=torch.int8), # X_zero_point (scalar broadcast) - 0, # out_zero_point - torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31) - torch.tensor([0]), # out_shift - torch.int8, # dtype - torch.tensor( - [0, 0, 0, -2], dtype=torch.int8 - ), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2) - ), + *[ + ( + "basic_int8", + torch.tensor([-1, 0, 1, 3], dtype=dtype), # input + 0, # X_zero_point (scalar broadcast) + 0, # out_zero_point + 1073741824, # out_multiplier (0.5 * 2^31) + 0, # out_shift + dtype, # dtype + torch.tensor( + [0, 0, 0, -2], dtype=dtype + ), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2) + ) + for dtype in [torch.int8] + ], # Test case 2: uint8 with non-zero zero point - ( - "uint8_with_zp", - torch.tensor([126, 128, 130, 132], dtype=torch.uint8), # input - torch.tensor([128], dtype=torch.uint8), # X_zero_point - 64, # out_zero_point - torch.tensor([536870912]), # out_multiplier (0.25 * 2^31) - torch.tensor([0]), # out_shift - torch.uint8, # dtype - torch.tensor( - [64, 64, 64, 63], dtype=torch.uint8 - ), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63) - ), + *[ + ( + "uint8_with_zp", + torch.tensor([126, 128, 130, 132], dtype=dtype), # input + 128, # X_zero_point + 64, # out_zero_point + 536870912, # out_multiplier (0.25 * 2^31) + 0, # out_shift + dtype, # dtype + torch.tensor( + [64, 64, 64, 63], dtype=dtype + ), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63) + ) + for dtype in [torch.uint8] + ], # Test case 3: All negative values (should all become zero after ReLU) - ( - "all_negative_int8", - torch.tensor([-5, -3, -1], dtype=torch.int8), # input - torch.tensor([0], dtype=torch.int8), # X_zero_point - 10, # out_zero_point - torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31) - torch.tensor([0]), # out_shift - torch.int8, # dtype - torch.tensor( - [10, 10, 10], dtype=torch.int8 - ), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10) - ), + *[ + ( + "all_negative_int8", + torch.tensor([-5, -3, -1], dtype=dtype), # input + 0, # X_zero_point + 10, # out_zero_point + 1073741824, # out_multiplier (0.5 * 2^31) + 0, # out_shift + dtype, # dtype + torch.tensor( + [10, 10, 10], dtype=dtype + ), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10) + ) + for dtype in [torch.int8] + ], # Test case 4: All positive values with shift (scale becomes -0.25) - ( - "positive_with_shift", - torch.tensor([2, 4, 6, 8], dtype=torch.int8), # input - torch.tensor([1], dtype=torch.int8), # X_zero_point - 5, # out_zero_point - torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31) - torch.tensor([1]), # out_shift (multiply by 2^1 = 2) - torch.int8, # dtype - torch.tensor( - [4, 2, 0, -2], dtype=torch.int8 - ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2) - ), + *[ + ( + "positive_with_shift", + torch.tensor([2, 4, 6, 8], dtype=dtype), # input + 1, # X_zero_point + 5, # out_zero_point + 1073741824, # out_multiplier (0.5 * 2^31) + 1, # out_shift (multiply by 2^1 = 2) + dtype, # dtype + torch.tensor( + [4, 2, 0, -2], dtype=dtype + ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2) + ) + for dtype in [torch.int8, torch.uint8] + ], + # Test case 4: Non-per-tensor + *[ + ( + "non_per_tensor", + torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype), # input + torch.tensor([0, 0, 0, 1, 1, 1]), # X_zero_point + 5, # out_zero_point + torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31) + torch.tensor([1]), # out_shift (multiply by 2^1 = 2) + dtype, # dtype + torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype), + ) + for dtype in [torch.int8] + ], ] ) def test_quantized_relu( self, name: str, X: torch.Tensor, - X_zero_point: torch.Tensor, + X_zero_point: torch.Tensor | int, out_zero_point: int, - out_multiplier: torch.Tensor, - out_shift: torch.Tensor, + out_multiplier: torch.Tensor | int, + out_shift: torch.Tensor | int, dtype: torch.dtype, expected_output: torch.Tensor, ) -> None: - output = torch.ops.cadence.quantized_relu( - X, X_zero_point, out_zero_point, out_multiplier, out_shift - ) + + if isinstance(X_zero_point, int): + assert isinstance(out_multiplier, int) + assert isinstance(out_shift, int) + + match dtype: + case torch.int8: + quantized_relu = ( + torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor + ) + case torch.uint8: + quantized_relu = ( + torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor + ) + case _: + quantized_relu = torch.ops.cadence.quantized_relu_per_tensor + + output = quantized_relu( + X, + X_zero_point, + out_zero_point, + out_multiplier, + out_shift, + ) + else: + output = torch.ops.cadence.quantized_relu( + X, X_zero_point, out_zero_point, out_multiplier, out_shift + ) # Verify output properties self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") From 3d5202cf19bea0abc408b441650b4f1f4dbef1f7 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 07/10] Remove non-per-tensor quantized add and replace with per-tensor variant Summary: As discussed offline, we don't need a non-per-tensor variant of quantized_add, so removing from ref implementations. Differential Revision: D81950579 --- backends/cadence/aot/ref_implementations.py | 28 +++++++++---------- .../aot/tests/test_ref_implementations.py | 8 +++--- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 27a31c37503..199c134de85 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -127,14 +127,14 @@ def dequantize_per_tensor( return (input_tensor - zero_point).to(dtype) * scale -@impl(m, "quantized_add") -def quantized_add( +@impl(m, "quantized_add.per_tensor") +def quantized_add_per_tensor( X: torch.Tensor, - X_scale: torch.Tensor, - X_zero_point: torch.Tensor, + X_scale: float, + X_zero_point: int, Y: torch.Tensor, - Y_scale: torch.Tensor, - Y_zero_point: torch.Tensor, + Y_scale: float, + Y_zero_point: int, out_scale: float, out_zero_point: int, ) -> torch.Tensor: @@ -149,17 +149,17 @@ def quantized_add( out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point Args: - - X (Tensor): The first operand - - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized + - X: The first operand + - X_scale: The ratio between the sizes of X's floating point and quantized ranges - - X_zero_point (Tensor): The quantized mapping of zero for X - - Y (Tensor): The second operand - - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized + - X_zero_point: The quantized mapping of zero for X + - Y: The second operand + - Y_scale: The ratio between the sizes of Y's floating point and quantized ranges - - Y_zero_point (Tensor): The quantized mapping of zero for Y - - out_scale (float): The ratio between the sizes of the output's floating point and + - Y_zero_point: The quantized mapping of zero for Y + - out_scale: The ratio between the sizes of the output's floating point and quantized ranges - - out_zero_point (int): The quantized mapping of zero for the output + - out_zero_point: The quantized mapping of zero for the output """ supported_dtypes = [torch.int8, torch.uint8] if X.dtype != Y.dtype: diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index e5ab6750f60..2e8e963e104 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -124,11 +124,11 @@ def test_quantized_add( output = torch.ops.cadence.quantized_add( X_tensor, - torch.tensor(X_scale), - torch.tensor(X_zero_point, dtype=dtype), + X_scale, + X_zero_point, Y_tensor, - torch.tensor(Y_scale), - torch.tensor(Y_zero_point, dtype=dtype), + Y_scale, + Y_zero_point, out_scale, out_zero_point, ) From 192c6e5585f726bbe37083b370b6fb12054a7c86 Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 08/10] Add int8/uint8 specialized variants of quantized_add_per_tensor Summary: Add type specialized variants of quantized_add_per_tensor Differential Revision: D81951110 --- backends/cadence/aot/ref_implementations.py | 42 +++++++++++++++++++ .../aot/tests/test_ref_implementations.py | 23 +++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 199c134de85..0bfd79b4994 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -193,6 +193,48 @@ def quantized_add_per_tensor( ) +@impl(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor") +def quantized_add_asym8sxasym8s_asym8s_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + if X.dtype != torch.int8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.int8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_add_per_tensor( + X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point + ) + + +@impl(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor") +def quantized_add_asym8uxasym8u_asym8u_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + if X.dtype != torch.uint8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.uint8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_add_per_tensor( + X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point + ) + + def quantized_linear_common( src: torch.Tensor, weight: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 2e8e963e104..b3ff9917ad2 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -100,7 +100,7 @@ def test_dequantize_per_tensor( [ # Only these types need to be tested as per ET_FORALL_JARVIS_QUANTIZED_TYPES in # on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h - ("int16", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8), + ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8), ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8), ] ) @@ -122,6 +122,27 @@ def test_quantized_add( Y_tensor = torch.tensor([Y], dtype=dtype) expected_output = torch.tensor([expected_value], dtype=dtype) + quantized_add = ( + torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor + if dtype == torch.int8 + else torch.ops.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor + ) + output = quantized_add( + X_tensor, + X_scale, + X_zero_point, + Y_tensor, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + ) + + self.assertTrue( + torch.equal(output, expected_output), + f"Values don't match in {name}: got {output}, expected {expected_output}", + ) + output = torch.ops.cadence.quantized_add( X_tensor, X_scale, From bd63826dc068a7e027dc91c31defe6fdb21485bf Mon Sep 17 00:00:00 2001 From: agrebenisan Date: Tue, 9 Sep 2025 09:25:35 -0700 Subject: [PATCH 09/10] Support custom quantized_matmul + variants (#14095) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/14095 Built on top of quantized_linear infrastructure. Differential Revision: D81973532 --- backends/cadence/aot/ref_implementations.py | 144 +++++++++++++++--- .../aot/tests/test_ref_implementations.py | 125 ++++++++++++--- 2 files changed, 225 insertions(+), 44 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 0bfd79b4994..8b5528b3bf8 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -241,7 +241,7 @@ def quantized_linear_common( bias: torch.Tensor, in_zero_point: int, weight_zero_point: torch.Tensor | int, - out_multiplier: torch.Tensor | int, + out_multiplier: int, out_shift: int, out_zero_point: int, ) -> torch.Tensor: @@ -329,34 +329,30 @@ def variant( assert isinstance(weight_zero_point, int) assert isinstance(out_multiplier, int) assert isinstance(out_shift, int) - return quantized_linear_common( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - ) + _out_shift = out_shift + _out_multiplier = out_multiplier else: assert isinstance(out_shift, torch.Tensor) + assert isinstance(out_multiplier, torch.Tensor) if out_shift.numel() != 1: raise ValueError("out_shift must be a scalar") if out_shift.dtype != torch.int64: raise ValueError("out_shift must be an int64") - return quantized_linear_common( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - int(out_shift.item()), - out_zero_point, - ) + _out_shift = int(out_shift.item()) + _out_multiplier = int(out_multiplier[0].item()) + + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + _out_multiplier, + _out_shift, + out_zero_point, + ) return variant @@ -403,6 +399,112 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... +@impl(m, "quantized_matmul") +def quantized_matmul( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + """ + Quantized matmul operation. + + Args: + - X (Tensor): The activations tensor + - X_zero_point (int): The quantized mapping of zero for the input + - Y (Tensor): The weight tensor + - Y_zero_point (int): The quantized mapping of zero for the weight + - bias (Tensor): The bias tensor + - out_multiplier (int): The multiplier used to scale the output + - out_shift (int): The shift used to scale the output + - out_zero_point (int): The quantized mapping of zero for the output + - transposed (bool): Whether to transpose the weight tensor + """ + if bias is not None and not torch.all(bias == 0): + raise ValueError("bias must be None or all zeros since unused in out variant") + + # Looks weird, but quantized linear assumes weights are pre-transposed, + # hence we transpose only if `transposed` is False. + if not transposed: + Y = Y.T + + return quantized_linear_common( + X, + Y, + bias or torch.zeros(1, dtype=torch.int32), + X_zero_point, + Y_zero_point, + out_multiplier, + out_shift, + out_zero_point, + ) + + +@impl(m, "quantized_matmul_asym8sxasym8s_asym8s") +def quantized_matmul_asym8sxasym8s_asym8s( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + if X.dtype != torch.int8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.int8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + ) + + +@impl(m, "quantized_matmul_asym8uxasym8u_asym8u") +def quantized_matmul_asym8uxasym8u_asym8u( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + if X.dtype != torch.uint8: + raise ValueError("X dtype must be torch.uint8") + if Y.dtype != torch.uint8: + raise ValueError("Y dtype must be torch.uint8") + + return quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + ) + + @impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index b3ff9917ad2..da994c44e4d 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -177,6 +177,8 @@ def test_quantized_add( 0, # out_zero_point torch.tensor([[-2]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -200,6 +202,8 @@ def test_quantized_add( 0, # out_zero_point torch.tensor([[-10, -30]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -225,6 +229,8 @@ def test_quantized_add( [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype ), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -248,6 +254,8 @@ def test_quantized_add( 1, # out_zero_point torch.tensor([[-15, 25]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -271,6 +279,8 @@ def test_quantized_add( 1, # out_zero_point torch.tensor([[-23, 17]], dtype=dtype), # expected_output False, + False, + False, ) for dtype in (torch.int8, torch.uint8) ], @@ -292,9 +302,34 @@ def test_quantized_add( 1, # out_zero_point torch.tensor([[-7, 13]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8)) ], + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2] + ), # weight_shape: 2 output features, 2 input features + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (0.125 * 2^31) + torch.tensor( + [1], dtype=torch.int64 + ), # out_shift (shift=1, doubles the scale) + 1, # out_zero_point + torch.tensor([[-7, 17]], dtype=dtype), # expected_output + per_tensor, + matmul, + transposed_matmul, + ) + for (matmul, transposed_matmul) in ((True, False), (True, True)) + for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8)) + ], ] ) def test_quantized_linear( @@ -308,7 +343,12 @@ def test_quantized_linear( out_zero_point: int, expected_output: torch.Tensor, per_tensor: bool, + matmul: bool, + transposed_matmul: bool, ) -> None: + if not per_tensor and matmul: + self.skipTest("Only per_tensor supported for matmul") + src = ( torch.arange(np.prod(src_shape)) .reshape(src_shape) @@ -319,7 +359,9 @@ def test_quantized_linear( .reshape(weight_shape) .to(expected_output.dtype) ) - bias = torch.arange(weight_shape[0]).to(torch.int32) + if matmul and not transposed_matmul: + weight = weight.T + if per_tensor: weight_zero_point = weight_zero_point[0] out_multiplier = out_multiplier[0] @@ -328,20 +370,34 @@ def test_quantized_linear( if per_tensor: match expected_output.dtype: case torch.int8: - linear_ops = ( - torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, - torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, - ) + if matmul: + linear_ops = ( + # Doesn't have per tensor name, but it is per tensor + torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s, + ) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, + ) case torch.uint8: - linear_ops = ( - torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, - ) + if matmul: + linear_ops = ( + torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u, + ) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, + ) case _: - linear_ops = ( - torch.ops.cadence.quantized_linear.per_tensor, - torch.ops.cadence.quantized_fully_connected.per_tensor, - ) + if matmul: + linear_ops = (torch.ops.cadence.quantized_matmul,) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear.per_tensor, + torch.ops.cadence.quantized_fully_connected.per_tensor, + ) else: linear_ops = ( torch.ops.cadence.quantized_linear, @@ -349,17 +405,40 @@ def test_quantized_linear( ) for linear_op in linear_ops: - output = linear_op( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - typing.cast(torch.Tensor, None), + # Get the function name for linear_op for debugging + op_name = ( + linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op) ) + if matmul: + assert "quantized_matmul" in op_name + output = linear_op( + src, + in_zero_point, + weight, + weight_zero_point, + None, + out_multiplier, + out_shift, + out_zero_point, + transposed_matmul, + ) + else: + assert ( + "quantized_linear" in op_name + or "quantized_fully_connected" in op_name + ) + bias = torch.arange(weight_shape[0]).to(torch.int32) + output = linear_op( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + typing.cast(torch.Tensor, None), + ) self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch") From 2744287c796832d0eca753ebf9a23b1e16bc9443 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Tue, 9 Sep 2025 09:38:36 -0700 Subject: [PATCH 10/10] Utility function for numerical correctness of edge dialect graphs and reference implementations (#14036) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/14036 Created two utility functions 1. Converts an edge dialect graph into one where custom cadence op nodes are replaced with python references 2. Validates the outputs (and optionally intermediates) of the graphs Updated two tests in test_replace_ops_passes to utilize these utility functions. Differential Revision: D81843001 --- backends/cadence/aot/TARGETS | 2 + backends/cadence/aot/pass_utils.py | 131 +++++++++++++++++- backends/cadence/aot/replace_ops.py | 9 +- .../aot/tests/test_replace_ops_passes.py | 73 ++++++---- 4 files changed, 180 insertions(+), 35 deletions(-) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 54b4a8b83f3..b7558a253a4 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -82,6 +82,8 @@ python_library( ], deps = [ ":utils", + ":ops_registrations", + ":ref_implementations", "//caffe2:torch", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 9aedef2ce2f..efb47f16173 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -5,9 +5,15 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - from dataclasses import dataclass -from typing import Callable, List, Optional, Set, Type, Union +from functools import partial +from operator import attrgetter +from torch.utils._python_dispatch import _disable_current_modes + +from typing import Any, Callable, cast, List, Optional, Set, Type, Union + +import executorch.backends.cadence.aot.ops_registrations # noqa +import executorch.backends.cadence.aot.ref_implementations # noqa import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -16,6 +22,8 @@ from executorch.exir.pass_base import PassBase, PassResult from torch._ops import OpOverloadPacket +from torch.fx import GraphModule +from torch.utils._pytree import PyTree # Is an overlap in tensor lifetime and storage allowed at the current opt level? @@ -114,6 +122,125 @@ def op_counts_match( return False return True +def validate_pass( + +) -> Callable[[type[PassBase]], type[PassBase]]: + tolerance = 1e-5 + log_differences = False + fail_on_mismatch = True + + def decorator(pass_class: type[PassBase]) -> type[PassBase]: + class WrappedPass(pass_class): + def call(self, graph_module: GraphModule) -> PassResult: + # Ensure we're not in fake tensor mode for actual execution + with _disable_current_modes(): + # Get inputs for the graph module + original_inputs = self._get_concrete_inputs(graph_module) + + if original_inputs is None: + raise RuntimeError("Could not extract concrete inputs for {pass_class.__name__}") + + # Run original graph and collect outputs + with torch.no_grad(): + original_outputs = graph_module(*original_inputs) + + # Apply the transformation + result = super().call(graph_module) + + # Run transformed graph and collect outputs + with torch.no_grad(): + transformed_outputs = result.graph_module(*original_inputs) + + # Compare outputs + self._compare_outputs( + original_outputs, + transformed_outputs, + pass_class.__name__, + tolerance, + log_differences, + fail_on_mismatch + ) + + return result + + def _get_concrete_inputs(self, graph_module: GraphModule) -> Optional[List[torch.Tensor]]: + """Extract concrete tensor inputs from the graph module metadata.""" + inputs = [] + for node in graph_module.graph.nodes: + if node.op == "placeholder": + if "val" in node.meta: + val = node.meta["val"] + if hasattr(val, "constant") and val.constant is not None: + inputs.append(val.constant.detach().clone()) + elif isinstance(val, torch.Tensor): + # Create a concrete tensor with the same properties + concrete_tensor = torch.testing.make_tensor(val.shape, dtype=val.dtype, device='cpu') + # concrete_tensor = torch.randn(val.shape, dtype=val.dtype) + if hasattr(val, 'device'): + concrete_tensor = concrete_tensor.to(val.device) + inputs.append(concrete_tensor) + else: + raise ValueError(f"Unsupported type for {node.name}: {type(val)}") + else: + raise ValueError(f"Missing 'val' in node metadata for {node.name}") + return inputs + + def _compare_outputs( + self, + original: Any, + transformed: Any, + pass_name: str, + tolerance: float, + log_differences: bool, + fail_on_mismatch: bool + ) -> None: + """Compare outputs and optionally log/fail on differences.""" + if isinstance(original, torch.Tensor) and isinstance(transformed, torch.Tensor): + if not torch.allclose(original, transformed, atol=tolerance, rtol=tolerance): + max_diff = torch.max(torch.abs(original - transformed)).item() + message = f"{pass_name}: Output mismatch detected. Max difference: {max_diff}" + + if log_differences: + pass + # logging.warning(message) + # logging.warning(f"Original shape: {original.shape}, Transformed shape: {transformed.shape}") + + if fail_on_mismatch: + raise ValueError(message) + else: + if log_differences: + pass + # logging.info(f"{pass_name}: Outputs match within tolerance {tolerance}") + + elif isinstance(original, (list, tuple)) and isinstance(transformed, (list, tuple)): + if len(original) != len(transformed): + message = f"{pass_name}: Output count mismatch. Original: {len(original)}, Transformed: {len(transformed)}" + if log_differences: + # logging.warning(message) + pass + if fail_on_mismatch: + raise ValueError(message) + else: + for i, (orig_item, trans_item) in enumerate(zip(original, transformed)): + self._compare_outputs( + orig_item, trans_item, f"{pass_name}[{i}]", + tolerance, log_differences, fail_on_mismatch + ) + else: + if log_differences: + pass + # logging.info(f"{pass_name}: Non-tensor outputs, skipping numerical comparison") + + # Preserve the original class name and documentation + WrappedPass.__name__ = pass_class.__name__ + WrappedPass.__qualname__ = pass_class.__qualname__ + WrappedPass.__doc__ = pass_class.__doc__ + + return cast(type[PassBase], WrappedPass) # type: ignore[return-value] + + return decorator + + # Testing utils # Return the compute/function nodes in the graph diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 75190b9c7be..68f533cc5a4 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -34,6 +34,7 @@ CadencePassAttribute, none_throws, register_cadence_pass, + validate_pass ) from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -947,7 +948,7 @@ def transpose_dims( exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta ) - +@validate_pass() @register_cadence_pass(CadencePassAttribute(opt_level=3)) class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: @@ -979,18 +980,18 @@ def call_operator( ) -> ProxyValue: if op not in { exir_ops.edge.cadence.convolution.default, - exir_ops.edge.cadence.quantized_conv_nchw.default, + exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, }: return super().call_operator(op, args, kwargs, meta) - quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.default + quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.per_tensor if not quantized_op and len(args) == 8 and args[-1] is True: # Already in NHWC layout. return super().call_operator(op, args, kwargs, meta) new_op = ( - exir_ops.edge.cadence.quantized_conv_nhwc.default + exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor if quantized_op else exir_ops.edge.cadence.convolution.default ) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index ca5168db2be..140b6236d19 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -15,7 +15,11 @@ GraphBuilder, single_op_builder, ) -from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match +from executorch.backends.cadence.aot.pass_utils import ( + count_node, + op_counts_match, + validate_pass +) from executorch.backends.cadence.aot.replace_ops import ( MakeSliceAndCatDimOutermostPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, @@ -1612,7 +1616,7 @@ def test_no_transpose_if_already_channel_last(self) -> None: def create_quantized_convolution_graph_module( self, channels_last: Optional[bool] = None - ) -> torch.fx.GraphModule: + ) -> tuple[torch.fx.GraphModule, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Helper to create a quantized conv node. quantized_conv( @@ -1622,23 +1626,32 @@ def create_quantized_convolution_graph_module( Tensor out_shift, bool channel_last=False) -> (Tensor Z)" """ if channels_last: - x = torch.randn(1, 224, 56, 3) - w = torch.randn(16, 16, 16, 3) + x = torch.randint( + low=-128, high=127, size=(1, 224, 56, 3), dtype=torch.int8 + ) + w = torch.randint( + low=-128, high=127, size=(16, 16, 16, 3), dtype=torch.int8 + ) else: - x = torch.randn(1, 3, 224, 56) - w = torch.randn(16, 3, 16, 16) - b = torch.randn(16) + x = torch.randint( + low=-128, high=127, size=(1, 3, 224, 56), dtype=torch.int8 + ) + w = torch.randint( + low=-128, high=127, size=(16, 3, 16, 16), dtype=torch.int8 + ) + + b = torch.randint(low=-128, high=127, size=(16,), dtype=torch.int32) stride = (2, 2) padding = (0, 0) dilation = (1, 1) groups = 1 input_zero_point = 0 - w_zero_point = torch.randn(1) - b_scale = torch.randn(1) + w_zero_point = 1 + b_scale = 0.8 out_scale = 1 out_zero_point = 0 - out_multiplier = torch.randn(1) - out_shift = torch.randn(1) + out_multiplier = 0 + out_shift = 0 args = ( x, w, @@ -1661,44 +1674,39 @@ def create_quantized_convolution_graph_module( x, w, b, - w_zero_point, - b_scale, - out_multiplier, - out_shift, ), - op=exir_ops.edge.cadence.quantized_conv_nhwc.default, + op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, args=args, - ) + ), (x, w, b) else: return single_op_builder( placeholders=( x, w, b, - w_zero_point, - b_scale, - out_multiplier, - out_shift, ), - op=exir_ops.edge.cadence.quantized_conv_nchw.default, + op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, args=args, - ) + ), (x, w, b) def test_quantized_convolution_default_channel_last(self) -> None: # Create a graph with a single convolution node. - gm = self.create_quantized_convolution_graph_module() + gm, (x, w, b) = self.create_quantized_convolution_graph_module() self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.default), 1 + count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), 1 ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + # self.assertTrue(numerically_equivalent(gm, (x, w, b), True)) + # Apply replacement pass. p = ReplaceConvWithChannelLastConvPass() gm_after_replacement = p.call(gm).graph_module # Check that no replacement was made. self.assertEqual( count_node( - gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, ), 1, ) @@ -1708,14 +1716,19 @@ def test_quantized_convolution_default_channel_last(self) -> None: 3, ) + # self.assertTrue(numerically_equivalent(gm_after_replacement, (x, w, b), True)) + def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Create a graph with a single im2row node. - gm = self.create_quantized_convolution_graph_module(channels_last=True) + gm, (x, w, b) = self.create_quantized_convolution_graph_module( + channels_last=True + ) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.default), 1 + count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), 1 ) + # self.assertTrue(numerically_equivalent(gm, (x, w, b), True)) # Apply replacement pass. p = ReplaceConvWithChannelLastConvPass() @@ -1723,11 +1736,13 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Check that no replacement was made. self.assertEqual( count_node( - gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, ), 1, ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + # self.assertTrue(numerically_equivalent(gm_after_replacement, (x, w, b), True)) class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):