19 changes: 0 additions & 19 deletions backends/arm/common/annotation_meta.py

This file was deleted.
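For context, the deleted `ArmAnnotationInfo` can be reconstructed from its former call sites, which are removed later in this diff (in tosa_supported_operators.py and arm_quantizer_utils.py). A hypothetical sketch; only the attribute names are visible in the diff, and the key string is an assumption:

```python
from dataclasses import dataclass
from typing import ClassVar

@dataclass(frozen=True)
class ArmAnnotationInfo:
    # The key string is assumed; only the attribute name
    # CUSTOM_META_KEY appears in this diff.
    CUSTOM_META_KEY: ClassVar[str] = "arm_annotation_info"
    quantized: bool
```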

76 changes: 3 additions & 73 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -21,7 +21,6 @@
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
from executorch.backends.arm.operator_support.ethos_u55_support import (
EthosU55CastCheck,
@@ -141,7 +140,6 @@ def tosa_support_factory(
]

if not tosa_spec.support_float():
negative_checks.append(CheckArmQuantized(reporter))
negative_checks.append(CheckProperQuantization(reporter))
if tosa_spec.is_U55_subset:
negative_checks.append(EthosU55NotSupported(reporter))
@@ -169,6 +167,7 @@ class TOSAProINTSupportList(OperatorSupportBase):
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList


@@ -181,78 +180,8 @@ class TOSAProFPSupportList(OperatorSupportBase):
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList


class CheckArmQuantized(OperatorSupportBase):
"""
Check if the node was marked as quantized in the Arm backend.
This is used to ensure that nodes that were quantized in the Arm backend
are only partitioned if they are supported by the TOSA backend.
"""

def __init__(self, reporter: WhyNoPartitionReporter):
self.reporter = reporter

def _is_quantized(self, node: torch.fx.Node) -> bool:
"""Checks if the node is quantized.

A node is considered quantized if at least one criterion is met:
- Its dtype is neither floating point nor complex, i.e. it is integer.
- It is one of the special cases where the node has been created in to_edge, e.g.
.Scalar operations that have been promoted to .Tensor operations
where the scalar is replaced by a full op.
- It has been marked as quantized in the ArmAnnotationInfo custom meta.

Args:
node (torch.fx.Node): The FX node to check.

Returns:
bool: True if the node is quantized, False otherwise.
"""
node_dtype = get_first_fake_tensor(node).dtype
if not node_dtype.is_complex and not node_dtype.is_floating_point:
return True
if node.target in (
exir_ops.edge.aten.full_like.default,
*ComputeConstantOpsAOT.targeted_ops,
):
# Special cases where nodes have been created in to_edge, e.g.
# .Scalar operations that have been promoted to .Tensor operations
# where the scalar is replaced by a full op.
if all(user.target in Q_OPS for user in node.users):
return True
for user in node.users:
if (
user.target
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
):
dim_order_dtype = get_first_fake_tensor(user).dtype
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
return False
else:
return False
return True
return (
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized
)

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
if node.op != "call_function":
return False

if node.target in (*DQ_OPS, *Q_OPS):
return True

if not self._is_quantized(node):
self.reporter.report_reject(
node, "Node was not marked as quantized in the Arm backend."
)
return False
return True
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList


class CheckProperQuantization(OperatorSupportBase):
@@ -498,6 +427,7 @@ def is_node_supported


class CheckFloat64Inputs(OperatorSupportBase):

def __init__(
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
):
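The mechanism removed above was a round trip between the quantizer and the partitioner: the quantizer stamped each annotated node's `meta["custom"]`, and `CheckArmQuantized` read the stamp back. A minimal sketch of that round trip, using the hypothetical `ArmAnnotationInfo` reconstruction from the top of this diff:

```python
from torch import fx

def mark_quantized(node: fx.Node) -> None:
    # Quantizer side (removed from arm_quantizer_utils.py below).
    meta_custom = node.meta.get("custom", {})
    meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = ArmAnnotationInfo(quantized=True)
    node.meta["custom"] = meta_custom

def is_marked_quantized(node: fx.Node) -> bool:
    # Partitioner side (the final return of _is_quantized above).
    info = node.meta.get("custom", {}).get(ArmAnnotationInfo.CUSTOM_META_KEY)
    return info is not None and info.quantized
```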
10 changes: 1 addition & 9 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -14,8 +14,6 @@

from typing import cast

from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo

from torch.fx import Node

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -67,10 +65,4 @@ def mark_node_as_annotated(node: Node) -> None:
"""
if Q_ANNOTATION_KEY not in node.meta:
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
annotation_info = ArmAnnotationInfo(
quantized=True,
)
node.meta[Q_ANNOTATION_KEY]._annotated = True
meta_custom = node.meta.get("custom", {})
meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info
node.meta["custom"] = meta_custom
4 changes: 1 addition & 3 deletions backends/arm/quantizer/quantization_annotator.py
@@ -394,7 +394,6 @@ def _match_pattern(
torch.ops.aten.view.default,
torch.ops.aten.view_as.default,
torch.ops.aten.view_copy.default,
torch.ops.aten._unsafe_view.default,
torch.ops.aten.select.int,
torch.ops.aten.select_copy.int,
torch.ops.aten.slice.Tensor,
@@ -427,7 +426,6 @@ def _match_pattern(
]

_one_to_one_shared_input_or_input_act_qspec = [
torch.ops.aten.alias.default,
torch.ops.aten.clone.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
@@ -695,10 +693,10 @@ def any_or_hardtanh_min_zero(n: Node):
]
quant_properties.quant_output = None
elif node.target in [
torch.ops.aten.scalar_tensor.default,
torch.ops.aten.full.default,
torch.ops.aten.full,
torch.ops.aten.fill_.Scalar,
torch.ops.aten.scalar_tensor.default,
]:
quant_properties.quant_inputs = []
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
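The branch above groups the constant-producing targets: their outputs are determined by scalar arguments alone, which is why `quant_inputs` is empty and only the output receives a qspec. For intuition, plain PyTorch equivalents (illustrative only, not backend code):

```python
import torch

# Each op materializes a constant tensor from scalar arguments,
# so there is no quantizable tensor input to annotate.
full = torch.full((2, 2), 0.5)       # aten.full.default
scalar = torch.scalar_tensor(1.0)    # aten.scalar_tensor.default
filled = torch.empty(3).fill_(2.0)   # aten.fill_.Scalar
```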
6 changes: 5 additions & 1 deletion backends/arm/test/misc/test_int64.py
@@ -68,6 +68,10 @@ def forward(self, x: torch.Tensor):
ConstAdd(torch.int64, 2**40),
(torch.rand(10) - 0.5,),
),
"int64_in+float_const": (
ConstAdd(torch.float32),
(torch.randint(0, 10, (10,)),),
),
"fp32_in+int64_buffer_chain": (
BufferChainAdd(torch.int64),
(torch.rand(2, 5, 3) - 0.5,),
@@ -90,7 +94,7 @@ def test_int64_tosa_FP(test_data: Tuple):
ArmTester(
model,
inputs,
common.get_tosa_compile_spec("TOSA-1.0+FP"),
common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
)
.export()
.to_edge_transform_and_lower()
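The new `int64_in+float_const` case above exercises an int64 activation added to a float constant. `ConstAdd` is defined earlier in test_int64.py and is not shown in this diff; a hypothetical sketch consistent with its call sites:

```python
import torch

class ConstAdd(torch.nn.Module):
    """Sketch inferred from ConstAdd(torch.int64, 2**40) and
    ConstAdd(torch.float32); the real definition may differ."""

    def __init__(self, dtype: torch.dtype, magnitude: float = 1.0):
        super().__init__()
        # Constant buffer of the requested dtype.
        self.const = torch.tensor(magnitude, dtype=dtype)

    def forward(self, x: torch.Tensor):
        return x + self.const
```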
100 changes: 0 additions & 100 deletions backends/arm/test/misc/test_quant_custom_meta.py

This file was deleted.

@@ -39,8 +39,7 @@ class TestSD3Transformer2DModel:

ops_after_partitioner_INT = {
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.higher_order.executorch_call_delegate": 3,
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

def _prepare_inputs(
1 change: 1 addition & 0 deletions backends/arm/test/models/test_nn_functional.py
@@ -102,6 +102,7 @@ def test_nn_functional_FP(test_data):
@parametrize(
"test_data",
module_tests,
{"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
)
def test_nn_functional_INT(test_data):
module, inputs = test_data
4 changes: 1 addition & 3 deletions backends/arm/test/models/test_torch_functions.py
@@ -126,12 +126,10 @@ def test_torch_fns_FP(test_data):
xfails={
"nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
"Requires dynamic output shape.",
"eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. "
"Is the original torch function supported?",
"topk": "NotImplementedError: No registered serialization name for <class 'torch.return_types.topk'> found",
"sort": "NotImplementedError: No registered serialization name for <class 'torch.return_types.sort'> found",
},
strict=True,
strict=False,
)
def test_torch_fns_INT(test_data):
module, inputs = test_data
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_eye.py
@@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t):
input_data(),
EyeAdd.aten_op,
use_to_edge_transform_and_lower=True,
)
).dump_artifact("to_edge_transform_and_lower")
pipeline.pop_stage("check.quant_nodes")
pipeline.run()
