From 495abcddeb3b5639a48d6a7b605230458457e835 Mon Sep 17 00:00:00 2001
From: Ethan Ng
Date: Wed, 4 Jun 2025 13:19:00 -0700
Subject: [PATCH] Create decompose_ops.py and test_decompose_ops.py (#11299)

Summary: Create new class and test suite for passes that decompose an op
into an equivalent series of simpler ops

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Reviewed By: hsharma35

Differential Revision: D75826474

Pulled By: ethanng72
---
 backends/cadence/aot/TARGETS                  |  39 ++++++
 backends/cadence/aot/decompose_ops.py         | 133 ++++++++++++++++++
 backends/cadence/aot/replace_ops.py           |  80 +----------
 .../aot/tests/test_decompose_ops_passes.py    |  80 ++++++++++++
 .../aot/tests/test_replace_ops_passes.py      |  37 ------
 5 files changed, 253 insertions(+), 116 deletions(-)
 create mode 100644 backends/cadence/aot/decompose_ops.py
 create mode 100644 backends/cadence/aot/tests/test_decompose_ops_passes.py

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 1613cfb28ca..a0de747cf3f 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -276,6 +276,24 @@ python_library(
     ],
 )
 
+python_library(
+    name = "decompose_ops",
+    srcs = [
+        "decompose_ops.py",
+    ],
+    typing = True,
+    deps = [
+        ":pass_utils",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/dialects/edge:lib",
+        "//executorch/exir/passes:spec_prop_pass",
+    ],
+)
+
+
 python_unittest(
     name = "test_graph_builder",
     srcs = [
@@ -314,6 +332,27 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_decompose_ops_passes",
+    srcs = [
+        "tests/test_decompose_ops_passes.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        ":compiler",
+        ":decompose_ops",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:graph_builder",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/passes:lib",
+    ],
+)
+
 python_unittest(
     name = "test_fusion_ops_passes",
     srcs = [
diff --git a/backends/cadence/aot/decompose_ops.py b/backends/cadence/aot/decompose_ops.py
new file mode 100644
index 00000000000..60514c52902
--- /dev/null
+++ b/backends/cadence/aot/decompose_ops.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# This file contains all the functions that decompose one op into simpler ops in the
+# graph. The functions decomposing ops for models deployed with Jarvis are grouped
+# together in class 'CadenceDecomposeOpsInGraph'. Some examples of functions in the class are
+# 1. functions that decompose an ATen gelu op into an equivalent series of simpler ops
+
+# pyre-strict
+
+from typing import Dict
+
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    register_cadence_pass,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from torch.fx.node import Argument
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class DecomposeAtenApproxGeluPass(ExportPass):
+    """
+    Decompose the aten gelu op with an approximate arg into a series of simpler ops
+    """
+
+    def call_operator(
+        self,
+        op: EdgeOpOverload,
+        args: tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        # Only decompose gelu ops that carry an approximate kwarg; all other
+        # ops pass through unchanged.
+        if op not in {
+            exir_ops.edge.aten.gelu.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        if "approximate" not in kwargs:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
+        # as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
+
+        # Get 0.5 * x
+        half = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.5),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x
+        scaled = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.044715),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of
+        # pow.Tensor because it is much more efficient on DSP backends)
+        scaled_square = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^3
+        scaled_cubed = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled_square, args[0]),
+            {},
+            meta,
+        )
+
+        # Get x + 0.044715 * x^3
+        inner_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (scaled_cubed, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.7978845608028654 * (x + 0.044715 * x^3)
+        scaled_sum = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (inner_sum, 0.7978845608028654),
+            {},
+            meta,
+        )
+
+        # Get torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        tanh = super().call_operator(
+            exir_ops.edge.aten.tanh.default,
+            (scaled_sum,),
+            {},
+            meta,
+        )
+
+        # Get 1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
+        outer_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (tanh, 1.0),
+            {},
+            meta,
+        )
+
+        # Return the final result: 0.5 * x * (1 + tanh(...))
+        return super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (half, outer_sum),
+            {},
+            meta,
+        )
+
+
+# This class encapsulates all the functions that decompose one op in the graph.
+class CadenceDecomposeOpsInGraph: + passes = [ + DecomposeAtenApproxGeluPass, + ] diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index e5a88c10a3f..d78bdfeba6e 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2078,89 +2078,11 @@ def call_operator( kwargs: Dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - if "approximate" not in kwargs: - return super().call_operator(op, args, kwargs, meta) - if op not in { exir_ops.edge.aten.gelu.default, }: return super().call_operator(op, args, kwargs, meta) - - # compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi)) - # as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))) - - # Get 0.5 * x - half = super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (args[0], 0.5), - {}, - meta, - ) - - scaled = super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (args[0], 0.044715), - {}, - meta, - ) - - # Get x^2 (note that we use mul.Tensor twice instead of pow.Tensor because - # it is much more efficient on DSP backends) - scaled_square = super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (scaled, args[0]), - {}, - meta, - ) - - # Get x^3 - scaled_cubed = super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (scaled_square, args[0]), - {}, - meta, - ) - - # Get x + 0.044715 * x^3 - inner_sum = super().call_operator( - exir_ops.edge.aten.add.Tensor, - (scaled_cubed, args[0]), - {}, - meta, - ) - - # Get 0.7978845608028654 * ( x + 0.044715 * x^3) - scaled_sum = super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (inner_sum, 0.7978845608028654), - {}, - meta, - ) - - # Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)) - tanh = super().call_operator( - exir_ops.edge.aten.tanh.default, - (scaled_sum,), - {}, - meta, - ) - - # Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3)) - # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.) - outer_sum = super().call_operator( - exir_ops.edge.aten.add.Tensor, - (tanh, 1.0), - {}, - meta, - ) - - # Retunr the final result - return super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (half, outer_sum), - {}, - meta, - ) + return super().call_operator(op, args, kwargs, meta) # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py diff --git a/backends/cadence/aot/tests/test_decompose_ops_passes.py b/backends/cadence/aot/tests/test_decompose_ops_passes.py new file mode 100644 index 00000000000..e4bdf42ff62 --- /dev/null +++ b/backends/cadence/aot/tests/test_decompose_ops_passes.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-strict
+
+import unittest
+from typing import Union
+
+import torch
+from executorch.backends.cadence.aot.decompose_ops import DecomposeAtenApproxGeluPass
+from executorch.backends.cadence.aot.graph_builder import single_op_builder
+from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass
+
+
+class TestDecomposeOpsPasses(unittest.TestCase):
+    def assertTargetCountEqual(
+        self,
+        graph_module: torch.fx.GraphModule,
+        target: Union[EdgeOpOverload, str],
+        expected_count: int,
+    ) -> None:
+        """Helper function to check the number of nodes with a given target."""
+        actual_count = count_node(graph_module, target)
+        self.assertEqual(
+            actual_count,
+            expected_count,
+            f"{target} count mismatch for graph {graph_module}",
+        )
+
+    def assertTargetCountsEqual(
+        self,
+        graph_module: torch.fx.GraphModule,
+        targets_and_counts: list[tuple[Union[EdgeOpOverload, str], int]],
+    ) -> None:
+        """Helper function to check the node count for each target in a list."""
+        for target, expected_count in targets_and_counts:
+            self.assertTargetCountEqual(graph_module, target, expected_count)
+
+    def test_decompose_aten_approximate_gelu(self) -> None:
+        inputs = torch.randn(2, 1, 64)
+
+        gm = single_op_builder(
+            placeholders=(inputs,),
+            op=exir_ops.edge.aten.gelu.default,
+            args=(inputs,),
+            kwargs={"approximate": "tanh"},
+        )
+        gm = ExportPass().call(gm).graph_module
+
+        p = DecomposeAtenApproxGeluPass()
+        graph_after_passes = p.call(gm).graph_module
+
+        # Assert that the aten.gelu op was decomposed
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.gelu.default,
+            ),
+            0,
+        )
+
+        # The decomposition should produce one tanh, two add, and six mul ops
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
+            2,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
+            6,
+        )
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index e8215c378f9..25526cd2779 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -1309,43 +1309,6 @@ def test_no_replace_aten_gelu_with_approximate_gelu(self):
             1,
         )
 
-    def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
-        inputs = torch.randn(2, 1, 64)
-
-        gm = single_op_builder(
-            placeholders=(inputs,),
-            op=exir_ops.edge.aten.gelu.default,
-            args=(inputs,),
-            kwargs={"approximate": "tanh"},
-        )
-        gm = ExportPass().call(gm).graph_module
-
-        p = ReplaceAtenApproxGeluWithApproxGeluPass()
-        graph_after_passes = p.call(gm).graph_module
-
-        # Assert that aten.gelu op was decomposed
-        self.assertEqual(
-            count_node(
-                graph_after_passes,
-                exir_ops.edge.aten.gelu.default,
-            ),
-            0,
-        )
-
-        # The decomposition should have one tanh, 2 add and 6 mul
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
-            1,
-        )
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
-            2,
-        )
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
-            6,
-        )
-
     def test_replace_split_with_sizes_with_slice(self):
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(1, 16, 8, 4))
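
Note: for reviewers who want to sanity-check the tanh-based decomposition emitted by
DecomposeAtenApproxGeluPass, the snippet below is a standalone sketch, not part of the
patch. It mirrors the op-by-op sequence from decompose_ops.py using plain eager tensor
ops (rather than the pass infrastructure) and the same input shape as the new unit
test, and compares the result against PyTorch's own approximate gelu.

    import torch

    x = torch.randn(2, 1, 64)

    # Mirror the pass's decomposition using only mul/add/tanh
    # (0.7978845608028654 is sqrt(2 / pi)).
    half = x * 0.5
    scaled = x * 0.044715
    scaled_square = scaled * x        # 0.044715 * x^2
    scaled_cubed = scaled_square * x  # 0.044715 * x^3
    inner_sum = scaled_cubed + x      # x + 0.044715 * x^3
    scaled_sum = inner_sum * 0.7978845608028654
    outer_sum = torch.tanh(scaled_sum) + 1.0
    decomposed = half * outer_sum

    # Should match PyTorch's tanh-approximate gelu to within float tolerance.
    reference = torch.nn.functional.gelu(x, approximate="tanh")
    assert torch.allclose(decomposed, reference, atol=1e-5)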