diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index de9a793b9aa..aa3131e5954 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -42,6 +42,7 @@ from .decompose_elu_pass import DecomposeEluPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_expm1_pass import DecomposeExpm1Pass # noqa +from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv import DecomposeGroupedConv # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b579d910752..ba1c9ef87cf 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -51,6 +51,7 @@ DecomposeEluPass, DecomposeEmbeddingPass, DecomposeExpm1Pass, + DecomposeFloorDividePass, DecomposeGeluPass, DecomposeGluPass, DecomposeGroupedConv, @@ -243,6 +244,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSinhPass()) self.add_pass(DecomposeSignPass()) + self.add_pass(DecomposeFloorDividePass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(DecomposeEmbeddingPass()) @@ -335,6 +337,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeAddmmPass()) + self.add_pass(DecomposeFloorDividePass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeAddSubAlphaPass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) diff --git a/backends/arm/_passes/decompose_floor_divide_pass.py b/backends/arm/_passes/decompose_floor_divide_pass.py new file mode 100644 index 00000000000..7f3f49aefd1 --- /dev/null +++ b/backends/arm/_passes/decompose_floor_divide_pass.py @@ -0,0 +1,64 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_div_tensor_mode import ( + DecomposeDivTensorModePass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_floor_divide_ops = (exir_ops.edge.aten.floor_divide.default,) +aten_floor_divide_ops = (torch.ops.aten.floor_divide.default,) + + +def get_floor_divide_decomposition(op) -> tuple: + """ + Returns the decomposition of the given aten.floor_div operation into + its equivalent TOSA-supported operations + + This handles both edge dialect ops and core PyTorch ops. The decomposition strategy + is: + floor_div(x, y) → div_tensor_mode(x, y, rounding_mode="floor") + + Returns: + A tuple (div_op,) corresponding to the appropriate operator overload for the input op. + + Raises: + RuntimeError: If the provided operator is not a supported floor_divide variant. + """ + + if op in edge_floor_divide_ops: + return (exir_ops.edge.aten.div.Tensor_mode,) + if op in aten_floor_divide_ops: + return (torch.ops.aten.div.Tensor_mode,) + + raise RuntimeError(f"Can't get floor_div decomposition for op {op}") + + +class DecomposeFloorDividePass(ArmPass): + """ + Decomposes aten.floor_divide into aten.div.Tensor_mode with rounding_mode="floor". + """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_floor_divide_ops + aten_floor_divide_ops): + return super().call_operator(op, args, kwargs, meta, updated=False) + + (div_op,) = get_floor_divide_decomposition(op) + + input = args[0] + other = args[1] + + div_node = super().call_operator( + div_op, (input, other), {"rounding_mode": "floor"}, meta, updated=True + ) + + return div_node diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index ee61aa4cce6..e1bb69ff78c 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -229,6 +229,7 @@ exir_ops.edge.aten.logit.default, exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, + exir_ops.edge.aten.floor_divide.default, } diff --git a/backends/arm/test/ops/test_floor_div.py b/backends/arm/test/ops/test_floor_div.py new file mode 100644 index 00000000000..9d057c454c5 --- /dev/null +++ b/backends/arm/test/ops/test_floor_div.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, Union + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +test_data_suite = { + # (test_name, input, other) + "op_floor_div_rank1_ones": lambda: ( + torch.ones(5), + torch.ones(5), + ), + "op_floor_div_rank1_rand": lambda: ( + torch.rand(5) * 5, + torch.rand(5) * 5, + ), + "op_floor_div_rank4_negative_ones": lambda: ( + (-1) * torch.ones(5, 10, 25, 20), + torch.ones(5, 10, 25, 20), + ), + "op_floor_div_rank4_ones_div_negative": lambda: ( + torch.ones(5, 10, 25, 20), + (-1) * torch.ones(5, 10, 25, 20), + ), + "op_floor_div_rank4_large_rand": lambda: ( + 200 * torch.rand(5, 10, 25, 20), + torch.rand(5, 10, 25, 20), + ), + "op_floor_div_rank4_randn_mutltiple_broadcasts": lambda: ( + torch.randn(1, 4, 4, 1), + torch.randn(1, 1, 4, 4), + ), + "op_floor_div_rank4_randn_scalar": lambda: ( + torch.randn(1, 4, 4, 1), + 2, + ), +} + + +class FloorDivide(torch.nn.Module): + aten_op = "torch.ops.aten.floor_divide.default" + aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default", "aten.floor.default"] + exir_op = "executorch_exir_dialects_edge__ops_aten_div_Tensor_mode" + exir_ops_int = [ + "executorch_exir_dialects_edge__ops_aten_reciprocal_default", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_floor_default", + ] + + def forward( + self, + input_: Union[torch.Tensor, torch.types.Number], + other_: Union[torch.Tensor, torch.types.Number], + ): + return torch.floor_divide(input=input_, other=other_) + + +input_t1 = Tuple[torch.Tensor, Union[torch.Tensor, int]] + + +@common.parametrize("test_data", test_data_suite) +def test_floor_divide_tosa_FP(test_data: input_t1): + pipeline = TosaPipelineFP[input_t1]( + FloorDivide(), + test_data(), + FloorDivide.aten_op, + FloorDivide.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_floor_divide_tosa_INT(test_data: input_t1): + pipeline = TosaPipelineINT[input_t1]( + FloorDivide(), + test_data(), + aten_op=FloorDivide.aten_ops_int, + exir_op=FloorDivide.exir_ops_int, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_floor_divide_u55_INT(test_data: input_t1): + pipeline = EthosU55PipelineINT[input_t1]( + FloorDivide(), + test_data(), + aten_ops=FloorDivide.aten_ops_int, + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=False, + ) + pipeline.pop_stage("check_not.exir") + pipeline.pop_stage("check_count.exir") + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_floor_divide_u85_INT(test_data: input_t1): + pipeline = EthosU85PipelineINT[input_t1]( + FloorDivide(), + test_data(), + aten_ops=FloorDivide.aten_ops_int, + exir_ops=FloorDivide.exir_ops_int, + run_on_fvp=True, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_floor_divide_vgf_FP(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + FloorDivide(), + test_data(), + FloorDivide.aten_op, + FloorDivide.exir_op, + tosa_version="TOSA-1.0+FP", + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_floor_divide_vgf_INT(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + FloorDivide(), + test_data(), + aten_op=FloorDivide.aten_ops_int, + exir_op=FloorDivide.exir_ops_int, + tosa_version="TOSA-1.0+INT", + use_to_edge_transform_and_lower=False, + ) + pipeline.run()