From a0a708413e7a78b2f8a2410d900d3ae09b2e79f2 Mon Sep 17 00:00:00 2001
From: Iliyan Georgiev
Date: Fri, 28 Mar 2025 16:02:55 +0000
Subject: [PATCH] Arm backend: Add GELU operator

- Add GELU decomposition pass
- BI handled by table op pass
- Add tests
- xfail partitioning test with task to revisit

Signed-off-by: Iliyan Georgiev
Change-Id: I794c84d88ae5e34c3732d458f548dd6ea0efe679
---
 backends/arm/_passes/__init__.py              |   1 +
 backends/arm/_passes/arm_pass_manager.py      |   2 +
 backends/arm/_passes/decompose_gelu_pass.py   | 149 ++++++++++++++++++
 backends/arm/_passes/insert_table_ops.py      |  14 ++
 .../tosa_supported_operators.py               |   2 +
 .../arm/quantizer/quantization_annotator.py   |   1 +
 ...test_partition_decomposed_quantized_ops.py |   7 +-
 backends/arm/test/ops/test_gelu.py            | 125 +++++++++++++++
 8 files changed, 300 insertions(+), 1 deletion(-)
 create mode 100644 backends/arm/_passes/decompose_gelu_pass.py
 create mode 100644 backends/arm/test/ops/test_gelu.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 2b935172a7f..1a1719bf8ae 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -20,6 +20,7 @@
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_gelu_pass import DecomposeGeluPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 85004686ebe..261ee045790 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -25,6 +25,7 @@
     ConvertToClampPass,
     DecomposeBatchNormPass,
     DecomposeDivPass,
+    DecomposeGeluPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
     DecomposeLinearPass,
@@ -132,6 +133,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxPass())
+        self.add_pass(DecomposeGeluPass())
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py
new file mode 100644
index 00000000000..6e72175e68b
--- /dev/null
+++ b/backends/arm/_passes/decompose_gelu_pass.py
@@ -0,0 +1,149 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
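+
+# Standard GELU definitions realized by this decomposition:
+#   approximate="none": GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+#   approximate="tanh": GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))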
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+torch_gelu = (torch.ops.aten.gelu.default,)
+
+edge_gelu = (exir_ops.edge.aten.gelu.default,)
+
+
+def _get_gelu_ops(op) -> tuple:
+    """
+    Returns the operators needed to decompose GELU.
+    """
+
+    if op in edge_gelu:
+        return (
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.tanh.default,
+            exir_ops.edge.aten.erf.default,
+        )
+    if op in torch_gelu:
+        return (
+            torch.ops.aten.full.default,
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.tanh.default,
+            torch.ops.aten.erf.default,
+        )
+    raise RuntimeError(f"Can't get GELU decomposition ops for op {op}")
+
+
+class DecomposeGeluPass(ExportPass):
+    """
+    Decomposes the GELU operator into primitive ops, aiming to adhere
+    closely to the reference implementations built into ExecuTorch,
+    including the use of the same pre-calculated constants.
+
+    The operator has two formulae depending on the value of the
+    approximate argument. The examples below include the full
+    operators added to initialize the constants used in each
+    respective formula.
+
+    aten.gelu(x, approximate="none") becomes:
+        %FULL_0_5 = full()
+        %FULL_1 = full()
+        %FULL_SQRT1_2 = full()
+        %op1 = mul(x, %FULL_SQRT1_2)
+        %op2 = erf(%op1)
+        %op3 = add(%op2, %FULL_1)
+        %op4 = mul(%op3, %FULL_0_5)
+        %op5 = mul(%x, %op4)
+
+    aten.gelu(x, approximate="tanh") becomes:
+        %FULL_0_5 = full()
+        %FULL_1 = full()
+        %FULL_SQRT2 = full()
+        %FULL_2_SQRTPI = full()
+        %FULL_CUBE_COEFF = full()
+        %SQRT_MUL = mul(%FULL_SQRT2, %FULL_2_SQRTPI)
+        %SQRT_2_PI = mul(%SQRT_MUL, %FULL_0_5)
+        %sqr_x = mul(x, x)
+        %cube_x = mul(sqr_x, x)
+        %op1 = mul(%cube_x, %FULL_CUBE_COEFF)
+        %op2 = add(%x, %op1)
+        %op3 = mul(%op2, %SQRT_2_PI)
+        %op4 = tanh(%op3)
+        %op5 = add(%op4, %FULL_1)
+        %op6 = mul(%x, %op5)
+        %op7 = mul(%op6, %FULL_0_5)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in torch_gelu + edge_gelu:
+            return super().call_operator(op, args, kwargs, meta)
+
+        full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)
+
+        input = get_node_arg(args, 0)
+        # If approximate is the default ("none"), it does not appear in kwargs.
+        approximate = get_node_arg(kwargs, "approximate", "none")
+
+        shape = meta["val"].size()
+        dtype = meta["val"].dtype
+
+        FULL_0_5 = super().call_operator(
+            full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta
+        )
+        FULL_1 = super().call_operator(
+            full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta
+        )
+
+        if approximate == "none":
+            # Constant mirrors the ExecuTorch implementation for parity.
+            FULL_SQRT1_2 = super().call_operator(
+                full_op, ([1] * len(shape), 0.70710678118654752440), {}, meta
+            )
+
+            op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta)
+            op2 = super().call_operator(erf_op, (op1,), {}, meta)
+            op3 = super().call_operator(add_op, (op2, FULL_1), {}, meta)
+            op4 = super().call_operator(mul_op, (op3, FULL_0_5), {}, meta)
+            return super().call_operator(mul_op, (input, op4), {}, meta)
+
+        elif approximate == "tanh":
+            # Constants mirror the ExecuTorch implementation for parity.
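+            # FULL_SQRT2 holds sqrt(2) and FULL_2_SQRTPI holds 2/sqrt(pi);
+            # their product scaled by 0.5 yields the sqrt(2/pi) factor below.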
+            FULL_SQRT2 = super().call_operator(
+                full_op,
+                ([1] * len(shape), 1.41421356237309504880),
+                {"dtype": dtype},
+                meta,
+            )
+            FULL_2_SQRTPI = super().call_operator(
+                full_op,
+                ([1] * len(shape), 1.12837916709551257390),
+                {"dtype": dtype},
+                meta,
+            )
+            FULL_CUBE_COEFF = super().call_operator(
+                full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta
+            )
+
+            # Mirrors the ExecuTorch implementation for calculating this value.
+            SQRT_MUL = super().call_operator(
+                mul_op, (FULL_SQRT2, FULL_2_SQRTPI), {}, meta
+            )
+            SQRT_2_PI = super().call_operator(mul_op, (SQRT_MUL, FULL_0_5), {}, meta)
+
+            # Avoid using POW in order to reduce reliance on pass ordering.
+            sqr_x = super().call_operator(mul_op, (input, input), {}, meta)
+            cube_x = super().call_operator(mul_op, (sqr_x, input), {}, meta)
+            op1 = super().call_operator(mul_op, (cube_x, FULL_CUBE_COEFF), {}, meta)
+            op2 = super().call_operator(add_op, (input, op1), {}, meta)
+            op3 = super().call_operator(mul_op, (op2, SQRT_2_PI), {}, meta)
+            op4 = super().call_operator(tanh_op, (op3,), {}, meta)
+            op5 = super().call_operator(add_op, (op4, FULL_1), {}, meta)
+            op6 = super().call_operator(mul_op, (input, op5), {}, meta)
+            return super().call_operator(mul_op, (op6, FULL_0_5), {}, meta)
+        else:
+            raise RuntimeError(
+                f"approximate argument expected 'none' or 'tanh' but got {approximate}"
+            )
diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index 02510600d82..a5f66829da5 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -56,6 +56,7 @@ class TableOps:
     # Targets that must be treated explicitly
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
+        exir_ops.edge.aten.gelu.default,
     }

     def __init__(self, exported_program: ExportedProgram):
@@ -76,6 +77,19 @@ def __getitem__(self, node: Node):
                     # Exponent is a constant. Embed it into a lambda.
                     exp = cast(int, node.args[1])
                     return lambda x: torch.pow(x, exp).flatten()
+                case exir_ops.edge.aten.gelu.default:
+                    # If the kwarg is not present, the default is "none".
+                    approximate = cast(
+                        str,
+                        (
+                            node.kwargs["approximate"]
+                            if "approximate" in node.kwargs
+                            else "none"
+                        ),
+                    )
+                    return lambda x: torch.nn.functional.gelu(
+                        x, approximate=approximate
+                    ).flatten()
                 case _:
                     # Op must be handled if it's inside self.special_ops
                     raise AssertionError("Unhandled table operation")
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 1cec35923de..f84bde7fadc 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -225,6 +225,7 @@ def is_node_supported(
             exir_ops.edge.aten.bitwise_left_shift.Tensor,
             exir_ops.edge.aten.__lshift__.Scalar,
             torch.ops.aten.scalar_tensor.default,
+            exir_ops.edge.aten.gelu.default,
         ]

         return supported
@@ -361,6 +362,7 @@ def is_node_supported(
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.upsample_nearest2d.vec,
+            exir_ops.edge.aten.gelu.default,
         ):
             return True
         elif node.target in (
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index b0f9e90b10f..1575d59fd77 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -178,6 +178,7 @@ def _match_pattern(
     torch.ops.aten.hardswish_.default,
     torch.ops.aten.full_like.default,
     torch.ops.aten.pow.Tensor_Scalar,
+    torch.ops.aten.gelu.default,
 ]

 _one_to_one_shared_input_qspec = [
diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
index f69d9d34462..49efbbb4a9c 100644
--- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
+++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
@@ -117,7 +117,12 @@ def test_softplus_tosa_BI(test_data: input_t1):
 # Since GELU will not be quantized by TosaQuantizer, the Dropout's input will not be quantized either.
 # If so, the Dropout should not be partitioned by TosaPartitioner for TOSA BI profile. This test tests that the
 # partitioner indeed does not partition the Dropout (clone) for TOSA BI.
-@common.parametrize("test_data", test_data)
+@common.parametrize(
+    "test_data",
+    test_data,
+    {"3d_rand": "MLETORCH-909: Partition test to not rely on unsupported ops"},
+    strict=False,
+)
 def test_linear_residaul_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](
         LinearResidualModule(),
diff --git a/backends/arm/test/ops/test_gelu.py b/backends/arm/test/ops/test_gelu.py
new file mode 100644
index 00000000000..fb1253fdb0c
--- /dev/null
+++ b/backends/arm/test/ops/test_gelu.py
@@ -0,0 +1,125 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
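+
+# In the BI (quantized) pipelines below GELU is lowered via a table operator
+# (see insert_table_ops.py), while the MI pipeline exercises the
+# DecomposeGeluPass decomposition.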
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+input_t1 = Tuple[torch.Tensor]
+
+
+class Gelu(torch.nn.Module):
+    aten_op = "torch.ops.aten.gelu.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_gelu_default"
+
+    test_data: dict[str, Tuple[str, input_t1]] = {
+        "zeros_none": (
+            "none",
+            torch.zeros(1, 10, 10, 10),
+        ),
+        "ones_none": (
+            "none",
+            torch.ones(10, 10, 10),
+        ),
+        "rand_none": (
+            "none",
+            (torch.rand(10, 10) - 0.5),
+        ),
+        "randn_pos_none": (
+            "none",
+            (torch.randn(1, 4, 4, 4) + 10),
+        ),
+        "randn_neg_none": (
+            "none",
+            (torch.randn(1, 4, 4, 4) - 10),
+        ),
+        "ramp_none": (
+            "none",
+            torch.arange(-16, 16, 0.2),
+        ),
+        "zeros_tanh": (
+            "tanh",
+            torch.zeros(1, 10, 10, 10),
+        ),
+        "ones_tanh": (
+            "tanh",
+            torch.ones(10, 10, 10),
+        ),
+        "rand_tanh": (
+            "tanh",
+            (torch.rand(10, 10) - 0.5),
+        ),
+        "randn_pos_tanh": (
+            "tanh",
+            (torch.randn(1, 4, 4, 4) + 10),
+        ),
+        "randn_neg_tanh": (
+            "tanh",
+            (torch.randn(1, 4, 4, 4) - 10),
+        ),
+        "ramp_tanh": (
+            "tanh",
+            torch.arange(-16, 16, 0.2),
+        ),
+    }
+
+    def __init__(self, approximate: str = "none"):
+        super().__init__()
+        self.gelu = torch.nn.GELU(approximate)
+
+    def forward(self, x: torch.Tensor):
+        return self.gelu(x)
+
+
+@common.parametrize("test_data", Gelu.test_data)
+def test_gelu_tosa_MI(test_data: input_t1):
+    approximate = test_data[0]
+    TosaPipelineMI[input_t1](
+        Gelu(approximate),
+        (test_data[1],),
+        Gelu.aten_op,
+        Gelu.exir_op,
+        use_to_edge_transform_and_lower=False,
+    ).run()
+
+
+@common.parametrize("test_data", Gelu.test_data)
+def test_gelu_tosa_BI(test_data: input_t1):
+    approximate = test_data[0]
+    TosaPipelineBI[input_t1](
+        Gelu(approximate),
+        (test_data[1],),
+        Gelu.aten_op,
+        Gelu.exir_op,
+    ).run()
+
+
+@common.parametrize("test_data", Gelu.test_data)
+def test_gelu_u55_BI(test_data: input_t1):
+    approximate = test_data[0]
+    EthosU55PipelineBI[input_t1](
+        Gelu(approximate),
+        (test_data[1],),
+        Gelu.aten_op,
+        Gelu.exir_op,
+    ).run()
+
+
+@common.parametrize("test_data", Gelu.test_data)
+def test_gelu_u85_BI(test_data: input_t1):
+    approximate = test_data[0]
+    EthosU85PipelineBI[input_t1](
+        Gelu(approximate),
+        (test_data[1],),
+        Gelu.aten_op,
+        Gelu.exir_op,
+    ).run()