Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions backends/arm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@ ethos-u-vela compilation stack. which follows the fully AoT flow.
## Layout

Export:
- `arm_backend.py` - Main entrypoint for the ArmPartitioner and ArmBackend. For more information see the section on
[Arm Backend Architecture](#arm-backend-architecture). For examples of use see `executorch/examples/arm`.
- `arm_backend.py` - Main entrypoint for the ArmPartitioner and ArmBackend. For more information see the section on
[Arm Backend Architecture](#arm-backend-architecture). For examples of use see `executorch/examples/arm`.
- `tosa_mapping.py` - utilities for mapping edge dialect to TOSA
- `tosa_quant_utils.py` - utilities for mapping quantization information to TOSA encoding

Operators:
- `node_visitor.py` - Base class for edge operator lowering
- `op_*.py` - Edge operator lowering/serialization to TOSA

Passes:
- `arm_pass_manager.py` - Pass manager. Decides which passes need to be applied depending on the compile_spec.
- `*_pass.py` - Compiler passes derived from ExportPass

Quantization:
- `arm_quantizer.py` - Quantizer for Arm backend
- `arm_quantizer_utils.py` - Utilities for quantization
Expand All @@ -36,8 +44,10 @@ This is the structure of the test directory

```
test # Root test folder
├── misc # Testing of debug features
├── models # Full model tests
├── ops # Single op tests
├── passes # Compiler passes tests
├── tester # Arm Tester class
├── tosautil # Utility functions for TOSA artifacts
├ common.py # Common functions and definitions used by many tests
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_placeholder import process_placeholder
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
from executorch.backends.arm.tosa_utils import (
Expand Down Expand Up @@ -241,10 +242,13 @@ def preprocess( # noqa: C901
# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager().transform_to_backend_pipeline(
graph_module=edge_program.graph_module, compile_spec=compile_spec
)

node_visitors = get_node_visitors(edge_program)

for node in edge_program.graph.nodes:
for node in graph_module.graph.nodes:
if node.op == "call_function":
# Unpack arguments and convert
inputs = []
Expand Down
58 changes: 6 additions & 52 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from executorch.backends.arm.arm_backend import ArmBackend
from executorch.backends.arm.passes.tag_io_quant_pass import TagIOQuantPass
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Expand All @@ -18,6 +19,7 @@
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

Expand Down Expand Up @@ -54,9 +56,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
supported &= self.is_node_supported_custom(node)

# Override partitioning based on pre partition passes
if supported and "arm_partition" in node.meta:
supported = supported & node.meta["arm_partition"]
node.meta.pop("arm_partition")
if "arm_override_partition" in node.meta:
supported = supported & node.meta["arm_override_partition"]
node.meta.pop("arm_override_partition")

return supported

Expand All @@ -69,54 +71,6 @@ def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
return True


from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import PassManager


class TagIOQuant(ExportPass):
    """
    Pre-partitioning pass that tags the quantize ops consuming graph inputs,
    and the dequantize op producing the first graph output, with
    ``arm_partition = False`` so that they are not greedily partitioned for
    the device. Float conversion has to happen outside a TOSA base inference
    profile.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        # Stored on the instance; none of the methods below read it.
        self.edge_program = edge_program

    def is_quant_node(self, node: torch.fx.node.Node):
        """Return True when ``node.target`` is one of the edge quantize ops."""
        quantize_targets = {
            exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        }
        return node.target in quantize_targets

    def is_dequant_node(self, node: torch.fx.node.Node):
        """Return True when ``node.target`` is one of the edge dequantize ops."""
        dequantize_targets = {
            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        }
        return node.target in dequantize_targets

    def call(self, graph_module: torch.fx.GraphModule):
        """Tag IO quantize/dequantize nodes in ``graph_module`` and return a PassResult."""
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                # Input side: any quantize op fed directly by a placeholder.
                for consumer in node.users:
                    if self.is_quant_node(consumer):
                        consumer.meta["arm_partition"] = False
            elif node.op == "output":
                # Output side: only the first entry of the output tuple is inspected.
                first_output, *_ = node.args[0]
                if self.is_dequant_node(first_output):
                    first_output.meta["arm_partition"] = False

        graph_module.recompile()
        return PassResult(graph_module, True)


@final
class ArmPartitioner(Partitioner):
def __init__(self, compile_spec: List[CompileSpec]) -> None:
Expand All @@ -133,7 +87,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Exclude IO quantization from the partition
passes = PassManager(
passes=[
TagIOQuant(exported_program),
TagIOQuantPass(),
]
)
passes(exported_program.graph_module)
Expand Down
3 changes: 1 addition & 2 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -9,7 +9,6 @@
op_addmm,
op_avg_pool2d,
op_batch_norm,
op_clone,
op_conv2d,
op_dequant,
op_div,
Expand Down
34 changes: 0 additions & 34 deletions backends/arm/operators/op_clone.py

This file was deleted.

25 changes: 25 additions & 0 deletions backends/arm/passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager


class ArmPassManager(PassManager):
    """Pass manager that runs the Arm backend's graph passes before lowering."""

    def _transform(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Run the registered passes on ``graph_module``.

        Returns the ``graph_module`` attribute of the result produced by
        calling this manager.
        """
        return self(graph_module).graph_module

    def transform_to_backend_pipeline(
        self, graph_module: torch.fx.GraphModule, compile_spec: List[CompileSpec]
    ) -> torch.fx.GraphModule:
        """Apply passes before transforming program to backend.

        Args:
            graph_module: The graph module to transform.
            compile_spec: Compile specs for the backend. Currently not read
                here; the same pass is registered regardless of its contents.

        Returns:
            The graph module produced by the pass pipeline.
        """
        # NOTE(review): each call registers another RemoveClonePass on this
        # manager (presumably appended to the pass list), so a manager
        # instance is meant to be used for a single pipeline run.
        self.add_pass(RemoveClonePass())

        return self._transform(graph_module)
25 changes: 25 additions & 0 deletions backends/arm/passes/remove_clone_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveClonePass(ExportPass):
    """Remove edge ``aten.clone`` nodes by rewiring their users to the clone's input."""

    def call(self, graph_module: torch.fx.GraphModule):
        """Strip clone ops from ``graph_module``.

        Returns:
            A PassResult holding the same graph module and a flag that is True
            only if a clone node was removed or dead code was eliminated.
        """
        graph = graph_module.graph
        modified = False

        # Iterate over a snapshot: nodes are erased from the graph inside the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.clone.default:
                continue
            # TODO remove dq/q-ops around removed clone-op
            # Point every user at the clone's input so the clone has no users
            # left and can be erased.
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
            modified = True

        # eliminate_dead_code() reports whether it changed the graph.
        if graph.eliminate_dead_code():
            modified = True
        graph_module.recompile()
        return PassResult(graph_module, modified)
49 changes: 49 additions & 0 deletions backends/arm/passes/tag_io_quant_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class TagIOQuantPass(ExportPass):
    """
    Pass run before partitioning to tag Q/DQ on any placeholder and output
    to ensure we don't greedily partition them for device. Float conversion
    has to happen outside a TOSA base inference profile.

    Tagged nodes get ``meta["arm_override_partition"] = False``.
    """

    def is_quant_node(self, node: torch.fx.node.Node):
        """Return True when ``node.target`` is one of the edge quantize ops."""
        return node.target in {
            exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        }

    def is_dequant_node(self, node: torch.fx.node.Node):
        """Return True when ``node.target`` is one of the edge dequantize ops."""
        return node.target in {
            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        }

    def call(self, graph_module: torch.fx.GraphModule):
        """Tag IO quantize/dequantize nodes in ``graph_module``.

        Every quantize node fed directly by a placeholder, and every
        dequantize node that directly produces a graph output, is tagged.

        Returns:
            A PassResult wrapping the same (recompiled) graph module.
        """
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                # tag q of input
                for user in node.users:
                    # if we have an input going into a quantize
                    if self.is_quant_node(user):
                        user.meta["arm_override_partition"] = False
            elif node.op == "output":
                # tag dq of outputs
                outputs = node.args[0]
                # Treat a bare (non-sequence) output as a single-entry tuple.
                if not isinstance(outputs, (tuple, list)):
                    outputs = (outputs,)
                # Check every output, not just the first; skip entries that
                # are not graph nodes since they have no target to inspect.
                for producer in outputs:
                    if isinstance(producer, torch.fx.Node) and self.is_dequant_node(
                        producer
                    ):
                        producer.meta["arm_override_partition"] = False

        graph_module.recompile()
        return PassResult(graph_module, True)
6 changes: 4 additions & 2 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def get_tosa_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
return compile_spec


def get_u55_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
def get_u55_compile_spec(
permute_memory_to_nhwc=False, quantize_io=False, custom_path=None
):
"""
Default compile spec for Ethos-U55 tests.
"""
Expand All @@ -115,7 +117,7 @@ def get_u55_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
memory_mode="Shared_Sram",
extra_flags=None,
)
.set_quantize_io(is_option_enabled("quantize_io"))
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(artifact_path)
.build()
Expand Down
63 changes: 63 additions & 0 deletions backends/arm/test/passes/test_tag_io_quant_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester


class Add(torch.nn.Module):
    """Minimal module that adds its input to itself."""

    def get_inputs(self):
        """Return a one-element tuple holding a random tensor of shape (1, 10, 10, 10)."""
        example = torch.rand(1, 10, 10, 10)
        return (example,)

    def forward(self, x):
        """Return ``x + x``."""
        doubled = x + x
        return doubled


class TestTagIOQuantPass(unittest.TestCase):
    """Checks the quantize/dequantize op counts before and after partitioning."""

    def _tosa_BI_u55_pipeline(self, module: torch.nn.Module):
        """Quantize, export and partition ``module`` with a U55 compile spec.

        Expects two quantize and two dequantize ops in the edge program, and
        one quantize op, one delegate call and one dequantize op after
        partitioning.
        """
        quantize_op = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
        dequantize_op = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
        delegate_op = "torch.ops.higher_order.executorch_call_delegate"

        tester = ArmTester(
            module,
            example_inputs=module.get_inputs(),
            compile_spec=common.get_u55_compile_spec(quantize_io=True),
        )

        # Edge program, before partitioning.
        stage = tester.quantize().export().to_edge()
        stage = stage.check_count({quantize_op: 2})
        stage = stage.check_count({dequantize_op: 2})

        # After partitioning.
        stage = stage.partition()
        stage = stage.check_count({quantize_op: 1})
        stage = stage.check_count({delegate_op: 1})
        stage.check_count({dequantize_op: 1})
        # .to_executorch() requires additional steps

    def test_BI_u55_artifact(self):
        """Run the pipeline on the ``Add`` module."""
        self._tosa_BI_u55_pipeline(Add())