From c9eb4e96bc33f08d85d68e08305b6fc051ce49eb Mon Sep 17 00:00:00 2001
From: Michael Maitland
Date: Mon, 2 Jun 2025 12:09:25 -0700
Subject: [PATCH] Support approximate gelu (#11246)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11246

GELU accepts an `approximate` argument, which is `none` by default or `tanh`.

When the `approximate` kwarg is present, decompose the op. An existing test, test_aten_gelu_out, already makes sure the op is supported.
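For reference, a minimal sketch of the two eager-level variants (illustrative only, not part of this diff):

    import torch

    x = torch.randn(2, 1, 64)
    # Exact GELU, x * Phi(x) computed via erf; approximate="none" is the default
    y_exact = torch.nn.functional.gelu(x)
    # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    y_tanh = torch.nn.functional.gelu(x, approximate="tanh")

The pass only rewrites gelu calls that carry the `approximate` kwarg; plain gelu calls are left untouched, as covered by the new test_no_replace_aten_gelu_with_approximate_gelu test.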
Reviewed By: zonglinpeng, hsharma35

Differential Revision: D75454999
---
 backends/cadence/aot/replace_ops.py           | 12 +++---
 .../aot/tests/test_replace_ops_passes.py      | 40 +++++++++++++++----
 2 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index e10bbf2b39a..e5a88c10a3f 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -2065,11 +2065,10 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
-@register_cadence_pass(CadencePassAttribute(opt_level=2))
-class ReplaceGeluWithApproximateGeluPass(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
     """
-    Replace the gelu op with an approximate gelu op. The approximate gelu op
-    is more efficient on DSP backends.
+    Replace the aten gelu op that has an approximate arg with an approximate gelu op.
     """
 
     def call_operator(
@@ -2079,6 +2078,9 @@ def call_operator(
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
+        if "approximate" not in kwargs:
+            return super().call_operator(op, args, kwargs, meta)
+
         if op not in {
             exir_ops.edge.aten.gelu.default,
         }:
@@ -2414,7 +2416,7 @@ class CadenceReplaceOpsInGraph:
             ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
             ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
             ReplaceWhereWithFullArgsWithWhereScalar,
-            ReplaceGeluWithApproximateGeluPass,
+            ReplaceAtenApproxGeluWithApproxGeluPass,
             ReplaceSplitWithSlicePass,
             ReplacePowWithMulPass,
         ]
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index e7bf8e9cefa..e8215c378f9 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -26,13 +26,13 @@
     ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
     ReplaceAddMMWithLinearPass,
+    ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithJarvisConvolutionPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
     ReplaceConvWithIm2RowAndLinear,
     ReplaceEmptyTensorsWithFullPass,
     ReplaceFunctionallyEquivalentOpTargets,
-    ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
     ReplaceMatmulWithTransposedMatmulPass,
@@ -1287,17 +1287,41 @@ def forward(self, cond: torch.Tensor):
             1,
         )
 
-    def test_replace_aten_gelu_with_approximate_gelu(self):
-        class Gelu(torch.nn.Module):
-            def forward(self, input):
-                return torch.nn.functional.gelu(input)
+    def test_no_replace_aten_gelu_with_approximate_gelu(self):
+        inputs = torch.randn(2, 1, 64)
+
+        gm = single_op_builder(
+            placeholders=(inputs,),
+            op=exir_ops.edge.aten.gelu.default,
+            args=(inputs,),
+        )
+        gm = ExportPass().call(gm).graph_module
+
+        p = ReplaceAtenApproxGeluWithApproxGeluPass()
+        graph_after_passes = p.call(gm).graph_module
 
+        # Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.gelu.default,
+            ),
+            1,
+        )
+
+    def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
         inputs = torch.randn(2, 1, 64)
-        graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
 
+        gm = single_op_builder(
+            placeholders=(inputs,),
+            op=exir_ops.edge.aten.gelu.default,
+            args=(inputs,),
+            kwargs={"approximate": "tanh"},
+        )
+        gm = ExportPass().call(gm).graph_module
 
-        p = ReplaceGeluWithApproximateGeluPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        p = ReplaceAtenApproxGeluWithApproxGeluPass()
+        graph_after_passes = p.call(gm).graph_module
 
         # Assert that aten.gelu op was decomposed
         self.assertEqual(