2 changes: 1 addition & 1 deletion backends/cadence/README.md
@@ -2,7 +2,7 @@

## Supported DSPs (in progress)
- HiFi Audio
- ...
- Fusion G3

## Tutorial

160 changes: 159 additions & 1 deletion backends/cadence/aot/TARGETS
@@ -39,6 +39,7 @@ python_library(
        ":passes",
        ":utils",
        ":ops_registrations",
        ":replace_ops",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot/quantizer:fusion_pass",
        "//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -74,12 +75,14 @@ python_library(
        ":utils",
        ":fuse_ops",
        ":simplify_ops",
        ":replace_ops",
        ":reorder_ops",
        ":remove_ops",
        "//caffe2:torch",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:lib",
        "//executorch/exir/passes:spec_prop_pass",
        "//executorch/backends/transforms:remove_clone_ops"
    ],
)

@@ -180,6 +183,63 @@ python_library(
    ],
)

python_library(
    name = "remove_ops",
    srcs = [
        "remove_ops.py",
    ],
    typing = True,
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:simplify_ops",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
        "//executorch/exir/passes:spec_prop_pass",
        "//executorch/backends/transforms:remove_clone_ops"
    ],
)

python_library(
    name = "reorder_ops",
    srcs = [
        "reorder_ops.py",
    ],
    typing = True,
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler_utils",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:utils",
        "//executorch/exir:pass_base",
        "//executorch/exir:tensor",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
    ],
)

python_library(
    name = "replace_ops",
    srcs = [
        "replace_ops.py",
    ],
    typing = True,
    deps = [
        ":pass_utils",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler_utils",
        "//executorch/backends/cadence/aot:fuse_ops",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:remove_ops",
        "//executorch/backends/cadence/aot:utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
        "//executorch/exir/passes:spec_prop_pass",
    ],
)

python_unittest(
    name = "test_graph_builder",
    srcs = [
@@ -196,3 +256,101 @@ python_unittest(
        ":ops_registrations"
    ],
)

python_unittest(
    name = "test_replace_ops_passes",
    srcs = [
        "tests/test_replace_ops_passes.py",
    ],
    supports_static_listing = False,
    typing = True,
    deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        ":compiler",
        ":replace_ops",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler",
        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:lib",
    ],
)

python_unittest(
    name = "test_fusion_ops_passes",
    srcs = [
        "tests/test_fusion_ops_passes.py",
    ],
    typing = True,
    deps = [
        ":compiler",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler",
        "//executorch/backends/cadence/aot:fuse_ops",
        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot:ops_registrations",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
    ],
)

python_unittest(
    name = "test_remove_ops_passes",
    srcs = [
        "tests/test_remove_ops_passes.py",
    ],
    supports_static_listing = False,
    typing = True,
    deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        "fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
        ":compiler",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler",
        "//executorch/backends/cadence/aot:ops_registrations",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:remove_ops",
        "//executorch/backends/cadence/aot/quantizer:quantizer",
        "//executorch/exir/dialects:lib",
    ],
)

python_unittest(
    name = "test_simplify_ops_passes",
    srcs = [
        "tests/test_simplify_ops_passes.py",
    ],
    supports_static_listing = False,
    typing = True,
    deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler",
        "//executorch/backends/cadence/aot:ops_registrations",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:simplify_ops",
        "//executorch/exir/dialects:lib",
    ],
)

python_unittest(
    name = "test_reorder_ops_passes",
    srcs = [
        "tests/test_reorder_ops_passes.py",
    ],
    typing = True,
    deps = [
        ":compiler",
        ":pass_utils",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler",
        "//executorch/backends/cadence/aot:fuse_ops",
        "//executorch/backends/cadence/aot:ops_registrations",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/backends/cadence/aot:reorder_ops",
        "//executorch/exir/dialects:lib",
    ],
)
26 changes: 21 additions & 5 deletions backends/cadence/aot/compiler.py
@@ -12,10 +12,10 @@

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch

from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
@@ -194,9 +194,6 @@ def export_to_edge(
    return edge_prog_manager


# Export the model and lower it to an EdgeProgramManager (in edge IR), and
# apply passes specific to Cadence DSP execution. Return both to print the
# differences.
def export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
@@ -216,6 +213,25 @@ def export_to_cadence(
    return cadence_prog_manager


def quantize_and_export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
) -> EdgeProgramManager:
    quantized_model = quantize_pt2(model, inputs)

    return export_to_cadence(
        quantized_model,
        inputs,
        opt_level=opt_level,
        dump_graphs=dump_graphs,
    )


# Export the model and lower it to an EdgeProgramManager (in edge IR), and
# apply passes specific to Cadence DSP execution. Return both to print the
# differences.
def export_to_executorch_gen_etrecord(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
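(Note: the following is an illustrative usage sketch of the quantize_and_export_to_cadence helper added to compiler.py above; the toy module, example inputs, and import path are assumptions for the example and are not part of this diff.)

# Illustrative sketch only (not part of this PR): how the new helper might be called.
import torch

from executorch.backends.cadence.aot.compiler import quantize_and_export_to_cadence


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 16),)

# Quantize with the Cadence quantizer, then lower to edge IR with the
# Cadence-specific passes applied at the requested optimization level.
edge_prog_manager = quantize_and_export_to_cadence(
    model,
    example_inputs,
    dump_graphs=False,
    opt_level=1,
)

As the diff shows, the helper simply chains the existing quantize_pt2 and export_to_cadence entry points, so callers no longer have to invoke them separately.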
48 changes: 46 additions & 2 deletions backends/cadence/aot/pass_utils.py
@@ -3,7 +3,7 @@
# pyre-strict

from dataclasses import dataclass
from typing import Callable, Optional, Set, Union
from typing import Callable, List, Optional, Set, Union

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -50,7 +50,7 @@ def get_all_available_cadence_passes() -> Set[ExportPass]:
    return set(ALL_CADENCE_PASSES.keys())


# Create a new filter to filter out relevant passes from all Jarvis passes.
# Create a new filter to filter out relevant passes from all passes.
def create_cadence_pass_filter(
    opt_level: int, debug: bool = False
) -> Callable[[ExportPass], bool]:
@@ -98,3 +98,47 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
        if node.op == "call_function" and node.target == target:
            total += 1
    return total


# Testing utils
# Return the compute/function nodes in the graph
def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
    nodes = []
    for x in graph_module.graph.nodes:
        if x.op == "call_function":
            if isinstance(x.target, torch._ops.OpOverload):
                nodes.append(x.target.overloadpacket)
            elif isinstance(x.target, EdgeOpOverload):
                nodes.append(get_edge_overload_packet(x.target))
    return nodes


# Return true if there is no edge from a node with target pred_target to a
# node with target succ_target in the graph.
def nodes_not_connected_in_gm(
    graph_module: torch.fx.GraphModule,
    pred_target: torch.fx.Node,
    succ_target: torch.fx.Node,
) -> bool:
    for node in graph_module.graph.nodes:
        if node.target != pred_target:
            continue
        for user in node.users:
            if user.target == succ_target:
                return False
    return True


# Returns true if there is no instance of a node with target succ_target
# positioned immediately after a node with target pred_target in the graph
def nodes_not_adjacent_in_gm(
    graph_module: torch.fx.GraphModule,
    pred_target: torch.fx.Node,
    succ_target: torch.fx.Node,
) -> bool:
    for node in graph_module.graph.nodes:
        if node.target != pred_target:
            continue
        if node.next.target == succ_target:
            return False
    return True
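
(Note: the following is an illustrative sketch of how the new testing utilities above might be used together with the existing count_node helper; the toy module and the use of torch.export here are assumptions for the example, not taken from this diff's tests.)

# Illustrative sketch only (not part of this PR).
import torch

from executorch.backends.cadence.aot.pass_utils import (
    count_node,
    get_compute_nodes_in_gm,
    nodes_not_adjacent_in_gm,
    nodes_not_connected_in_gm,
)


class AddMul(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x + 1.0) * 2.0


gm = torch.export.export(AddMul(), (torch.randn(4),)).graph_module

# Overload packets of the call_function nodes in the graph,
# e.g. roughly [aten.add, aten.mul] for this toy module.
print(get_compute_nodes_in_gm(gm))

# Number of nodes with a specific target.
print(count_node(gm, torch.ops.aten.add.Tensor))

# Structural checks a pass test might assert: here add feeds mul directly,
# so both helpers are expected to return False for this graph.
print(nodes_not_connected_in_gm(gm, torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor))
print(nodes_not_adjacent_in_gm(gm, torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor))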