From ff28b0d8cbff2cfe7c96244e81132ebb4e9b5c05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Wed, 16 Oct 2024 09:29:56 +0200 Subject: [PATCH 1/2] Add TosaSpecification to ArmPartitioner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The TOSA specification is added as an argument to TOSASupportedOperators to be able to make decisions on which operators to support. Signed-off-by: Per Åstrand Change-Id: I5ba3b9dd5f5dca54c3c6d8db29ebc9c52d5cc0f7 --- backends/arm/arm_partitioner.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index f4050351d1d..859b8dd1f4f 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -13,6 +13,7 @@ import torch from executorch.backends.arm.arm_backend import ArmBackend # usort: skip from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -36,6 +37,10 @@ class TOSASupportedOperators(OperatorSupportBase): + def __init__(self, tosa_spec: TosaSpecification): + super().__init__() + self.tosa_spec = tosa_spec + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: supported = node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor, @@ -111,6 +116,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: logger.info("ArmPartitioner::partition") partition_tags = {} + tosa_spec = TosaSpecification.create_from_compilespecs( + self.delegation_spec.compile_specs + ) + + logger.info(f"Partitioning for {tosa_spec}") + for spec in self.delegation_spec.compile_specs: if spec.key == "quantize_io" and spec.value.decode() == "True": # Exclude IO quantization from the partition @@ 
-123,7 +134,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - TOSASupportedOperators(), + TOSASupportedOperators(tosa_spec), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() From 95821867b2c2f8033eb9137b7919e43585368698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 17 Oct 2024 13:55:30 +0200 Subject: [PATCH 2/2] Refactor TOSAOperatorSupport to allow for mapping against TosaSpecification Change-Id: Ib7eb502861948c2f8dee995f28cbb7f2baa00afb --- backends/arm/arm_partitioner.py | 78 +---------- backends/arm/operator_support/__init__.py | 8 ++ .../arm/operator_support/mean_dim_support.py | 33 +++++ .../tosa_supported_operators.py | 128 ++++++++++++++++++ .../var_correction_support.py | 33 +++++ 5 files changed, 206 insertions(+), 74 deletions(-) create mode 100644 backends/arm/operator_support/__init__.py create mode 100644 backends/arm/operator_support/mean_dim_support.py create mode 100644 backends/arm/operator_support/tosa_supported_operators.py create mode 100644 backends/arm/operator_support/var_correction_support.py diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 859b8dd1f4f..bb9df2a054f 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -6,13 +6,15 @@ # pyre-unsafe import logging -import operator import os -from typing import Callable, cast, final, List, Optional, Tuple +from typing import Callable, final, List, Optional, Tuple import torch from executorch.backends.arm.arm_backend import ArmBackend # usort: skip from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + TOSASupportedOperators, +) from executorch.backends.arm.tosa_specification import TosaSpecification from 
executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( @@ -21,13 +23,10 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data -from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.passes import PassManager from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from torch.fx.passes.operator_support import OperatorSupportBase - logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" @@ -36,75 +35,6 @@ logger.setLevel(logging.INFO) -class TOSASupportedOperators(OperatorSupportBase): - def __init__(self, tosa_spec: TosaSpecification): - super().__init__() - self.tosa_spec = tosa_spec - - def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - supported = node.op == "call_function" and node.target in [ - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.expand_copy.default, - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.bmm.default, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.convolution.default, - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten.exp.default, - exir_ops.edge.aten.log.default, - exir_ops.edge.aten.linear.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.full.default, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - exir_ops.edge.aten.native_layer_norm.default, - exir_ops.edge.aten.avg_pool2d.default, - exir_ops.edge.aten.max_pool2d_with_indices.default, - exir_ops.edge.aten.sigmoid.default, - exir_ops.edge.aten.mm.default, - exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.reciprocal.default, - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.rsqrt.default, - exir_ops.edge.aten._softmax.default, - 
exir_ops.edge.aten.select_copy.int, - exir_ops.edge.aten._log_softmax.default, - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.upsample_nearest2d.vec, - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.var.correction, - exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.squeeze_copy.dims, - operator.getitem, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - ] - - supported &= self.is_node_supported_custom(node) - - # Override partitioning based on pre partition passes - if "arm_override_partition" in node.meta: - supported = supported & node.meta["arm_override_partition"] - node.meta.pop("arm_override_partition") - - return supported - - def is_node_supported_custom(self, node: torch.fx.Node) -> bool: - if node.target == exir_ops.edge.aten.mean.dim: - keep_dim = node.args[2] if len(node.args) > 2 else False - return cast(bool, keep_dim) - if node.target == exir_ops.edge.aten.var.correction: - keep_dim = node.kwargs.get("keepdim", False) - return cast(bool, keep_dim) - return True - - @final class ArmPartitioner(Partitioner): def __init__(self, compile_spec: List[CompileSpec]) -> None: diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py new file mode 100644 index 00000000000..0a88bc45aa7 --- /dev/null +++ b/backends/arm/operator_support/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from . 
import mean_dim_support, tosa_supported_operators, var_correction_support # noqa diff --git a/backends/arm/operator_support/mean_dim_support.py b/backends/arm/operator_support/mean_dim_support.py new file mode 100644 index 00000000000..67a7c204069 --- /dev/null +++ b/backends/arm/operator_support/mean_dim_support.py @@ -0,0 +1,33 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import cast + +import torch.fx as fx + +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class MeanDimSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.mean.dim] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + assert node.target in self.targets + + keep_dim = node.args[2] if len(node.args) > 2 else False + return cast(bool, keep_dim) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py new file mode 100644 index 00000000000..0c7c56f3975 --- /dev/null +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -0,0 +1,128 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import operator + +import torch.fx as fx +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx.passes.operator_support import OperatorSupportBase + + +class SupportedTOSAOperatorCheck: + """ + Supported OP for TOSA lowering + """ + + # Should be populated by subclass implementation + tosa_specs: list[TosaSpecification] = [] + targets: list[str] = [] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + """ + Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec. + To be implemented by subclasses targeting specific operators. + """ + raise NotImplementedError("NodeVisitor must be extended.") + + +# container for all SupportedTOSAOperatorCheck classes +_tosa_spec_dicts: dict[TosaSpecification, dict[str, SupportedTOSAOperatorCheck]] = { + TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {}, + TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {}, +} + + +def register_tosa_support_check(checker): + """ + Decorator to mark a subclass implementation of SupportedTOSAOperatorCheck + to be registered for checking if a torch.fx.Node is lowerable given + a TOSA specification. 
+ """ + for tosa_spec in checker.tosa_specs: + for target in checker.targets: + _tosa_spec_dicts[tosa_spec][target] = checker + return checker + + +def get_registered_tosa_support_checks( + tosa_spec: TosaSpecification, +) -> dict[str, SupportedTOSAOperatorCheck]: + + if tosa_spec not in _tosa_spec_dicts: + raise RuntimeError + + tosa_support_checks = {} + for target, tosa_check in _tosa_spec_dicts[tosa_spec].items(): + tosa_support_checks[target] = tosa_check() + + return tosa_support_checks + + +class TOSASupportedOperators(OperatorSupportBase): + def __init__(self, tosa_spec: TosaSpecification): + super().__init__() + self.tosa_spec = tosa_spec + + def is_node_supported(self, submodules, node: fx.Node) -> bool: + supported = node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.bmm.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.log.default, + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.native_layer_norm.default, + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten.select_copy.int, + exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.sum.dim_IntList, + 
exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + operator.getitem, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ] + + if not supported: + supported = self.is_node_supported_custom(node) + + # Override partitioning based on pre partition passes + if "arm_override_partition" in node.meta: + supported = supported & node.meta["arm_override_partition"] + node.meta.pop("arm_override_partition") + + return supported + + def is_node_supported_custom(self, node: fx.Node) -> bool: + tosa_checks = get_registered_tosa_support_checks(self.tosa_spec) + if node.target in tosa_checks.keys(): + return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) + return False diff --git a/backends/arm/operator_support/var_correction_support.py b/backends/arm/operator_support/var_correction_support.py new file mode 100644 index 00000000000..4aa2ae5e97d --- /dev/null +++ b/backends/arm/operator_support/var_correction_support.py @@ -0,0 +1,33 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import cast + +import torch.fx as fx + +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class VarCorrectionSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.var.correction] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + assert node.target in self.targets + + keep_dim = node.kwargs.get("keepdim", False) + return cast(bool, keep_dim)