From a23d0c1336674b3efa990a895c3058fcc755a5b8 Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Sun, 13 Apr 2025 21:59:18 -0700
Subject: [PATCH] Add approximate gelu replacement to opt level 2 (#10129)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/10129

As titled. Gelu is prohibitively expensive to run on DSPs due to the
std::erf call in the function. The PT approximate version uses a
`tanh`-based approximation instead, which is faster (on the ASR encoder
27M model, for example).

It seems the linter complains about BUCK files (even ones with just
on_call commands).

Differential Revision: D72935935
---
 backends/cadence/aot/replace_ops.py           | 97 +++++++++++++++++++
 .../aot/tests/test_replace_ops_passes.py      | 36 +++++++
 2 files changed, 133 insertions(+)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 5a4922ae069..f2cda53981d 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -2110,6 +2110,102 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class ReplaceGeluWithApproximateGeluPass(ExportPass):
+    """
+    Replace the gelu op with an approximate gelu op. The approximate gelu op
+    is more efficient on DSP backends.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.gelu.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
+        # as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
+
+        # Get 0.5 * x
+        half = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.5),
+            {},
+            meta,
+        )
+
+        scaled = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.044715),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^2 (note that we use repeated mul.Tensor instead of
+        # pow.Tensor because it is much more efficient on DSP backends)
+        scaled_square = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^3
+        scaled_cubed = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled_square, args[0]),
+            {},
+            meta,
+        )
+
+        # Get x + 0.044715 * x^3
+        inner_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (scaled_cubed, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.7978845608028654 * (x + 0.044715 * x^3)
+        scaled_sum = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (inner_sum, 0.7978845608028654),
+            {},
+            meta,
+        )
+
+        # Get torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        tanh = super().call_operator(
+            exir_ops.edge.aten.tanh.default,
+            (scaled_sum,),
+            {},
+            meta,
+        )
+
+        # Get 1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
+        outer_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (tanh, 1.0),
+            {},
+            meta,
+        )
+
+        # Return the final result: 0.5 * x * (1 + tanh(...))
+        return super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (half, outer_sum),
+            {},
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2149,4 +2245,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
+        # ReplaceGeluWithApproximateGeluPass,
     ]
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index e40c26c0f4e..c22dd947d5d 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -29,6 +29,7 @@
     ReplaceConvWithIm2RowAndLinear,
     ReplaceEmptyTensorsWithFullPass,
     ReplaceFunctionallyEquivalentOpTargets,
+    ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
     ReplaceMMWithAddMMPass,
@@ -1301,6 +1302,41 @@ def forward(self, cond: torch.Tensor):
             1,
         )
 
+    def test_replace_aten_gelu_with_approximate_gelu(self):
+        class Gelu(torch.nn.Module):
+            def forward(self, input):
+                return torch.nn.functional.gelu(input)
+
+        inputs = torch.randn(2, 1, 64)
+
+        graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
+
+        p = ReplaceGeluWithApproximateGeluPass()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that the aten.gelu op was decomposed
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.gelu.default,
+            ),
+            0,
+        )
+
+        # The decomposition should have one tanh, two adds and six muls
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
+            2,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
+            6,
+        )
+
 
 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self):
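
For reference, a quick standalone check of the tanh-based approximation that the pass decomposes into. This snippet is illustrative only (not part of the patch) and assumes a PyTorch build where torch.nn.functional.gelu accepts approximate="tanh":

import torch

x = torch.randn(2, 1, 64)

# Manual expansion mirroring the decomposition in ReplaceGeluWithApproximateGeluPass:
# 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
manual = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x**3)))

# PyTorch's built-in tanh approximation should agree up to float rounding
approx = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(manual, approx, atol=1e-6)

# The exact (erf-based) gelu differs only by the small approximation error
exact = torch.nn.functional.gelu(x)
print((approx - exact).abs().max())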