From 956d70c3b12294a41cff3709033853c4834bd45c Mon Sep 17 00:00:00 2001
From: Suchir Bhatt
Date: Thu, 14 Nov 2024 12:10:02 -0800
Subject: [PATCH] Add unit tests for old lowering flow for op_cat.py (#6847)

Summary:
The team moved to a new API that improves the reliability of our lowering
infra. Lowering here refers to converting a PyTorch model into a form that
the underlying hardware can recognize. This diff makes sure there are still
unit tests covering the older APIs.

Reviewed By: mcr229

Differential Revision: D65914291
---
 backends/xnnpack/test/ops/cat.py | 88 +++++++++++++++++++++-----------
 1 file changed, 57 insertions(+), 31 deletions(-)

diff --git a/backends/xnnpack/test/ops/cat.py b/backends/xnnpack/test/ops/cat.py
index 23fca91f5b8..28bf1bec29a 100644
--- a/backends/xnnpack/test/ops/cat.py
+++ b/backends/xnnpack/test/ops/cat.py
@@ -36,39 +36,45 @@ def forward(self, arg1, arg2, arg3, arg4, arg5):
             return x + x
 
     # Quantize by propagation.
     def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
-        tester = Tester(module, inputs)