From 5a0096efa4466e181db8fa6bacbcf5edaf2a03bb Mon Sep 17 00:00:00 2001
From: Oscar Andersson
Date: Mon, 24 Nov 2025 09:18:03 +0100
Subject: [PATCH] Arm backend: Make INT+FP default for vgf-backend

Make INT+FP the default profile for the vgf backend. Enabling INT+FP
required solving the following issues:
- Make sure that pow.Tensor_Scalar is not replaced by pow.Tensor_Tensor
  during transform_for_annotation.
- Make sure that Scalar ops that also map to table ops are not replaced
  by Tensor ops when quantized.
- Fix a bug in tosa_supported_operators where nodes with integer outputs
  were considered OK for partitioning when they shouldn't be.

Signed-off-by: Oscar Andersson
Change-Id: I53c992422b80fd7e93da9d1b07256dccd7115025
---
 .../replace_scalar_with_tensor_pass.py        | 12 +++++--
 .../tosa_supported_operators.py               | 36 ++++++++++++++-----
 .../arm/quantizer/quantization_annotator.py   |  1 +
 backends/arm/test/common.py                   | 17 ++++++---
 backends/arm/test/ops/test_eq.py              | 16 ++++++---
 backends/arm/test/tester/test_pipeline.py     |  2 +-
 backends/arm/vgf/compile_spec.py              | 15 +++-----
 7 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py
index 3a3ae0d5081..9f6b672c4fa 100644
--- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py
+++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py
@@ -7,6 +7,7 @@
 from typing import Dict, Set, Type, Union
 
 import torch
+from executorch.backends.arm._passes.insert_table_ops import TableOps
 from executorch.backends.arm.tosa.specification import get_context_spec
 
 from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -64,7 +65,6 @@
     Union[EdgeOpOverload, torch._ops.OpOverload],
 ] = _common_ops | {
     exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
-    torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
 }
 
 _int_profile_ops: Dict[
@@ -101,7 +101,15 @@ def call_operator(self, op, args, kwargs, meta):
             included_ops |= _fp_profile_ops
 
         if included_ops == {}:
-            raise ValueError("Profile must support either INT or FP")
+            raise ValueError("Profile must support at least INT or FP")
+
+        if op in TableOps.included_ops():
+            # Quantized table ops are lowered via InsertTableOps, not here.
+            input_qparams = meta.data.get("input_qparams", {})
+            output_qparams = meta.data.get("output_qparams", {})
+            if len(input_qparams) > 0 and len(output_qparams) > 0:
+                # Do not handle; forward unchanged.
+                return ExportPass.call_operator(self, op, args, kwargs, meta)
 
         if op in included_ops:
             # Include this op based on the current profile.
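Reviewer note (not part of the patch): the hunk above leaves quantized table ops untouched so that InsertTableOps can lower them to TOSA table lookups. A minimal sketch of the predicate being implemented, assuming meta.data carries "input_qparams"/"output_qparams" dicts as elsewhere in the Arm backend; _is_quantized_table_op is a hypothetical name used only for illustration:

    def _is_quantized_table_op(op, meta, table_ops) -> bool:
        # A table op counts as quantized only when both input and output
        # quantization parameters were recorded during annotation.
        if op not in table_ops:
            return False
        input_qparams = meta.data.get("input_qparams", {})
        output_qparams = meta.data.get("output_qparams", {})
        return len(input_qparams) > 0 and len(output_qparams) > 0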
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 7c5cc647f5b..9240f14da54 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -146,6 +146,10 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
     return checker
 
 
+def _is_integer_dtype(dtype: torch.dtype) -> bool:
+    return not dtype.is_floating_point and not dtype.is_complex
+
+
 def _is_quantized_constant(node: torch.fx.Node) -> bool:
     if node.target not in (
         exir_ops.edge.aten.full_like.default,
@@ -161,7 +165,7 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool:
     for user in users:
         if user.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
             dim_order_dtype = get_first_fake_tensor(user).dtype
-            if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
+            if not _is_integer_dtype(dim_order_dtype):
                 return False
         else:
             return False
@@ -184,10 +188,24 @@ def is_quantized(node: torch.fx.Node) -> bool:
         bool: True if the node is quantized, False otherwise.
 
     """
-    node_dtype = get_first_fake_tensor(node).dtype
-    # Integer-like dtype implies the node is already quantized.
-    if not node_dtype.is_complex and not node_dtype.is_floating_point:
-        return True
+    try:
+        node_dtype = get_first_fake_tensor(node).dtype
+        # An integer-like output dtype implies the node is already
+        # quantized, provided all of its inputs are integer-like too.
+        if _is_integer_dtype(node_dtype):
+            input_nodes = node.all_input_nodes
+            input_nodes_dtypes = [
+                get_first_fake_tensor(input_node).dtype for input_node in input_nodes
+            ]
+            if all(
+                _is_integer_dtype(input_node_dtype)
+                for input_node_dtype in input_nodes_dtypes
+            ):
+                return True
+
+    except TypeError:
+        # Could not determine dtype; fall back to the checks below.
+        pass
 
     # Nodes introduced during lowering that exclusively feed quantized users.
     if _is_quantized_constant(node):
@@ -510,7 +528,7 @@ def is_node_supported(
 
         input_quantized = input_quantized or all(
             (input_node.target in DQ_OPS)
-            or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
+            or _is_integer_dtype(get_first_fake_tensor(input_node).dtype)
             for input_node in node.all_input_nodes
         )
 
@@ -519,8 +537,10 @@
             return False
 
         all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
-        is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
-        output_quantized = output_quantized or all_q_users or not is_floating_point
+        output_dtype = get_first_fake_tensor(node).dtype
+        output_quantized = (
+            output_quantized or all_q_users or _is_integer_dtype(output_dtype)
+        )
 
         if not output_quantized:
             self.reporter.report_reject(node, "One or more outputs were not quantized.")
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index c0f28cc3d87..ea5c174f4b6 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -356,6 +356,7 @@ def _match_pattern(
         torch.ops.aten.hardswish.default,
         torch.ops.aten.hardswish_.default,
         torch.ops.aten.full_like.default,
+        torch.ops.aten.zeros_like.default,
         torch.ops.aten.pow.Tensor_Scalar,
         torch.ops.aten.gelu.default,
         torch.ops.aten.sinh.default,
diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py
index b9dc9b00725..c2522941215 100644
--- a/backends/arm/test/common.py
+++ b/backends/arm/test/common.py
@@ -170,13 +170,22 @@ def get_vgf_compile_spec(
     if not custom_path:
         custom_path = maybe_get_tosa_collate_path()
 
+    profiles = []
     if "FP" in repr(tosa_spec):
-        artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_fp_")
-    elif "INT" in repr(tosa_spec):
-        artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_int_")
-    else:
+        profiles.append("fp")
+    if "INT" in repr(tosa_spec):
+        profiles.append("int")
+    if len(profiles) == 0:
         raise ValueError(f"Unsupported vgf compile_spec: {repr(tosa_spec)}")
 
+    if custom_path is None:
+        artifact_path = "arm_vgf"
+        for profile in profiles:
+            artifact_path = artifact_path + f"_{profile}"
+        artifact_path = tempfile.mkdtemp(prefix=artifact_path + "_")
+    else:
+        artifact_path = custom_path
+
     if not os.path.exists(artifact_path):
         os.makedirs(artifact_path, exist_ok=True)
diff --git a/backends/arm/test/ops/test_eq.py b/backends/arm/test/ops/test_eq.py
index e49f09471fa..fa2b8fda126 100644
--- a/backends/arm/test/ops/test_eq.py
+++ b/backends/arm/test/ops/test_eq.py
@@ -122,7 +122,7 @@ def test_eq_scalar_tosa_INT(test_module):
 
 
 @common.parametrize("test_module", test_data_tensor)
-def test_eq_tensor_tosa_INT_a16w8(test_module):
+def test_eq_tensor_tosa_INT_16a8w(test_module):
     pipeline = TosaPipelineINT[input_t](
         test_module(),
         test_module().get_inputs(),
@@ -134,7 +134,7 @@
 
 
 @common.parametrize("test_module", test_data_scalar)
-def test_eq_scalar_tosa_INT_a16w8(test_module):
+def test_eq_scalar_tosa_INT_16a8w(test_module):
     pipeline = TosaPipelineINT[input_t](
         test_module(),
         test_module().get_inputs(),
@@ -238,7 +238,11 @@ def test_eq_scalar_16a8w_u85_INT16(test_module):
 @common.SkipIfNoModelConverter
 def test_eq_scalar_vgf_FP_tensor(test_module):
     pipeline = VgfPipeline[input_t](
-        test_module(), test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
+        test_module(),
+        test_module().get_inputs(),
+        Equal.aten_op_Tensor,
+        Equal.exir_op,
+        tosa_version="TOSA-1.0+FP",
     )
     pipeline.run()
 
 
@@ -247,7 +251,11 @@
 def test_eq_scalar_vgf_FP(test_module):
     pipeline = VgfPipeline[input_t](
-        test_module(), test_module().get_inputs(), Equal.aten_op_Scalar, Equal.exir_op
+        test_module(),
+        test_module().get_inputs(),
+        Equal.aten_op_Scalar,
+        Equal.exir_op,
+        tosa_version="TOSA-1.0+FP",
     )
     pipeline.run()
 
diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py
index 8b8daca6d2d..084302471cd 100644
--- a/backends/arm/test/tester/test_pipeline.py
+++ b/backends/arm/test/tester/test_pipeline.py
@@ -990,7 +990,7 @@ def __init__(
         exir_op: Optional[str | List[str]] = None,
         run_on_vulkan_runtime: bool = True,
         vgf_compiler_flags: Optional[str] = "",
-        tosa_version: str = "TOSA-1.0+FP",
+        tosa_version: str = "TOSA-1.0+INT+FP",
         symmetric_io_quantization: bool = False,
         per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py
index f0ae83c654a..0e160492a9e 100644
--- a/backends/arm/vgf/compile_spec.py
+++ b/backends/arm/vgf/compile_spec.py
@@ -28,13 +28,12 @@ def __init__(
             tosa_spec (TosaSpecification | str | None): TOSA specification to
                 target. Strings are parsed via
                 :meth:`TosaSpecification.create_from_string`. Defaults to
-                ``"TOSA-1.0+FP"``.
+                ``"TOSA-1.0+FP+INT"``.
             compiler_flags (list[str] | None): Optional converter-backend flags.
-
         """
         if tosa_spec is None:
-            tosa_spec = "TOSA-1.0+FP"
-        if isinstance(tosa_spec, str):
+            tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP+INT")
+        elif isinstance(tosa_spec, str):
             tosa_spec = TosaSpecification.create_from_string(tosa_spec)
 
         if compiler_flags is None:
@@ -55,13 +54,7 @@ def validate(self):
 
         if "FP" not in tosa_profiles and "INT" not in tosa_profiles:
             raise ValueError(
-                "Arm backend only supports converter-backend for FP or INT. "
-                f"Invalid TOSA profile: {tosa_profiles}"
-            )
-
-        if len(tosa_profiles) != 1:
-            raise ValueError(
-                "For now Arm backend only supports converter-backend for either FP or INT. "
+                "Arm backend only supports converter-backend for FP and/or INT. "
                 f"Invalid TOSA profile: {tosa_profiles}"
             )
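Reviewer note (not part of the patch): the tightened is_quantized check exists because an integer output dtype alone does not prove a node is quantized. A small self-contained illustration of the dtype rule, mirroring the _is_integer_dtype helper added in tosa_supported_operators.py:

    import torch

    def _is_integer_dtype(dtype: torch.dtype) -> bool:
        # Same rule as the patch: neither floating-point nor complex.
        return not dtype.is_floating_point and not dtype.is_complex

    # argmax emits int64 even from a float input, so a node with an
    # integer output can still sit on an unquantized floating-point
    # path; hence the added requirement that all inputs be integer-like.
    x = torch.randn(4, 8)               # float32, unquantized
    idx = torch.argmax(x, dim=1)        # int64 output
    assert _is_integer_dtype(idx.dtype)      # output alone looks "quantized"
    assert not _is_integer_dtype(x.dtype)    # but the input is float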
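Reviewer note (not part of the patch): in the common.py change, tempfile.mkdtemp takes suffix as its first positional parameter, so the profile-derived name must be passed via prefix= as done above. A quick sketch of the resulting directory naming under that assumption:

    import tempfile

    profiles = ["fp", "int"]        # as derived from the TOSA spec
    prefix = "arm_vgf"
    for profile in profiles:
        prefix += f"_{profile}"
    # Creates e.g. /tmp/arm_vgf_fp_int_<random suffix>
    artifact_path = tempfile.mkdtemp(prefix=prefix + "_")
    print(artifact_path)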