39 changes: 39 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -276,6 +276,24 @@ python_library(
],
)

python_library(
name = "decompose_ops",
srcs = [
"decompose_ops.py",
],
typing = True,
deps = [
":pass_utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
"//executorch/exir/passes:spec_prop_pass",
],
)


python_unittest(
name = "test_graph_builder",
srcs = [
@@ -314,6 +332,27 @@ python_unittest(
],
)

python_unittest(
name = "test_decompose_ops_passes",
srcs = [
"tests/test_decompose_ops_passes.py",
],
supports_static_listing = False,
typing = True,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
":compiler",
":decompose_ops",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//executorch/exir/passes:lib",
],
)
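
As a usage note (assuming a standard Buck setup; the exact invocation depends on the local toolchain), the new test target could be run with:

buck2 test //executorch/backends/cadence/aot:test_decompose_ops_passes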

python_unittest(
name = "test_fusion_ops_passes",
srcs = [
122 changes: 122 additions & 0 deletions backends/cadence/aot/decompose_ops.py
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# This file contains all the passes that decompose one op into a series of simpler
# ops in the graph. The passes that decompose ops for models deployed with Jarvis
# are grouped together in the class 'CadenceDecomposeOpsInGraph'. One example is
# the pass that decomposes an ATen gelu op into an equivalent series of simpler ops.

# pyre-strict

from typing import Dict

from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
register_cadence_pass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch.fx.node import Argument


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class DecomposeAtenApproxGeluPass(ExportPass):
"""
    Decompose the aten gelu op with an 'approximate' arg into a series of simpler ops
"""

def call_operator(
self,
op: EdgeOpOverload,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
    ) -> ProxyValue:
        # Decompose only gelu ops that carry an 'approximate' kwarg; pass every
        # other op through unchanged (same guards as the code this replaces in
        # replace_ops.py).
        if "approximate" not in kwargs:
            return super().call_operator(op, args, kwargs, meta)

        if op not in {
            exir_ops.edge.aten.gelu.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
        # as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))

# Get 0.5 * x
half = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[0], 0.5),
{},
meta,
)

scaled = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[0], 0.044715),
{},
meta,
)

        # Get 0.044715 * x^2 (note that we chain mul.Tensor instead of using
        # pow.Tensor because it is much more efficient on DSP backends)
scaled_square = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(scaled, args[0]),
{},
meta,
)

        # Get 0.044715 * x^3
scaled_cubed = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(scaled_square, args[0]),
{},
meta,
)

# Get x + 0.044715 * x^3
inner_sum = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(scaled_cubed, args[0]),
{},
meta,
)

# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
scaled_sum = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(inner_sum, 0.7978845608028654),
{},
meta,
)

# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
tanh = super().call_operator(
exir_ops.edge.aten.tanh.default,
(scaled_sum,),
{},
meta,
)

# Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
outer_sum = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(tanh, 1.0),
{},
meta,
)

# Return the final result
return super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(half, outer_sum),
{},
meta,
)


# This class encapsulates all the passes that decompose one op into simpler ops in the graph.
class CadenceDecomposeOpsInGraph:
passes = [
DecomposeAtenApproxGeluPass,
]
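
For reference, here is a minimal numerical sanity check (not part of this PR; written for illustration) that the op sequence emitted by DecomposeAtenApproxGeluPass matches PyTorch's built-in tanh-approximate gelu:

import torch

x = torch.randn(2, 1, 64)

# Mirror the decomposed op sequence:
# 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
half = x * 0.5
inner_sum = (x * 0.044715) * x * x + x       # x + 0.044715 * x^3
scaled_sum = inner_sum * 0.7978845608028654  # sqrt(2 / pi)
decomposed = half * (torch.tanh(scaled_sum) + 1.0)

reference = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(decomposed, reference, atol=1e-6)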
80 changes: 1 addition & 79 deletions backends/cadence/aot/replace_ops.py
@@ -2078,89 +2078,11 @@ def call_operator(
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if "approximate" not in kwargs:
return super().call_operator(op, args, kwargs, meta)

if op not in {
exir_ops.edge.aten.gelu.default,
}:
return super().call_operator(op, args, kwargs, meta)

# compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
# as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))

# Get 0.5 * x
half = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[0], 0.5),
{},
meta,
)

scaled = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[0], 0.044715),
{},
meta,
)

# Get x^2 (note that we use mul.Tensor twice instead of pow.Tensor because
# it is much more efficient on DSP backends)
scaled_square = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(scaled, args[0]),
{},
meta,
)

# Get x^3
scaled_cubed = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(scaled_square, args[0]),
{},
meta,
)

# Get x + 0.044715 * x^3
inner_sum = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(scaled_cubed, args[0]),
{},
meta,
)

# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
scaled_sum = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(inner_sum, 0.7978845608028654),
{},
meta,
)

# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
tanh = super().call_operator(
exir_ops.edge.aten.tanh.default,
(scaled_sum,),
{},
meta,
)

# Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
outer_sum = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(tanh, 1.0),
{},
meta,
)

        # Return the final result
return super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(half, outer_sum),
{},
meta,
)
return super().call_operator(op, args, kwargs, meta)


# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
80 changes: 80 additions & 0 deletions backends/cadence/aot/tests/test_decompose_ops_passes.py
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Union

import torch
from executorch.backends.cadence.aot.decompose_ops import DecomposeAtenApproxGeluPass
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass


class TestDecomposeOpsPasses(unittest.TestCase):
def assertTargetCountEqual(
self,
graph_module: torch.fx.GraphModule,
target: Union[EdgeOpOverload, str],
expected_count: int,
) -> None:
"""Helper function to check the number of nodes with a given target."""
actual_count = count_node(graph_module, target)
self.assertEqual(
actual_count,
expected_count,
f"{target} count mismatch for graph {graph_module}",
)

def assertTargetCountsEqual(
self,
graph_module: torch.fx.GraphModule,
targets_and_counts: list[tuple[Union[EdgeOpOverload, str], int]],
) -> None:
"""Helper function to check the number of nodes of all types for a given target."""
for target, expected_count in targets_and_counts:
self.assertTargetCountEqual(graph_module, target, expected_count)

def test_decompose_aten_approximate_gelu(self) -> None:
inputs = torch.randn(2, 1, 64)

gm = single_op_builder(
placeholders=(inputs,),
op=exir_ops.edge.aten.gelu.default,
args=(inputs,),
kwargs={"approximate": "tanh"},
)
gm = ExportPass().call(gm).graph_module

p = DecomposeAtenApproxGeluPass()
graph_after_passes = p.call(gm).graph_module

# Assert that aten.gelu op was decomposed
self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.gelu.default,
),
0,
)

        # The decomposition should contain exactly 1 tanh, 2 add, and 6 mul ops
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
1,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
2,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
6,
)
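
As a usage note, the registered passes can be applied in sequence via the pass list; the sketch below follows the same pattern as the test above (apply_decompose_passes is a hypothetical helper, not part of this PR):

import torch
from executorch.backends.cadence.aot.decompose_ops import CadenceDecomposeOpsInGraph

def apply_decompose_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Run every registered decomposition pass, threading the rewritten
    # graph module through each one.
    for pass_cls in CadenceDecomposeOpsInGraph.passes:
        gm = pass_cls().call(gm).graph_module
    return gm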
37 changes: 0 additions & 37 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -1309,43 +1309,6 @@ def test_no_replace_aten_gelu_with_approximate_gelu(self):
1,
)

def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
inputs = torch.randn(2, 1, 64)

gm = single_op_builder(
placeholders=(inputs,),
op=exir_ops.edge.aten.gelu.default,
args=(inputs,),
kwargs={"approximate": "tanh"},
)
gm = ExportPass().call(gm).graph_module

p = ReplaceAtenApproxGeluWithApproxGeluPass()
graph_after_passes = p.call(gm).graph_module

# Assert that aten.gelu op was decomposed
self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.gelu.default,
),
0,
)

# The decomposition should have one tanh, 2 add and 6 mul
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
1,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
2,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
6,
)

def test_replace_split_with_sizes_with_slice(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 16, 8, 4))