1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -20,6 +20,7 @@
from .convert_to_clamp import ConvertToClampPass # noqa
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linear_pass import DecomposeLinearPass # noqa
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -25,6 +25,7 @@
ConvertToClampPass,
DecomposeBatchNormPass,
DecomposeDivPass,
DecomposeGeluPass,
DecomposeLayerNormPass,
DecomposeLeakyReLUPass,
DecomposeLinearPass,
@@ -132,6 +133,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxPass())
self.add_pass(DecomposeGeluPass())
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
149 changes: 149 additions & 0 deletions backends/arm/_passes/decompose_gelu_pass.py
@@ -0,0 +1,149 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

torch_gelu = (torch.ops.aten.gelu.default,)

edge_gelu = (exir_ops.edge.aten.gelu.default,)


def _get_gelu_ops(op) -> tuple:
"""
Returns the operators needed to decompose GELU
"""

if op in edge_gelu:
return (
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.erf.default,
)
if op in torch_gelu:
return (
torch.ops.aten.full.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.tanh.default,
torch.ops.aten.erf.default,
)
raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}")


class DecomposeGeluPass(ExportPass):
"""
    This pass decomposes the GELU operator into primitive ops, aiming to
    adhere closely to the reference implementations built into ExecuTorch,
    including the use of the same pre-calculated constants.

    This operator has two formulae depending on the value of the
    approximate argument. The examples below include the added full
    operators necessary to initialize the constants used in each
    respective formula.

    aten.gelu(x, approximate="none") becomes:
        %FULL_0_5 = full()
        %FULL_1 = full()
        %FULL_SQRT1_2 = full()
        %op1 = mul(%x, %FULL_SQRT1_2)
        %op2 = erf(%op1)
        %op3 = add(%op2, %FULL_1)
        %op4 = mul(%op3, %FULL_0_5)
        %op5 = mul(%x, %op4)

    aten.gelu(x, approximate="tanh") becomes:
        %FULL_0_5 = full()
        %FULL_1 = full()
        %FULL_SQRT2 = full()
        %FULL_2_SQRTPI = full()
        %FULL_CUBE_COEFF = full()
        %SQRT_MUL = mul(%FULL_SQRT2, %FULL_2_SQRTPI)
        %SQRT_2_PI = mul(%SQRT_MUL, %FULL_0_5)
        %sqr_x = mul(%x, %x)
        %cube_x = mul(%sqr_x, %x)
        %op1 = mul(%cube_x, %FULL_CUBE_COEFF)
        %op2 = add(%x, %op1)
        %op3 = mul(%op2, %SQRT_2_PI)
        %op4 = tanh(%op3)
        %op5 = add(%op4, %FULL_1)
        %op6 = mul(%x, %op5)
        %op7 = mul(%op6, %FULL_0_5)
"""

def call_operator(self, op, args, kwargs, meta):
if op not in torch_gelu + edge_gelu:
return super().call_operator(op, args, kwargs, meta)

full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)

input = get_node_arg(args, 0)
        # If approximate is the default ("none"), it does not appear in kwargs
approximate = get_node_arg(kwargs, "approximate", "none")

shape = meta["val"].size()
dtype = meta["val"].dtype

FULL_0_5 = super().call_operator(
full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta
)
FULL_1 = super().call_operator(
full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta
)

if approximate == "none":
# Constant mirrors ExecuTorch implementation for parity.
FULL_SQRT1_2 = super().call_operator(
full_op, ([1] * len(shape), 0.70710678118654752440), {}, meta
)

op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta)
op2 = super().call_operator(erf_op, (op1,), {}, meta)
op3 = super().call_operator(add_op, (op2, FULL_1), {}, meta)
op4 = super().call_operator(mul_op, (op3, FULL_0_5), {}, meta)
return super().call_operator(mul_op, (input, op4), {}, meta)

elif approximate == "tanh":
# Constants mirror ExecuTorch implementation for parity.
FULL_SQRT2 = super().call_operator(
full_op,
([1] * len(shape), 1.41421356237309504880),
{"dtype": dtype},
meta,
)
FULL_2_SQRTPI = super().call_operator(
full_op,
([1] * len(shape), 1.12837916709551257390),
{"dtype": dtype},
meta,
)
FULL_CUBE_COEFF = super().call_operator(
full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta
)

            # Mirrors the ExecuTorch implementation for calculating this value
SQRT_MUL = super().call_operator(
mul_op, (FULL_SQRT2, FULL_2_SQRTPI), {}, meta
)
SQRT_2_PI = super().call_operator(mul_op, (SQRT_MUL, FULL_0_5), {}, meta)

            # Avoid using POW in order to reduce reliance on pass ordering.
sqr_x = super().call_operator(mul_op, (input, input), {}, meta)
cube_x = super().call_operator(mul_op, (sqr_x, input), {}, meta)
op1 = super().call_operator(mul_op, (cube_x, FULL_CUBE_COEFF), {}, meta)
op2 = super().call_operator(add_op, (input, op1), {}, meta)
op3 = super().call_operator(mul_op, (op2, SQRT_2_PI), {}, meta)
op4 = super().call_operator(tanh_op, (op3,), {}, meta)
op5 = super().call_operator(add_op, (op4, FULL_1), {}, meta)
op6 = super().call_operator(mul_op, (input, op5), {}, meta)
return super().call_operator(mul_op, (op6, FULL_0_5), {}, meta)
else:
raise RuntimeError(
f"approximate argument expected 'none' or 'tanh' but got {approximate}"
)
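
Note: the decomposition above corresponds to the two standard GELU formulae, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) for approximate="none" and 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) for approximate="tanh". The following standalone sketch is not part of this diff; the helper name decomposed_gelu is illustrative only. It re-creates the decomposed graph with plain tensor ops and checks it against torch.nn.functional.gelu:

# Standalone sketch (not part of this diff): reproduce the decomposition with
# plain tensor ops and compare it against the ATen reference implementation.
import torch


def decomposed_gelu(x: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Constants match the values materialized by the full() ops in the pass.
    full_0_5 = torch.full([1], 0.5)
    full_1 = torch.full([1], 1.0)
    if approximate == "none":
        full_sqrt1_2 = torch.full([1], 0.70710678118654752440)  # 1/sqrt(2)
        return x * ((torch.erf(x * full_sqrt1_2) + full_1) * full_0_5)
    if approximate == "tanh":
        full_sqrt2 = torch.full([1], 1.41421356237309504880)  # sqrt(2)
        full_2_sqrtpi = torch.full([1], 1.12837916709551257390)  # 2/sqrt(pi)
        full_cube_coeff = torch.full([1], 0.044715)
        sqrt_2_pi = full_sqrt2 * full_2_sqrtpi * full_0_5  # sqrt(2/pi)
        inner = (x + x * x * x * full_cube_coeff) * sqrt_2_pi
        return x * ((torch.tanh(inner) + full_1) * full_0_5)
    raise ValueError(f"unexpected approximate value: {approximate}")


if __name__ == "__main__":
    x = torch.arange(-16, 16, 0.2)
    for mode in ("none", "tanh"):
        ref = torch.nn.functional.gelu(x, approximate=mode)
        assert torch.allclose(decomposed_gelu(x, mode), ref, atol=1e-6)
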
14 changes: 14 additions & 0 deletions backends/arm/_passes/insert_table_ops.py
@@ -56,6 +56,7 @@ class TableOps:
# Targets that must be treated explicitly
special_table_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.gelu.default,
}

def __init__(self, exported_program: ExportedProgram):
@@ -76,6 +77,19 @@ def __getitem__(self, node: Node):
# Exponent is a constant. Embed it into a lambda.
exp = cast(int, node.args[1])
return lambda x: torch.pow(x, exp).flatten()
case exir_ops.edge.aten.gelu.default:
                # If the kwarg is not present, the default is "none"
approximate = cast(
str,
(
node.kwargs["approximate"]
if "approximate" in node.kwargs
else "none"
),
)
return lambda x: torch.nn.functional.gelu(
x, approximate=approximate
).flatten()
case _:
# Op must be handled if it's inside self.special_ops
raise AssertionError("Unhandled table operation")
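
Note: the lambdas returned above are consumed by the table-generation step of this pass, which this diff does not show. As a rough, illustrative sketch (not part of this diff; build_int8_table and the scale/zero-point values are made-up placeholders), an int8 lookup table for such a lambda could be produced like this:

# Illustrative sketch only: turning a returned lambda into an int8 lookup table.
import torch


def build_int8_table(fn, scale_in, zp_in, scale_out, zp_out) -> torch.Tensor:
    qmin, qmax = -128, 127
    q_in = torch.arange(qmin, qmax + 1, dtype=torch.int32)
    x = (q_in - zp_in) * scale_in  # dequantize every possible int8 input
    y = fn(x.to(torch.float32))  # evaluate the float op, e.g. gelu
    q_out = torch.round(y / scale_out + zp_out)  # re-quantize the outputs
    return torch.clamp(q_out, qmin, qmax).to(torch.int8)


# Example: the gelu entry above with approximate="none".
table = build_int8_table(
    lambda x: torch.nn.functional.gelu(x, approximate="none").flatten(),
    scale_in=0.05,
    zp_in=0,
    scale_out=0.05,
    zp_out=0,
)
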
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -225,6 +225,7 @@ def is_node_supported(
exir_ops.edge.aten.bitwise_left_shift.Tensor,
exir_ops.edge.aten.__lshift__.Scalar,
torch.ops.aten.scalar_tensor.default,
exir_ops.edge.aten.gelu.default,
]

return supported
@@ -361,6 +362,7 @@ def is_node_supported(
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.gelu.default,
):
return True
elif node.target in (
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -178,6 +178,7 @@ def _match_pattern(
torch.ops.aten.hardswish_.default,
torch.ops.aten.full_like.default,
torch.ops.aten.pow.Tensor_Scalar,
torch.ops.aten.gelu.default,
]

_one_to_one_shared_input_qspec = [
@@ -117,7 +117,12 @@ def test_softplus_tosa_BI(test_data: input_t1):
# Since GELU will not be quantized by TosaQuantizer, the Dropout's input will not be quantized either.
# If so, the Dropout should not be partitioned by TosaPartitioner for TOSA BI profile. This test tests that the
# partitioner indeed does not partition the Dropout (clone) for TOSA BI.
@common.parametrize("test_data", test_data)
@common.parametrize(
"test_data",
test_data,
{"3d_rand": "MLETORCH-909: Partition test to not rely on unsupported ops"},
strict=False,
)
def test_linear_residaul_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](
LinearResidualModule(),
125 changes: 125 additions & 0 deletions backends/arm/test/ops/test_gelu.py
@@ -0,0 +1,125 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

input_t1 = Tuple[torch.Tensor]


class Gelu(torch.nn.Module):
aten_op = "torch.ops.aten.gelu.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_gelu_default"

test_data: dict[str, Tuple[str, input_t1]] = {
"zeros_none": (
"none",
torch.zeros(1, 10, 10, 10),
),
"ones_none": (
"none",
torch.ones(10, 10, 10),
),
"rand_none": (
"none",
(torch.rand(10, 10) - 0.5),
),
"randn_pos_none": (
"none",
(torch.randn(1, 4, 4, 4) + 10),
),
"randn_neg_none": (
"none",
(torch.randn(1, 4, 4, 4) - 10),
),
"ramp_none": (
"none",
torch.arange(-16, 16, 0.2),
),
"zeros_tanh": (
"tanh",
torch.zeros(1, 10, 10, 10),
),
"ones_tanh": (
"tanh",
torch.ones(10, 10, 10),
),
"rand_tanh": (
"tanh",
(torch.rand(10, 10) - 0.5),
),
"randn_pos_tanh": (
"tanh",
(torch.randn(1, 4, 4, 4) + 10),
),
"randn_neg_tanh": (
"tanh",
(torch.randn(1, 4, 4, 4) - 10),
),
"ramp_tanh": (
"tanh",
torch.arange(-16, 16, 0.2),
),
}

def __init__(self, approximate: str = "none"):
super().__init__()
self.gelu = torch.nn.GELU(approximate)

def forward(self, x: torch.Tensor):
return self.gelu(x)


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_tosa_MI(test_data: input_t1):
approximate = test_data[0]
TosaPipelineMI[input_t1](
Gelu(approximate),
(test_data[1],),
Gelu.aten_op,
Gelu.exir_op,
use_to_edge_transform_and_lower=False,
).run()


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_tosa_BI(test_data: input_t1):
approximate = test_data[0]
TosaPipelineBI[input_t1](
Gelu(approximate),
(test_data[1],),
Gelu.aten_op,
Gelu.exir_op,
).run()


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_u55_BI(test_data: input_t1):
approximate = test_data[0]
EthosU55PipelineBI[input_t1](
Gelu(approximate),
(test_data[1],),
Gelu.aten_op,
Gelu.exir_op,
).run()


@common.parametrize("test_data", Gelu.test_data)
def test_gelu_u85_BI(test_data: input_t1):
approximate = test_data[0]
EthosU85PipelineBI[input_t1](
Gelu(approximate),
(test_data[1],),
Gelu.aten_op,
Gelu.exir_op,
).run()