-
Notifications
You must be signed in to change notification settings - Fork 696
Arm backend: Add support for floor_divide.default #14933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
agrima1304
wants to merge
6
commits into
pytorch:main
Choose a base branch
from
agrima1304:op-floor-div
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+223
−0
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
a726363
Arm backend: Add support for floor_divide.default
agrima1304 7feef67
Merge branch 'pytorch:main' into op-floor-div
agrima1304 f482fe7
Arm backend: Add support for floor_divide.default
agrima1304 0abfba1
Merge branch 'main' into op-floor-div
zingo e05b531
Merge branch 'main' into op-floor-div
Sebastian-Larsson 5f5bf9b
Merge branch 'main' into op-floor-div
Sebastian-Larsson File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import Set, Type | ||
|
|
||
| import torch | ||
| from executorch.backends.arm._passes import ArmPass | ||
| from executorch.backends.arm._passes.decompose_div_tensor_mode import ( | ||
| DecomposeDivTensorModePass, | ||
| ) | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
| from executorch.exir.pass_base import ExportPass | ||
|
|
||
| edge_floor_divide_ops = (exir_ops.edge.aten.floor_divide.default,) | ||
| aten_floor_divide_ops = (torch.ops.aten.floor_divide.default,) | ||
|
|
||
|
|
||
| def get_floor_divide_decomposition(op) -> tuple: | ||
| """ | ||
| Returns the decomposition of the given aten.floor_div operation into | ||
| its equivalent TOSA-supported operations | ||
|
|
||
| This handles both edge dialect ops and core PyTorch ops. The decomposition strategy | ||
| is: | ||
| floor_div(x, y) → div_tensor_mode(x, y, rounding_mode="floor") | ||
|
|
||
| Returns: | ||
| A tuple (div_op,) corresponding to the appropriate operator overload for the input op. | ||
|
|
||
| Raises: | ||
| RuntimeError: If the provided operator is not a supported floor_divide variant. | ||
| """ | ||
|
|
||
| if op in edge_floor_divide_ops: | ||
| return (exir_ops.edge.aten.div.Tensor_mode,) | ||
| if op in aten_floor_divide_ops: | ||
| return (torch.ops.aten.div.Tensor_mode,) | ||
|
|
||
| raise RuntimeError(f"Can't get floor_div decomposition for op {op}") | ||
|
|
||
|
|
||
| class DecomposeFloorDividePass(ArmPass): | ||
| """ | ||
| Decomposes aten.floor_divide into aten.div.Tensor_mode with rounding_mode="floor". | ||
| """ | ||
|
|
||
| _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} | ||
|
|
||
| def call_operator(self, op, args, kwargs, meta): | ||
| if op not in (edge_floor_divide_ops + aten_floor_divide_ops): | ||
| return super().call_operator(op, args, kwargs, meta, updated=False) | ||
|
|
||
| (div_op,) = get_floor_divide_decomposition(op) | ||
|
|
||
| input = args[0] | ||
| other = args[1] | ||
|
|
||
| div_node = super().call_operator( | ||
| div_op, (input, other), {"rounding_mode": "floor"}, meta, updated=True | ||
| ) | ||
|
|
||
| return div_node | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # Copyright 2024-2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import Tuple, Union | ||
|
|
||
| import torch | ||
| from executorch.backends.arm.test import common | ||
|
|
||
| from executorch.backends.arm.test.tester.test_pipeline import ( | ||
| EthosU55PipelineINT, | ||
| EthosU85PipelineINT, | ||
| TosaPipelineFP, | ||
| TosaPipelineINT, | ||
| VgfPipeline, | ||
| ) | ||
|
|
||
| test_data_suite = { | ||
| # (test_name, input, other) | ||
| "op_floor_div_rank1_ones": lambda: ( | ||
| torch.ones(5), | ||
| torch.ones(5), | ||
| ), | ||
| "op_floor_div_rank1_rand": lambda: ( | ||
| torch.rand(5) * 5, | ||
| torch.rand(5) * 5, | ||
| ), | ||
| "op_floor_div_rank4_negative_ones": lambda: ( | ||
| (-1) * torch.ones(5, 10, 25, 20), | ||
| torch.ones(5, 10, 25, 20), | ||
| ), | ||
| "op_floor_div_rank4_ones_div_negative": lambda: ( | ||
| torch.ones(5, 10, 25, 20), | ||
| (-1) * torch.ones(5, 10, 25, 20), | ||
| ), | ||
| "op_floor_div_rank4_large_rand": lambda: ( | ||
| 200 * torch.rand(5, 10, 25, 20), | ||
| torch.rand(5, 10, 25, 20), | ||
| ), | ||
| "op_floor_div_rank4_randn_mutltiple_broadcasts": lambda: ( | ||
| torch.randn(1, 4, 4, 1), | ||
| torch.randn(1, 1, 4, 4), | ||
| ), | ||
| "op_floor_div_rank4_randn_scalar": lambda: ( | ||
| torch.randn(1, 4, 4, 1), | ||
| 2, | ||
| ), | ||
| } | ||
|
|
||
|
|
||
| class FloorDivide(torch.nn.Module): | ||
| aten_op = "torch.ops.aten.floor_divide.default" | ||
| aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default", "aten.floor.default"] | ||
| exir_op = "executorch_exir_dialects_edge__ops_aten_div_Tensor_mode" | ||
| exir_ops_int = [ | ||
| "executorch_exir_dialects_edge__ops_aten_reciprocal_default", | ||
| "executorch_exir_dialects_edge__ops_aten_mul_Tensor", | ||
| "executorch_exir_dialects_edge__ops_aten_floor_default", | ||
| ] | ||
|
|
||
| def forward( | ||
| self, | ||
| input_: Union[torch.Tensor, torch.types.Number], | ||
| other_: Union[torch.Tensor, torch.types.Number], | ||
| ): | ||
| return torch.floor_divide(input=input_, other=other_) | ||
|
|
||
|
|
||
| input_t1 = Tuple[torch.Tensor, Union[torch.Tensor, int]] | ||
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite) | ||
| def test_floor_divide_tosa_FP(test_data: input_t1): | ||
| pipeline = TosaPipelineFP[input_t1]( | ||
| FloorDivide(), | ||
| test_data(), | ||
| FloorDivide.aten_op, | ||
| FloorDivide.exir_op, | ||
| use_to_edge_transform_and_lower=False, | ||
| ) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite) | ||
| def test_floor_divide_tosa_INT(test_data: input_t1): | ||
| pipeline = TosaPipelineINT[input_t1]( | ||
| FloorDivide(), | ||
| test_data(), | ||
| aten_op=FloorDivide.aten_ops_int, | ||
| exir_op=FloorDivide.exir_ops_int, | ||
| use_to_edge_transform_and_lower=False, | ||
| ) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite) | ||
| @common.XfailIfNoCorstone300 | ||
| def test_floor_divide_u55_INT(test_data: input_t1): | ||
| pipeline = EthosU55PipelineINT[input_t1]( | ||
| FloorDivide(), | ||
| test_data(), | ||
| aten_ops=FloorDivide.aten_ops_int, | ||
| exir_ops=[], | ||
| run_on_fvp=True, | ||
| use_to_edge_transform_and_lower=False, | ||
| ) | ||
| pipeline.pop_stage("check_not.exir") | ||
| pipeline.pop_stage("check_count.exir") | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite) | ||
| @common.XfailIfNoCorstone320 | ||
| def test_floor_divide_u85_INT(test_data: input_t1): | ||
| pipeline = EthosU85PipelineINT[input_t1]( | ||
| FloorDivide(), | ||
| test_data(), | ||
| aten_ops=FloorDivide.aten_ops_int, | ||
| exir_ops=FloorDivide.exir_ops_int, | ||
| run_on_fvp=True, | ||
| use_to_edge_transform_and_lower=False, | ||
| ) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite) | ||
| @common.SkipIfNoModelConverter | ||
| def test_floor_divide_vgf_FP(test_data: input_t1): | ||
| pipeline = VgfPipeline[input_t1]( | ||
| FloorDivide(), | ||
| test_data(), | ||
| FloorDivide.aten_op, | ||
| FloorDivide.exir_op, | ||
| tosa_version="TOSA-1.0+FP", | ||
| use_to_edge_transform_and_lower=False, | ||
| ) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite) | ||
| @common.SkipIfNoModelConverter | ||
| def test_floor_divide_vgf_INT(test_data: input_t1): | ||
| pipeline = VgfPipeline[input_t1]( | ||
| FloorDivide(), | ||
| test_data(), | ||
| aten_op=FloorDivide.aten_ops_int, | ||
| exir_op=FloorDivide.exir_ops_int, | ||
| tosa_version="TOSA-1.0+INT", | ||
| use_to_edge_transform_and_lower=False, | ||
| ) | ||
| pipeline.run() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1