From b566bb60ac5aebb2b963031fd2185ef45540aa51 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Fri, 18 Jul 2025 14:55:45 +0200
Subject: [PATCH] Arm backend: Do not partition noop subgraphs

- Reject partitions that will be lowered to empty subgraphs, i.e.
  partitions containing only clones or noop expands.
- For eye/ones/zeros the graph is not really empty since it contains one
  constant, which works after a Vela update. Simply move the previously
  xfailing tests to MI/BI. (u55/u85 still fail because of missing CPU
  ops, which is why we did not notice when this started working.)

Signed-off-by: Adrian Lundell
Change-Id: Iefa034a8e731d70465eb4883602c958f51aca976
---
 .../_passes/convert_expand_copy_to_repeat.py  |  49 ++++---
 backends/arm/_passes/remove_clone_pass.py     |  10 ++
 backends/arm/operator_support/__init__.py     |   1 +
 .../arm/operator_support/clone_support.py     |  37 ++++++
 .../tosa_supported_operators.py               |   1 -
 backends/arm/scripts/parse_test_names.py      |   1 +
 backends/arm/test/ops/test_alias_copy.py      |   4 +-
 backends/arm/test/ops/test_clone.py           | 122 ++++++++++--------
 backends/arm/test/ops/test_expand.py          |  43 ++----
 backends/arm/test/ops/test_eye.py             |  22 ++--
 backends/arm/test/ops/test_ones.py            |  23 ++--
 backends/arm/test/ops/test_zeros.py           |  23 ++--
 backends/arm/tosa_partitioner.py              |  51 +++++++-
 13 files changed, 240 insertions(+), 147 deletions(-)
 create mode 100644 backends/arm/operator_support/clone_support.py

diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py
index 5632c253437..ee509c7ebb5 100644
--- a/backends/arm/_passes/convert_expand_copy_to_repeat.py
+++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py
@@ -8,12 +8,43 @@
 import logging
 from typing import cast
 
+import torch
+
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
 logger = logging.getLogger(__name__)
 
 
+def calculate_multiples(args):
+    input_node_or_tensor = args[0]
+
+    if isinstance(input_node_or_tensor, torch.fx.node.Node):
+        input_data = input_node_or_tensor.meta["val"]
+    else:
+        input_data = input_node_or_tensor.data
+
+    input_shape = input_data.shape
+
+    multiples = cast(list[int], args[1])
+    expanded_rank = len(multiples)
+
+    # Expanded shape is 'input_shape' front-padded with ones.
+    padding = expanded_rank - len(input_shape)
+    extended_shape = [
+        input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
+    ]
+
+    # To convert expand arg to repeat arg, non-repeated dims should have
+    # multiples[dim] = 1. Passing -1 to expand arg means
+    # not changing the size of that dimension.
+    multiples = [
+        multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
+        for i in range(expanded_rank)
+    ]
+    return multiples
+
+
 class ConvertExpandCopyToRepeatPass(ExportPass):
     """
     Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
@@ -26,23 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
         if op != self.expand_copy:
             return super().call_operator(op, args, kwargs, meta)
 
-        input_shape = args[0].data.shape
-        multiples = cast(list[int], args[1])
-        expanded_rank = len(multiples)
-
-        # Expanded shape is 'input_shape' front-padded with ones.
-        padding = expanded_rank - len(input_shape)
-        extended_shape = [
-            input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
-        ]
-
-        # To convert expand arg to repeat arg, non-repeated dims should have
-        # multiples[dim] = 1. Passing -1 to expand arg means
-        # not changing the size of that dimension.
- multiples = [ - multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 - for i in range(expanded_rank) - ] + multiples = calculate_multiples(args) if all((x == 1 for x in multiples)): # All dimensions/repetitions occur only once. Remove node diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_clone_pass.py index a2822c7378e..09ddd39542c 100644 --- a/backends/arm/_passes/remove_clone_pass.py +++ b/backends/arm/_passes/remove_clone_pass.py @@ -6,9 +6,13 @@ # pyre-unsafe +import logging + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass +logger = logging.getLogger(__name__) + class RemoveClonePass(ExportPass): """Remove all clones from graph_module""" @@ -21,4 +25,10 @@ def call_operator(self, op, args, kwargs, meta): raise ValueError( f"clone operator expects exactly one argument, got {len(args)}" ) + + if "memory_format" in kwargs: + logger.warning( + f"Removing clone with memory_format '{kwargs['memory_format']}'." + ) + return args[0] diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 2075e0f554f..b62cc83ed8f 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -6,6 +6,7 @@ # pyre-unsafe from . import ( # noqa + clone_support, convolution_support, embedding_support, ethos_u55_support, diff --git a/backends/arm/operator_support/clone_support.py b/backends/arm/operator_support/clone_support.py new file mode 100644 index 00000000000..786c2803c33 --- /dev/null +++ b/backends/arm/operator_support/clone_support.py @@ -0,0 +1,37 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + +logger = logging.getLogger(__name__) + + +@register_tosa_support_check +class CloneSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.clone.default] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + + input_node = node.args[0] + if not isinstance(input_node, fx.Node): + self.reporter.report_reject(node, "Non tensor clones are not supported") + return False + + return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index d813fbda531..838e0164a17 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -228,7 +228,6 @@ def is_node_supported( exir_ops.edge.aten.var.correction, exir_ops.edge.aten.var.dim, exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.clone.default, exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.pow.Tensor_Scalar, diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 9ceb5d73d23..c6eaafa597b 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -25,6 +25,7 @@ "unflatten.int", "_native_batch_norm_legit_no_training.default", "_native_batch_norm_legit.no_stats", + "alias_copy.default", ] ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS diff --git a/backends/arm/test/ops/test_alias_copy.py b/backends/arm/test/ops/test_alias_copy.py index cf8caca02c4..8b951a4d856 100644 --- a/backends/arm/test/ops/test_alias_copy.py +++ b/backends/arm/test/ops/test_alias_copy.py @@ -41,7 +41,9 @@ def __init__(self): super().__init__() def forward(self, x: torch.Tensor): - return torch.alias_copy(x) + return ( + torch.alias_copy(x) * 1 + ) # Multiply by one to make sure it is partitioned. @common.parametrize("test_data", AliasCopy.test_data) diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index b4f2879be48..cf7b4a417cd 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -3,13 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# -# Tests the clone op which copies the data of the input tensor (possibly with new data format) -# from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common @@ -28,57 +24,82 @@ input_t = Tuple[torch.Tensor] -class Clone(torch.nn.Module): - """A simple module that clones an input tensor.""" +class CloneFirstArg(torch.nn.Module): + def forward(self, x): + return x.clone() + x - def forward(self, x: torch.Tensor): - return x.clone() +class CloneSecondArg(torch.nn.Module): + def forward(self, x): + return x * x.clone() + + +class CloneOutput(torch.nn.Module): + def forward(self, x): + return (x / x).clone() + + +class CloneBothArgs(torch.nn.Module): + def forward(self, x): + return x.clone() + x.clone() + + +class CloneAfterOtherOp(torch.nn.Module): + def forward(self, x): + x = x * 2 + return x.clone() + x + + +class CloneParallelToOtherOp(torch.nn.Module): + def forward(self, x): + return x * 2 + x.clone() -test_data_suite = { - "ones_1D_10": lambda: (torch.ones(10),), - "ones_1D_50": lambda: (torch.ones(50),), - "rand_1D_20": lambda: (torch.rand(20),), - "rand_2D_10x10": lambda: (torch.rand(10, 10),), - "rand_3D_5x5x5": lambda: (torch.rand(5, 5, 5),), - "rand_4D_2x3x4x5": lambda: (torch.rand(2, 3, 4, 5),), - "large_tensor": lambda: (torch.rand(1000),), -} +delegated_clones = { + "clone_first_arg": lambda: (CloneFirstArg, (torch.rand(1, 2, 3, 4),)), + "clone_second_arg": lambda: (CloneSecondArg, (torch.rand(1, 2, 3, 4),)), + "clone_output": lambda: (CloneOutput, (torch.rand(1, 2, 3, 4),)), + "clone_both_args": lambda: (CloneBothArgs, (torch.rand(1, 2, 3, 4),)), + "clone_after_other_op": lambda: (CloneAfterOtherOp, (torch.rand(1, 2, 3, 4),)), + "clone_parallel_to_other_op": lambda: ( + CloneParallelToOtherOp, + (torch.rand(1, 2, 3, 4),), + ), +} -@common.parametrize("test_data", test_data_suite) -def test_clone_tosa_FP(test_data: Tuple[torch.Tensor]): +@common.parametrize("input_data", delegated_clones) +def test_clone_tosa_FP(input_data): + module, input_tensor = input_data() pipeline = TosaPipelineFP[input_t]( - Clone(), - test_data(), - aten_op, - exir_op, + module(), + input_tensor, + [], ) - pipeline.run() -@common.parametrize("test_data", test_data_suite) -def test_clone_tosa_INT(test_data): +@common.parametrize("input_data", delegated_clones) +def test_clone_tosa_INT(input_data): + module, input_tensor = input_data() + pipeline = TosaPipelineINT[input_t]( - Clone(), - test_data(), + module(), + input_tensor, aten_op, exir_op, ) pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize("input_data", delegated_clones) @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" -) -def test_clone_u55_INT(test_data): +def test_clone_u55_INT(input_data): + module, input_tensor = input_data() + pipeline = EthosU55PipelineINT[input_t]( - Clone(), - test_data(), + module(), + input_tensor, aten_op, exir_op, run_on_fvp=True, @@ -87,15 +108,14 @@ def test_clone_u55_INT(test_data): pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize("input_data", delegated_clones) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Empty subgraph leads to Vela compilation failure. 
See: https://jira.arm.com/browse/MLBEDSW-10477" -) -def test_clone_u85_INT(test_data): +def test_clone_u85_INT(input_data): + module, input_tensor = input_data() + pipeline = EthosU85PipelineINT[input_t]( - Clone(), - test_data(), + module(), + input_tensor, aten_op, exir_op, run_on_fvp=True, @@ -104,27 +124,23 @@ def test_clone_u85_INT(test_data): pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize("test_data", delegated_clones) @common.SkipIfNoModelConverter -@pytest.mark.xfail( - reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" -) def test_clone_vgf_FP(test_data): + module, input_tensor = test_data() pipeline = VgfPipeline[input_t]( - Clone(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + module(), input_tensor, aten_op, exir_op, tosa_version="TOSA-1.0+FP" ) pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize("test_data", delegated_clones) @common.SkipIfNoModelConverter -@pytest.mark.xfail( - reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" -) def test_clone_vgf_INT(test_data): + module, input_tensor = test_data() pipeline = VgfPipeline[input_t]( - Clone(), - test_data(), + module(), + input_tensor, aten_op, exir_op, tosa_version="TOSA-1.0+INT", diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index a0a7ccadeb4..b5784c9ff93 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -10,20 +10,21 @@ from typing import Sequence, Tuple -import pytest - import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, + OpNotSupportedPipeline, TosaPipelineFP, TosaPipelineINT, VgfPipeline, ) aten_op = "torch.ops.aten.expand.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_expand_copy_default" + input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x, Input y @@ -48,7 +49,7 @@ def forward(self, x: torch.Tensor, m: Sequence): return x.expand(m) -@common.parametrize("test_data", Expand.test_parameters | Expand.test_reject_set) +@common.parametrize("test_data", Expand.test_parameters) def test_expand_tosa_FP(test_data: Tuple): pipeline = TosaPipelineFP[input_t1]( Expand(), @@ -59,7 +60,7 @@ def test_expand_tosa_FP(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Expand.test_parameters | Expand.test_reject_set) +@common.parametrize("test_data", Expand.test_parameters) def test_expand_tosa_INT(test_data: Tuple): pipeline = TosaPipelineINT[input_t1]( Expand(), @@ -96,7 +97,7 @@ def test_expand_u85_INT(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Expand.test_parameters | Expand.test_reject_set) +@common.parametrize("test_data", Expand.test_parameters) @common.SkipIfNoModelConverter def test_expand_vgf_FP(test_data: Tuple): pipeline = VgfPipeline[input_t1]( @@ -109,7 +110,7 @@ def test_expand_vgf_FP(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Expand.test_parameters | Expand.test_reject_set) +@common.parametrize("test_data", Expand.test_parameters) @common.SkipIfNoModelConverter def test_expand_vgf_INT(test_data: Tuple): pipeline = VgfPipeline[input_t1]( @@ -123,32 +124,8 @@ def test_expand_vgf_INT(test_data: Tuple): @common.parametrize("test_data", Expand.test_reject_set) -@common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="MLETORCH-716: Node will 
be optimized away and Vela can't handle empty graphs" -) -def test_expand_u55_INT_failure_set(test_data: Tuple): - pipeline = EthosU55PipelineINT[input_t1]( - Expand(), - test_data(), - aten_op, - exir_ops=[], - run_on_fvp=True, - ) - pipeline.run() - - -@common.parametrize("test_data", Expand.test_reject_set) -@common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs" -) -def test_expand_u85_INT_failure_set(test_data: Tuple): - pipeline = EthosU85PipelineINT[input_t1]( - Expand(), - test_data(), - aten_op, - exir_ops=[], - run_on_fvp=True, +def test_expand_u55_INT_not_delegated(test_data: Tuple): + pipeline = OpNotSupportedPipeline[input_t1]( + Expand(), test_data(), {exir_op: 1}, n_expected_delegates=0 ) pipeline.run() diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index d93a5dcc5fe..eef32259c10 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -39,13 +39,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ), } - test_data_not_delegated: dict[str, test_data_t] = { + # Mixed dtypes - the eye op is delegated, but it leads to a non-delegated add op. + test_data_mixed_dtypes: dict[str, test_data_t] = { "fp32_int64": (lambda: (torch.randn(10),), (10, torch.int64)), "fp32_int32": (lambda: (torch.randn(10),), (10, torch.int32)), - "int32_int64": ( - lambda: (torch.randint(0, 10, [10], dtype=torch.int32),), - (10, torch.int64), - ), } @@ -63,7 +60,7 @@ def test_eye_tosa_FP(test_data: test_data_t): pipeline.run() -@common.parametrize("test_data", EyeAdd.test_data) +@common.parametrize("test_data", EyeAdd.test_data | EyeAdd.test_data_mixed_dtypes) def test_eye_tosa_INT(test_data: test_data_t): input_data, init_data = test_data pipeline = TosaPipelineINT[input_t]( @@ -141,16 +138,15 @@ def test_eye_vgf_INT(test_data: test_data_t): @common.parametrize( "test_data", - EyeAdd.test_data_not_delegated, - xfails={ - "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela", - "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela", - "int32_int64": "MLETORCG-716: Do not delegate empty networks to vela", - }, + EyeAdd.test_data_mixed_dtypes, ) def test_eye_tosa_INT_not_delegated(test_data: test_data_t): input_data, init_data = test_data pipeline = OpNotSupportedPipeline[input_t]( - EyeAdd(*init_data), input_data(), non_delegated_ops={}, quantize=True + EyeAdd(*init_data), + input_data(), + non_delegated_ops={"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}, + n_expected_delegates=1, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_ones.py b/backends/arm/test/ops/test_ones.py index ca2313b11a0..f4dafca5e10 100644 --- a/backends/arm/test/ops/test_ones.py +++ b/backends/arm/test/ops/test_ones.py @@ -39,17 +39,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ), } - test_data_not_delegated: dict[str, test_data_t] = { + # Mixed dtypes - the ones op is delegated, but it leads to a non-delegated add op. 
+ test_data_mixed_dtypes: dict[str, test_data_t] = { "fp32_int64": (lambda: (torch.randn(10),), (10, torch.int64)), "fp32_int32": (lambda: (torch.randn(10),), (10, torch.int32)), - "int32_int64": ( - lambda: (torch.randint(0, 10, [10], dtype=torch.int32),), - (10, torch.int64), - ), } -@common.parametrize("test_data", OnesAdd.test_data) +@common.parametrize("test_data", OnesAdd.test_data | OnesAdd.test_data_mixed_dtypes) def test_ones_tosa_FP(test_data: test_data_t): input_data, init_data = test_data pipeline = TosaPipelineFP[input_t]( @@ -60,7 +57,7 @@ def test_ones_tosa_FP(test_data: test_data_t): pipeline.run() -@common.parametrize("test_data", OnesAdd.test_data) +@common.parametrize("test_data", OnesAdd.test_data | OnesAdd.test_data_mixed_dtypes) def test_ones_tosa_INT(test_data: test_data_t): input_data, init_data = test_data pipeline = TosaPipelineINT[input_t]( @@ -102,16 +99,16 @@ def test_ones_u85_INT(test_data: test_data_t): @common.parametrize( "test_data", - OnesAdd.test_data_not_delegated, - xfails={ - "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela", - "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela", - }, + OnesAdd.test_data_mixed_dtypes, ) def test_ones_tosa_INT_not_delegated(test_data: test_data_t): input_data, init_data = test_data pipeline = OpNotSupportedPipeline[input_t]( - OnesAdd(*init_data), input_data(), non_delegated_ops={}, quantize=True + OnesAdd(*init_data), + input_data(), + non_delegated_ops={"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}, + n_expected_delegates=1, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_zeros.py b/backends/arm/test/ops/test_zeros.py index 88ac62cea1c..caee678282a 100644 --- a/backends/arm/test/ops/test_zeros.py +++ b/backends/arm/test/ops/test_zeros.py @@ -39,17 +39,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ), } - test_data_not_delegated: dict[str, test_data_t] = { + # Mixed dtypes - the zeros op is delegated, but it leads to a non-delegated add op. 
+ test_data_mixed_dtypes: dict[str, test_data_t] = { "fp32_int64": (lambda: (torch.randn(10),), (10, torch.int64)), "fp32_int32": (lambda: (torch.randn(10),), (10, torch.int32)), - "int32_int64": ( - lambda: (torch.randint(0, 10, [10], dtype=torch.int32),), - (10, torch.int64), - ), } -@common.parametrize("test_data", ZerosAdd.test_data) +@common.parametrize("test_data", ZerosAdd.test_data | ZerosAdd.test_data_mixed_dtypes) def test_zeros_tosa_FP(test_data: test_data_t): input_data, init_data = test_data pipeline = TosaPipelineFP[input_t]( @@ -60,7 +57,7 @@ def test_zeros_tosa_FP(test_data: test_data_t): pipeline.run() -@common.parametrize("test_data", ZerosAdd.test_data) +@common.parametrize("test_data", ZerosAdd.test_data | ZerosAdd.test_data_mixed_dtypes) def test_zeros_tosa_INT(test_data: test_data_t): input_data, init_data = test_data pipeline = TosaPipelineINT[input_t]( @@ -102,16 +99,16 @@ def test_zeros_u85_INT(test_data: test_data_t): @common.parametrize( "test_data", - ZerosAdd.test_data_not_delegated, - xfails={ - "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela", - "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela", - }, + ZerosAdd.test_data_mixed_dtypes, ) def test_zeros_tosa_INT_not_delegated(test_data: test_data_t): input_data, init_data = test_data pipeline = OpNotSupportedPipeline[input_t]( - ZerosAdd(*init_data), input_data(), non_delegated_ops={}, quantize=True + ZerosAdd(*init_data), + input_data(), + non_delegated_ops={"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}, + n_expected_delegates=1, + quantize=True, ) pipeline.run() diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index 3c51f781ea5..9ba46dd1d2d 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -9,11 +9,14 @@ from typing import Callable, List, Optional, Sequence, Tuple import torch -from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.arm_backend import ( is_tosa, ) # usort: skip from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( + calculate_multiples, +) +from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.tosa_supported_operators import ( tosa_support_factory, ) @@ -26,14 +29,30 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter +from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase - logger = logging.getLogger(__name__) +def is_noop_clone(node: torch.fx.node.Node) -> bool: + return node.target == exir_ops.edge.aten.clone.default + + +def is_noop_alias_copy(node: torch.fx.node.Node) -> bool: + return node.target == exir_ops.edge.aten.alias_copy.default + + +def is_noop_expand(node: torch.fx.node.Node) -> bool: + if node.target != exir_ops.edge.aten.expand_copy.default: + return False + else: + multiples = calculate_multiples(node.args) + return all(m == 1 for m in multiples) + + class TOSAPartitioner(Partitioner): def __init__( self, @@ -50,7 +69,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no # subgraphs containing the nodes with the tags logger.info("TOSAPartitioner::partition") - partition_tags = {} + 
partition_tags: dict[str, DelegationSpec] = {} tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs) @@ -66,6 +85,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() + + def reject_partition(reason: str, partition, tag) -> None: + for node in partition.nodes: + if "delegation_tag" in node.meta: + del node.meta["delegation_tag"] + reporter.report_reject( + node, + reason, + ) + partition_tags.pop(tag, None) + for partition in partition_list: tag = f"tag{partition.id}" @@ -112,6 +142,21 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: del node.meta["delegation_tag"] break + is_noop_partition = all( + is_noop_clone(node) + or is_noop_alias_copy(node) + or is_noop_expand(node) + or node.target in Q_OPS + or node.target in DQ_OPS + for node in partition.nodes + ) + if is_noop_partition: + reject_partition( + "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", + partition, + tag, + ) + tag_constant_data(exported_program) logger.info(f"The following nodes were rejected for {tosa_spec}:") logger.info("\n" + reporter.get_table_report())
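
For illustration, a minimal sketch (not part of the patch) of how the new behaviour can be exercised: a clone-only module whose partition is now rejected instead of being lowered to an empty TOSA subgraph. It reuses the OpNotSupportedPipeline utility that the updated tests already use; the helper name, the CloneOnly module, the clone edge-op string and the input shape are assumptions made for this sketch. For the noop-expand case, calculate_multiples maps e.g. expand(-1, 3) on a (1, 3) input to multiples [1, 1], which is_noop_expand then treats as a no-op.

from typing import Tuple

import torch

from executorch.backends.arm.test.tester.test_pipeline import OpNotSupportedPipeline

input_t = Tuple[torch.Tensor]


class CloneOnly(torch.nn.Module):
    def forward(self, x):
        # Lowers to a single aten.clone, which RemoveClonePass strips during
        # TOSA lowering, so delegating it would leave an empty partition.
        return x.clone()


def check_clone_only_not_delegated():
    # With this change the partitioner rejects the clone-only partition, so
    # the clone stays in the CPU graph and no delegate is created.
    pipeline = OpNotSupportedPipeline[input_t](
        CloneOnly(),
        (torch.rand(1, 2, 3, 4),),
        non_delegated_ops={
            "executorch_exir_dialects_edge__ops_aten_clone_default": 1
        },
        n_expected_delegates=0,
    )
    pipeline.run()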