From 0be5d7831c2553b91fa37ca82535c158a7c95d74 Mon Sep 17 00:00:00 2001
From: Ethan Ng
Date: Mon, 15 Sep 2025 08:59:14 -0700
Subject: [PATCH] Clean up test type dispatches (#14228)

Summary: use expand decorator in test type dispatches

Reviewed By: DrJessop

Differential Revision: D82239681
---
 backends/cadence/aot/TARGETS                |   1 +
 .../aot/tests/test_type_dispatch_passes.py  | 884 ++++++------
 2 files changed, 295 insertions(+), 590 deletions(-)

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 27f9c00f4ac..d547a1ed555 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -344,6 +344,7 @@ python_unittest(
     typing = True,
     deps = [
         ":ops_registrations",
+        ":typing_stubs",
        ":type_dispatch",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:graph_builder",
diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py
index 704d92a3197..4ae10ea83dd 100644
--- a/backends/cadence/aot/tests/test_type_dispatch_passes.py
+++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py
@@ -13,41 +13,36 @@
 from executorch.backends.cadence.aot.graph_builder import single_op_builder
 from executorch.backends.cadence.aot.pass_utils import count_node
 from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
+from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.fx.passes.infra.pass_base import PassResult


 class TestTypeDispatchPasses(unittest.TestCase):
-    def test_int8_dispatch_quantized_fully_connected(self) -> None:
-        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
-        x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
-        w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
-        b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
-            args=(x, w, b, 0, 0, 1, 0, 0, None),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8",
+                torch.int8,
                 exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_fully_connected(self) -> None:
-        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
-        x = torch.randint(0, 255, (1, 3), dtype=torch.uint8)
-        w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
+            (
+                "uint8",
+                torch.uint8,
+                exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_fully_connected(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_fully_connected dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, (1, 3), dtype=dtype)
+        w = torch.randint(min_val, max_val, (4, 3), dtype=dtype)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, w, b),
@@ -61,45 +56,33 @@ def test_uint8_dispatch_quantized_fully_connected(self) -> None:
             count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor),
             0,
         )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
-            ),
-            1,
-        )
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)

-    def test_int8_dispatch_quantized_linear(self) -> None:
-        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_linear"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
-        w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
-        b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_linear.per_tensor,
-            args=(x, w, b, 0, 0, 1, 0, 0, None),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8",
+                torch.int8,
                 exir_ops.edge.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_quantized_linear_dispatch(self) -> None:
-        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_linear"""
-        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
-        w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
+            (
+                "uint8",
+                torch.uint8,
+                exir_ops.edge.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_linear(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_linear dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, (2, 3), dtype=dtype)
+        w = torch.randint(min_val, max_val, (4, 3), dtype=dtype)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, w, b),
@@ -113,14 +96,8 @@ def test_uint8_quantized_linear_dispatch(self) -> None:
             count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor),
             0,
         )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
-            ),
-            1,
-        )
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)

     def test_mixed_types_error(self) -> None:
         """Test mixed int8/uint8 inputs should raise RuntimeError"""
@@ -138,33 +115,29 @@ def test_mixed_types_error(self) -> None:
             cast(PassResult, p(gm)).graph_module
         self.assertIn("Unsupported input types", str(context.exception))

-    def test_int8_dispatch_quantized_relu(self) -> None:
-        """Test int8 input should dispatch to asym8s_asym8s variant for quantized_relu"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
-        gm = single_op_builder(
-            placeholders=(x,),
-            op=exir_ops.edge.cadence.quantized_relu.per_tensor,
-            args=(x, 0, 0, 1, 0),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8",
+                torch.int8,
                 exir_ops.edge.cadence.quantized_relu_asym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_relu(self) -> None:
-        """Test uint8 input should dispatch to asym8u_asym8u variant for quantized_relu"""
-        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+            (
+                "uint8",
+                torch.uint8,
+                exir_ops.edge.cadence.quantized_relu_asym8u_asym8u.per_tensor,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_relu(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_relu dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, (2, 3), dtype=dtype)
         gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.cadence.quantized_relu.per_tensor,
@@ -177,45 +150,33 @@ def test_uint8_dispatch_quantized_relu(self) -> None:
             count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor),
             0,
         )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_relu_asym8u_asym8u.per_tensor,
-            ),
-            1,
-        )
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)

-    def test_int8_dispatch_quantized_matmul(self) -> None:
-        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_matmul"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
-        y = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
-        bias = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, y, bias),
-            op=exir_ops.edge.cadence.quantized_matmul.default,
-            args=(x, 0, y, 0, bias, 1, 0, 0, False),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_matmul.default),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8",
+                torch.int8,
                 exir_ops.edge.cadence.quantized_matmul_asym8sxasym8s_asym8s.default,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_matmul(self) -> None:
-        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_matmul"""
-        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
-        y = torch.randint(0, 255, (3, 4), dtype=torch.uint8)
+            (
+                "uint8",
+                torch.uint8,
+                exir_ops.edge.cadence.quantized_matmul_asym8uxasym8u_asym8u.default,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_matmul(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_matmul dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, (2, 3), dtype=dtype)
+        y = torch.randint(min_val, max_val, (3, 4), dtype=dtype)
         bias = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, y, bias),
@@ -229,356 +190,204 @@ def test_uint8_dispatch_quantized_matmul(self) -> None:
             count_node(gm, exir_ops.edge.cadence.quantized_matmul.default),
             0,
         )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_matmul_asym8uxasym8u_asym8u.default,
-            ),
-            1,
-        )
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)

-    def test_int8_dispatch_quantized_conv_nchw(self) -> None:
-        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_conv_nchw"""
-        x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
-        w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8_nchw",
+                torch.int8,
+                (1, 3, 8, 8),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nchw(self) -> None:
-        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_conv_nchw"""
-        x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8)
-        w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "uint8_nchw",
+                torch.uint8,
+                (1, 3, 8, 8),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
             ),
-            1,
-        )
-
-    def test_int8_dispatch_quantized_conv_nhwc(self) -> None:
-        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_conv_nhwc"""
-        x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8)
-        w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "int8_nhwc",
+                torch.int8,
+                (1, 8, 8, 3),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nhwc(self) -> None:
-        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_conv_nhwc"""
-        x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8)
-        w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "uint8_nhwc",
+                torch.uint8,
+                (1, 8, 8, 3),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor,
             ),
-            1,
-        )
-
-    def test_int8_dispatch_quantized_conv_nchw_dilated(self) -> None:
-        """Test int8 x int8 inputs with dilation should dispatch to dilated_asym8sxasym8s_asym8s variant for quantized_conv_nchw_dilated"""
-        x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
-        w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
+        ]
+    )
+    def test_dispatch_quantized_conv_2d(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        x_shape: tuple[int, ...],
+        original_op: torch._ops.OpOverload,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_conv_2d (nchw/nhwc) dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, x_shape, dtype=dtype)
+        w = torch.randint(min_val, max_val, (16, 3, 3, 3), dtype=dtype)
         b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+            op=original_op,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
         )
         p = CompileTimeTypeDispatchPass()
         gm = cast(PassResult, p(gm)).graph_module
         # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+        self.assertEqual(count_node(gm, original_op), 0)
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)
+
+    @expand(
+        [
+            (
+                "int8_nchw_dilated",
+                torch.int8,
+                (1, 3, 8, 8),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nchw_dilated(self) -> None:
-        """Test uint8 x uint8 inputs with dilation should dispatch to dilated_asym8uxasym8u_asym8u variant for quantized_conv_nchw"""
-        x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8)
-        w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "uint8_nchw_dilated",
+                torch.uint8,
+                (1, 3, 8, 8),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
             ),
-            1,
-        )
-
-    def test_int8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
-        """Test int8 x int8 inputs with dilation should dispatch to dilated_asym8sxasym8s_asym8s variant for quantized_conv_nhwc"""
-        x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8)
-        w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "int8_nhwc_dilated",
+                torch.int8,
+                (1, 8, 8, 3),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
-        """Test uint8 x uint8 inputs with dilation should dispatch to dilated_asym8uxasym8u_asym8u variant for quantized_conv_nhwc"""
-        x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8)
-        w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "uint8_nhwc_dilated",
+                torch.uint8,
+                (1, 8, 8, 3),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
             ),
-            1,
-        )
-
-    def test_int8_dispatch_quantized_conv_nchw_1d(self) -> None:
-        """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxasym8s_asym8s variant for quantized_conv_nchw"""
-        x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8)
-        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
+        ]
+    )
+    def test_dispatch_quantized_conv_2d_dilated(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        x_shape: tuple[int, ...],
+        original_op: torch._ops.OpOverload,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_conv_2d with dilation dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, x_shape, dtype=dtype)
+        w = torch.randint(min_val, max_val, (16, 3, 3, 3), dtype=dtype)
         b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+            op=original_op,
+            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
         )
         p = CompileTimeTypeDispatchPass()
         gm = cast(PassResult, p(gm)).graph_module
         # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with 1D int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+        self.assertEqual(count_node(gm, original_op), 0)
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)
+
+    @expand(
+        [
+            (
+                "int8_nchw_1d",
+                torch.int8,
+                (1, 3, 8),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nchw_1d(self) -> None:
-        """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxasym8u_asym8u variant for quantized_conv_nchw"""
-        x = torch.randint(0, 255, (1, 3, 8), dtype=torch.uint8)
-        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with 1D uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "uint8_nchw_1d",
+                torch.uint8,
+                (1, 3, 8),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor,
             ),
-            1,
-        )
-
-    def test_int8_dispatch_quantized_conv_nhwc_1d(self) -> None:
-        """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxasym8s_asym8s variant for quantized_conv_nhwc"""
-        x = torch.randint(-128, 127, (1, 8, 3), dtype=torch.int8)
-        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
-        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
-            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with 1D int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "int8_nhwc_1d",
+                torch.int8,
+                (1, 8, 3),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nhwc_1d(self) -> None:
-        """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxasym8u_asym8u variant for quantized_conv_nhwc"""
-        x = torch.randint(0, 255, (1, 8, 3), dtype=torch.uint8)
-        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
+            (
+                "uint8_nhwc_1d",
+                torch.uint8,
+                (1, 8, 3),  # x_shape
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+                exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_conv_1d(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        x_shape: tuple[int, ...],
+        original_op: torch._ops.OpOverload,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_conv_1d (nchw/nhwc) dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, x_shape, dtype=dtype)
+        w = torch.randint(min_val, max_val, (16, 3, 3), dtype=dtype)
         b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            op=original_op,
             args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
         )
         p = CompileTimeTypeDispatchPass()
         gm = cast(PassResult, p(gm)).graph_module
         # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with 1D uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor,
-            ),
-            1,
-        )
+        self.assertEqual(count_node(gm, original_op), 0)
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)

-    def test_int8_dispatch_quantized_add(self) -> None:
-        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
-        y = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
-        gm = single_op_builder(
-            placeholders=(x, y),
-            op=exir_ops.edge.cadence.quantized_add.per_tensor,
-            args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0),
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8",
+                torch.int8,
                 exir_ops.edge.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_add(self) -> None:
-        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_add"""
-        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
-        y = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+            (
+                "uint8",
+                torch.uint8,
+                exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_add(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_add dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, (2, 3), dtype=dtype)
+        y = torch.randint(min_val, max_val, (2, 3), dtype=dtype)
         gm = single_op_builder(
             placeholders=(x, y),
             op=exir_ops.edge.cadence.quantized_add.per_tensor,
@@ -591,158 +400,62 @@ def test_uint8_dispatch_quantized_add(self) -> None:
             count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor),
             0,
         )
-        # Should be replaced with uint8 specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor,
-            ),
-            1,
-        )
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)

-    def test_int8_dispatch_quantized_conv_nchw_depthwise(self) -> None:
-        """Test int8 x int8 inputs with depthwise should dispatch to depthwise_asym8sxsym8s_asym8s variant for quantized_conv_nchw"""
-        # Depthwise convolution: groups == input_channels
-        x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
-        w = torch.randint(
-            -128, 127, (3, 1, 3, 3), dtype=torch.int8
-        )  # groups=3, input_channels=3
-        b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(
-                x,
-                w,
-                b,
-                [1, 1],
-                [0, 0],
-                [1, 1],
-                3,
-                0,
-                0,
-                1.0,
-                1.0,
-                0,
-                1,
-                1,
-            ),  # groups=3
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 depthwise specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+    @expand(
+        [
+            (
+                "int8_nchw_depthwise",
+                torch.int8,
+                (1, 3, 8, 8),  # x_shape
+                (3, 1, 3, 3),  # w_shape (groups=3, input_channels=3)
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nchw_depthwise(self) -> None:
-        """Test uint8 x uint8 inputs with depthwise should dispatch to depthwise_asym8uxasym8u_asym8u variant for quantized_conv_nchw"""
-        # Depthwise convolution: groups == input_channels
-        x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8)
-        w = torch.randint(
-            0, 255, (3, 1, 3, 3), dtype=torch.uint8
-        )  # groups=3, input_channels=3
-        b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
-            args=(
-                x,
-                w,
-                b,
-                [1, 1],
-                [0, 0],
-                [1, 1],
-                3,
-                0,
-                0,
-                1.0,
-                1.0,
-                0,
-                1,
-                1,
-            ),  # groups=3
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
-            0,
-        )
-        # Should be replaced with uint8 depthwise specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "uint8_nchw_depthwise",
+                torch.uint8,
+                (1, 3, 8, 8),  # x_shape
+                (3, 1, 3, 3),  # w_shape (groups=3, input_channels=3)
+                exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor,
             ),
-            1,
-        )
-
-    def test_int8_dispatch_quantized_conv_nhwc_depthwise(self) -> None:
-        """Test int8 x int8 inputs with depthwise should dispatch to depthwise_asym8sxsym8s_asym8s variant for quantized_conv_nhwc"""
-        # Depthwise convolution: groups == input_channels
-        x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8)
-        w = torch.randint(
-            -128, 127, (3, 3, 3, 1), dtype=torch.int8
-        )  # groups=3, input_channels=3
-        b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32)
-        gm = single_op_builder(
-            placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
-            args=(
-                x,
-                w,
-                b,
-                [1, 1],
-                [0, 0],
-                [1, 1],
-                3,
-                0,
-                0,
-                1.0,
-                1.0,
-                0,
-                1,
-                1,
-            ),  # groups=3
-        )
-        p = CompileTimeTypeDispatchPass()
-        gm = cast(PassResult, p(gm)).graph_module
-        # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with int8 depthwise specific variant
-        self.assertEqual(
-            count_node(
-                gm,
+            (
+                "int8_nhwc_depthwise",
+                torch.int8,
+                (1, 8, 8, 3),  # x_shape
+                (3, 3, 3, 1),  # w_shape (groups=3, input_channels=3)
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
             ),
-            1,
-        )
-
-    def test_uint8_dispatch_quantized_conv_nhwc_depthwise(self) -> None:
-        """Test uint8 x uint8 inputs with depthwise should dispatch to depthwise_asym8uxasym8u_asym8u variant for quantized_conv_nhwc"""
-        # Depthwise convolution: groups == input_channels
-        x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8)
-        w = torch.randint(
-            0, 255, (3, 3, 3, 1), dtype=torch.uint8
-        )  # groups=3, input_channels=3
+            (
+                "uint8_nhwc_depthwise",
+                torch.uint8,
+                (1, 8, 8, 3),  # x_shape
+                (3, 3, 3, 1),  # w_shape (groups=3, input_channels=3)
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+                exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
+            ),
+        ]
+    )
+    def test_dispatch_quantized_conv_depthwise(
+        self,
+        _: str,
+        dtype: torch.dtype,
+        x_shape: tuple[int, ...],
+        w_shape: tuple[int, ...],
+        original_op: torch._ops.OpOverload,
+        expected_op: torch._ops.OpOverload,
+    ) -> None:
+        """Test quantized_conv depthwise (groups == input_channels) dispatches to correct dtype-specific variant"""
+        min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+        x = torch.randint(min_val, max_val, x_shape, dtype=dtype)
+        w = torch.randint(min_val, max_val, w_shape, dtype=dtype)
         b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32)
         gm = single_op_builder(
             placeholders=(x, w, b),
-            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            op=original_op,
             args=(
                 x,
                 w,
@@ -758,20 +471,11 @@ def test_uint8_dispatch_quantized_conv_nhwc_depthwise(self) -> None:
                 0,
                 1,
                 1,
-            ),  # groups=3
+            ),
         )
         p = CompileTimeTypeDispatchPass()
         gm = cast(PassResult, p(gm)).graph_module
         # Original op should be replaced
-        self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
-            0,
-        )
-        # Should be replaced with uint8 depthwise specific variant
-        self.assertEqual(
-            count_node(
-                gm,
-                exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
-            ),
-            1,
-        )
+        self.assertEqual(count_node(gm, original_op), 0)
+        # Should be replaced with dtype-specific variant
+        self.assertEqual(count_node(gm, expected_op), 1)
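
Note on the `expand` decorator used throughout this patch: `typing_stubs.expand` is
not defined here, and in OSS builds it may simply re-export `parameterized.expand`.
The sketch below is a hypothetical, minimal stand-in covering only the usage in this
file (a label string followed by positional arguments); all names in it are
illustrative, not the actual implementation.

    import sys
    from typing import Any, Callable, Iterable, Sequence

    def expand(
        cases: Iterable[Sequence[Any]],
    ) -> Callable[[Callable[..., None]], None]:
        """Generate one unittest method per case tuple; case[0] labels the variant."""

        def decorator(fn: Callable[..., None]) -> None:
            # While the decorator runs, the class body is still executing one
            # frame up, so injecting into its namespace lets unittest discover
            # each generated variant (CPython-specific; the `parameterized`
            # library uses the same frame-inspection trick).
            class_ns = sys._getframe(1).f_locals

            def make_variant(args: Sequence[Any]) -> Callable[..., None]:
                def variant(self: Any) -> None:
                    # The label is passed through as the test's `_` parameter.
                    fn(self, *args)

                variant.__name__ = f"{fn.__name__}_{args[0]}"
                return variant

            for case in cases:
                variant = make_variant(tuple(case))
                class_ns[variant.__name__] = variant

            # Returning None removes the undecorated method from the class, so
            # only the generated variants are collected (unittest skips
            # non-callable attributes).
            return None

        return decorator

With a decorator along these lines, `test_dispatch_quantized_add` above would expand
into `test_dispatch_quantized_add_int8` and `test_dispatch_quantized_add_uint8`, each
invoking the shared body with its tuple's dtype and expected op.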