6 changes: 1 addition & 5 deletions backends/arm/TARGETS
@@ -17,11 +17,7 @@ runtime.python_library(
 )
 runtime.python_library(
     name = "common",
-    srcs = [
-        "common/__init__.py",
-        "common/debug.py",
-        "common/type.py",
-    ],
+    srcs = glob(["common/*.py"]),
     deps = [
         "fbsource//third-party/tosa_tools:serializer",
         "//caffe2:torch",
39 changes: 39 additions & 0 deletions backends/arm/common/annotation_meta.py
@@ -0,0 +1,39 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Mapping, Optional


@dataclass(frozen=True, init=False)
class ArmAnnotationInfo(dict):
    """
    Dataclass wrapper that behaves like a dict so serialization can treat it as
    a plain mapping, while still exposing a typed attribute for convenience.
    """

    quantized: bool
    CUSTOM_META_KEY: str = "_arm_annotation_info"

    def __init__(
        self,
        value: Optional[Mapping[str, Any]] = None,
        *,
        quantized: Optional[bool] = None,
    ) -> None:
        if quantized is not None:
            resolved = bool(quantized)
        elif isinstance(value, Mapping):
            resolved = bool(value.get("quantized", False))
        else:
            raise TypeError(
                "ArmAnnotationInfo expects a mapping with a 'quantized' entry or a keyword 'quantized'."
            )
        dict.__init__(self, quantized=resolved)
        object.__setattr__(self, "quantized", resolved)
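
A quick illustration of how the wrapper behaves (not part of the diff; a hypothetical usage sketch using only the class defined above):

# Hypothetical usage sketch for ArmAnnotationInfo.
info = ArmAnnotationInfo(quantized=True)
assert info.quantized is True             # typed attribute access
assert dict(info) == {"quantized": True}  # serializes as a plain mapping

# Rebuilding from the stored mapping restores the typed attribute, which is
# how the flag is read back out of node.meta["custom"] later in this PR.
restored = ArmAnnotationInfo({"quantized": True})
assert restored.quantized is True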
75 changes: 75 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -27,6 +27,7 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55CastCheck,
@@ -209,6 +210,7 @@ def tosa_support_factory(
     ]

     if not tosa_spec.support_float():
+        negative_checks.append(CheckArmQuantized(reporter))
         negative_checks.append(CheckProperQuantization(reporter))
     if tosa_spec.is_U55_subset:
         negative_checks.append(EthosU55NotSupported(reporter))
@@ -260,6 +262,79 @@ def is_node_supported(
         return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList


+class CheckArmQuantized(OperatorSupportBase):
+    """
+    Check whether the node was marked as quantized in the Arm backend.
+    This is used to ensure that nodes quantized in the Arm backend are only
+    partitioned if they are supported by the TOSA backend.
+    """
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
+
+    def _is_quantized(self, node: torch.fx.Node) -> bool:
+        """Check whether the node is quantized.
+
+        A node is considered quantized if at least one of these criteria is met:
+        - Its dtype is neither floating point nor complex, i.e. it is an integer type.
+        - It is one of the special cases where the node was created in to_edge, e.g.
+          .Scalar operations that have been promoted to .Tensor operations,
+          where the scalar is replaced by a full op.
+        - It has been marked as quantized in the ArmAnnotationInfo custom meta.
+
+        Args:
+            node (torch.fx.Node): The FX node to check.
+
+        Returns:
+            bool: True if the node is quantized, False otherwise.
+        """
+        node_dtype = get_first_fake_tensor(node).dtype
+        if not node_dtype.is_complex and not node_dtype.is_floating_point:
+            return True
+        if node.target in (
+            exir_ops.edge.aten.full_like.default,
+            *ComputeConstantOpsAOT.targeted_ops,
+        ):
+            # Special cases where nodes have been created in to_edge, e.g.
+            # .Scalar operations that have been promoted to .Tensor operations
+            # where the scalar is replaced by a full op.
+            if all(user.target in Q_OPS for user in node.users):
+                return True
+            for user in node.users:
+                if (
+                    user.target
+                    == exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+                ):
+                    dim_order_dtype = get_first_fake_tensor(user).dtype
+                    if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
+                        return False
+                else:
+                    return False
+            return True
+        return (
+            ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
+            and ArmAnnotationInfo(
+                node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY]
+            ).quantized
+        )
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        if node.op != "call_function":
+            return False
+
+        if node.target in (*DQ_OPS, *Q_OPS):
+            return True
+
+        if not self._is_quantized(node):
+            self.reporter.report_reject(
+                node, "Node was not marked as quantized in the Arm backend."
+            )
+            return False
+        return True
+
+
 class CheckProperQuantization(OperatorSupportBase):
     """Ensure targeted nodes are properly quantized.

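
Taken together, CheckArmQuantized rejects any call_function node that is neither a Q/DQ op nor provably quantized by one of the three criteria above. A minimal sketch of the custom-meta lookup behind the third criterion (hypothetical standalone helper; the real logic lives in _is_quantized above):

def node_marked_quantized(node: torch.fx.Node) -> bool:
    # The quantizer stores a plain dict under the custom-meta key; wrapping it
    # in ArmAnnotationInfo restores the typed `quantized` attribute.
    entry = node.meta.get("custom", {}).get(ArmAnnotationInfo.CUSTOM_META_KEY)
    return entry is not None and ArmAnnotationInfo(entry).quantized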
10 changes: 9 additions & 1 deletion backends/arm/quantizer/arm_quantizer_utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -14,6 +14,8 @@

 from typing import cast

+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
+
 from torch.fx import Node

 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -65,4 +67,10 @@ def mark_node_as_annotated(node: Node) -> None:
     """
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
+    annotation_info = ArmAnnotationInfo(
+        quantized=True,
+    )
     node.meta[Q_ANNOTATION_KEY]._annotated = True
+    meta_custom = node.meta.get("custom", {})
+    meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info)
+    node.meta["custom"] = meta_custom
4 changes: 3 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
@@ -394,6 +394,7 @@ def _match_pattern(
     torch.ops.aten.view.default,
     torch.ops.aten.view_as.default,
     torch.ops.aten.view_copy.default,
+    torch.ops.aten._unsafe_view.default,
     torch.ops.aten.select.int,
     torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
@@ -426,6 +427,7 @@
 ]

 _one_to_one_shared_input_or_input_act_qspec = [
+    torch.ops.aten.alias.default,
     torch.ops.aten.clone.default,
     torch.ops.aten.hardtanh.default,
     torch.ops.aten.hardtanh_.default,
@@ -693,10 +695,10 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = None
     elif node.target in [
-        torch.ops.aten.scalar_tensor.default,
         torch.ops.aten.full.default,
         torch.ops.aten.full,
         torch.ops.aten.fill_.Scalar,
+        torch.ops.aten.scalar_tensor.default,
     ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
6 changes: 1 addition & 5 deletions backends/arm/test/misc/test_int64.py
@@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor):
         ConstAdd(torch.int64, 2**40),
         (torch.rand(10) - 0.5,),
     ),
-    "int64_in+float_const": (
-        ConstAdd(torch.float32),
-        (torch.randint(0, 10, (10,)),),
-    ),
     "fp32_in+int64_buffer_chain": (
         BufferChainAdd(torch.int64),
         (torch.rand(2, 5, 3) - 0.5,),
@@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple):
         ArmTester(
             model,
             inputs,
-            common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
+            common.get_tosa_compile_spec("TOSA-1.0+FP"),
         )
         .export()
         .to_edge_transform_and_lower()
100 changes: 100 additions & 0 deletions backends/arm/test/misc/test_quant_custom_meta.py
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize


class AddSigmoidMul(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, y):
        return self.sigmoid(x + y) * x


def get_selective_quantizer(modules):
    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    quantizer.set_global(get_symmetric_quantization_config())
    for module in modules:
        quantizer.set_module_type(module, None)

    return Quantize(quantizer, get_symmetric_quantization_config())


def test_qdq_squeezed_fp_op():
    """Test that a float operation surrounded by quantize-dequantize pairs
    is correctly handled by the partitioner and the TOSA backend.
    Pattern:
        q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
                          |______Non-delegated______|
    """
    aten_op = "torch.ops.aten.add.Tensor"
    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
    module = AddSigmoidMul()
    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 3, 4)
    pipeline = TosaPipelineINT(
        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
    )
    pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
    pipeline.change_args(
        "check_count.exir",
        {
            "torch.ops.higher_order.executorch_call_delegate": 2,
            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
        },
    )
    pipeline.run()


class MulAddSigmoidConv(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sigmoid = torch.nn.Sigmoid()
        self.conv = torch.nn.Conv1d(3, 3, 1)

    def forward(self, x, y):
        return self.conv(self.sigmoid(x + y * x))


def test_quantized_to_float_transition():
    """Test that a model executing quantized ops followed by float ops
    is correctly handled by the partitioner and the TOSA backend.
    Pattern:
        q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
                                            |_____Non-delegated____|
    """
    aten_op = "torch.ops.aten.add.Tensor"
    exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
    module = MulAddSigmoidConv()
    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 3, 4)
    pipeline = TosaPipelineINT(
        module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
    )
    pipeline.change_args(
        "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
    )
    pipeline.change_args(
        "check_count.exir",
        {
            "torch.ops.higher_order.executorch_call_delegate": 1,
            "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
            "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
        },
    )
    pipeline.run()
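
Both tests rely on get_selective_quantizer above: passing None to set_module_type clears the quantization config for that module type, so matching modules are never annotated, never receive the ArmAnnotationInfo flag, and are therefore rejected from the partition by the new CheckArmQuantized check. A minimal sketch of the same idea (assumed API, mirroring the helper above):

quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config())
quantizer.set_module_type(torch.nn.Sigmoid, None)  # keep Sigmoid ops in float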
62 changes: 62 additions & 0 deletions backends/arm/test/misc/test_save_exported_model.py
@@ -0,0 +1,62 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class SimpleModule(torch.nn.Module):
    example_inputs = (torch.randn(1, 10),)

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def test_save_load_exported_int_model():
    module = SimpleModule().eval()
    example_inputs = module.example_inputs
    exported_module = torch.export.export(module, example_inputs)

    # Set up the quantizer
    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    quantizer.set_global(get_symmetric_quantization_config())
    # Quantize the model
    prepared_module = prepare_pt2e(exported_module.module(), quantizer)
    prepared_module(*example_inputs)
    quantized_module = convert_pt2e(prepared_module)
    quantized_exported_module = torch.export.export(quantized_module, example_inputs)

    base_path = "arm_test/misc/"
    if not os.path.exists(base_path):
        os.makedirs(base_path)
    file_path = base_path + "exported_module.pt2"
    # Verify that we can save the model
    torch.export.save(quantized_exported_module, file_path)

    # Verify that we can load the model back
    loaded_model = torch.export.load(file_path)
    for original_node, loaded_node in zip(
        quantized_exported_module.graph.nodes, loaded_model.graph.nodes
    ):
        # Verify that the custom metadata is preserved after save/load
        assert original_node.meta.get("custom", {}) == loaded_node.meta.get(
            "custom", {}
        )
        if original_node.target == torch.ops.aten.linear.default:
            assert ArmAnnotationInfo.CUSTOM_META_KEY in original_node.meta.get(
                "custom", {}
            )
@@ -39,7 +39,8 @@ class TestSD3Transformer2DModel:

     ops_after_partitioner_INT = {
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
-        "torch.ops.higher_order.executorch_call_delegate": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
     }

     def _prepare_inputs(