From 4da1550319b59ec08ab85cafd38c1859ce66c868 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 01/19] =?UTF-8?q?Revert=20"Arm=20backend:=20Merge=20Retrac?= =?UTF-8?q?eFoldedDtypesPass=20into=20FoldAndAnnotateQParam=E2=80=A6=20(#1?= =?UTF-8?q?5377)"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 467529235f64f64ad5df3e87fa2f39172db03d9c. --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 3 ++ .../fold_qdq_with_annotated_qparams_pass.py | 45 ++++++++++++++----- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index deacfb7ec6f..55daf92a5a9 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -72,6 +72,7 @@ from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, QuantizeOperatorArguments, + RetraceFoldedDtypesPass, ) from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d6e63100603..728eb984ae9 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -88,6 +88,7 @@ RemoveNoopPass, ReplaceInfValues, ReplaceScalarWithTensorByProfilePass, + RetraceFoldedDtypesPass, RewriteConv2dPass, RewriteMatmulPass, RewriteUpsamplePass, @@ -175,6 +176,7 @@ def _tosa_INT_pipeline( self.add_pass(QuantizeOperatorArguments()) self.add_pass(ConvertELUParamsPass()) self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] + self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(MatchArgRanksPass(exported_program)) if self.tosa_spec.is_U55_subset: @@ -269,6 +271,7 @@ def _tosa_FP_pipeline( self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] + self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeAdaptiveAvgPool2dPass()) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 52e96878042..7fd9c2f2119 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -13,7 +13,6 @@ from executorch.backends.arm._passes.arm_pass_utils import ( get_param_tensor, is_param_node, - set_node_arg, ) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -23,6 +22,7 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -66,6 +66,38 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: return output_qparams +class RetraceFoldedDtypesPass(ArmPass): + """ + FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced + some operators are retraced to types that cannot be handled by TOSA. One + such example is sum.dim_IntList: + q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... + After folding it becomes: + q (int8) -> sum (int64) -> ... + This pass changes types of ops in self.targeted_ops, such as sum, so that + the output type of that matches the type of the output_qparams. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + targeted_ops: Set[EdgeOpOverload] = { + exir_ops.edge.aten.sum.dim_IntList, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta, False) + + node_kwargs = kwargs.copy() + output_qparams = meta["output_qparams"] + if len(output_qparams) == 0: + return super().call_operator(op, args, kwargs, meta, False) + + output_dtype = output_qparams[0].dtype + node_kwargs["dtype"] = output_dtype + return super().call_operator(op, args, node_kwargs, meta, True) + + class FoldAndAnnotateQParamsPass(ArmPass): """ A pass that walks the graph and removes any DQ and Q nodes before and after the target @@ -97,6 +129,7 @@ class FoldAndAnnotateQParamsPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { + RetraceFoldedDtypesPass, InsertTableOpsPass, RemoveNoopPass, } @@ -201,16 +234,6 @@ def call(self, graph_module: GraphModule) -> PassResult: user.replace_all_uses_with(n) graph_module.graph.erase_node(user) - # Some op(s) contain a "dtype" key in their node kwargs. Set this - # to the type of output qparams. - output_qparams = n.meta["output_qparams"] - if ( - n.target in {exir_ops.edge.aten.sum.dim_IntList} - and len(output_qparams) > 0 - ): - output_dtype = output_qparams[0].dtype - set_node_arg(n, "dtype", output_dtype) - # retrace the graph to update the fake tensor types graph_module = super().call(graph_module).graph_module From 5b1a947f1a90dc521c7e5097d65eb450cd194b51 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 02/19] Revert "Arm backend: Serialize controlflow submodules. (#15381)" This reverts commit a4e747566cb6cabedbdb32e74dc8901ba776058f. --- backends/arm/_passes/arm_pass_manager.py | 28 ++--- backends/arm/process_node.py | 10 +- backends/arm/tosa/backend.py | 148 +++++++++-------------- backends/arm/tosa/mapping.py | 4 +- 4 files changed, 74 insertions(+), 116 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 728eb984ae9..b1eea847792 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module - def _tosa_INT_pipeline( - self, exported_program: ExportedProgram, graph_module: GraphModule - ) -> GraphModule: + def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) + self.add_pass( + DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + ) self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) @@ -218,11 +218,9 @@ def _tosa_INT_pipeline( self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() - return self._transform(graph_module) + return self._transform(exported_program.graph_module) - def _tosa_FP_pipeline( - self, exported_program: ExportedProgram, graph_module: GraphModule - ) -> GraphModule: + def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) @@ -257,7 +255,9 @@ def _tosa_FP_pipeline( self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeBatchNormNoStatsPass()) self.add_pass(DecomposeVarPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) + self.add_pass( + DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + ) self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeAddSubAlphaPass()) @@ -305,16 +305,14 @@ def _tosa_FP_pipeline( self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() - return self._transform(graph_module) + return self._transform(exported_program.graph_module) - def transform_to_backend_pipeline( - self, exported_program: ExportedProgram, graph_module: GraphModule - ): + def transform_to_backend_pipeline(self, exported_program: ExportedProgram): """Apply passes before transforming program to backend""" if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"): - return self._tosa_FP_pipeline(exported_program, graph_module) + return self._tosa_FP_pipeline(exported_program) elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"): - return self._tosa_INT_pipeline(exported_program, graph_module) + return self._tosa_INT_pipeline(exported_program) else: raise NotImplementedError( f"No pass pipeline implemented for {self.tosa_spec=}" diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index f9694c1abdf..54797b825b7 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -158,7 +158,7 @@ def process_inputs_to_buffers( buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) tosa_graph.addConst( - buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name + buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name ) @@ -215,9 +215,11 @@ def process_placeholder( raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") -def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification): +def process_output( + node: torch.fx.Node, + tosa_graph: Any, +): for output in cast(tuple[torch.fx.Node, ...], node.args[0]): - output_arg = TosaArg(output, tosa_spec) tosa_graph.addOutputTensor( - tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name] + tosa_graph.currRegion.currBasicBlock.tensors[output.name] ) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 75cdd28c6da..8fb50707952 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -24,13 +24,10 @@ process_placeholder, ) from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec -from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.graph_module import get_control_flow_submodules from torch.export.exported_program import ExportedProgram -from torch.fx import Graph, GraphModule, Node - +from torch.fx import Graph, Node # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -55,39 +52,13 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): # Walk backwards so we touch every producer q.extend(n.all_input_nodes) - out = ep_graph.output_node() - # First argument of output node is tuple of outputs - output_list = cast(tuple, out.args[0]) + out = next(n for n in ep_graph.nodes if n.op == "output") seen: Set[Node] = set() - for idx, val in enumerate(output_list): + for idx, val in enumerate(out.args[0]): bfs_mark([val], idx, seen) return node2external_id -def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]): - def _external_id(n: Node, node_2_id, fallback: int) -> int: - return node_2_id.get(n.name, fallback) - - out_node = graph_module.graph.output_node() - out_list = cast(tuple, out_node.args[0]) - _counter = count() - - # sort nodes by the key that is id - def _sort_key(t: Node) -> int: - return _external_id(t, node_to_id_map, next(_counter)) - - orig_ord = tuple(sorted(out_list, key=_sort_key)) - - current_order = tuple(out_list) - if orig_ord != current_order: - replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord - out_node.args = (replacement,) - graph_module.graph.lint() - graph_module.recompile() - - return graph_module - - def arm_get_first_delegation_tag(graph_module) -> str: """Get the first delegation tag from the graph_module or return empty string.""" for node in graph_module.graph.nodes: @@ -122,9 +93,9 @@ def _preprocess( # noqa: C901 artifact_path = compile_spec.get_intermediate_path() tosa_spec = compile_spec.tosa_spec dump_debug_info = compile_spec.tosa_debug_mode - debug_hook = None - if dump_debug_info is not None: - debug_hook = DebugHook(dump_debug_info) + + # Assign to every node external id + node_2_id = _annotate_external_ids(edge_program.graph) logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") @@ -145,66 +116,45 @@ def _preprocess( # noqa: C901 f"doesn't match specification {tosa_spec}" ) - TOSABackend._preprocess_module( - edge_program.graph_module, - edge_program, - compile_spec, - tosa_graph, - debug_hook, - ) - # Serialize and return the TOSA flatbuffer. - binary = tosa_graph.serialize() - - if artifact_path: - tag = arm_get_first_delegation_tag(edge_program.graph_module) - debug_tosa_dump( - binary, - artifact_path, - suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), - ) - - if debug_hook is not None: - if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: - json_output = debug_hook.serialize() - with open(f"{artifact_path}/debug.json", "w") as f: - f.write(json_output) - - return PreprocessResult(processed_bytes=binary) - - @staticmethod - def _preprocess_module( # noqa: C901 - graph_module: GraphModule, - edge_program: ExportedProgram, - compile_spec: TosaCompileSpec, - tosa_graph: ts.TosaSerializer, - debug_hook: DebugHook | None, - submodule_name: str | None = None, - ): - """Convert 'graph_module' to a tosa_graph""" - tosa_spec = compile_spec.tosa_spec - node_to_id_map = _annotate_external_ids(graph_module.graph) - artifact_path = compile_spec.get_intermediate_path() - # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore - exported_program=edge_program, graph_module=graph_module + exported_program=edge_program ) + debug_hook = None + if dump_debug_info is not None: + debug_hook = DebugHook(dump_debug_info) + # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) - graph_module = _sort_outputs(graph_module, node_to_id_map) - if submodule_name is not None: - tosa_graph.startRegion(submodule_name) - tosa_graph.currRegion.addBasicBlock(submodule_name) - suffix = f"_{submodule_name}" - for loop_node in graph_module.graph.nodes: - loop_node.meta[TOSA_TENSOR_NAME_META] = suffix + # Re-shuffle output nodes to preserve author's order + def _external_id(n: Node, node_2_id, fallback: int) -> int: + return node_2_id.get(n.name, fallback) + + out_node = next(n for n in graph_module.graph.nodes if n.op == "output") + _counter = count() + + # sort nodes by the key that is id + def _sort_key(t: Node) -> int: + return _external_id(t, node_2_id, next(_counter)) + orig_ord = tuple(sorted(out_node.args[0], key=_sort_key)) + + current_order = tuple(out_node.args[0]) + if orig_ord != current_order: + replacement = ( + list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord + ) + out_node.args = (replacement,) + graph_module.graph.lint() + graph_module.recompile() + + input_count = 0 for node in graph_module.graph.nodes: node = cast(Node, node) try: @@ -214,27 +164,37 @@ def _preprocess_module( # noqa: C901 if len(node.users) == 0: continue process_placeholder(node, tosa_graph, edge_program, tosa_spec) + if node.name in edge_program.graph_signature.user_inputs: + input_count += 1 elif node.op == "output": - process_output(node, tosa_graph, tosa_spec) + process_output(node, tosa_graph) else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. raise RuntimeError(f"{node.name} is unsupported op {node.op}") except Exception: - debug_fail(node, graph_module, tosa_graph, artifact_path) + debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path) raise - # Recursively preprocess controlflow submodules. - for name, submodule, _ in get_control_flow_submodules(graph_module): - TOSABackend._preprocess_module( - submodule, - edge_program, - compile_spec, - tosa_graph, - debug_hook, - submodule_name=name, + # Serialize and return the TOSA flatbuffer. + binary = tosa_graph.serialize() + + if artifact_path: + tag = arm_get_first_delegation_tag(graph_module) + debug_tosa_dump( + binary, + artifact_path, + suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), ) + if debug_hook is not None: + if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: + json_output = debug_hook.serialize() + with open(f"{artifact_path}/debug.json", "w") as f: + f.write(json_output) + + return PreprocessResult(processed_bytes=binary) + @staticmethod def filter_tosa_compile_specs( compile_spec: ArmCompileSpec, diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 5162d2c6a53..2287a727009 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -17,8 +17,6 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification -TOSA_TENSOR_NAME_META = "tosa_tensor_name" - UNSUPPORTED_DTYPES = ( torch.float64, torch.double, @@ -146,7 +144,7 @@ def __process_node(self, argument: torch.fx.Node): argument (torch.fx.Node): FX node to inspect. """ - self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "") + self.name: str = argument.name output_dtype, self.shape, self.dim_order = extract_tensor_meta( argument.meta, self.tosa_spec ) From 4250b49298e4deb65083c9e013350064080b7b00 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 03/19] Revert "Arm backend: Remove pyre-unsafe from remaining files (#15391)" This reverts commit 542691866a0bc37c83609f6b26181219e2c29889. --- backends/arm/common/arm_compile_spec.py | 1 + backends/arm/operator_support/__init__.py | 1 + backends/arm/operator_support/ethos_u55_support.py | 1 + backends/arm/operator_support/right_shift_support.py | 2 ++ backends/arm/operator_support/to_dim_order_copy_support.py | 1 + backends/arm/operator_support/tosa_supported_operators.py | 1 + backends/arm/test/misc/test_outputs_order.py | 1 + backends/arm/util/arm_model_evaluator.py | 1 + 8 files changed, 9 insertions(+) diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index 29037ec833a..b38fe72b29c 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe # # Main implementation of AoT flow to partition and preprocess for Arm target diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index f3c50ee3719..53d37407ee6 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from . import ( # noqa clone_dim_order_support, diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index f6fdada7d52..2403cfffa7e 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -10,6 +10,7 @@ """ +# pyre-unsafe import typing from typing import cast diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index 82c4387fc85..df124319887 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -9,6 +9,8 @@ """ +# pyre-unsafe + import logging diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index 181796b97fe..3cc587d99d3 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -10,6 +10,7 @@ """ +# pyre-unsafe import copy import logging diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 3a1d11eab8c..ba479818a81 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import itertools import operator diff --git a/backends/arm/test/misc/test_outputs_order.py b/backends/arm/test/misc/test_outputs_order.py index 8eb34b8605c..ff02ffc360a 100644 --- a/backends/arm/test/misc/test_outputs_order.py +++ b/backends/arm/test/misc/test_outputs_order.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # +# pyre-unsafe import tempfile from pathlib import Path diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index ac2c9cfe065..8c36128cea8 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import json import logging From d975528d1ccd40bd158401d94ad74493c784728f Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 04/19] Revert "Arm backend: Move rescales from ADD & SUB visitors to pass (#15378)" This reverts commit 4efd79c8d875cde82eaa57c67aa4b0723a37d07e. --- backends/arm/_passes/insert_rescales_pass.py | 43 +-- backends/arm/operators/op_add.py | 110 +++++- backends/arm/operators/op_sub.py | 105 ++++- .../test/misc/test_conv_relu_residual_add.py | 7 - backends/arm/test/ops/test_var.py | 12 +- .../passes/test_insert_rescale_i32_pass.py | 14 +- backends/arm/tosa/quant_utils.py | 361 +++++++++++++++++- 7 files changed, 570 insertions(+), 82 deletions(-) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 3826cb13337..831dfe360b1 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -76,12 +76,13 @@ def call(self, graph_module: GraphModule) -> PassResult: class InsertRescaleInt32Pass(ArmPass): - """Numerous TOSA ops require inputs and outputs to be 32-bit integers in their + """ + Numerous TOSA ops require inputs and outputs to be 32-bit integers in their quantized implementations. This pass treats such operator nodes by - inserting rescale ops before and after them if needed. Note that extra - logic that handles the scales and zero points are in place here because the - affected TOSA ops have naive implementations that do not account for the - quantization parameters. + inserting rescale ops before and after them if needed. Note that extra logic + that handles the scales and zero points must be in place because the affected + TOSA have naive implementations that do not account for the quantization + parameters. """ # SUM must be decomposed after this pass to prevent insertion of RESCALE @@ -92,7 +93,6 @@ class InsertRescaleInt32Pass(ArmPass): included_targets = [ exir_ops.edge.aten.abs.default, - exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.gt.Tensor, @@ -101,7 +101,6 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, ] @@ -143,34 +142,6 @@ def _get_inputs_rescaled_qparams( qparams = { i: self._int32_qargs(min_scale) for i in range(len(input_qparams)) } - elif target in [ - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.sub.Tensor, - ]: - if input_qparams[0].dtype != input_qparams[1].dtype: - raise ValueError( - "Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}" - ) - - # We are handling two INT8 or two INT16 numbers. For INT8, if the - # zero point is non-null, the result will be in the range [-255; - # 255], therefore we need 9 bits for the result. We have a 32-bit - # accumulator, so we can divide the scale by (1 << 20) which is - # equivalent to shifting the INT8 operands 20 bits to the left - # before rescaling them both to 2 * max(lhs, rhs). - # - # For INT16, similary logic can be applied, but we instead end up - # with a left shift of 12. - lhs_scale, rhs_scale = ( - qp.get_scale_per_tensor() for qp in input_qparams.values() - ) - max_scale_2x = 2 * max(lhs_scale, rhs_scale) - - # Select shift based on input dtype. - shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20 - - scale = max_scale_2x / (1 << shift_bits) - qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))} elif target in [ exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sum.dim_IntList, @@ -197,8 +168,6 @@ def _get_output_qparams( exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.sub.Tensor, ]: # The op has not altered the scale; the output scale is equal to # the operands' scales. diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 6c1ff2e1449..2ae792f0ee1 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -6,6 +6,8 @@ from typing import Any, List +import executorch.backends.arm.tosa.quant_utils as tqutils +import executorch.backends.arm.tosa.utils as tutils import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -17,20 +19,22 @@ validate_same_dtype, validate_valid_dtype, ) +from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class AddVisitor(NodeVisitor): +class AddVisitor_INT(NodeVisitor): target = "aten.add.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), ] + def __init__(self, *args): + super().__init__(*args) + def define_node( self, node: Node, @@ -40,21 +44,113 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) + valid_dtypes = [] + if self.tosa_spec.support_integer(): + valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]) + if self.tosa_spec.support_float(): + valid_dtypes.extend([ts.DType.INT32]) + validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + valid_dtypes, output.tosa_spec, ) + scale_back = 1.0 + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( + tosa_graph, inputs, node, self.tosa_spec + ) + elif inputs[0].dtype == ts.DType.INT16: + rescaled_inputs, scale_back = ( + tqutils.insert_rescale_ops_int16_to_int32_maxscale( + tosa_graph, inputs, node, self.tosa_spec + ) + ) + else: + # input[0].dtype == ts.DType.INT16 or ts.DType.INT32 + # Non quantized input, natively support by TOSA.ADD + rescaled_inputs = inputs + + if output.dtype in [ts.DType.INT8, ts.DType.INT16]: + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) + add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT16 or ts.DType.INT32 + add_output = output + input1, input2 = rescaled_inputs attr = ts.TosaSerializerAttribute() attr.AddAttribute() - + # Do the INT32 Add self._serialize_operator( node, tosa_graph, ts.Op.ADD, - [inputs[0].name, inputs[1].name], - [output.name], + [input1.name, input2.name], + [add_output.name], attr, ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + # pyre-ignore + tqutils.insert_rescale_op_to_int8( + tosa_graph, + add_output, + scale_back, + node, + compute_rescale=False, + tosa_spec=self.tosa_spec, + ) # type: ignore[possibly-undefined] + elif output.dtype == ts.DType.INT16: + tqutils.insert_rescale_op_to_int16( + tosa_graph, + add_output, + scale_back, + node, + compute_rescale=False, + tosa_spec=self.tosa_spec, + ) # type: ignore[possibly-undefined] + + +@register_node_visitor +class AddVisitor_FP(AddVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 2) + validate_same_dtype(self.target, [*inputs, output], ts) + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + # FP32 Add lowering + validate_valid_dtype( + self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + ) + + input1, input2 = inputs + attr = ts.TosaSerializerAttribute() + attr.AddAttribute() + # FP lowering + self._serialize_operator( + node, + tosa_graph, + ts.Op.ADD, + [input1.name, input2.name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 039a2f6bd68..f5f82679ca8 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -6,6 +6,8 @@ from typing import Any, List +import executorch.backends.arm.tosa.quant_utils as tqutils +import executorch.backends.arm.tosa.utils as tutils import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -17,20 +19,22 @@ validate_same_dtype, validate_valid_dtype, ) +from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class SubVisitor(NodeVisitor): +class SubVisitor_INT(NodeVisitor): target = "aten.sub.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), ] + def __init__(self, *args): + super().__init__(*args) + def define_node( self, node: Node, @@ -43,21 +47,106 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], output.tosa_spec, ) + scale_back = 1.0 + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( + tosa_graph, inputs, node, self.tosa_spec + ) + elif inputs[0].dtype == ts.DType.INT16: + rescaled_inputs, scale_back = ( + tqutils.insert_rescale_ops_int16_to_int32_maxscale( + tosa_graph, inputs, node, self.tosa_spec + ) + ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.SUB + rescaled_inputs = inputs + + if output.dtype in [ts.DType.INT8, ts.DType.INT16]: + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) + sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + sub_output = output + + # Do the INT32 Sub attr = ts.TosaSerializerAttribute() attr.SubAttribute() - self._serialize_operator( node, tosa_graph, ts.Op.SUB, [ - inputs[0].name, - inputs[1].name, + rescaled_inputs[0].name, + rescaled_inputs[1].name, ], - [output.name], + [sub_output.name], attr, ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + # pyre-ignore + tqutils.insert_rescale_op_to_int8( + tosa_graph, + sub_output, + scale_back, + node, + compute_rescale=False, + tosa_spec=self.tosa_spec, + ) # type: ignore[possibly-undefined] + elif output.dtype == ts.DType.INT16: + tqutils.insert_rescale_op_to_int16( + tosa_graph, + sub_output, + scale_back, + node, + compute_rescale=False, + tosa_spec=self.tosa_spec, + ) # type: ignore[possibly-undefined] + + +@register_node_visitor +class SubVisitor_FP(SubVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 2) + validate_same_dtype(self.target, [*inputs, output], ts) + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + # FP32 Sub lowering + validate_valid_dtype( + self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + ) + + # MI lowering + attr = ts.TosaSerializerAttribute() + attr.SubAttribute() + self._serialize_operator( + node, + tosa_graph, + ts.Op.SUB, + [inputs[0].name, inputs[1].name], + [output.name], + attr, + ) diff --git a/backends/arm/test/misc/test_conv_relu_residual_add.py b/backends/arm/test/misc/test_conv_relu_residual_add.py index 72886fb4b29..d88a9c74b7c 100644 --- a/backends/arm/test/misc/test_conv_relu_residual_add.py +++ b/backends/arm/test/misc/test_conv_relu_residual_add.py @@ -76,13 +76,6 @@ def test_tosa_INT(per_channel_quantization): pipeline.run() -# TODO: Xfail until the Ethos-U Vela compiler ships commit -# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that -# causes this test to fail. -@pytest.mark.xfail( - reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"), - strict=True, -) @pytest.mark.slow @common.XfailIfNoCorstone300 @common.parametrize("per_channel_quantization", quant_test_data) diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 282c3a4455d..9f1c437fc65 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -344,17 +344,7 @@ def test_var_dim_tosa_INT_correction(test_data: Tuple): pipeline.run() -# TODO: Xfail "var_3d_dims_keep_dim_0_correction" until the Ethos-U Vela compiler ships commit -# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that causes the test to fail. -@common.parametrize( - "test_data", - VarCorrection.test_parameters, - xfails={ - "var_3d_dims_keep_dim_0_correction": ( - "Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f" - ), - }, -) +@common.parametrize("test_data", VarCorrection.test_parameters) @common.XfailIfNoCorstone300 def test_var_dim_u55_INT_correction(test_data: Tuple): test_data, dim, keepdim, correction = test_data() diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 4b5c16ab31a..2f625b955ce 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -19,13 +19,11 @@ class MultipleOpsModel(torch.nn.Module): input_t = Tuple[torch.Tensor, torch.Tensor] def forward(self, x, y): - a = x - y - b = x * a - c = torch.maximum(a, b) - d = torch.abs(b) - e = c + d - f = e > a - return f + a = x * y + b = torch.maximum(a, y) + c = torch.abs(b) + d = c > b + return d def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -40,7 +38,7 @@ def get_inputs(self, dtype) -> input_t: def get_num_expected_rescales(self): # "number of op nodes with i8 output" + "number of i8 node inputs" - return 5 + 11 + return 3 + 7 class SumModel(torch.nn.Module): diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py index b3840c6ab1c..9ad2192bb9a 100644 --- a/backends/arm/tosa/quant_utils.py +++ b/backends/arm/tosa/quant_utils.py @@ -11,13 +11,270 @@ from typing import Any, Tuple import tosa_serializer as ts +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) + +from executorch.backends.arm.tosa.mapping import TosaArg +from torch.fx import Node + + +def insert_rescale_ops_to_int32_maxscale( + tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None +) -> tuple[list[Any], float]: + """For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale)) + compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision + for the computation without overflowing. + + Returns a list of the rescaled nodes and the scale factor used, + needed by insert_rescale_op_to_int8. + """ + + if len(inputs) > 2: + raise ValueError("More than two inputs not supported") + + tensors = inputs.copy() + # Reshape tensor according to TOSA dim order + for tensor in tensors: + dim_order = tensor.dim_order + tensor.shape = [tensor.shape[i] for i in dim_order] + + input_qparams = get_input_qparams(node) + lhs_qparams, rhs_qparams = input_qparams.values() + lhs_scale = lhs_qparams.get_scale_per_tensor() + rhs_scale = rhs_qparams.get_scale_per_tensor() + # Common scale for the two numbers + max_scale_2x = 2 * max(lhs_scale, rhs_scale) + SHIFT_INT8 = 20 + # We are adding two int8 numbers. If the zero point is non-null, the result will be in the range [-255;255], therefore we need 9 bits for the result. + # We have a 32-bit accumulator, so we can shift to the left by 20 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale) + # we are shifting to the left by 19. + lhs_factor = (1 << SHIFT_INT8) * lhs_scale / max_scale_2x + rhs_factor = (1 << SHIFT_INT8) * rhs_scale / max_scale_2x + rescaled_lhs = build_rescale_to_int32( + tosa_graph, + tensors[0], + lhs_qparams.get_zp_per_tensor(), + lhs_factor, + tosa_spec=tosa_spec, + ) + rescaled_rhs = build_rescale_to_int32( + tosa_graph, + tensors[1], + rhs_qparams.get_zp_per_tensor(), + rhs_factor, + tosa_spec=tosa_spec, + ) + out_qparam = get_output_qparams(node)[0] + out_scale = out_qparam.get_scale_per_tensor() + back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT8)) + + return [rescaled_lhs, rescaled_rhs], back_scale + + +def insert_rescale_ops_int16_to_int32_maxscale( + tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None +) -> tuple[list[Any], float]: + """For ADD and SUB with int16 inputs, we rescale to int32 using a different common scale(2*max(left scale,right scale)) + compared to all the other cases. We multiply the left and right scales by 1<<12 giving us extra precision + for the computation without overflowing. + + Returns a list of the rescaled nodes and the scale factor used, + needed by insert_rescale_op_to_int16. + """ + + if len(inputs) > 2: + raise ValueError("More than two inputs not supported") + + tensors = inputs.copy() + # Reshape tensor according to TOSA dim order + for tensor in tensors: + dim_order = tensor.dim_order + tensor.shape = [tensor.shape[i] for i in dim_order] + + input_qparams = get_input_qparams(node) + lhs_qparams, rhs_qparams = input_qparams.values() + lhs_scale = lhs_qparams.get_scale_per_tensor() + rhs_scale = rhs_qparams.get_scale_per_tensor() + # Common scale for the two numbers + max_scale_2x = 2 * max(lhs_scale, rhs_scale) + SHIFT_INT16 = 12 + # We are adding two int16 numbers. If the zero point is non-null, the result will be in the range [-131070;131070], therefore we need 18 bits for the result. + # We have a 32-bit accumulator, so we can shift to the left by 12 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale) + # we are shifting to the left by 11. + lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x + rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x + rescaled_lhs = build_rescale_to_int32( + tosa_graph, + tensors[0], + lhs_qparams.get_zp_per_tensor(), + lhs_factor, + tosa_spec=tosa_spec, + ) + rescaled_rhs = build_rescale_to_int32( + tosa_graph, + tensors[1], + rhs_qparams.get_zp_per_tensor(), + rhs_factor, + tosa_spec=tosa_spec, + ) + out_qparam = get_output_qparams(node)[0] + out_scale = out_qparam.get_scale_per_tensor() + back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16)) + + return [rescaled_lhs, rescaled_rhs], back_scale + + +def insert_rescale_ops_to_int32( + tosa_graph: Any, + inputs: list[TosaArg], + node: Node, + tosa_spec=None, +) -> tuple[list[Any], float]: + """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. + The scales are adjusted using the smallest scale of all 'nodes'. + + Returns a list of the rescaled nodes and the scale factor used, + needed by insert_rescale_op_to_int8. + + This functions is used in serialization to TOSA for target ops that are + handled by the DQ/D folding pass, which stores the quantization parameters + in the node meta dict. + """ + + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + ) + + tensors = inputs.copy() + + # Reshape tensor according to TOSA dim order + for tensor in tensors: + dim_order = tensor.dim_order + tensor.shape = [tensor.shape[i] for i in dim_order] + + input_qparams = get_input_qparams(node) + qargs = input_qparams.values() + + # Scale the int8 quantized input to a common scale in the integer + # domain + min_scale = min([qarg.get_scale_per_tensor() for qarg in qargs]) + scales = [qarg.get_scale_per_tensor() / min_scale for qarg in qargs] + + rescaled_nodes: list[Any] = [] + for tensor, qarg, scale in zip(tensors, qargs, scales): + rescaled_nodes.append( + build_rescale_to_int32( + tosa_graph, tensor, qarg.get_zp_per_tensor(), scale, tosa_spec=tosa_spec + ) + ) + return rescaled_nodes, min_scale + + +def insert_rescale_op_to_int8( + tosa_graph: Any, + last_tensor: TosaArg, + scale: float, + node: Node, + compute_rescale=True, + tosa_spec=None, +) -> None: + """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. + Parameters: + node: The original node that is being handled by the rescales. + last_tensor:the tosa tensor to rescale back. + scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' + compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. + tosa_graph: the tosa_graph to manipulate. + + This functions is used in serialization to TOSA for target ops that are + handled by the DQ/D folding pass, which stores the quantization parameters + in the node meta dict. + """ + _insert_rescale_op_to_dtype( + tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec + ) + + +def insert_rescale_op_to_int16( + tosa_graph: Any, + last_tensor: TosaArg, + scale: float, + node: Node, + compute_rescale=True, + tosa_spec=None, +) -> None: + """Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'. + Parameters: + node: The original node that is being handled by the rescales. + last_tensor:the tosa tensor to rescale back. + scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' + compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. + tosa_graph: the tosa_graph to manipulate. + + This functions is used in serialization to TOSA for target ops that are + handled by the DQ/D folding pass, which stores the quantization parameters + in the node meta dict. + """ + _insert_rescale_op_to_dtype( + tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec + ) + + +def _insert_rescale_op_to_dtype( + tosa_graph: Any, + last_tensor: TosaArg, + scale: float, + node: Node, + output_dtype: Any, + compute_rescale=True, + tosa_spec=None, +) -> None: + """Common implementation for rescaling nodes back to a specific dtype. + Parameters: + node: The original node that is being handled by the rescales. + last_tensor:the tosa tensor to rescale back. + scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' + output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16) + compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. + tosa_graph: the tosa_graph to manipulate. + + This functions is used in serialization to TOSA for target ops that are + handled by the DQ/D folding pass, which stores the quantization parameters + in the node meta dict. + """ + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, + ) + + output_qparams = get_output_qparams(node) + if len(output_qparams) != 1: + raise ValueError("More than one output not supported") + + qargs_out = output_qparams[0] + if compute_rescale: + output_rescale_scale = scale / qargs_out.get_scale_per_tensor() + else: + output_rescale_scale = scale + + # Rescale Back to the specified dtype + build_rescale_from_int32_to_dtype( + tosa_graph, + last_tensor, + node.name, + qargs_out.get_zp_per_tensor(), + output_rescale_scale, + output_dtype, + tosa_spec=tosa_spec, + ) # TOSA uses the RESCALE operation to scale between values with differing precision. # The RESCALE operator is defined using an integer multiply, add, and shift. # This utility function is for calculating the multiplier and shift given a scale. # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling -def _compute_multiplier_and_shift( +def compute_multiplier_and_shift( scales: list[float], scaleWidth: int = 32 ) -> Tuple[list[int], list[int]]: if scaleWidth == 16: @@ -70,7 +327,7 @@ def _compute_multiplier_and_shift( # For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be # const inputs. Create constant operators from the data already initialized. -def _create_const_ops_for_rescale( +def create_const_ops_for_rescale( tosa_fb, scale_32, input_dtype, @@ -116,8 +373,8 @@ def build_rescale( ): scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 is_scale32 = False if input_node.dtype == ts.DType.INT48 else True - multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) - rescale_inputs = _create_const_ops_for_rescale( + multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth) + rescale_inputs = create_const_ops_for_rescale( tosa_fb, is_scale32, input_node.dtype, @@ -146,3 +403,99 @@ def build_rescale( ) return + + +def build_rescale_to_int32( + tosa_fb: Any, + input_arg: TosaArg, + input_zp: int, + rescale_scale: float, + is_scale32: bool = True, + is_double_round: bool = False, + per_channel: bool = False, + tosa_spec=None, +) -> Any: + input_A_rescaled_to_int32 = None + + input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32) + + build_rescale( + tosa_fb, + [rescale_scale], + input_arg, + input_A_rescaled_to_int32.name, + ts.DType.INT32, + [input_zp], + [0], + rounding_mode=ts.RoundingMode.SINGLE_ROUND, + ) # type: ignore[call-arg] + + return input_A_rescaled_to_int32 + + +def build_rescale_from_int32( + tosa_fb: Any, + input_node: TosaArg, + output_name: str, + output_zp: int, + rescale_scale: float, + is_scale32: bool = True, + is_double_round: bool = False, + per_channel: bool = False, + tosa_spec=None, +) -> None: + # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs + # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale + build_rescale_from_int32_to_dtype( + tosa_fb, + input_node, + output_name, + output_zp, + rescale_scale, + ts.DType.INT8, + is_scale32, + is_double_round, + per_channel, + tosa_spec, + ) + + return + + +def build_rescale_from_int32_to_dtype( + tosa_fb: Any, + input_node: TosaArg, + output_name: str, + output_zp: int, + rescale_scale: float, + output_dtype: Any, + is_scale32: bool = True, + is_double_round: bool = False, + per_channel: bool = False, + tosa_spec=None, +) -> None: + """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16). + + Parameters: + tosa_fb: The TOSA serializer + input_node: Input tensor (should be INT32) + output_name: Name for the output tensor + output_zp: Output zero point + rescale_scale: Rescaling factor + output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16) + Other parameters: Standard rescale parameters + """ + # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs + # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale + build_rescale( + tosa_fb, + [rescale_scale], + input_node, + output_name=output_name, + output_type=output_dtype, + input_zp=[0], + output_zp=[output_zp], + rounding_mode=ts.RoundingMode.SINGLE_ROUND, + ) # type: ignore[call-arg] + + return From 5c59fa0f6e21cf003e8c7dcd36dcc1cee5b27e63 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 05/19] Revert "Arm backend: Remove pyre-unsafe from operators/ (#15376)" This reverts commit cf407fc91912795a4eba0df0dfc6ac1ac5330624. --- backends/arm/operators/__init__.py | 1 + backends/arm/operators/node_visitor.py | 1 + backends/arm/operators/op_abs.py | 1 + backends/arm/operators/op_add.py | 1 + backends/arm/operators/op_any.py | 1 + backends/arm/operators/op_avg_pool2d.py | 1 + backends/arm/operators/op_cat.py | 1 + backends/arm/operators/op_clamp.py | 1 + backends/arm/operators/op_constant_pad_nd.py | 1 + backends/arm/operators/op_cos.py | 1 + backends/arm/operators/op_eq.py | 1 + backends/arm/operators/op_erf.py | 1 + backends/arm/operators/op_exp.py | 1 + backends/arm/operators/op_ge.py | 1 + backends/arm/operators/op_gt.py | 1 + backends/arm/operators/op_index_select.py | 1 + backends/arm/operators/op_index_tensor.py | 1 + backends/arm/operators/op_le.py | 1 + backends/arm/operators/op_log.py | 1 + backends/arm/operators/op_lt.py | 1 + backends/arm/operators/op_max_pool2d.py | 1 + backends/arm/operators/op_maximum.py | 1 + backends/arm/operators/op_minimum.py | 1 + backends/arm/operators/op_mul.py | 1 + backends/arm/operators/op_neg.py | 1 + backends/arm/operators/op_permute.py | 1 + backends/arm/operators/op_pow.py | 1 + backends/arm/operators/op_reciprocal.py | 1 + backends/arm/operators/op_repeat.py | 1 + backends/arm/operators/op_rshift_tensor.py | 1 + backends/arm/operators/op_rsqrt.py | 1 + backends/arm/operators/op_sigmoid.py | 1 + backends/arm/operators/op_sin.py | 1 + backends/arm/operators/op_slice.py | 1 + backends/arm/operators/op_sub.py | 1 + backends/arm/operators/op_sum.py | 1 + backends/arm/operators/op_tanh.py | 1 + backends/arm/operators/op_to_dim_order_copy.py | 1 + backends/arm/operators/op_tosa_conv2d.py | 1 + backends/arm/operators/op_tosa_depthwise_conv2d.py | 2 ++ backends/arm/operators/op_tosa_matmul.py | 1 + backends/arm/operators/op_tosa_rescale.py | 1 + backends/arm/operators/op_tosa_resize.py | 1 + backends/arm/operators/op_tosa_table.py | 1 + backends/arm/operators/op_tosa_transpose.py | 1 + backends/arm/operators/op_view.py | 1 + backends/arm/operators/ops_binary.py | 1 + backends/arm/operators/ops_identity.py | 1 + 48 files changed, 49 insertions(+) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a180d0a6e86..e7812630f91 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from . import ( # noqa node_visitor, diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 682c849fe80..c26929dab28 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import json from typing import Any, Dict, List, Optional diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index b5a58136395..82e09f5f1d4 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import tosa_serializer as ts diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 2ae792f0ee1..24f43d62e56 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 3cbdd91d2e4..b024a01bd58 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, cast, List import tosa_serializer as ts diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 83f5f5d45f3..eb4bb743d5b 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 2cfa4720c3c..11b46038fbf 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 74722abe281..1d394f76dab 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree +# pyre-unsafe from typing import Any, List, Tuple diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 3bda87af5ed..f1f0f1bcb19 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index e6039730b69..af97c0c95b3 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import List import tosa_serializer as ts diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 8fb789a9d01..1ffdda219a2 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e642a4059fe..ef68c97ffcf 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch.fx diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 72e89b6906b..aef9ec7aca0 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import tosa_serializer as ts diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 5994cbc9c0f..3b55d526282 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 859e5c236d7..aa261ea06d7 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index db2488fa163..cf99f4efc50 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 760e744923c..af738033e1d 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import math from typing import Any, List diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index fb26b5b8606..086fd892a49 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 565d6d56027..254e02c9adf 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import tosa_serializer as ts diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index f5cf71420f4..ed831206e36 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 1cab28f9153..a068a2f49a7 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index d3ab305ea3b..463bee41a52 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 7f72d158d43..52125ed1f54 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 78b0b1b6675..4b97d5e50b4 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index e0bb408e155..d025e8f0bd4 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch.fx diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 80ccfae04e6..26e66b40301 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 33cbc290d2c..a46c6dd8df9 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 108a4fac0fb..10f9192a9c2 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 21a8f8e1b04..49c45913614 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 0b5717aa403..9fcd2b56381 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index a86eaa40985..259e34f129a 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 908544ff00c..814158a1d32 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import tosa_serializer as ts diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index faa249917c3..ac0ae212e78 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import List import tosa_serializer as ts diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index c5510493eae..941f8d690c3 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index f5f82679ca8..52caa9d0f8a 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 91c37e25f43..7be3e884275 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index c4603e90118..0799628cd7d 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import tosa_serializer as ts diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index 9d3aff83554..c41431a1b6d 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch diff --git a/backends/arm/operators/op_tosa_conv2d.py b/backends/arm/operators/op_tosa_conv2d.py index 918db443ef9..0e10867da7e 100644 --- a/backends/arm/operators/op_tosa_conv2d.py +++ b/backends/arm/operators/op_tosa_conv2d.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import tosa_serializer as ts """Provide a visitor for lowering 2D convolution to TOSA (INT/FP).""" diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py index 78e6e4424cb..1d1c317a0b8 100644 --- a/backends/arm/operators/op_tosa_depthwise_conv2d.py +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + """Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP).""" import tosa_serializer as ts diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 2281564a0c4..e88ef9be55d 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe """Provide a visitor for lowering batched matmul (BMM) to TOSA.""" from typing import Any, List diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 75268938579..26f27370bdd 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, cast, List diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index fb8e305839f..60328a3b3ab 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List import torch diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 11407517b6a..9572e49781f 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index bbd9252f8f8..2159a67b285 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index f13c386a5ee..c6f6e36b6e9 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, cast, List import torch diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 3e8cda76b5a..360e15a0ad2 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, Callable, List diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index d570c52ed31..994b43a7c15 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Any, List From efae4a8c727aed478e6187d29d7e1fa3ad47d4e4 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 06/19] Revert "Arm backend: Align operator_validation_utils docstrings with backend (#15369)" This reverts commit ce6b6e5af965eb1a3ef5ba2bf89e6e66f9fe906f. --- .../operators/operator_validation_utils.py | 192 ++++++++++-------- 1 file changed, 112 insertions(+), 80 deletions(-) diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 32c01143f4f..9419e116789 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -2,42 +2,46 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Provide validation helpers for operator inputs and dtypes. - -Use these utilities to validate input counts, ensure dtype consistency, check -allowed dtypes, and compute pooling padding adjustments. - -""" from math import ceil, floor from typing import Any, List, Optional def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): - """Validate the number of inputs against expected values. + """ + Validates the number of inputs provided to an operation against expected values. + + This function checks whether the length of the input list matches the expected + number(s) of inputs. + + Parameters: + ----------- + op_name : str + The name of the operation for which the inputs are being validated. + Used in the error message to provide context. - This function checks whether the length of the input list matches the - expected number(s) of inputs. + inputs : List[TosaArg] + A list of inputs to be validated, where each input is assumed to be an + instance of `TosaArg`. - Args: - op_name (str): The name of the operation for which the inputs are being - validated. Used in the error message to provide context. - inputs (List[TosaArg]): A list of inputs to be validated, where each - input is assumed to be an instance of ``TosaArg``. - expected (int | List[int]): The expected number of inputs. Can be either - an integer or a list of integers. + expected : int or List[int] + The expected number of inputs. Can be either an integer or a list of integers. Raises: - ValueError: If the number of inputs does not match the expected - value(s); the message indicates the operation name and the mismatch - in expected versus provided counts. + ------- + ValueError + If the number of inputs does not match the expected value(s), a `ValueError` is + raised with a message indicating the operation name and the mismatch in expected + versus provided number of inputs. Example: - from executorch.backends.arm.operators.operator_validation_utils import \ - validate_num_inputs - - validate_num_inputs(self.target, inputs, [3, 4]) + -------- + # Example usage: + from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + ) + validate_num_inputs(self.target, inputs, [3, 4]) """ if isinstance(expected, int): expected = [expected] @@ -50,28 +54,39 @@ def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[in def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = None): - """Validate that all given tensors have the same dtype. + """ + Validates that all given tensors have the same dtype attribute. + + This function checks whether all items in the `tensors` list have the same + `dtype` as the first item. + + Parameters: + ----------- + op_name : str + The name of the operation for which the dtype validation is being performed. + Used in the error message to provide context. - This function checks whether all items in the ``tensors`` list have the - same ``dtype`` as the first item. + tensors : List[Any] + A list of tensors to be validated, each is assumed to have a `dtype` attribute. - Args: - op_name (str): The name of the operation for which the dtype validation - is being performed. Used in the error message to provide context. - tensors (List[Any]): A list of tensors to be validated, each assumed to - have a ``dtype`` attribute. - ts (Optional[Any]): TOSA serializer (optional) to improve readability of - dtype names in error messages. + ts: Optional[Any] + TOSA serializer. Not required but only to get clearer error messages. Raises: - ValueError: If the dtype of any item in the list does not match the - dtype of the first item, or if the list is empty. + ------- + ValueError + If the dtype of any item in the list does not match the dtype of the first item, + a `ValueError` is raised with a message indicating the operation name and the + mismatch in dtypes. Example: - from executorch.backends.arm.operators.operator_validation_utils import \ - validate_same_dtype + -------- + # Example usage: + from executorch.backends.arm.operators.operator_validation_utils import ( + validate_same_dtype, + ) - validate_same_dtype(self.target, [input1, input2, output]) + validate_same_dtype(self.target, [input1, input2, output]) """ if not tensors: @@ -95,40 +110,48 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No def validate_valid_dtype( op_name: str, tensors: Any | List[Any], valid_dtypes: Any | List[Any], tosa_spec ): - """Validate that one or more tensors have allowed dtypes. - - This function checks whether the ``dtype`` attribute of the provided - tensor(s) is one of the valid dtype values. It supports checking a single - tensor or a list of tensors. - - Args: - op_name (str): The name of the operation performing the validation. - tensors (Any | List[Any]): A tensor or list of tensors (each assumed to - have ``dtype`` and ``name`` attributes) whose dtype will be - validated. - valid_dtypes (Any | List[Any]): A dtype enum or list of dtype enums - representing allowed dtype values. - tosa_spec (Any): A TosaSpecification instance indicating which TOSA - version is targeted. This determines which serializer to use for - dtype name resolution. + """ + Validates that one or more tensors have dtypes within a set of allowed dtypes. + + This function checks whether the `dtype` attribute of the provided tensor(s) is one + of the valid dtype values. It supports checking a single tensor or a list of + tensors. + + Parameters: + ----------- + op_name : str + The name of the operation performing the validation. + tensors : Any or List[Any] + A tensor or list of tensors (each assumed to have `dtype` and `name` attributes) + whose dtype will be validated. + valid_dtypes : Any or List[Any] + A dtype enum or list of dtype enums representing allowed dtype values. + tosa_spec : Any + A TosaSpecification instance indicating which TOSA version is targeted. This + determines which serializer to use for dtype name resolution. Raises: - ValueError: If no tensors are provided, or if any tensor has a dtype not - in ``valid_dtypes``. + ------- + ValueError + If no tensors are provided, or if any tensor has a dtype not in `valid_dtypes`. Example: - from executorch.backends.arm.operators.operator_validation_utils import \ - validate_valid_dtype - import serializer.tosa_serializer as ts - - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], - output.tosa_spec, - ) + -------- + # Example usage: + from executorch.backends.arm.operators.operator_validation_utils import ( + validate_valid_dtype, + ) + + + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT32], + output.tosa_spec, + ) """ + if not tensors: raise ValueError( f"{op_name}: Input tensor list is empty, cannot validate dtypes" @@ -153,27 +176,36 @@ def validate_valid_dtype( def adjust_pooling_pad_if_needed( input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool ) -> int: - """Compute the post padding needed for pooling. + """ + The Aten pooling ops has one value 'pad' per dimension to specify padding, but they + do not require input and output sizes to match up perfectly. Instead, the output + size is rounded up or down depending on ceil_mode, and padding at the end of the + input is automatically added or removed. TOSA on the other hand specifies two + padding values, one for pre-padding and one for post-padding, and these must satisfy - ATen pooling uses a single symmetric ``pad`` per dimension and rounds the - output size up or down depending on ``ceil_mode``. TOSA requires distinct - pre- and post-padding values that satisfy: + output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1 - output_size == (input_size + pre_pad + post_pad - kernel_size) / stride + 1 + This function returns the post_pad value required to satisfy the above condition. - This function returns the required ``post_pad`` given a symmetric ``pad``. + Parameters: + ----------- + input_size : int + The size of the input to the operator. - Args: - input_size (int): Input size. - kernel_size (int): Kernel size. - stride (int): Stride size. - pad (int): Symmetric padding specified by ATen. - ceil_mode (bool): Use ceil when computing output size. + kernel_size : int + The size of the kernel. - Returns: - int: Post-padding to satisfy the TOSA formula. + stride : int + The size of the stride. + pad : int + The amount of padding. + + Output: + ------- + An int, giving the post-padding to use for the """ + if ceil_mode: output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1 else: From 3a50a0be20c68052d87e27df0670884ceddf46ef Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 07/19] Revert "Arm backend: Deprecate internal models using aot_arm_compiler (#15302)" This reverts commit cea66e350dcd775627c92da1392e4eaed5683b46. --- examples/arm/aot_arm_compiler.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 1348542de07..db248f0bf56 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -91,11 +91,6 @@ def _load_internal_model( model_name: str, example_inputs: Any ) -> Optional[Tuple[torch.nn.Module, Any]]: """Load a bundled example model from the internal `MODELS` mapping.""" - logging.info( - "Loading internal models is deprecated. Use --model_name .py/.pt " - "or a model from examples/models." - ) - if model_name not in MODELS: return None @@ -447,7 +442,7 @@ def get_args(): "-m", "--model_name", required=True, - help=f"Model file .py/.pth/.pt or a model from examples/models. Valid names: {set(MODEL_NAME_TO_MODEL.keys())}", + help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(MODELS.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}", ) parser.add_argument( "--model_input", From e83e676460c63dd0cd8c53f90670c25ecfdd6c37 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 08/19] Revert "Arm backend: Remove pyre-unsafe from quantizer/ & backends/arm (#15374)" This reverts commit 80b4a841563cd6a4ab9db2dff91b8bda0aeab576. --- backends/arm/arm_vela.py | 1 + backends/arm/process_node.py | 1 + backends/arm/quantizer/arm_quantizer.py | 1 + backends/arm/quantizer/arm_quantizer_utils.py | 1 + backends/arm/quantizer/quantization_config.py | 1 + 5 files changed, 5 insertions(+) diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py index 1ecaca3c454..5e2af9c5f39 100644 --- a/backends/arm/arm_vela.py +++ b/backends/arm/arm_vela.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import os import struct diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 54797b825b7..7dd8f9a7d38 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. # +# pyre-unsafe from typing import Any, cast, Dict import numpy as np diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index cad714607e2..e6b1358e7e0 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe # # Quantizer for Arm backend diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index c1137ea4149..90876386aa6 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe """Provide utilities for quantization annotations. Use these helpers to check and mark annotation state when working with diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index bb1c8ec51cd..7495ff22ac6 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -11,6 +11,7 @@ """ +# pyre-unsafe from dataclasses import dataclass From ddfa961f95256610f9f0e6a0819b1eed3a7398a1 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 09/19] Revert "Arm backend: Add docstrings for operator_support/embedding_support.py (#15373)" This reverts commit ef04360e0dc73a0427f11106a2aecef41ffab6a5. --- .../arm/operator_support/embedding_support.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/backends/arm/operator_support/embedding_support.py b/backends/arm/operator_support/embedding_support.py index 3ad17012cbb..24395d56cbf 100644 --- a/backends/arm/operator_support/embedding_support.py +++ b/backends/arm/operator_support/embedding_support.py @@ -2,12 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Declare operator support for ``aten.embedding`` in TOSA. -Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes -are rejected by this check. - -""" import torch @@ -22,8 +17,6 @@ @register_tosa_support_check class EmbeddingSupported(SupportedTOSAOperatorCheck): - """Provide TOSA support check for ``aten.embedding``.""" - targets = [exir_ops.edge.aten.embedding.default] tosa_specs = [ @@ -34,20 +27,16 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - """Return True if the node is supported by TOSA. + # Note aten.embedding.default requires int64 indices and TOSA does not + # support it. Int32 indices here for aten.embedding.default is ok since + # it will be decomposed into ops that can handle it. - PyTorch's ``aten.embedding`` typically takes int64 indices, but for - TOSA we only allow int32 indices. The export path decomposes the op so - that int32 indices are ok. - - """ if len(node.all_input_nodes) != 2: self.reporter.report_reject( node, (f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"), ) return False - indices_val = node.all_input_nodes[1].meta["val"] indices_dtype = indices_val.dtype From a4c3cd7793776d08f2b7f9eafdda87c422b3b041 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 10/19] Revert "Arm backend: Merge passes that replace scalars (#15298)" This reverts commit de56c8176e46f513a00f2031d9aea549b2661a84. --- backends/arm/_passes/__init__.py | 3 +- backends/arm/_passes/arm_pass_manager.py | 9 +-- backends/arm/_passes/decompose_acosh_pass.py | 4 +- .../_passes/decompose_asin_and_acos_pass.py | 4 +- backends/arm/_passes/decompose_asinh_pass.py | 4 +- backends/arm/_passes/decompose_atan_pass.py | 4 +- backends/arm/_passes/decompose_atanh_pass.py | 4 +- backends/arm/_passes/decompose_cosh_pass.py | 4 +- backends/arm/_passes/decompose_expm1_pass.py | 4 +- backends/arm/_passes/decompose_logit_pass.py | 4 +- backends/arm/_passes/decompose_sinh_pass.py | 4 +- .../replace_scalar_with_tensor_pass.py | 56 ++++--------------- 12 files changed, 37 insertions(+), 67 deletions(-) diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 55daf92a5a9..de9a793b9aa 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -88,7 +88,8 @@ from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa from .replace_scalar_with_tensor_pass import ( # noqa - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSABI, + ReplaceScalarWithTensorArgPassTOSAMI, ) from .rewrite_conv2d_pass import RewriteConv2dPass # noqa from .rewrite_matmul import RewriteMatmulPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b1eea847792..1cda9917037 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -87,7 +87,8 @@ QuantizeOperatorArguments, RemoveNoopPass, ReplaceInfValues, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSABI, + ReplaceScalarWithTensorArgPassTOSAMI, RetraceFoldedDtypesPass, RewriteConv2dPass, RewriteMatmulPass, @@ -171,7 +172,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(CastToInt32Pass()) self.add_pass(CastBoolToInt8Pass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(ConvertELUParamsPass()) @@ -241,7 +242,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeSinhPass()) self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) @@ -334,7 +335,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeAddmmPass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 8b10cccb913..8f9ae76817c 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -32,7 +32,7 @@ class DecomposeAcoshPass(ArmPass): DecomposeSqrtPass, InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, MatchArgDtypePass, } diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 3a0f87af835..734ca0fdb41 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -19,7 +19,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -71,7 +71,7 @@ class DecomposeAsinAndAcosPass(ArmPass): ConvertFullLikeToFullPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, } def _build_polynomial( diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 7ffe75cd255..52ef4b91c2e 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -32,7 +32,7 @@ class DecomposeAsinhPass(ArmPass): DecomposeSqrtPass, InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, MatchArgDtypePass, } diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 6f1adccd257..03ed62e7870 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass): InsertTableOpsPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, } def _rational_approximation(self, z, ops, meta): diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index 1a41e77eacc..2c8347e7e9f 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass): InsertTableOpsPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, } def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index 6716ba499ad..cbfbd5783e2 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, MatchArgDtypePass, } diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 0fe95d37ba2..5de03cbf102 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass): ConvertIntPowToMuls, InsertTableOpsPass, DecomposeDivPass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, MatchArgDtypePass, MatchArgRanksPass, } diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index 69a250b41cb..213b8f038e8 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -73,7 +73,7 @@ class DecomposeLogitPass(ArmPass): InsertTableOpsPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, } def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 772cc7c4741..acb18df3134 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -11,7 +11,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -36,7 +36,7 @@ class DecomposeSinhPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorByProfilePass, + ReplaceScalarWithTensorArgPassTOSAMI, MatchArgDtypePass, } diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index 579ac825e9e..b7715654e7c 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -7,8 +7,6 @@ from typing import Dict, Set, Type, Union import torch - -from executorch.backends.arm.tosa.specification import get_context_spec from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -17,8 +15,6 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass -from .arm_pass import ArmPass - # Operators that are included for both TOSA profiles _common_ops: Dict[ @@ -57,51 +53,23 @@ torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor, } -_fp_profile_ops: Dict[ - Union[EdgeOpOverload, torch._ops.OpOverload], - Union[EdgeOpOverload, torch._ops.OpOverload], -] = _common_ops | { - exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, - torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, -} -_int_profile_ops: Dict[ - Union[EdgeOpOverload, torch._ops.OpOverload], - Union[EdgeOpOverload, torch._ops.OpOverload], -] = _common_ops +class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass): + _passes_required_after: Set[Type[ExportPass]] = set() -_all_ops: Dict[ - Union[EdgeOpOverload, torch._ops.OpOverload], - Union[EdgeOpOverload, torch._ops.OpOverload], -] = ( - _fp_profile_ops | _int_profile_ops -) + scalar_to_tensor_ops = _common_ops | { + exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, + torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, + } + def __init__(self): + super().__init__(self.scalar_to_tensor_ops) -class ReplaceScalarWithTensorByProfilePass(ReplaceScalarWithTensorArgPass, ArmPass): - """Profile-aware scalar-to-tensor replacement pass for binary ops.""" +class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass): _passes_required_after: Set[Type[ExportPass]] = set() - def __init__(self): - # Initialize base (ReplaceScalarWithTensorArgPass) with the full - # superset which will make the superclass handle ops in _all_ops. - # Actual selection is done per-call in call_operator. - super().__init__(_all_ops) - - def call_operator(self, op, args, kwargs, meta): - tosa_spec = get_context_spec() + scalar_to_tensor_ops = _common_ops - if tosa_spec.support_integer(): - included_ops = _int_profile_ops - elif tosa_spec.support_float(): - included_ops = _fp_profile_ops - else: - raise ValueError("Profile must support either INT or FP") - - if op in included_ops: - # Include this op based on the current profile. - return super().call_operator(op, args, kwargs, meta) - else: - # Do not handle; forward unchanged. - return ExportPass.call_operator(self, op, args, kwargs, meta) + def __init__(self): + super().__init__(self.scalar_to_tensor_ops) From 16f7f7a36fd89c97e74ed74e51920853fdc2ddb3 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 11/19] Revert "Arm backend: Use reshape instead of view before edge (#15269)" This reverts commit 69f79b926862c1bf8ef3b7f339238bb9955398f8. --- .../arm/_passes/decompose_embedding_pass.py | 2 +- .../arm/_passes/decompose_groupnorm_pass.py | 2 +- .../arm/_passes/decompose_layernorm_pass.py | 2 +- .../arm/_passes/decompose_meandim_pass.py | 2 +- backends/arm/_passes/decompose_sum_pass.py | 2 +- backends/arm/test/ops/test_embedding.py | 26 ++----------------- .../passes/test_decompose_meandim_pass.py | 6 ++--- 7 files changed, 10 insertions(+), 32 deletions(-) diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index a87b26366d7..2633145bc6c 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -42,7 +42,7 @@ class DecomposeEmbeddingPass(ArmPass): def get_decomposition(self, op): if op in self.aten_ops: return ( - torch.ops.aten.reshape.default, + torch.ops.aten.view_copy.default, torch.ops.aten.index_select.default, ) diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index ecd4ecc23a4..595fd0e269f 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -39,7 +39,7 @@ def get_group_norm_decomposition(op) -> tuple: torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, - torch.ops.aten.reshape.default, + torch.ops.aten.view_copy.default, ) raise RuntimeError(f"Can't get group_norm composition for op {op}") diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 7623e410cf9..2abca6e008d 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -39,7 +39,7 @@ def get_layer_norm_decomposition(op) -> tuple: torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, - torch.ops.aten.reshape.default, + torch.ops.aten.view_copy.default, ) raise RuntimeError(f"Can't get layer_norm composition for op {op}") diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 2ec7497ae82..135e2830d5d 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -46,7 +46,7 @@ def get_view(op): if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return exir_ops.edge.aten.view_copy.default if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): - return torch.ops.aten.reshape.default + return torch.ops.aten.view_copy.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 589dcfcefa7..989d299a7e8 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -19,7 +19,7 @@ def _get_sum_decomp(op): exir_ops.edge.aten.sum.dim_IntList, ) case torch.ops.aten.sum.dim_IntList: - return (torch.ops.aten.reshape.default, torch.ops.aten.sum.dim_IntList) + return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList) case _: raise RuntimeError("Unvalid op in DecomposeSumPass") diff --git a/backends/arm/test/ops/test_embedding.py b/backends/arm/test/ops/test_embedding.py index 23b14ae5c44..901fbbc0916 100644 --- a/backends/arm/test/ops/test_embedding.py +++ b/backends/arm/test/ops/test_embedding.py @@ -27,17 +27,10 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor): return torch.embedding(weights, indices) -class ExpandEmbedding(Embedding): - example_inputs = (torch.randn(10, 3), torch.tensor([[1, 2, 3]], dtype=torch.int32)) - - def forward(self, weights: torch.Tensor, indices: torch.Tensor): - return torch.embedding(weights, indices.expand(2, 3)) - - -input_params = Tuple[torch.Tensor, torch.Tensor] +input_params = Tuple[torch.Tensor, torch.Tensor, torch.dtype] -test_input: dict[str, input_params] = { +test_input: dict[input_params] = { "test_1": ( torch.randn(10, 3), torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32), @@ -96,21 +89,6 @@ def test_embedding_tosa_INT(test_input: input_params): pipeline.run() -def test_expand_embedding_tosa_INT(): - op = ExpandEmbedding() - pipeline = TosaPipelineINT( - op, - ExpandEmbedding.example_inputs, - ExpandEmbedding.aten_op, - ExpandEmbedding.exir_op, - use_to_edge_transform_and_lower=True, - ) - pipeline.pop_stage("check.aten") - pipeline.pop_stage("check_count.exir") - - pipeline.run() - - @pytest.mark.skip("reason=MLETORCH-1274 Improve data type checks during partitioning") @common.parametrize("test_input", test_input) @common.SkipIfNoModelConverter diff --git a/backends/arm/test/passes/test_decompose_meandim_pass.py b/backends/arm/test/passes/test_decompose_meandim_pass.py index e771d74b5c4..22dda5d9244 100644 --- a/backends/arm/test/passes/test_decompose_meandim_pass.py +++ b/backends/arm/test/passes/test_decompose_meandim_pass.py @@ -28,7 +28,7 @@ class MeanDim(torch.nn.Module): } ops_not_after_pass = u55_ops_not_after_pass = [ - "torch.ops.aten.reshape.default", + "torch.ops.aten.view_copy.default", "torch.ops.aten.avg_pool2d.default", "torch.ops.aten.mean.dim", ] @@ -52,7 +52,7 @@ class MeanDimTensor(torch.nn.Module): "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, "torch.ops.aten.avg_pool2d.default": 1, - "torch.ops.aten.reshape.default": 1, + "torch.ops.aten.view_copy.default": 1, } ops_not_after_pass = [ @@ -62,7 +62,7 @@ class MeanDimTensor(torch.nn.Module): u55_ops_after_pass = { "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, - "torch.ops.aten.reshape.default": 1, + "torch.ops.aten.view_copy.default": 1, } u55_ops_not_after_pass = [ From a71332dc761204b5dbf9d39c4d24e8ea7db1bd30 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 12/19] Revert "Arm backend: Fix arg-type MyPy errors (#15367)" This reverts commit c66078cc67df7c589cdf521d15a2fdacfa0ab013. --- .../arm/_passes/annotate_decomposed_matmul.py | 8 ++--- backends/arm/_passes/arm_pass_utils.py | 9 ++--- .../arm/_passes/scalars_to_attribute_pass.py | 4 +-- .../arm/_passes/to_tosa_memory_format_pass.py | 8 +---- backends/arm/common/type.py | 28 --------------- .../operator_support/index_tensor_support.py | 7 ++-- .../tosa_supported_operators.py | 8 +++-- backends/arm/quantizer/arm_quantizer.py | 2 +- .../arm/quantizer/quantization_annotator.py | 34 ++++++++----------- backends/arm/test/tester/arm_tester.py | 6 ++-- backends/arm/tosa/partitioner.py | 4 +-- 11 files changed, 37 insertions(+), 81 deletions(-) delete mode 100644 backends/arm/common/type.py diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index f378802d2c0..03e41c8dfc5 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -51,7 +51,7 @@ def _match_partition_to_node( raise RuntimeError(f"Cannot find an input node which matches, {node}.") def call(self, graph_module: GraphModule) -> PassResult: - matmul_partitions_map = get_source_partitions( + matmul_partitions = get_source_partitions( graph_module.graph, [ torch.matmul, @@ -60,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult: None, ) matmul_partitions = list( - itertools.chain.from_iterable(matmul_partitions_map.values()) + itertools.chain.from_iterable(matmul_partitions.values()) ) matmul_targets = { exir_ops.edge.aten.bmm.default, @@ -88,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Create new dq-node before matmul dq_node = create_node( graph=graph_module.graph, - op_target=cast(EdgeOpOverload, input_node.target), + op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type] ) dq_node.args = (node, *input_node.args[1:]) matmul_node.replace_input_with(node, dq_node) @@ -109,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Create q-node after matmul q_node = create_node( graph=graph_module.graph, - op_target=cast(EdgeOpOverload, partition_output.target), + op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type] ) matmul_node.replace_all_uses_with(q_node) q_node.args = (matmul_node, *partition_output.args[1:]) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index de42c961d08..b88b0ba7e78 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -13,10 +13,8 @@ import torch import torch.fx from executorch.backends.arm.common.debug import get_node_debug_info -from executorch.backends.arm.common.type import ensure_type from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._export.utils import ( get_buffer, @@ -83,18 +81,17 @@ def get_param_tensor( elif is_lifted_tensor_constant(exp_prog, node): return get_lifted_tensor_constant(exp_prog, node) elif is_get_attr_node(node): - target_node = ensure_type(str, node.target) # This is a hack to support both lifted and unlifted graph try: - return getattr(node.graph.owning_module, target_node) + return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type] except AttributeError: - return getattr(exp_prog.graph_module, target_node) + return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type] raise RuntimeError(f"unsupported param type, {node.op}.") def create_node( graph: torch.fx.Graph, - op_target: OpOverload | EdgeOpOverload, + op_target: OpOverload, args: tuple = (), kwargs: Optional[dict] = None, quantize: bool = False, diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index ddef9c75213..8a1cd91593f 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -49,7 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult: shape = get_first_fake_tensor(arg).shape biggest_rank = max(biggest_rank, len(shape)) - new_args: list[Node | int] = [] + new_args = [] for arg in n.args: if isinstance(arg, Node): new_args.append(arg) @@ -57,7 +57,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if isinstance(arg, int) and not torch.is_floating_point( get_first_fake_tensor(n) ): - new_args.append(arg) + new_args.append(arg) # type: ignore[arg-type] continue prefix = "_tensor_constant_" diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 956eb77b62c..2e6db91640d 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -259,19 +259,13 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): # Transpose outputs if they are in (N)NCHW format outputs = output_node.args[0] - if not isinstance(outputs, (list, tuple)): - raise TypeError( - f"Expected output node args to be a list or tuple, got {type(outputs)}" - ) output_dim_orders = output_node.meta.get("original_dim_orders") if output_dim_orders is None: raise RuntimeError( f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}." ) - for output_node_input, output_dim_order in zip( - outputs, output_dim_orders, strict=True - ): + for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type] if output_dim_order in ( NCHW_ORDER, NNCHW_ORDER, diff --git a/backends/arm/common/type.py b/backends/arm/common/type.py deleted file mode 100644 index e53dc1ee769..00000000000 --- a/backends/arm/common/type.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -"""Type checking utilities.""" - -from typing import TypeVar - -T = TypeVar("T") - - -def ensure_type(expected_type: type[T], arg: object) -> T: - """Ensure that the argument is of the expected type. - - Args: - expected_type (type[T]): The expected type. - arg (object): The argument to check. - - Returns: - T: The argument, if it is of the expected type. - - """ - if isinstance(arg, expected_type): - return arg - - expected_name = getattr(expected_type, "__name__", str(expected_type)) - actual_name = type(arg).__name__ - raise TypeError(f"Expected value of type {expected_name}, got {actual_name!r}") diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py index 5de70c0a2de..92b0ce48a32 100644 --- a/backends/arm/operator_support/index_tensor_support.py +++ b/backends/arm/operator_support/index_tensor_support.py @@ -14,7 +14,6 @@ import torch import torch.fx as fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor -from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.operator_support.tosa_supported_operators import ( register_tosa_support_check, SupportedTOSAOperatorCheck, @@ -138,8 +137,7 @@ def is_node_tosa_supported( return False # Usage 1 guard - index = ensure_type(torch.fx.Node, index) - fake_tensor = get_first_fake_tensor(index) + fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type] if len(fake_tensor.size()) > 3: self.reporter.report_reject( node, @@ -148,8 +146,7 @@ def is_node_tosa_supported( return False # Usage 3 guard - input_node = ensure_type(torch.fx.Node, node.args[0]) - total_vals = math.prod(get_first_fake_tensor(input_node).shape) + total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type] if total_vals > torch.iinfo(torch.int32).max: self.reporter.report_reject( node, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index ba479818a81..f7857894d40 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -219,7 +219,7 @@ def _is_matmul_node_supported( """ for graph_module in submodules.values(): graph_module = typing.cast(fx.GraphModule, graph_module) - matmul_partitions_map = get_source_partitions( + matmul_partitions = get_source_partitions( graph_module.graph, [ torch.matmul, @@ -228,7 +228,7 @@ def _is_matmul_node_supported( None, ) matmul_partitions = list( - itertools.chain.from_iterable(matmul_partitions_map.values()) + itertools.chain.from_iterable(matmul_partitions.values()) ) matched_partition = None for partition in matmul_partitions: @@ -406,7 +406,9 @@ def is_node_supported( if input_node.target in ComputeConstantOpsAOT.targeted_ops: # This is not perfect since the input_node can still be rejected by other checks but # this should cover the majority of cases. - if self.is_node_supported({}, input_node): + if self.is_node_supported( + None, input_node # type: ignore[arg-type] #(we don't use 'submodules') + ): continue self.reporter.report_reject( node, f"Non-constant int64 input {input_node.name}" diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e6b1358e7e0..2b0b028c5e4 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -374,7 +374,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule: # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager - return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline( + return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type] graph_module=model ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index ee7003aacb8..b429bacd738 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -12,7 +12,6 @@ import torch.fx import torch.nn.functional as F from executorch.backends.arm.common.debug import get_node_debug_info -from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.quantizer import QuantizationConfig from torch._subclasses import FakeTensor @@ -511,8 +510,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.minimum.default, torch.ops.aten.maximum.default, ): - lhs_node = ensure_type(Node, node.args[0]) - shared_qspec = SharedQuantizationSpec((lhs_node, node)) + shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty( @@ -522,24 +520,22 @@ def any_or_hardtanh_min_zero(n: Node): ] quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in (torch.ops.aten.where.self,): - true_node = ensure_type(Node, node.args[1]) - shared_qspec = SharedQuantizationSpec(true_node) + shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type] quant_properties.quant_inputs = [ _QuantProperty(1, shared_qspec), _QuantProperty(2, shared_qspec), ] quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in _one_to_one_shared_input_or_input_act_qspec: - input_node = ensure_type(Node, node.args[0]) input_qspec = ( - SharedQuantizationSpec(input_node) - if is_output_annotated(input_node) + SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] + if is_output_annotated(node.args[0]) # type: ignore[arg-type] else input_act_qspec ) quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] quant_properties.quant_output = _QuantProperty( 0, - SharedQuantizationSpec((input_node, node)), + SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type] ) elif node.target in ( torch.ops.aten.cat.default, @@ -554,12 +550,15 @@ def any_or_hardtanh_min_zero(n: Node): ) if len(node.args[0]) == 0: raise ValueError("Expected non-empty list for node.args[0]") - inputs = [ensure_type(Node, element) for element in node.args[0]] - shared_qspec = SharedQuantizationSpec((inputs[0], node)) + + shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type] quant_properties.quant_inputs = [ _QuantProperty( 0, - [input_act_qspec if n == inputs[0] else shared_qspec for n in inputs], + [ + input_act_qspec if n == node.args[0][0] else shared_qspec # type: ignore[misc] + for n in node.args[0] + ], ) ] quant_properties.quant_output = _QuantProperty(0, shared_qspec) @@ -567,11 +566,10 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in _one_to_one_shared_input_qspec: - input_node = ensure_type(Node, node.args[0]) quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty( 0, - SharedQuantizationSpec((input_node, node)), + SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type] ) elif node.target in [ torch.ops.aten.eq.Tensor, @@ -580,8 +578,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.le.Tensor, torch.ops.aten.lt.Tensor, ]: - input_node = ensure_type(Node, node.args[0]) - shared_qspec = SharedQuantizationSpec((input_node, node)) + shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty( @@ -599,10 +596,9 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in [operator.getitem]: - input_node = ensure_type(Node, node.args[0]) - if not is_output_annotated(input_node): + if not is_output_annotated(node.args[0]): # type: ignore[arg-type] return None - shared_qspec = SharedQuantizationSpec(input_node) + shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] quant_properties.quant_output = _QuantProperty(0, shared_qspec) else: diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 7be249609b0..44b1a7aef13 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -604,9 +604,9 @@ def run_transform_for_annotation_pipeline( # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run. artifact = self.get_artifact(stage) if self.cur == StageType.EXPORT: - new_gm = ArmPassManager( - self.compile_spec.tosa_spec - ).transform_for_annotation_pipeline(graph_module=artifact.graph_module) + new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type] + graph_module=artifact.graph_module + ) else: raise RuntimeError("Can only run passes on Export stage.") _copy_module(artifact.graph_module, new_gm) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 1ae743101b6..58abf4e7e2e 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -22,7 +22,6 @@ from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) -from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.tosa_supported_operators import ( tosa_support_factory, @@ -89,8 +88,7 @@ def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default: return False else: - input_node = ensure_type(torch.fx.Node, node.args[0]) - return node.meta.get("dtype") == get_first_fake_tensor(input_node).dtype + return node.meta.get("dtype") == get_first_fake_tensor(node.args[0]).dtype # type: ignore[arg-type] def is_noop_expand(node: torch.fx.node.Node) -> bool: From e204ea6526d40f0004f54420ec13379985aa3c24 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 13/19] Revert "Arm backend: Remove pyre-unsafe from _passes/ (#15351)" This reverts commit 4b42bea3dc0fbf24177c6089f417a76613f72317. --- backends/arm/_passes/annotate_decomposed_matmul.py | 1 + backends/arm/_passes/arm_pass.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 ++ backends/arm/_passes/arm_pass_utils.py | 1 + backends/arm/_passes/cast_int64_pass.py | 1 + backends/arm/_passes/convert_expand_copy_to_repeat.py | 1 + backends/arm/_passes/convert_int64_const_ops_to_int32.py | 2 ++ backends/arm/_passes/convert_int64_output_ops_to_int32.py | 2 ++ backends/arm/_passes/convert_int_pow_to_mul.py | 1 + backends/arm/_passes/convert_split_to_slice.py | 1 + backends/arm/_passes/convert_squeezes_to_view.py | 1 + backends/arm/_passes/decompose_acosh_pass.py | 1 + backends/arm/_passes/decompose_asin_and_acos_pass.py | 1 + backends/arm/_passes/decompose_asinh_pass.py | 2 ++ backends/arm/_passes/decompose_batch_norm_no_stats.py | 1 + backends/arm/_passes/decompose_div_pass.py | 1 + backends/arm/_passes/decompose_div_tensor_mode.py | 1 + backends/arm/_passes/decompose_embedding_pass.py | 2 ++ backends/arm/_passes/decompose_groupnorm_pass.py | 1 + backends/arm/_passes/decompose_int16_activation_conv2d_pass.py | 1 + backends/arm/_passes/decompose_layernorm_pass.py | 1 + backends/arm/_passes/decompose_leaky_relu_pass.py | 1 + backends/arm/_passes/decompose_linear_pass.py | 1 + backends/arm/_passes/decompose_masked_fill.py | 2 ++ backends/arm/_passes/decompose_maxpool2d_with_dilation.py | 1 + backends/arm/_passes/decompose_select.py | 1 + backends/arm/_passes/decompose_silu_pass.py | 1 + backends/arm/_passes/decompose_softmax_unstable_pass.py | 1 + backends/arm/_passes/decompose_sqrt_pass.py | 1 + backends/arm/_passes/decompose_var_pass.py | 2 ++ backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py | 2 ++ backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py | 1 + backends/arm/_passes/fuse_batchnorm2d_pass.py | 1 + backends/arm/_passes/fuse_quantized_activation_pass.py | 1 + .../arm/_passes/insert_int32_casts_after_int64_placeholders.py | 2 ++ backends/arm/_passes/insert_table_ops.py | 1 + backends/arm/_passes/match_arg_ranks_pass.py | 1 + backends/arm/_passes/mm_to_bmm_pass.py | 1 + backends/arm/_passes/remove_noop_pass.py | 1 + backends/arm/_passes/replace_scalar_with_tensor_pass.py | 2 ++ backends/arm/_passes/scalars_to_attribute_pass.py | 1 + backends/arm/_passes/size_adjust_input_pass.py | 1 + backends/arm/_passes/to_tosa_memory_format_pass.py | 2 ++ backends/arm/_passes/unsqueeze_before_repeat_pass.py | 1 + backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py | 1 + 45 files changed, 56 insertions(+) diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 03e41c8dfc5..666214ec267 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import itertools import operator diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index f893eba4fc9..3cc5e3ee0c0 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import traceback from abc import abstractmethod diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 1cda9917037..98240f6dc1d 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from collections import defaultdict diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index b88b0ba7e78..71e2030958f 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import traceback from inspect import isclass diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 4822c6c25c0..33d07f54af0 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import logging from typing import Set, Type diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index f932ae7f4c4..7f66a4343b9 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import logging from typing import cast, Set, Type diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index dff270fda13..798bbc6006f 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import logging from typing import Set, Type diff --git a/backends/arm/_passes/convert_int64_output_ops_to_int32.py b/backends/arm/_passes/convert_int64_output_ops_to_int32.py index 048219198b8..7eb02493d50 100644 --- a/backends/arm/_passes/convert_int64_output_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_output_ops_to_int32.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import logging from typing import Set, Type diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/convert_int_pow_to_mul.py index 2d8c72748a2..8f9b3a9cb4b 100644 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ b/backends/arm/_passes/convert_int_pow_to_mul.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 94a1a3157f5..cd9f8bef2f7 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index f7b9df3b5f4..c7d02c27a36 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 8f9ae76817c..509849fce4e 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 734ca0fdb41..5b1c575e9c9 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import logging from math import pi diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 52ef4b91c2e..088230ca4b2 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import Set, Type diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index ef9b9f859cd..b18bd4d9ac8 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import operator from typing import Set, Type diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index c1878e6ce0c..f2ae77514c5 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index cb7ffbb33b8..07e57c60f1b 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index 2633145bc6c..ac424230491 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import logging from math import prod diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index 595fd0e269f..29d68234b29 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import operator from typing import Set, Type diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py index e150d9466cb..388ce217807 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import cast diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 2abca6e008d..c73806b0022 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import operator from typing import Set, Type diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index 61cf8d4138b..8ae13a76eb0 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index ffe63f8cb65..70268c77a1d 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill.py index 5a0f12348ec..8c41c1a11bc 100644 --- a/backends/arm/_passes/decompose_masked_fill.py +++ b/backends/arm/_passes/decompose_masked_fill.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import Set, Type diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py index 9e98ad90aed..22d2ec1d85b 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import operator from typing import Set, Type diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index ba12f9d93d7..73f8decf4a1 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py index 80c9413acfb..413beb2625f 100644 --- a/backends/arm/_passes/decompose_silu_pass.py +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index 75cd90e4651..04e99a46b3e 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 50731388fed..e716de3b048 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Tuple, Type, Union import torch diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index f5903d61135..db5d820ac70 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import Set, Type diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py index a6f69a1fcc9..9d704520302 100644 --- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import Set, Type diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 7fd9c2f2119..4427e0357a0 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import copy diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py index 250cac230d8..5d4308ec3f6 100644 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index f50216153a5..46a1c0d66fe 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py index ef5aa9625c7..a12388e65df 100644 --- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py +++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import logging diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index ade287a0cee..8d8a1284011 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from itertools import chain from typing import Callable, cast, Dict, Iterator, Set, Type diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index c09df48f7be..e70e45c61b4 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import cast, Set, Type diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 48dbde43802..353977fba0a 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 9758ac7ba24..5035e26bc47 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import logging from typing import Set, Type diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index b7715654e7c..f6ef056f677 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import Dict, Set, Type, Union diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 8a1cd91593f..5ca6a60e844 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import cast, Set, Type, Union diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 23e6ec422aa..d0cc164ba30 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import cast, Set, Type, TypeAlias diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 2e6db91640d..3783f782610 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import logging from typing import Set, Type diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py index ed6aa82aad5..6384f001580 100644 --- a/backends/arm/_passes/unsqueeze_before_repeat_pass.py +++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type import torch diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py index bd093d6774e..5691b04ff2f 100644 --- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import Set, Type From eb2c8763874ad8e006b7a5908b83e6c6ec921490 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 14/19] Revert "Arm backend: Remove pyre-unsafe from tosa/, vgf/ and ethosu/ (#15352)" This reverts commit eff402415cfd8ebf3643ebeb6973349ef86e2532. --- backends/arm/ethosu/__init__.py | 1 + backends/arm/ethosu/backend.py | 1 + backends/arm/ethosu/partitioner.py | 1 + backends/arm/tosa/__init__.py | 1 + backends/arm/tosa/backend.py | 1 + backends/arm/tosa/mapping.py | 1 + backends/arm/tosa/partitioner.py | 1 + backends/arm/tosa/quant_utils.py | 1 + backends/arm/tosa/specification.py | 1 + backends/arm/tosa/utils.py | 1 + backends/arm/vgf/__init__.py | 1 + backends/arm/vgf/backend.py | 1 + backends/arm/vgf/partitioner.py | 1 + 13 files changed, 13 insertions(+) diff --git a/backends/arm/ethosu/__init__.py b/backends/arm/ethosu/__init__.py index 10b14d4a68a..25a91dc5929 100644 --- a/backends/arm/ethosu/__init__.py +++ b/backends/arm/ethosu/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # +# pyre-unsafe from .backend import EthosUBackend # noqa: F401 from .compile_spec import EthosUCompileSpec # noqa: F401 diff --git a/backends/arm/ethosu/backend.py b/backends/arm/ethosu/backend.py index c2feab6478b..a529aa126f4 100644 --- a/backends/arm/ethosu/backend.py +++ b/backends/arm/ethosu/backend.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe # # Main implementation of AoT flow to partition and preprocess for Arm target diff --git a/backends/arm/ethosu/partitioner.py b/backends/arm/ethosu/partitioner.py index 33ac6a1db62..4e6c5f4b985 100644 --- a/backends/arm/ethosu/partitioner.py +++ b/backends/arm/ethosu/partitioner.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import final, Optional, Sequence diff --git a/backends/arm/tosa/__init__.py b/backends/arm/tosa/__init__.py index 30860642ac5..132d3563a43 100644 --- a/backends/arm/tosa/__init__.py +++ b/backends/arm/tosa/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # +# pyre-unsafe from .specification import TosaSpecification diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 8fb50707952..e19d026e03b 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe # # Main implementation of AoT flow to partition and preprocess for Arm target diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 2287a727009..e21fd38723b 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe """Provide PyTorch-to-TOSA mapping helpers. Use these utilities to translate PyTorch dtypes and FX node metadata into diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 58abf4e7e2e..3a1a79ec8de 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe """Provide a partitioner for delegating subgraphs to the TOSA backend. Implement logic to identify and tag regions of an ``ExportedProgram`` that can diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py index 9ad2192bb9a..ddf9d31145c 100644 --- a/backends/arm/tosa/quant_utils.py +++ b/backends/arm/tosa/quant_utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe # Utility functions for TOSA quantized lowerings diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index 7afa7d9f0de..3edf27760b5 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe """Provide TOSA specification parsing and context utilities. Use these helpers to parse and validate TOSA profile/extension strings and to diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 14a22298d8a..edcef8ceb9d 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import logging from typing import Any diff --git a/backends/arm/vgf/__init__.py b/backends/arm/vgf/__init__.py index 88be90e084e..f4ce8f5d1a4 100644 --- a/backends/arm/vgf/__init__.py +++ b/backends/arm/vgf/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # +# pyre-unsafe from .backend import VgfBackend # noqa: F401 from .compile_spec import VgfCompileSpec # noqa: F401 diff --git a/backends/arm/vgf/backend.py b/backends/arm/vgf/backend.py index 82d200f44fd..7ed0154ab99 100644 --- a/backends/arm/vgf/backend.py +++ b/backends/arm/vgf/backend.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe # # Main implementation of AoT flow to partition and preprocess for VGF target diff --git a/backends/arm/vgf/partitioner.py b/backends/arm/vgf/partitioner.py index be684505d2b..ea10730e810 100644 --- a/backends/arm/vgf/partitioner.py +++ b/backends/arm/vgf/partitioner.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import final, Optional, Sequence From 8167327e153efaa9ce5d7153a9b3939d94129150 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 15/19] Revert "Arm backend: Move rescales from SUM visitor to pass (#15299)" This reverts commit c4cd274642ba63e0851760ff197131f4a8c515e2. --- backends/arm/_passes/arm_pass_manager.py | 5 +- backends/arm/_passes/decompose_sum_pass.py | 2 +- backends/arm/_passes/insert_rescales_pass.py | 10 +-- .../operator_support/reduce_sum_support.py | 9 +-- backends/arm/operators/op_sum.py | 64 ++++++++++++++++++- backends/arm/test/ops/test_sum.py | 1 - .../passes/test_insert_rescale_i32_pass.py | 55 ++++------------ 7 files changed, 83 insertions(+), 63 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 98240f6dc1d..b579d910752 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -194,6 +194,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) + self.add_pass(DecomposeSumPass()) self.add_pass(DecomposeCumsumPass(exported_program)) self.add_pass(Conv1dUnsqueezePass()) self.add_pass(DecomposeMaxPool2DPass()) @@ -214,11 +215,10 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RewriteMatmulPass()) self.add_pass(RewriteUpsamplePass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - self.add_pass(InsertRescaleInt32Pass()) - self.add_pass(DecomposeSumPass()) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.add_pass(InsertRescaleInt32Pass()) self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) @@ -361,6 +361,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(ConvertMinMaxPass()) self.add_pass(ReplaceInfValues()) + self.add_pass(DecomposeSumPass()) if not self.tosa_spec.is_U55_subset: # Uses where which is not supported on Ethos-U55 diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 989d299a7e8..59c352a0e07 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta): if not keepdims: shape = list(meta["val"].size()) input_node = super().call_operator( - view_op, (input_node, shape), {}, meta, updated=True + view_op, (input_node, shape), kwargs, meta, updated=True ) return input_node diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 831dfe360b1..89630978366 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -10,7 +10,6 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg -from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_output_qparams, ) @@ -85,11 +84,7 @@ class InsertRescaleInt32Pass(ArmPass): parameters. """ - # SUM must be decomposed after this pass to prevent insertion of RESCALE - # nodes between each subsequent SUM node after decomposition. RESCALE nodes - # should only be inserted before and after the SUM node prior to its - # decomposition. - _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass} + _passes_required_after: Set[Type[ExportPass]] = set() included_targets = [ exir_ops.edge.aten.abs.default, @@ -101,7 +96,6 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.sum.dim_IntList, ] def _int32_qargs(self, s): @@ -144,7 +138,6 @@ def _get_inputs_rescaled_qparams( } elif target in [ exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.sum.dim_IntList, ]: # The input scales do not need to be adjusted for these ops; they # can remain the same. @@ -167,7 +160,6 @@ def _get_output_qparams( exir_ops.edge.aten.abs.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.sum.dim_IntList, ]: # The op has not altered the scale; the output scale is equal to # the operands' scales. diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 76d1ba7bf36..4ff8f54ad69 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -29,13 +29,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): # U55 case, Vela 4.2.0 (25.02 release) input_shape = node.all_input_nodes[0].meta["val"].shape - - if node.args[1] is None: - # Dim is allowed to be None, which means to sum all dimensions - dim_list = list(range(len(input_shape))) - else: - dim_list = cast(list[int], node.args[1]) - dim_list = [dim % len(input_shape) for dim in dim_list] + dim_list = cast(list[int], node.args[1]) + dim_list = [dim % len(input_shape) for dim in dim_list] for dim in dim_list: if not 1 <= input_shape[dim] <= 65536: diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 7be3e884275..5c88c00537e 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -7,6 +7,8 @@ from typing import Any, List +import executorch.backends.arm.tosa.quant_utils as tqutils +import executorch.backends.arm.tosa.utils as tutils import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -23,14 +25,69 @@ @register_node_visitor -class SumVisitor(NodeVisitor): +class SumVisitor_INT(NodeVisitor): target = "aten.sum.dim_IntList" tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), TosaSpecification.create_from_string("TOSA-1.0+INT"), ] + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 3) + validate_same_dtype(self.target, [inputs[0], output], ts) + + tensor = inputs[0] + input_shape = list(tensor.shape) + dim = int(inputs[1].number % len(input_shape)) + + output_shape = input_shape + output_shape[dim] = 1 # Output shape is input shape with dim reduced + + # Rescale input to 32 bit + rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32( + tosa_graph, [tensor], node, self.tosa_spec + ) + + attr = ts.TosaSerializerAttribute() + attr.ReduceSumAttribute(tensor.dim_order.index(dim)) + + intermediate = tosa_graph.addIntermediate( + tutils.tosa_shape(output_shape, tensor.dim_order), + dtype=ts.DType.INT32, + ) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.REDUCE_SUM, + [rescaled_inputs[0].name], + [intermediate.name], + attr, + ) + + tqutils.insert_rescale_op_to_int8( + tosa_graph, intermediate, scale, node, self.tosa_spec + ) + + +@register_node_visitor +class SumVisitor_FP(SumVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + def define_node( self, node: Node, @@ -45,6 +102,9 @@ def define_node( input_shape = list(tensor.shape) dim = int(inputs[1].number % len(input_shape)) + output_shape = input_shape + output_shape[dim] = 1 # Output shape is input shape with dim reduced + attr = ts.TosaSerializerAttribute() attr.ReduceSumAttribute(tensor.dim_order.index(dim)) diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index f0af9a022e8..13c1e029032 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -35,7 +35,6 @@ class Sum(torch.nn.Module): "4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True), "4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True), "dim_None": lambda: (torch.rand(10), None, True), - "dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True), } def forward(self, x: torch.Tensor, dim: int, keepdim: bool): diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 2f625b955ce..66f09ba89a9 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -13,11 +13,14 @@ from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -class MultipleOpsModel(torch.nn.Module): +class NeedsRescaleOps(torch.nn.Module): """A module containing ops that require INT32 inputs/outputs.""" input_t = Tuple[torch.Tensor, torch.Tensor] + def __init__(self): + super().__init__() + def forward(self, x, y): a = x * y b = torch.maximum(a, y) @@ -36,41 +39,19 @@ def get_inputs(self, dtype) -> input_t: else: raise ValueError("Not a valid input dtype for model") - def get_num_expected_rescales(self): - # "number of op nodes with i8 output" + "number of i8 node inputs" - return 3 + 7 - - -class SumModel(torch.nn.Module): - input_t = Tuple[torch.Tensor] - - def forward(self, x): - a = torch.sum(x, 2, keepdim=True) # (1, 2, 1, 4) - b = torch.sum(a, [1, 3], keepdim=True) # (1, 1, 1, 1) - c = torch.sum(b, [0, 2], keepdim=False) # (1, 1) - return c - - def get_inputs(self, dtype) -> input_t: - if dtype == torch.float32: - return (torch.rand(1, 2, 3, 4),) - elif dtype == torch.int32: - return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),) - else: - raise ValueError("Not a valid input dtype for model") - def get_num_expected_rescales(self): - # Two RESCALE nodes per SUM node - return 6 - - -def _test_model_with_f32_data(model): +def test_insert_rescales(): + module = NeedsRescaleOps() + input_t = Tuple[torch.Tensor, torch.Tensor] ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(), + # "number of op nodes with i8 output" + "number of i8 node inputs" + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3 + + 7, } - pipeline = PassPipeline[model.input_t]( - model, - model.get_inputs(torch.float32), + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(torch.float32), quantize=True, ops_not_before_pass=ops_not_before, ops_after_pass=ops_after, @@ -80,16 +61,8 @@ def _test_model_with_f32_data(model): pipeline.run() -def test_insert_rescales_sum_model(): - _test_model_with_f32_data(SumModel()) - - -def test_insert_rescales_multiple_ops_model(): - _test_model_with_f32_data(MultipleOpsModel()) - - def test_dont_insert_rescales(): - module = MultipleOpsModel() + module = NeedsRescaleOps() input_t = Tuple[torch.Tensor, torch.Tensor] ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} # All inputs are already i32. Rescales should not be added. From 008a01468e2bc65f3659ae4509710d002ffaf153 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 16/19] Revert "Arm backend: Tag control flow submodules in partitioner (#15364)" This reverts commit b24c39a262359c22a253244c5c122cd686bfc6a5. --- backends/arm/tosa/partitioner.py | 173 ++++++++++++------------------- 1 file changed, 67 insertions(+), 106 deletions(-) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 3a1a79ec8de..6eb1dcbef72 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -15,7 +15,6 @@ """ import logging -from itertools import count from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -36,10 +35,8 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.graph_module import get_control_flow_submodules from torch.export.exported_program import ExportedProgram -from torch.fx import GraphModule -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase logger = logging.getLogger(__name__) @@ -113,43 +110,6 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: return all(m == 1 for m in multiples) -def is_partitioned( - node: torch.fx.Node, - tag: str, -) -> bool: - """Return True if the node currently belongs to the partition ``tag``. - - Args: - node (torch.fx.Node): FX node to check. - tag (str): Delegation tag identifying the partition. - - Returns: - bool: True if the node carries the matching delegation tag. - - """ - return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag - - -def reject_partition( - reason: str, partition: Partition, reporter: WhyNoPartitionReporter -) -> None: - """Remove a proposed partition and record the rejection reason. - - Args: - reason (str): Human-readable explanation for rejection. - partition (object): Proposed partition object from the - capability partitioner. - reporter (WhyNoPartitionReporter): used to report why nodes were rejected. - """ - for node in partition.nodes: - if "delegation_tag" in node.meta: - del node.meta["delegation_tag"] - reporter.report_reject( - node, - reason, - ) - - class TOSAPartitioner(Partitioner): """Partition an exported program into TOSA-delegable subgraphs. @@ -182,66 +142,97 @@ def __init__( self.additional_checks = additional_checks self.tosa_spec = compile_spec.tosa_spec - def _tag_module( # noqa - self, - module: GraphModule, - containing_program: ExportedProgram, - reporter: WhyNoPartitionReporter, - tag_iterator: count | None = None, - ) -> set[str]: - """Tag nodes in a module, possibly a submodule, from the containing program. + def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa + """Partition the program and tag TOSA-compatible subgraphs. + + Run the FX capability-based partitioner to propose subgraphs, then + refine tags by removing boundary-only quantize/dequantize nodes and by + rejecting partitions that would lower to no-ops. Emit a detailed report + of rejected nodes and their reasons. Args: - module: a GraphModule from `containing_program` to tag nodes in. - containing_program: The ExportedProgram that contains the module. - reporter: A reporter to report why nodes were rejected. + exported_program (ExportedProgram): Program to analyze and + partition. + Returns: - A set of strings with the partition tags. + PartitionResult: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. + """ - tags: set[str] = set() - if tag_iterator is None: - tag_iterator = count(0) - for _, submodule, _ in get_control_flow_submodules(module): - submodule_tags = self._tag_module( - submodule, containing_program, reporter, tag_iterator - ) - if len(tags & submodule_tags) != 0: - raise RuntimeError( - "Got overlapping tags in two different modules, this shouldn't happen." - ) - tags = tags | submodule_tags + logger.info("TOSAPartitioner::partition") + partition_tags: dict[str, DelegationSpec] = {} + + logger.info( + f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" + ) + + reporter = WhyNoPartitionReporter() operator_support = tosa_support_factory( - self.tosa_spec, containing_program, reporter, self.additional_checks + self.tosa_spec, exported_program, reporter, self.additional_checks ) capability_partitioner = CapabilityBasedPartitioner( - module, + exported_program.graph_module, operator_support, allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() + def reject_partition(reason: str, partition, tag) -> None: + """Remove a proposed partition and record the rejection reason. + + Args: + reason (str): Human-readable explanation for rejection. + partition (object): Proposed partition object from the + capability partitioner. + tag (str): Delegation tag associated with the partition. + + """ + for node in partition.nodes: + if "delegation_tag" in node.meta: + del node.meta["delegation_tag"] + reporter.report_reject( + node, + reason, + ) + partition_tags.pop(tag, None) + for partition in partition_list: - tag = f"tag{next(tag_iterator)}" - tags.add(tag) + tag = f"tag{partition.id}" + + def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: + """Return True if the node currently belongs to the partition ``tag``. + + Args: + node (torch.fx.Node): FX node to check. + tag (str): Delegation tag identifying the partition. + + Returns: + bool: True if the node carries the matching delegation tag. + + """ + return ( + "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag + ) for node in partition.nodes: node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec # De-tag outermost q-nodes upwards and dq-nodes downwards. # De-tag if at least one input/output is not part of the partition. - for node in module.graph.nodes: - if not is_partitioned(node, tag): + for node in exported_program.graph_module.graph.nodes: + if not is_partitioned(node): continue if node.target in Q_OPS: for input in node.all_input_nodes: - if not is_partitioned(input, tag): + if not is_partitioned(input): del node.meta["delegation_tag"] break continue if node.target in DQ_OPS: for user in node.users: - if not is_partitioned(user, tag): + if not is_partitioned(user): del node.meta["delegation_tag"] break continue @@ -249,9 +240,9 @@ def _tag_module( # noqa if self.tosa_spec.support_float(): continue - if is_partitioned(node, tag): + if is_partitioned(node): for input in node.all_input_nodes: - if is_partitioned(input, tag): + if is_partitioned(input): continue if get_first_fake_tensor(input).dtype.is_floating_point: reporter.report_reject( @@ -274,38 +265,8 @@ def _tag_module( # noqa reject_partition( "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", partition, - reporter, + tag, ) - tags.remove(tag) - return tags - - def partition(self, exported_program: ExportedProgram) -> PartitionResult: - """Partition the program and tag TOSA-compatible subgraphs. - - Run the FX capability-based partitioner to propose subgraphs, then - refine tags by removing boundary-only quantize/dequantize nodes and by - rejecting partitions that would lower to no-ops. Emit a detailed report - of rejected nodes and their reasons. - - Args: - exported_program (ExportedProgram): Program to analyze and - partition. - - Returns: - PartitionResult: The input program with nodes tagged for delegation - and a mapping of partition tags to delegation specs. - - """ - logger.info("TOSAPartitioner::partition") - logger.info( - f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" - ) - - reporter = WhyNoPartitionReporter() - tags = self._tag_module( - exported_program.graph_module, exported_program, reporter - ) - partition_tags = {tag: self.delegation_spec for tag in tags} tag_constant_data(exported_program) logger.info(f"The following nodes were rejected for {self.tosa_spec}:") From a3ff3261bf749a82d8c8445458893a494e7b4e33 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 17/19] Revert "Arm backend: support mean.default (#15363)" This reverts commit 9075855703738107178a102495484c3cecf277cf. --- .../arm/_passes/decompose_meandim_pass.py | 21 ++++------ .../tosa_profile_supported_op_lists.py | 1 - backends/arm/scripts/parse_test_names.py | 1 - backends/arm/test/ops/test_mean_dim.py | 42 ------------------- 4 files changed, 8 insertions(+), 57 deletions(-) diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 135e2830d5d..4d4c0ee75b1 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -19,13 +19,13 @@ def get_meandim_decomposition(op) -> tuple: - if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): + if op == exir_ops.edge.aten.mean.dim: return ( exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, ) - if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): + if op == torch.ops.aten.mean.dim: return ( torch.ops.aten.sum.dim_IntList, torch.ops.aten.full.default, @@ -35,17 +35,17 @@ def get_meandim_decomposition(op) -> tuple: def get_avgpool(op): - if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): + if op == exir_ops.edge.aten.mean.dim: return exir_ops.edge.aten.avg_pool2d.default - if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): + if op == torch.ops.aten.mean.dim: return torch.ops.aten.avg_pool2d.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") def get_view(op): - if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): + if op == exir_ops.edge.aten.mean.dim: return exir_ops.edge.aten.view_copy.default - if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): + if op == torch.ops.aten.mean.dim: return torch.ops.aten.view_copy.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") @@ -87,18 +87,13 @@ def __init__(self, graph_module, tosa_spec): ) def call_operator(self, op, args, kwargs, meta): - if op not in ( - exir_ops.edge.aten.mean.dim, - torch.ops.aten.mean.dim, - exir_ops.edge.aten.mean.default, - torch.ops.aten.mean.default, - ): + if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim): return super().call_operator(op, args, kwargs, meta) x = get_node_arg(args, 0) input_shape = list(x.data.shape) output_shape = list(meta["val"].shape) - dims_to_reduce = get_node_arg(args, 1, range(len(input_shape))) + dims_to_reduce = get_node_arg(args, 1) if dims_to_reduce is None: dims_to_reduce = range(len(input_shape)) dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index b91ed4fb130..ee61aa4cce6 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -178,7 +178,6 @@ exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.mean.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default, diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 56a2a9c6890..54f8aa7421d 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -14,7 +14,6 @@ "hardswish.default", "linear.default", "maximum.default", - "mean.default", "multihead_attention.default", "adaptive_avg_pool2d.default", "bitwise_right_shift.Tensor", diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index babfb7d10da..31797e72e78 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -4,8 +4,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable - import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -346,43 +344,3 @@ def test_mean_dim_vgf_INT(test_data): tosa_version="TOSA-1.0+INT", ) pipeline.run() - - -mean_input_t = tuple[torch.Tensor, bool] - - -class MeanDefault(torch.nn.Module): - def forward(self, tensor: torch.Tensor, keepdim: bool): - return tensor.mean() - - test_data_suite: dict[str, Callable[[], mean_input_t]] = { - "rank1": lambda: ( - torch.rand( - 1, - ), - False, - ), - "rank2": lambda: (torch.rand(5, 5), True), - "rank4": lambda: (torch.rand(5, 1, 10, 1), False), - } - - -@common.parametrize("test_data", MeanDefault.test_data_suite) -def test_mean_tosa_FP(test_data): - pipeline = TosaPipelineFP[mean_input_t]( - MeanDefault(), - test_data(), - [], # Might be sum, avgpool, or both - ) - pipeline.run() - - -@common.parametrize("test_data", MeanDefault.test_data_suite) -def test_mean_tosa_INT(test_data): - pipeline = TosaPipelineINT[mean_input_t]( - MeanDefault(), - test_data(), - [], # Might be sum, avgpool, or both - symmetric_io_quantization=True, - ) - pipeline.run() From 53bb98b62bc8a614412731f1315066bb0fb20b8d Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:16:35 -0700 Subject: [PATCH 18/19] Revert "Arm backend: Support per-channel in TOSA.RESCALE (#15267)" This reverts commit f7ca57e3be96d1378107892b2017adc445dd56cf. --- .../decompose_int16_activation_conv2d_pass.py | 6 +- backends/arm/_passes/insert_rescales_pass.py | 16 ++-- backends/arm/_passes/insert_table_ops.py | 2 +- backends/arm/_passes/rewrite_conv2d_pass.py | 76 ++----------------- backends/arm/_passes/rewrite_matmul.py | 2 +- backends/arm/_passes/rewrite_upsample.py | 2 +- backends/arm/operators/op_tosa_conv2d.py | 59 +++++++++++++- .../arm/operators/op_tosa_depthwise_conv2d.py | 4 - backends/arm/operators/op_tosa_rescale.py | 6 +- .../arm/test/misc/test_tosa_dialect_conv2d.py | 4 +- .../test/misc/test_tosa_dialect_dw_conv2d.py | 4 +- backends/arm/test/passes/test_rescale_pass.py | 14 ++-- backends/arm/tosa/dialect/ops/conv2d.py | 3 +- backends/arm/tosa/dialect/ops/rescale.py | 6 +- 14 files changed, 94 insertions(+), 110 deletions(-) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py index 388ce217807..d43c2a8c89c 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta): conv_output = super().call_operator( exir_ops.backend.tosa.RESCALE.default, - (convolution, torch.int32, [conv_rescale_factor], 0, 0), + (convolution, torch.int32, conv_rescale_factor, 0, 0), {}, new_meta, ) bias_rescaled = super().call_operator( exir_ops.backend.tosa.RESCALE.default, - (channel_bias, torch.int32, [bias_rescale_factor], 0, 0), + (channel_bias, torch.int32, bias_rescale_factor, 0, 0), {}, new_meta, ) @@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta): ( add, output_dtype, - [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))], + (common_scale / (conv_output_scale * (1 << bits_left_to_shift))), 0, 0, ), diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 89630978366..a7fa614c8c3 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule ( node.all_input_nodes[0], q_args.dtype, - [new_scale], + new_scale, dq_args.zp, q_args.zp, ), @@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b ( arg_node, torch.int32, - [ - qp.get_scale_per_tensor() - / rescale_qargs[i].get_scale_per_tensor() - ], # [Old scale / new scale] + qp.get_scale_per_tensor() + / rescale_qargs[ + i + ].get_scale_per_tensor(), # Old scale / new scale qp.get_zp_per_tensor(), # Old zero point rescale_qargs[i].get_zp_per_tensor(), # New zero point ), @@ -264,10 +264,8 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b ( node, qarg.dtype, - [ - rescale_qargs.get_scale_per_tensor() - / qarg.get_scale_per_tensor() - ], # [Old scale / new scale] + rescale_qargs.get_scale_per_tensor() + / qarg.get_scale_per_tensor(), # Old scale / new scale rescale_qargs.get_zp_per_tensor(), # Old zero point qarg.get_zp_per_tensor(), # New zero point ), diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 8d8a1284011..e77d0c64c71 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult: rescale_node = create_node( graph=graph_module.graph, op_target=exir_ops.backend.tosa.RESCALE.default, - args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0), + args=(table_op_node, output_qparams[0].dtype, scale, 0, 0), ) output_node = rescale_node diff --git a/backends/arm/_passes/rewrite_conv2d_pass.py b/backends/arm/_passes/rewrite_conv2d_pass.py index c46cfb3b205..8b4f43c35c7 100644 --- a/backends/arm/_passes/rewrite_conv2d_pass.py +++ b/backends/arm/_passes/rewrite_conv2d_pass.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. -import itertools from typing import Set, Type import torch @@ -17,10 +16,6 @@ is_buffer, is_param, ) -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import create_constant_placeholder @@ -161,40 +156,6 @@ def _add_bias( node.update_arg(2, bias_node) return bias_node - def insert_output_rescale(self, graph_module, node): - input_qparams = get_input_qparams(node) - output_qparams = get_output_qparams(node)[0] - weight_qparams = input_qparams[1] - input_qparams = input_qparams[0] - is_per_channel = weight_qparams.per_channel - if is_per_channel: - weight_scale = weight_qparams.get_scale_per_channel() - else: - weight_scale = [weight_qparams.get_scale_per_tensor()] - input_scale = input_qparams.get_scale_per_tensor() - post_conv2d_scale = [ - (inp * w) / out - for inp, w, out in zip( - itertools.cycle([input_scale]), - weight_scale, - itertools.cycle([output_qparams.get_scale_per_tensor()]), - ) - ] - with graph_module.graph.inserting_after(node): - rescale_node = create_node( - graph=graph_module.graph, - op_target=exir_ops.backend.tosa.RESCALE.default, - args=( - node, - output_qparams.dtype, - post_conv2d_scale, - 0, - output_qparams.get_zp_per_tensor(), - ), - from_node=node, - ) - return rescale_node - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: modified = False for node in graph_module.graph.nodes: @@ -219,20 +180,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: ) = node.args pad = [val for val in pad for _ in (0, 1)] - input_fake_tensor = get_first_fake_tensor(x) - weight_fake_tensor = get_first_fake_tensor(weight) + input_shape = get_first_fake_tensor(x).shape + weight_shape = get_first_fake_tensor(weight).shape # Adjust the pad value if needed to meet the # strict convolution output shape calculation. pad[1] = self._adjust_pad_if_needed( - input_fake_tensor.shape[2], - weight_fake_tensor.shape[2], + input_shape[2], + weight_shape[2], stride[0], pad[1], dilation[0], ) pad[3] = self._adjust_pad_if_needed( - input_fake_tensor.shape[3], - weight_fake_tensor.shape[3], + input_shape[3], + weight_shape[3], stride[1], pad[3], dilation[1], @@ -243,8 +204,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if self._is_depthwise_conv2d(node): target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default - self._reshape_weights(weight, input_fake_tensor.shape[1]) - weight_fake_tensor = get_first_fake_tensor(weight) + self._reshape_weights(weight, input_shape[1]) else: target_op = exir_ops.backend.tosa.CONV2D.default @@ -267,29 +227,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=conv2d_args, from_node=node, ) - bias_fake_tensor = get_first_fake_tensor(bias) if bias else None - tosa_node_fake_tensor = target_op( - input_fake_tensor, - weight_fake_tensor, - bias_fake_tensor, - *conv2d_args[3:], - ) - if ( - tosa_node_fake_tensor.dtype == torch.int32 - and input_fake_tensor.dtype == torch.int8 - ) or ( - tosa_node_fake_tensor.dtype == torch.int32 - and input_fake_tensor.dtype == torch.int16 - ): - output_rescale = self.insert_output_rescale(graph_module, tosa_op) - node.replace_all_uses_with(output_rescale) - if input_fake_tensor.dtype == torch.int16: - tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 - else: node.replace_all_uses_with(tosa_op) - - graph_module.graph.erase_node(node) + graph_module.graph.erase_node(node) if modified: graph_module.recompile() diff --git a/backends/arm/_passes/rewrite_matmul.py b/backends/arm/_passes/rewrite_matmul.py index 410f0d62bff..28ff800792b 100644 --- a/backends/arm/_passes/rewrite_matmul.py +++ b/backends/arm/_passes/rewrite_matmul.py @@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype): rescale_node.args = ( tosa_matmul_node, dtype, - [scale], + scale, 0, output_qparams.get_zp_per_tensor(), ) diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index e0ef1dbcf4a..c9f25a1e845 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -74,7 +74,7 @@ def call(self, graph_module): rescale_node.args = ( tosa_resize_node, output_dtype, - [output_scale], + output_scale, 0, # zero point 0, # zero point ) diff --git a/backends/arm/operators/op_tosa_conv2d.py b/backends/arm/operators/op_tosa_conv2d.py index 0e10867da7e..3631a143b50 100644 --- a/backends/arm/operators/op_tosa_conv2d.py +++ b/backends/arm/operators/op_tosa_conv2d.py @@ -8,12 +8,14 @@ """Provide a visitor for lowering 2D convolution to TOSA (INT/FP).""" +import itertools from typing import Any, List import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, + get_output_qparams, ) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -24,7 +26,9 @@ validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.quant_utils import build_rescale from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification +from executorch.backends.arm.tosa.utils import tosa_shape @register_node_visitor @@ -54,8 +58,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator.""" - + """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale.""" input, weight, bias, stride, pad, dilation, _, _, group = inputs validate_num_inputs(self.target, inputs, 9) @@ -102,8 +105,23 @@ def define_node( input_qparams = get_input_qparams(node) weight_zp = input_qparams[1].zp # type: ignore[assignment] - conv2d_output_name = output.name - acc_type = output.dtype + # The output type is int32 when input type is int8. + if inputs[0].dtype == ts.DType.INT8: + conv2d_res = tosa_graph.addIntermediate( + tosa_shape(output.shape, output.dim_order), ts.DType.INT32 + ) + conv2d_output_name = conv2d_res.name + acc_type = ts.DType.INT32 + elif inputs[0].dtype == ts.DType.INT16: + conv2d_res = tosa_graph.addIntermediate( + tosa_shape(output.shape, output.dim_order), ts.DType.INT48 + ) + conv2d_output_name = conv2d_res.name + acc_type = ts.DType.INT48 + else: + conv2d_output_name = output.name + conv2d_res = output + acc_type = ts.DType.FP32 tosa_graph.addConst( [1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp" @@ -140,3 +158,36 @@ def define_node( [conv2d_output_name], attr, ) + + # For quantized convolution, rescale the output value back to the same + # integer value domain of the next op. Otherwise return float32 output. + if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16: + # Get scale_factor from input, weight, and output. + input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61] + per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61] + if per_channel_quant: + weight_scale = input_qparams[1].get_scale_per_channel() + else: + weight_scale = [ + input_qparams[1].get_scale_per_tensor() + ] # pyre-ignore [61] + output_qargs = get_output_qparams(node) + post_conv2d_scale = [ + (inp * w) / out + for inp, w, out in zip( + itertools.cycle([input_scale]), + weight_scale, + itertools.cycle([output_qargs[0].get_scale_per_tensor()]), + ) + ] + build_rescale( + tosa_fb=tosa_graph, + scale=post_conv2d_scale, + input_node=conv2d_res, # type: ignore[possibly-undefined] + output_name=output.name, + output_type=output.dtype, + input_zp=[0], + output_zp=[output_qargs[0].get_zp_per_tensor()], + per_channel=per_channel_quant, + rounding_mode=ts.RoundingMode.SINGLE_ROUND, + ) diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py index 1d1c317a0b8..3538b6f31da 100644 --- a/backends/arm/operators/op_tosa_depthwise_conv2d.py +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -4,11 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe - -"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP).""" - import tosa_serializer as ts - from executorch.backends.arm.operators.node_visitor import register_node_visitor from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor from executorch.backends.arm.tosa import TosaSpecification diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 26f27370bdd..db3738a8fd1 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -41,7 +41,7 @@ def define_node( input_dtype = inputs[0].dtype output_dtype = cast(torch.dtype, node.args[1]) - scales = cast(list[float], node.args[2]) + scale = cast(float, node.args[2]) input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) @@ -63,12 +63,12 @@ def define_node( build_rescale( tosa_graph, - scale=scales, + scale=[scale], input_node=inputs[0], output_name=output.name, output_type=output.dtype, input_zp=[input_zp], output_zp=[output_zp], rounding_mode=ts.RoundingMode.SINGLE_ROUND, - per_channel=len(scales) > 1, + per_channel=False, ) diff --git a/backends/arm/test/misc/test_tosa_dialect_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_conv2d.py index 3496ca0d5b6..867578a4ff5 100644 --- a/backends/arm/test/misc/test_tosa_dialect_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_conv2d.py @@ -31,7 +31,7 @@ def test_conv2d_tosa_INT(): 4, ), (1, 8, 20, 20), - torch.int32, + torch.int8, ), ( ( @@ -46,7 +46,7 @@ def test_conv2d_tosa_INT(): 4, ), (1, 4, 10, 10), - torch.int32, + torch.int8, ), ] diff --git a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py index 8b50df20830..8d9224d90fe 100644 --- a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py @@ -32,7 +32,7 @@ def test_depthwise_conv2d_tosa_INT(): 8, ), (1, 16, 20, 20), - torch.int32, + torch.int8, ), ( ( @@ -48,7 +48,7 @@ def test_depthwise_conv2d_tosa_INT(): 8, ), (1, 32, 10, 10), - torch.int32, + torch.int8, ), ] diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py index ecd1deadf4f..9774ebd2fcd 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -31,21 +31,21 @@ def test_rescale_op(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - [0.2], + 0.2, 2, 0, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - [0.2], + 0.2, 0, -128, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int8, - [0.8], + 0.8, 10, 127, ), @@ -71,14 +71,14 @@ def test_nonzero_zp_for_int32(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - [0.2], + 0.2, 2, # Should be 0, expect error 1, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - [0.2], + 0.2, 1, 1, # Should be 0, expect error ), @@ -107,14 +107,14 @@ def test_zp_outside_range(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - [0.2], + 0.2, 128, # Should be <128, expect error 0, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - [0.2], + 0.2, 0, -129, # Should be >-129m expect error ), diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index 45afae51708..052c1111615 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -45,7 +45,8 @@ def validate_conv2d_args_dtypes( f"TOSA spec {tosa_spec} only supports {torch.int32} bias for {x.dtype} input but found {bias.dtype}", op=op, ) - output_dtype = torch.int32 + # TODO update to int32 for int8 inputs + output_dtype = torch.int8 if x.dtype == torch.int8 else torch.int16 elif x.dtype in supported_float_types: if not tosa_spec.support_float(): diff --git a/backends/arm/tosa/dialect/ops/rescale.py b/backends/arm/tosa/dialect/ops/rescale.py index f622bbf115d..5f0cf9d15dc 100644 --- a/backends/arm/tosa/dialect/ops/rescale.py +++ b/backends/arm/tosa/dialect/ops/rescale.py @@ -3,8 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List - import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op @@ -16,13 +14,13 @@ @register_fake_tosa_op( - "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp) -> Tensor", # schema + "RESCALE(Tensor input1, ScalarType dtype, float scale, int in_zp, int out_zp) -> Tensor", # schema ( TosaSpecification.create_from_string("TOSA-1.0+INT"), ), # target TOSA specifications ) def RESCALE( - x: torch.Tensor, dtype: torch.dtype, scales: List[float], in_zp: int, out_zp: int + x: torch.Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int ) -> torch.Tensor: tosa_spec = get_context_spec() """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op. From 36f25ce24a4d1fa5d0cf79decdc70d217089c47e Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 27 Oct 2025 11:25:46 -0700 Subject: [PATCH 19/19] Revert "ArBackend: Enable Pybindings for tosa_serialization lib (#15356)" This reverts commit fdfeaa4aed880377758ed43f8dd5e90ae95e6ac5. --- .mypy.ini | 6 -- backends/arm/common/debug.py | 22 +++++--- backends/arm/debug/schema.py | 13 +++-- backends/arm/ethosu/backend.py | 9 --- backends/arm/operators/node_visitor.py | 7 +-- backends/arm/operators/op_abs.py | 16 +++--- backends/arm/operators/op_add.py | 16 +++--- backends/arm/operators/op_amax.py | 7 +-- backends/arm/operators/op_amin.py | 8 +-- backends/arm/operators/op_any.py | 4 +- backends/arm/operators/op_avg_pool2d.py | 15 +++-- backends/arm/operators/op_bitwise_not.py | 8 +-- backends/arm/operators/op_cat.py | 4 +- backends/arm/operators/op_ceil.py | 8 +-- backends/arm/operators/op_clamp.py | 56 +++++++------------ backends/arm/operators/op_constant_pad_nd.py | 9 +-- backends/arm/operators/op_cos.py | 7 +-- backends/arm/operators/op_eq.py | 9 ++- backends/arm/operators/op_erf.py | 8 +-- backends/arm/operators/op_exp.py | 6 +- backends/arm/operators/op_floor.py | 8 +-- backends/arm/operators/op_ge.py | 8 +-- backends/arm/operators/op_gt.py | 8 +-- backends/arm/operators/op_index_select.py | 8 +-- backends/arm/operators/op_index_tensor.py | 18 ++---- backends/arm/operators/op_le.py | 8 +-- backends/arm/operators/op_log.py | 7 +-- backends/arm/operators/op_logical_not.py | 9 +-- backends/arm/operators/op_lt.py | 8 +-- backends/arm/operators/op_max_pool2d.py | 11 ++-- backends/arm/operators/op_maximum.py | 9 ++- backends/arm/operators/op_minimum.py | 8 ++- backends/arm/operators/op_mul.py | 12 +--- backends/arm/operators/op_neg.py | 10 ++-- backends/arm/operators/op_permute.py | 6 +- backends/arm/operators/op_pow.py | 9 ++- backends/arm/operators/op_reciprocal.py | 9 ++- backends/arm/operators/op_repeat.py | 10 ++-- backends/arm/operators/op_rshift_tensor.py | 6 +- backends/arm/operators/op_rsqrt.py | 9 ++- backends/arm/operators/op_sigmoid.py | 7 +-- backends/arm/operators/op_sin.py | 7 +-- backends/arm/operators/op_slice.py | 41 ++++---------- backends/arm/operators/op_sub.py | 14 ++--- backends/arm/operators/op_sum.py | 6 +- backends/arm/operators/op_tanh.py | 7 +-- .../arm/operators/op_to_dim_order_copy.py | 9 ++- backends/arm/operators/op_tosa_conv2d.py | 10 +++- .../arm/operators/op_tosa_depthwise_conv2d.py | 5 +- backends/arm/operators/op_tosa_matmul.py | 9 +-- backends/arm/operators/op_tosa_rescale.py | 8 ++- backends/arm/operators/op_tosa_resize.py | 29 +++++----- backends/arm/operators/op_tosa_table.py | 11 ++-- backends/arm/operators/op_tosa_transpose.py | 6 +- backends/arm/operators/op_view.py | 6 +- backends/arm/operators/op_where.py | 8 +-- .../operators/operator_validation_utils.py | 19 ++++--- backends/arm/operators/ops_binary.py | 50 ++++------------- backends/arm/operators/ops_identity.py | 9 +-- backends/arm/process_node.py | 3 +- backends/arm/requirements-arm-tosa.txt | 2 + backends/arm/scripts/mlsdk_utils.sh | 45 ++++++++++----- backends/arm/test/models/test_llama.py | 8 --- backends/arm/test/tester/arm_tester.py | 5 +- backends/arm/tosa/backend.py | 15 ++--- backends/arm/tosa/mapping.py | 5 +- backends/arm/tosa/quant_utils.py | 19 ++++--- backends/arm/tosa/utils.py | 12 ++-- backends/arm/vgf/backend.py | 2 +- examples/arm/setup.sh | 38 ++----------- 70 files changed, 345 insertions(+), 489 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index baea2efefa9..cd14cbac7ea 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -83,12 +83,6 @@ ignore_missing_imports = True [mypy-tosa_tools.*] ignore_missing_imports = True -[mypy-tosa_serializer] -ignore_missing_imports = True - -[mypy-tosa_serializer.*] -ignore_missing_imports = True - [mypy-setuptools.*] ignore_missing_imports = True diff --git a/backends/arm/common/debug.py b/backends/arm/common/debug.py index e5c90fe7c3d..5a74805591d 100644 --- a/backends/arm/common/debug.py +++ b/backends/arm/common/debug.py @@ -7,9 +7,8 @@ import os from typing import Optional +import serializer.tosa_serializer as ts import torch - -import tosa_serializer as ts from executorch.exir.print_program import inspect_node logger = logging.getLogger(__name__) @@ -51,20 +50,29 @@ def get_node_debug_info( return output -# Output TOSA flatbuffer for debugging -def debug_tosa_dump(tosa_graph: bytes, path: str, suffix: str = ""): +# Output TOSA flatbuffer and test harness file +def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): filename = f"output{suffix}.tosa" logger.info(f"Emitting debug output to: {path=}, {suffix=}") os.makedirs(path, exist_ok=True) + fb = tosa_graph.serialize() + js = tosa_graph.writeJson(filename) + filepath_tosa_fb = os.path.join(path, filename) with open(filepath_tosa_fb, "wb") as f: - f.write(tosa_graph) + f.write(fb) if not os.path.exists(filepath_tosa_fb): raise IOError("Failed to write TOSA flatbuffer") + filepath_desc_json = os.path.join(path, f"desc{suffix}.json") + with open(filepath_desc_json, "w") as f: + f.write(js) + if not os.path.exists(filepath_desc_json): + raise IOError("Failed to write TOSA JSON") + def debug_fail( node, @@ -73,7 +81,7 @@ def debug_fail( path: Optional[str] = None, ): logger.warning("Internal error due to poorly handled node:") - if tosa_graph is not None and path: - debug_tosa_dump(tosa_graph.serialize(), path) + if tosa_graph is not None and path is not None: + debug_tosa_dump(tosa_graph, path) logger.warning(f"Debug output captured in '{path}'.") debug_node(node, graph_module) diff --git a/backends/arm/debug/schema.py b/backends/arm/debug/schema.py index d4df2285304..48978c51a81 100644 --- a/backends/arm/debug/schema.py +++ b/backends/arm/debug/schema.py @@ -10,8 +10,8 @@ from dataclasses import asdict, dataclass from typing import Any, Optional +import serializer.tosa_serializer as ts import torch -import tosa_serializer as ts from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec @@ -114,9 +114,14 @@ def to_dict(self) -> dict[str, Any]: class DebugHook: def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None: self._debug_events: list[DebugSchema] = [] + self.__op_id_to_name = {} self.mode = debug_mode - def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSchema: + # Build up a mapping from TOSA 1.0 operator IDs to their names + for name, val in vars(ts.Op).items(): + self.__op_id_to_name[val] = name + + def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema: tosa_debug_info = None # If the debug data is being embedded into the TOSA flatbuffer @@ -124,8 +129,8 @@ def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSche if self.mode != ArmCompileSpec.DebugMode.TOSA: tosa_debug_info = TosaDebugSchema( node_name=str(tosa_op), - operator_name=str(tosa_op_id), - operator_id=int(tosa_op_id), + operator_name=self.__op_id_to_name[tosa_op_id], + operator_id=tosa_op_id, ) aten_debug_info = ATenDebugSchema.from_node(node) diff --git a/backends/arm/ethosu/backend.py b/backends/arm/ethosu/backend.py index a529aa126f4..00da88ef60b 100644 --- a/backends/arm/ethosu/backend.py +++ b/backends/arm/ethosu/backend.py @@ -51,15 +51,6 @@ def _compile_tosa_flatbuffer( "compile_flags are required in the CompileSpec list for EthosUBackend" ) - # Vela tooling only supports flatbuffers up to 2 GiB. - max_flatbuffer_size = 2 * 1024 * 1024 * 1024 - flatbuffer_size = len(tosa_flatbuffer) - if flatbuffer_size > max_flatbuffer_size: - raise RuntimeError( - "TOSA flatbuffer is too large for Vela " - f"({flatbuffer_size} bytes > {max_flatbuffer_size} bytes limit)." - ) - # Pass on the TOSA flatbuffer to the vela compiler. binary = vela_compile( tosa_flatbuffer, diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index c26929dab28..172adbc7c78 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional import torch -import tosa_serializer as ts from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook @@ -47,12 +46,12 @@ def _serialize_operator( self, node: torch.fx.Node, tosa_graph: Any, - tosa_op: ts.Op, + tosa_op: Any, inputs: List[str], outputs: List[str], attributes: Optional[Any] = None, ) -> None: - op_location = ts.TosaOpLocation() + op_location = "" if self.debug_hook: debug_info = self.debug_hook.add( node, @@ -61,7 +60,7 @@ def _serialize_operator( ) if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA: - op_location.text = json.dumps(debug_info.to_dict()) + op_location = json.dumps(debug_info.to_dict()) tosa_graph.addOperator( tosa_op, diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index 82e09f5f1d4..b6213bfb467 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -48,13 +48,11 @@ def define_node( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.AbsAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.ABS, - [inputs[0].name], + tosa_graph.addOperator( + ts.TosaOp.Op().ABS, + [ + inputs[0].name, + ], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 24f43d62e56..4215e53902f 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -9,7 +9,7 @@ import executorch.backends.arm.tosa.quant_utils as tqutils import executorch.backends.arm.tosa.utils as tutils -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -81,16 +81,15 @@ def define_node( add_output = output input1, input2 = rescaled_inputs - attr = ts.TosaSerializerAttribute() - attr.AddAttribute() + # Do the INT32 Add self._serialize_operator( node, tosa_graph, - ts.Op.ADD, + ts.TosaOp.Op().ADD, [input1.name, input2.name], [add_output.name], - attr, + None, ) if output.dtype == ts.DType.INT8: @@ -144,14 +143,13 @@ def define_node( ) input1, input2 = inputs - attr = ts.TosaSerializerAttribute() - attr.AddAttribute() + # FP lowering self._serialize_operator( node, tosa_graph, - ts.Op.ADD, + ts.TosaOp.Op().ADD, [input1.name, input2.name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index e4824fb59c2..e576aed711a 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( @@ -60,12 +60,11 @@ def define_node( ) attr = ts.TosaSerializerAttribute() - nan_mode = ts.NanPropagationMode.PROPAGATE - attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=nan_mode) + attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1) self._serialize_operator( node, tosa_graph, - ts.Op.REDUCE_MAX, + ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr, diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 34d4d37cdeb..5dca08a73f3 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( @@ -60,13 +60,11 @@ def define_node( ) attr = ts.TosaSerializerAttribute() - attr.ReduceMinAttribute( - axis=input.dim_order.index(dim), nan_mode=ts.NanPropagationMode.PROPAGATE - ) + attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1) self._serialize_operator( node, tosa_graph, - ts.Op.REDUCE_MIN, + ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr, diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index b024a01bd58..308c1af1738 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import Any, cast, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( # type: ignore NodeVisitor, @@ -55,7 +55,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.REDUCE_ANY, + ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index eb4bb743d5b..bad9c653d05 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -93,14 +93,17 @@ def _build_generic_avgpool2d( pad=pad_size_list, acc_type=accumulator_type, ) - dt: ts.DType = output.dtype - input_zp_tensor = tosa_graph.addConst(shape=[1], dtype=dt, vals=[input_zp]) - output_zp_tensor = tosa_graph.addConst(shape=[1], dtype=dt, vals=[output_zp]) + input_zp_tensor = tosa_graph.addConst( + shape=[1], dtype=output.dtype, vals=[input_zp] + ) + output_zp_tensor = tosa_graph.addConst( + shape=[1], dtype=output.dtype, vals=[output_zp] + ) self._serialize_operator( node, tosa_graph, - ts.Op.AVG_POOL2D, + ts.TosaOp.Op().AVG_POOL2D, [input_tensor.name, input_zp_tensor.name, output_zp_tensor.name], [output.name], attr, diff --git a/backends/arm/operators/op_bitwise_not.py b/backends/arm/operators/op_bitwise_not.py index ac0f758469d..79a2b53a8bc 100644 --- a/backends/arm/operators/op_bitwise_not.py +++ b/backends/arm/operators/op_bitwise_not.py @@ -5,7 +5,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -49,14 +49,10 @@ def define_node( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.BitwiseNotAttribute() - self._serialize_operator( node, tosa_graph, - ts.Op.BITWISE_NOT, + ts.TosaOp.Op().BITWISE_NOT, [inputs[0].name], [output.name], - attr, ) diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 11b46038fbf..60c112f455f 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -50,7 +50,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.CONCAT, + ts.TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], attr, diff --git a/backends/arm/operators/op_ceil.py b/backends/arm/operators/op_ceil.py index 27ee81d0abe..a33117e38e6 100644 --- a/backends/arm/operators/op_ceil.py +++ b/backends/arm/operators/op_ceil.py @@ -5,9 +5,9 @@ from typing import Any, List -import torch.fx +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch.fx from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -49,8 +49,6 @@ def define_node( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.CeilAttribute() self._serialize_operator( - node, tosa_graph, ts.Op.CEIL, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().CEIL, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 1d394f76dab..0d58fa72c0c 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -9,8 +9,8 @@ from typing import Any, List, Tuple import numpy as np +import serializer.tosa_serializer as ts import torch -import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -83,15 +83,16 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ClampAttribute( - np.frombuffer(np.int8(min_int8).tobytes(), dtype=np.uint8).tolist(), - np.frombuffer(np.int8(max_int8).tobytes(), dtype=np.uint8).tolist(), - ts.NanPropagationMode.PROPAGATE, + tosa_graph.builder, + np.int8(min_int8).tobytes(), + np.int8(max_int8).tobytes(), + nan_mode=1, ) self._serialize_operator( node, tosa_graph, - ts.Op.CLAMP, + ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr, @@ -125,43 +126,24 @@ def define_node( output.tosa_spec, ) + min_fp32, max_fp32 = self._get_min_max_arguments( + node, + torch.finfo(torch.float32).min, + torch.finfo(torch.float32).max, + ) + attr = ts.TosaSerializerAttribute() - match inputs[0].dtype: - case ts.DType.FP16: - min_f, max_f = self._get_min_max_arguments( - node, - torch.finfo(torch.float16).min, - torch.finfo(torch.float16).max, - ) - min_bytes = np.frombuffer( - np.float16(min_f).tobytes(), dtype=np.uint8 - ).tolist() - max_bytes = np.frombuffer( - np.float16(max_f).tobytes(), dtype=np.uint8 - ).tolist() - case ts.DType.FP32: - min_f, max_f = self._get_min_max_arguments( - node, - torch.finfo(torch.float32).min, - torch.finfo(torch.float32).max, - ) - min_bytes = np.frombuffer( - np.float32(min_f).tobytes(), dtype=np.uint8 - ).tolist() - max_bytes = np.frombuffer( - np.float32(max_f).tobytes(), dtype=np.uint8 - ).tolist() - case _: - raise RuntimeError( - f"Internal error: Unsupported dtype {inputs[0].dtype} in {self.target}" - ) - - attr.ClampAttribute(min_bytes, max_bytes, ts.NanPropagationMode.PROPAGATE) + attr.ClampAttribute( + tosa_graph.builder, + np.float32(min_fp32).tobytes(), + np.float32(max_fp32).tobytes(), + nan_mode=1, + ) self._serialize_operator( node, tosa_graph, - ts.Op.CLAMP, + ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index f1f0f1bcb19..07d4876b750 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -7,9 +7,9 @@ from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -100,13 +100,10 @@ def define_node( shape=[1], dtype=pad_const_dtype, vals=[pad_const_val] ) - attr = ts.TosaSerializerAttribute() - attr.PadAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.PAD, + ts.TosaOp.Op().PAD, [inputs[0].name, padding.name, pad_const.name], [output.name], - attr, ) diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index af97c0c95b3..59238535be6 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -43,8 +43,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.CosAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.COS, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().COS, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 1ffdda219a2..18ada5ba3bd 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -53,13 +53,12 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - attr = ts.TosaSerializerAttribute() - attr.EqualAttribute() + # Do the equal comparison self._serialize_operator( node, tosa_graph, - ts.Op.EQUAL, + ts.TosaOp.Op().EQUAL, [inputs[0].name, inputs[1].name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index ef68c97ffcf..bea4be0058b 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -5,9 +5,9 @@ # pyre-unsafe from typing import Any, List -import torch.fx +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch.fx from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -48,8 +48,6 @@ def define_node( ) # MI lowering - attr = ts.TosaSerializerAttribute() - attr.ErfAttribute() self._serialize_operator( - node, tosa_graph, ts.Op.ERF, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().ERF, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index aef9ec7aca0..4f9b32687ea 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -48,8 +48,6 @@ def define_node( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.ExpAttribute() self._serialize_operator( - node, tosa_graph, ts.Op.EXP, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().EXP, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_floor.py b/backends/arm/operators/op_floor.py index d9f831dfb35..8189076f73d 100644 --- a/backends/arm/operators/op_floor.py +++ b/backends/arm/operators/op_floor.py @@ -5,9 +5,9 @@ from typing import Any, List -import torch.fx +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch.fx from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -49,8 +49,6 @@ def define_node( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.FloorAttribute() self._serialize_operator( - node, tosa_graph, ts.Op.FLOOR, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().FLOOR, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 3b55d526282..f9367c9a2b5 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -53,13 +53,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - attr = ts.TosaSerializerAttribute() - attr.GreaterEqualAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.GREATER_EQUAL, + ts.TosaOp.Op().GREATER_EQUAL, [inputs[0].name, inputs[1].name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index aa261ea06d7..4721205af77 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -53,13 +53,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - attr = ts.TosaSerializerAttribute() - attr.GreaterAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.GREATER, + ts.TosaOp.Op().GREATER, [inputs[0].name, inputs[1].name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index cf99f4efc50..79a0e5ecf5e 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -8,7 +8,7 @@ from typing import Any, List import executorch.backends.arm.tosa.quant_utils as tqutils # noqa: F401 -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -84,15 +84,13 @@ def define_node( tosa_graph, indices.name, indices_new_shape, indices_reshaped.name ) - attr = ts.TosaSerializerAttribute() - attr.GatherAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.GATHER, + ts.TosaOp.Op().GATHER, [weights_reshaped.name, indices_reshaped.name], [output_name], - attr, + None, ) if len(weights.shape) == 2: diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index af738033e1d..a7e22c5a14b 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -11,7 +11,7 @@ import executorch.backends.arm.tosa.utils as tutils import numpy as np -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -167,15 +167,12 @@ def define_node( data = np.full(index_shape, int(values_strides[i] / C)) mul_const = tosa_graph.addConst(index_shape, index_dtype, data) tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") - attr = ts.TosaSerializerAttribute() - attr.MulAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.MUL, + ts.TosaOp.Op().MUL, [index_name, mul_const.name, f"{node.name}_{i}_shift"], [stride_shifted_indices.name], - attr, ) reshaped_idxs = tosa_graph.addIntermediate( @@ -199,15 +196,12 @@ def define_node( reshaped_idxs.shape, reshaped_idxs.dtype, ) - attr = ts.TosaSerializerAttribute() - attr.AddAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.ADD, + ts.TosaOp.Op().ADD, [gather_index_name, reshaped_idxs.name], [add_idxs.name], - attr, ) gather_index_name = add_idxs.name @@ -227,15 +221,13 @@ def define_node( gather_out_shape, output.dtype, ) - attr = ts.TosaSerializerAttribute() - attr.GatherAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.GATHER, + ts.TosaOp.Op().GATHER, [reshaped_input.name, gather_index_name], [gather_out.name], - attr, + None, ) output_shape = tutils.tosa_shape(output.shape, output.dim_order) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 086fd892a49..9f2515341a2 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -53,13 +53,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - attr = ts.TosaSerializerAttribute() - attr.GreaterEqualAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.GREATER_EQUAL, + ts.TosaOp.Op().GREATER_EQUAL, [inputs[1].name, inputs[0].name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 254e02c9adf..0473b39589d 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -44,8 +44,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.LogAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.LOG, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().LOG, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_logical_not.py b/backends/arm/operators/op_logical_not.py index 695af5f7a26..943b13f085f 100644 --- a/backends/arm/operators/op_logical_not.py +++ b/backends/arm/operators/op_logical_not.py @@ -5,9 +5,9 @@ from typing import Any, List -import torch.fx +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch.fx from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -50,13 +50,10 @@ def define_node( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.LogicalNotAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.LOGICAL_NOT, + ts.TosaOp.Op().LOGICAL_NOT, [inputs[0].name], [output.name], - attr, ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index ed831206e36..e48671474f3 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -53,13 +53,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - attr = ts.TosaSerializerAttribute() - attr.GreaterAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.GREATER, + ts.TosaOp.Op().GREATER, [inputs[1].name, inputs[0].name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index a068a2f49a7..dfe5a09d7d0 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -90,16 +90,13 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.MaxPool2dAttribute( - kernel=kernel_size, - stride=stride, - pad=pad_size_list, - nan_mode=ts.NanPropagationMode.PROPAGATE, + kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1 ) self._serialize_operator( node, tosa_graph, - ts.Op.MAX_POOL2D, + ts.TosaOp.Op().MAX_POOL2D, [input_tensor.name], [output.name], attr, diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 463bee41a52..7e0c3a1cb16 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -42,6 +42,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + from tosa.NanPropagationMode import NanPropagationMode # type: ignore + validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -52,12 +54,13 @@ def define_node( ) attr_maximum = ts.TosaSerializerAttribute() - attr_maximum.MaximumAttribute(nan_mode=ts.NanPropagationMode.PROPAGATE) + # Set to PROPAGATE as default + attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE) self._serialize_operator( node, tosa_graph, - ts.Op.MAXIMUM, + ts.TosaOp.Op().MAXIMUM, [ inputs[0].name, inputs[1].name, diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 52125ed1f54..3050e3f7a86 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -42,6 +42,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + from tosa.NanPropagationMode import NanPropagationMode # type: ignore validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) @@ -53,12 +54,13 @@ def define_node( ) attr_minimum = ts.TosaSerializerAttribute() - attr_minimum.MinimumAttribute(nan_mode=ts.NanPropagationMode.PROPAGATE) + # Set to PROPAGATE as default + attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE) self._serialize_operator( node, tosa_graph, - ts.Op.MINIMUM, + ts.TosaOp.Op().MINIMUM, [ inputs[0].name, inputs[1].name, diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 4b97d5e50b4..30b8eb23154 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -7,10 +7,9 @@ from typing import Any, List +import serializer.tosa_serializer as ts import torch -import tosa_serializer as ts - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -50,13 +49,8 @@ def define_node( ) tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") - attr = ts.TosaSerializerAttribute() - attr.MulAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.MUL, + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, [inputs[0].name, inputs[1].name, f"{node.name}_shift"], [output.name], - attr, ) diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index d025e8f0bd4..29a9073e5ed 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch.fx +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch.fx from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -81,13 +81,11 @@ def define_node( output_zp_tensor = tosa_graph.addConst( (1,), output.dtype, [output_zp], name=output.name + "_output_zp" ) - attr = ts.TosaSerializerAttribute() - attr.NegateAttribute() + self._serialize_operator( node, tosa_graph, - ts.Op.NEGATE, + ts.TosaOp.Op().NEGATE, [inputs[0].name, input_zp_tensor.name, output_zp_tensor.name], [output.name], - attr, ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 26e66b40301..bcbe8d1d8b4 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -7,9 +7,9 @@ from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -138,7 +138,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.TRANSPOSE, + ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index a46c6dd8df9..6545cd9470d 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -49,16 +49,15 @@ def define_node( [ts.DType.FP16, ts.DType.FP32], output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.PowAttribute() + self._serialize_operator( node, tosa_graph, - ts.Op.POW, + ts.TosaOp.Op().POW, [ inputs[0].name, inputs[1].name, ], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 10f9192a9c2..d73fbdc5761 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -45,8 +45,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.ReciprocalAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.RECIPROCAL, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 49c45913614..9305147d7d0 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -7,9 +7,9 @@ from typing import Any -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -60,13 +60,11 @@ def define_node( name=node.name + "_multiples", ) - attr = ts.TosaSerializerAttribute() - attr.TileAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.TILE, + ts.TosaOp.Op().TILE, [inputs[0].name, multiple_shapes.name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 9fcd2b56381..70661583d4e 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -7,9 +7,9 @@ from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -56,7 +56,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.ARITHMETIC_RIGHT_SHIFT, + ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, [inputs[0].name, inputs[1].name], [output.name], attr, diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 259e34f129a..3042c8054dc 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -45,8 +45,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.RsqrtAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.RSQRT, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 814158a1d32..445666599ef 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -44,8 +44,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.SigmoidAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.SIGMOID, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index ac0ae212e78..7b78070613e 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -43,8 +43,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.SinAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.SIN, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().SIN, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 941f8d690c3..7290c1243be 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -7,7 +7,7 @@ from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -23,34 +23,17 @@ def _fixup_start(start, shape, dim): - # Normalize start index and clamp into [0, shape[dim]]. - # If not a constant, default to 0. - idx = getattr(start, "number", 0) - # Handle negative wrap-around - if idx < 0: - idx = idx % shape[dim] - # Clamp into valid bounds - if idx < 0: - idx = 0 - elif idx > shape[dim]: - idx = shape[dim] - return idx + if start.number < 0: + return start.number % shape[dim] + else: + return start.number def _fixup_end(end, shape, dim): - # Normalize end index and clamp into [0, shape[dim]]. - max_dim = shape[dim] - # If not a constant, default to the full size - idx = getattr(end, "number", max_dim) - # Handle negative wrap-around - if idx < 0: - idx = idx % max_dim - # Clamp into valid bounds - if idx < 0: - idx = 0 - elif idx > max_dim: - idx = max_dim - return idx + if end.number < 0: + return end.number % shape[dim] + else: + return min(end.number, shape[dim]) @register_node_visitor @@ -134,13 +117,11 @@ def define_node( (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape" ) - attr = ts.TosaSerializerAttribute() - attr.SliceAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.SLICE, + ts.TosaOp.Op().SLICE, [input_node.name, start_tensor.name, sizes_tensor.name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 52caa9d0f8a..12ed94ec083 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -9,7 +9,7 @@ import executorch.backends.arm.tosa.quant_utils as tqutils import executorch.backends.arm.tosa.utils as tutils -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -76,18 +76,16 @@ def define_node( sub_output = output # Do the INT32 Sub - attr = ts.TosaSerializerAttribute() - attr.SubAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.SUB, + ts.TosaOp.Op().SUB, [ rescaled_inputs[0].name, rescaled_inputs[1].name, ], [sub_output.name], - attr, + None, ) if output.dtype == ts.DType.INT8: @@ -141,13 +139,11 @@ def define_node( ) # MI lowering - attr = ts.TosaSerializerAttribute() - attr.SubAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.SUB, + ts.TosaOp.Op().SUB, [inputs[0].name, inputs[1].name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 5c88c00537e..3f637d18390 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -9,7 +9,7 @@ import executorch.backends.arm.tosa.quant_utils as tqutils import executorch.backends.arm.tosa.utils as tutils -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -68,7 +68,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.REDUCE_SUM, + ts.TosaOp.Op().REDUCE_SUM, [rescaled_inputs[0].name], [intermediate.name], attr, @@ -111,7 +111,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.REDUCE_SUM, + ts.TosaOp.Op().REDUCE_SUM, [tensor.name], [output.name], attr, diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 0799628cd7d..dda86190202 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -6,7 +6,7 @@ # pyre-unsafe from typing import Any, List -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -45,8 +45,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - attr = ts.TosaSerializerAttribute() - attr.TanhAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.TANH, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().TANH, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index c41431a1b6d..08c5f70887d 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -43,8 +43,7 @@ def define_node( output: TosaArg, ) -> None: validate_num_inputs(self.target, inputs, 1) - attr = ts.TosaSerializerAttribute() - attr.CastAttribute() + self._serialize_operator( - node, tosa_graph, ts.Op.CAST, [inputs[0].name], [output.name], attr + node, tosa_graph, ts.TosaOp.Op().CAST, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_tosa_conv2d.py b/backends/arm/operators/op_tosa_conv2d.py index 3631a143b50..8672ea4f7ba 100644 --- a/backends/arm/operators/op_tosa_conv2d.py +++ b/backends/arm/operators/op_tosa_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -import tosa_serializer as ts +import serializer.tosa_serializer as ts """Provide a visitor for lowering 2D convolution to TOSA (INT/FP).""" @@ -46,7 +46,9 @@ def __init__(self, *args): super().__init__(*args) def _get_tosa_op(self): - return ts.Op.CONV2D + import serializer.tosa_serializer as ts # type: ignore + + return ts.TosaOp.Op().CONV2D def _get_attr_func(self, attr): return attr.Conv2dAttribute @@ -59,6 +61,8 @@ def define_node( output: TosaArg, ) -> None: """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale.""" + from tosa.RoundingMode import RoundingMode # type: ignore + input, weight, bias, stride, pad, dilation, _, _, group = inputs validate_num_inputs(self.target, inputs, 9) @@ -189,5 +193,5 @@ def define_node( input_zp=[0], output_zp=[output_qargs[0].get_zp_per_tensor()], per_channel=per_channel_quant, - rounding_mode=ts.RoundingMode.SINGLE_ROUND, + rounding_mode=RoundingMode.SINGLE_ROUND, ) diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py index 3538b6f31da..ef4da3845fe 100644 --- a/backends/arm/operators/op_tosa_depthwise_conv2d.py +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import register_node_visitor from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor from executorch.backends.arm.tosa import TosaSpecification @@ -22,7 +21,9 @@ class DepthwiseConv2dVisitor(Conv2dVisitor): ] def _get_tosa_op(self): - return ts.Op.DEPTHWISE_CONV2D + import serializer.tosa_serializer as ts # type: ignore + + return ts.TosaOp.Op().DEPTHWISE_CONV2D def _get_attr_func(self, attr): return attr.DepthwiseConv2dAttribute diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index e88ef9be55d..b177fd2ba37 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -10,7 +10,6 @@ from typing import Any, List import torch -import tosa_serializer as ts from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -50,6 +49,8 @@ def define_node( output: TosaArg, ) -> None: """Define the TOSA ``MATMUL`` operator.""" + import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs], ts) validate_valid_dtype( @@ -79,13 +80,10 @@ def define_node( tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name) # Add the MATMUL to the TOSA graph. - attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute() - self._serialize_operator( node, tosa_graph, - ts.Op.MATMUL, + ts.TosaOp.Op().MATMUL, [ inputs[0].name, inputs[1].name, @@ -93,5 +91,4 @@ def define_node( input_B_ZP_name, ], [output.name], - attr, ) diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index db3738a8fd1..7c76d815645 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -7,9 +7,9 @@ from typing import Any, cast, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,6 +37,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + from tosa.RoundingMode import RoundingMode # type: ignore + validate_num_inputs(self.target, inputs, 5) input_dtype = inputs[0].dtype @@ -69,6 +71,6 @@ def define_node( output_type=output.dtype, input_zp=[input_zp], output_zp=[output_zp], - rounding_mode=ts.RoundingMode.SINGLE_ROUND, + rounding_mode=RoundingMode.SINGLE_ROUND, per_channel=False, ) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index 60328a3b3ab..f56f6eb19b2 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -22,6 +22,8 @@ from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.utils import get_resize_parameters +from tosa.ResizeMode import ResizeMode # type: ignore + @register_node_visitor class ResizeVisitor(NodeVisitor): @@ -41,10 +43,10 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [3, 4]) if node.kwargs.get("resize_mode") == "bilinear": - resize_mode = ts.ResizeMode.BILINEAR + resize_mode = ResizeMode.BILINEAR align_corners = bool(node.args[2]) else: - resize_mode = ts.ResizeMode.NEAREST + resize_mode = ResizeMode.NEAREST align_corners = False validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( @@ -76,32 +78,27 @@ def in_int16_range(x): if not in_int16_range(border_yx): raise ValueError("border_yx is out of the int16 range") - scale_n_vals = [int(v) for v in scale_n_yx.tolist()] - scale_d_vals = [int(v) for v in scale_d_yx.tolist()] - scales = [ - scale_n_vals[0], - scale_d_vals[0], - scale_n_vals[1], - scale_d_vals[1], - ] + scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] scales_tensor = tosa_graph.addConst( [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" ) - offset = [int(v) for v in offset_yx.tolist()] + offset = offset_yx.tolist() offset_tensor = tosa_graph.addConst( [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" ) - border = [int(v) for v in border_yx.tolist()] + border = border_yx.tolist() border_tensor = tosa_graph.addConst( [len(border)], ts.DType.SHAPE, border, node.name + "_border" ) attr = ts.TosaSerializerAttribute() - attr.ResizeAttribute(resize_mode) + attr.ResizeAttribute( + mode=resize_mode, + ) self._serialize_operator( node, tosa_graph, - ts.Op.RESIZE, + ts.TosaOp.Op().RESIZE, [ inputs[0].name, scales_tensor.name, diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 9572e49781f..12f1f8dc5bd 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -7,9 +7,9 @@ from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -59,13 +59,12 @@ def define_node( table.detach().numpy(), name=table_tensor_name, ) - attr = ts.TosaSerializerAttribute() - attr.TableAttribute() + self._serialize_operator( node, tosa_graph, - ts.Op.TABLE, + ts.TosaOp.Op().TABLE, [inputs[0].name, table_tensor_name], [output.name], - attr, + None, ) diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index 2159a67b285..ff3e54e3d0e 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -7,9 +7,9 @@ from typing import Any, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -65,7 +65,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.TRANSPOSE, + ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index c6f6e36b6e9..d6c6aa588e8 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -6,9 +6,9 @@ # pyre-unsafe from typing import Any, cast, List -import torch +import serializer.tosa_serializer as ts -import tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -75,7 +75,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.Op.RESHAPE, + ts.TosaOp.Op().RESHAPE, [inputs[0].name, shape.name], [output.name], attr, diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index c6c9a95070c..9b1518e2bbc 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -5,7 +5,7 @@ from typing import Any, List, Sequence -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -53,15 +53,13 @@ def _add_node_to_tosa_graph( output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr.SelectAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.SELECT, + ts.TosaOp.Op().SELECT, [inputs[0].name, inputs[1].name, inputs[2].name], [output.name], - attr, + None, ) def define_node( diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 9419e116789..d712b461304 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -6,6 +6,8 @@ from math import ceil, floor from typing import Any, List, Optional +import serializer.tosa_serializer as ts + def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): """ @@ -96,14 +98,18 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No # Get dtype of the first tensor to reference for comparison reference_dtype = tensors[0].dtype - reference_dtype_name = str(reference_dtype) for tensor in tensors: + ref_dtype_name = ( + ts.DTypeNames[reference_dtype] if ts is not None else str(reference_dtype) + ) + inconsistent_dtype_name = ( + ts.DTypeNames[tensor.dtype] if ts is not None else str(tensor.dtype) + ) if tensor.dtype != reference_dtype: - inconsistent_dtype_name = str(tensor.dtype) raise ValueError( - f"{op_name}: Expected all tensors to have dtype {reference_dtype_name}, " - f"but found inconsistent dtype {inconsistent_dtype_name}." + f"{op_name}: Expected all tensors to have dtype {ref_dtype_name}, but " + f"found inconsistent dtype {inconsistent_dtype_name}." ) @@ -165,11 +171,10 @@ def validate_valid_dtype( for tensor in tensors: if tensor.dtype not in valid_dtypes: - valid_names = [str(dtype) for dtype in valid_dtypes] - got_name = str(tensor.dtype) raise ValueError( f"Expected tensor {tensor.name} in {op_name} to have one of the " - f"following dtypes: {valid_names}, got: {got_name}" + f"following dtypes: {[ts.DTypeNames[i] for i in valid_dtypes]}, " + f"got: {ts.DTypeNames[tensor.dtype]}" ) diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 360e15a0ad2..af0ece81c34 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -5,13 +5,13 @@ # pyre-unsafe -from typing import Any, Callable, List +from typing import Any, List + +import serializer.tosa_serializer as ts import torch import torch.fx -import tosa_serializer as ts - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -24,9 +24,7 @@ from executorch.backends.arm.tosa.mapping import TosaArg -def binary_operator_factory( - bw_target: str, tosa_op, attr_builder: Callable[[Any], None] -): +def binary_operator_factory(bw_target: str, tosa_op): """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" class BinaryOperator(NodeVisitor): @@ -66,48 +64,24 @@ def define_node( [ts.DType.BOOL], output.tosa_spec, ) - attr = ts.TosaSerializerAttribute() - attr_builder(attr) + self._serialize_operator( node, tosa_graph, tosa_op, [inputs[0].name, inputs[1].name], [output.name], - attr, ) register_node_visitor(BinaryOperator) +binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) +binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) +binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) +binary_operator_factory("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND) +binary_operator_factory("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR) +binary_operator_factory("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR) binary_operator_factory( - "aten.bitwise_and.Tensor", - ts.Op.BITWISE_AND, - lambda attr: attr.BitwiseAndAttribute(), -) -binary_operator_factory( - "aten.bitwise_xor.Tensor", - ts.Op.BITWISE_XOR, - lambda attr: attr.BitwiseXorAttribute(), -) -binary_operator_factory( - "aten.bitwise_or.Tensor", ts.Op.BITWISE_OR, lambda attr: attr.BitwiseOrAttribute() -) -binary_operator_factory( - "aten.logical_and.default", - ts.Op.LOGICAL_AND, - lambda attr: attr.LogicalAndAttribute(), -) -binary_operator_factory( - "aten.logical_xor.default", - ts.Op.LOGICAL_XOR, - lambda attr: attr.LogicalXorAttribute(), -) -binary_operator_factory( - "aten.logical_or.default", ts.Op.LOGICAL_OR, lambda attr: attr.LogicalOrAttribute() -) -binary_operator_factory( - "aten.bitwise_left_shift.Tensor", - ts.Op.LOGICAL_LEFT_SHIFT, - lambda attr: attr.LogicalLeftShiftAttribute(), + "aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT ) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 994b43a7c15..cb251ae12c1 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -7,11 +7,11 @@ from typing import Any, List +import serializer.tosa_serializer as ts + import torch import torch.fx -import tosa_serializer as ts - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -45,15 +45,12 @@ def define_node( validate_same_dtype(self.target, [*inputs, output], ts) # Simply add an identityOp - attr = ts.TosaSerializerAttribute() - attr.IdentityAttribute() self._serialize_operator( node, tosa_graph, - ts.Op.IDENTITY, + ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name], - attr, ) register_node_visitor(IdentityOperatorVisitor) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 7dd8f9a7d38..8865513a6dd 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -8,9 +8,9 @@ from typing import Any, cast, Dict import numpy as np +import serializer.tosa_serializer as ts import torch import torch.fx -import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype from executorch.backends.arm.tosa.specification import TosaSpecification @@ -85,6 +85,7 @@ def process_inputs( tosa_shape(input_shape, input_dim_order), tosa_arg.dtype, data=None, + placeholderFilename=tosa_arg.name + ".npy", ) tosa_graph.addInputTensor(tensor) diff --git a/backends/arm/requirements-arm-tosa.txt b/backends/arm/requirements-arm-tosa.txt index da115441c52..16aa01a6c23 100644 --- a/backends/arm/requirements-arm-tosa.txt +++ b/backends/arm/requirements-arm-tosa.txt @@ -7,3 +7,5 @@ ml_dtypes == 0.5.1 flatbuffers == 24.3.25 tosa-adapter-model-explorer == 0.0.1 ai-edge-model-explorer >= 0.1.16 + +tosa-tools @ git+https://git.gitlab.arm.com/tosa/tosa-reference-model.git@v2025.07.1 diff --git a/backends/arm/scripts/mlsdk_utils.sh b/backends/arm/scripts/mlsdk_utils.sh index 2257bc674ca..7f69cefb462 100755 --- a/backends/arm/scripts/mlsdk_utils.sh +++ b/backends/arm/scripts/mlsdk_utils.sh @@ -9,7 +9,7 @@ set -euo pipefail # URL and tag of the MLSDK manifest repository. Can be overridden by environment variables. # eg. export MLSDK_MANIFEST_URL=...; export MLSDK_MANIFEST_TAG=... mlsdk_manifest_url="${MLSDK_MANIFEST_URL:-https://github.com/arm/ai-ml-sdk-manifest.git}" -mlsdk_manifest_tag="${MLSDK_MANIFEST_TAG:-refs/tags/v2025.10.0}" +mlsdk_manifest_tag="${MLSDK_MANIFEST_TAG:-refs/tags/dev-snapshot-2025-09-12}" script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) @@ -36,14 +36,33 @@ function mlsdk_sync_manifest() { -g model-converter,emulation-layer,vgf-library local default_manifest=".repo/manifests/default.xml" - - # TODO: Remove this workaround once 2GB capable mlir translator is available - # in the official MLSDK repository. - if [[ "${OSTYPE:-}" == darwin* ]]; then - sed -i '' 's|revision="refs/tags/v2025.07.1"|revision="c3b324e643b4b4e592de8a9123a58c4179649d8c"|' "${default_manifest}" - else - sed -i 's|revision="refs/tags/v2025.07.1"|revision="c3b324e643b4b4e592de8a9123a58c4179649d8c"|' "${default_manifest}" + local local_manifest_path=".repo/local_manifests/tosa_gitlab.xml" + + # TODO: Remove this workaround once MLSDK switches to GitLab for tosa-mlir-translator + if [[ -f "${default_manifest}" ]] && grep -q '' "${default_manifest}"; then + log_step "mlsdk" "Patching MLSDK manifest to use GitLab tosa-mlir-translator mirror" + # Update dependencies to use gitlab tosa-mlir-translator + # Do not indent the xml. Heredoc indentation is significant. + mkdir -p .repo/local_manifests/ + cat > "${local_manifest_path}" <<'XML' + + + + + + + + + +XML fi + ./repo sync --force-sync -j"${parallel_jobs}" popd @@ -100,12 +119,12 @@ function download_ai_mlsdk_manifest() { log_step "mlsdk" "Manifest changed (url=${cached_url:-} -> ${mlsdk_manifest_url}, tag=${cached_tag:-} -> ${mlsdk_manifest_tag}); refreshing checkout" fi - # Clean up any local manifest changes to avoid repo sync errors. - # Since we patched in a local manifest for tosa_gitlab.xml, - # remove any existing local manifests to avoid conflicts. - # TODO: we should remove this at some point in the future, but its not hurting anything for now. + # Clean up any local changes to avoid repo sync errors. + # Note: This does not delete untracked files outside of .repo. + # Users should manually delete the checkout if they want a full clean. + # TODO: Remove this workaround once MLSDK switches to GitLab for tosa-mlir-translator if [[ -d "${_manifest_dir}/.repo/local_manifests" ]]; then - rm -rf "${_manifest_dir}/.repo/local_manifests/" + rm -f "${_manifest_dir}/.repo/local_manifests/tosa_gitlab.xml" fi # Clean up any local changes in the manifests repository. diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py index 937dbf93674..d47398be3b0 100644 --- a/backends/arm/test/models/test_llama.py +++ b/backends/arm/test/models/test_llama.py @@ -111,12 +111,9 @@ def test_llama_tosa_FP(): llama_inputs, aten_op=[], exir_op=[], - custom_path="llama_tosa_fb", - run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk use_to_edge_transform_and_lower=True, transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()], ) - pipeline.add_stage_after("to_executorch", pipeline.tester.serialize) pipeline.run() @@ -132,11 +129,8 @@ def test_llama_tosa_INT(): llama_inputs, aten_op=[], exir_op=[], - custom_path="llama_tosa_fb_int", - run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk use_to_edge_transform_and_lower=True, ) - pipeline.add_stage_after("to_executorch", pipeline.tester.serialize) pipeline.run() @@ -156,7 +150,6 @@ def test_llama_vgf_FP(): tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()], - run_on_vulkan_runtime=True, ) pipeline.run() @@ -176,6 +169,5 @@ def test_llama_vgf_INT(): exir_op=[], tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, - run_on_vulkan_runtime=True, ) pipeline.run() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 44b1a7aef13..996ae7340c6 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -24,11 +24,10 @@ import executorch.backends.xnnpack.test.tester.tester as tester +import serializer.tosa_serializer as ts + import torch.fx import torch.utils._pytree as pytree - -import tosa_serializer as ts - from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index e19d026e03b..6ffb7df4bfc 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -15,7 +15,7 @@ from itertools import count from typing import cast, Dict, final, List, Set -import tosa_serializer as ts +import serializer.tosa_serializer as ts from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook @@ -102,9 +102,6 @@ def _preprocess( # noqa: C901 # Converted output for this subgraph, serializer needs path early as it emits # const data directly. Path created and data written only in debug builds. - if not artifact_path: - artifact_path = "" - tosa_graph = ts.TosaSerializer(artifact_path) if not ( @@ -174,16 +171,13 @@ def _sort_key(t: Node) -> int: # any checking of compatibility. raise RuntimeError(f"{node.name} is unsupported op {node.op}") except Exception: - debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path) + debug_fail(node, graph_module, tosa_graph, artifact_path) raise - # Serialize and return the TOSA flatbuffer. - binary = tosa_graph.serialize() - if artifact_path: tag = arm_get_first_delegation_tag(graph_module) debug_tosa_dump( - binary, + tosa_graph, artifact_path, suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), ) @@ -194,6 +188,9 @@ def _sort_key(t: Node) -> int: with open(f"{artifact_path}/debug.json", "w") as f: f.write(json_output) + # Serialize and return the TOSA flatbuffer. + binary = bytes(tosa_graph.serialize()) + return PreprocessResult(processed_bytes=binary) @staticmethod diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index e21fd38723b..4bc60f9fb3f 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -14,8 +14,9 @@ from enum import Enum from typing import Any, Optional, Sequence +import serializer.tosa_serializer as ts + import torch -import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification UNSUPPORTED_DTYPES = ( @@ -39,7 +40,7 @@ class TosaSpecialDtype(Enum): INT48 = ts.DType.INT48 - def get_tosa_dtype(self) -> ts.DType: + def get_tosa_dtype(self) -> ts.TosaDType.DType: return self.value @staticmethod diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py index ddf9d31145c..dd295ff3e73 100644 --- a/backends/arm/tosa/quant_utils.py +++ b/backends/arm/tosa/quant_utils.py @@ -11,7 +11,8 @@ from typing import Any, Tuple -import tosa_serializer as ts +import serializer.tosa_serializer as ts + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -20,6 +21,8 @@ from executorch.backends.arm.tosa.mapping import TosaArg from torch.fx import Node +from tosa.RoundingMode import RoundingMode # type: ignore + def insert_rescale_ops_to_int32_maxscale( tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None @@ -368,10 +371,12 @@ def build_rescale( output_type: Any, input_zp: list[int], output_zp: list[int], - rounding_mode: ts.RoundingMode, - per_channel: bool = False, - is_scale32: bool = True, + rounding_mode: RoundingMode, + per_channel=False, ): + + import tosa.Op as TosaOp # type: ignore + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 is_scale32 = False if input_node.dtype == ts.DType.INT48 else True multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth) @@ -397,7 +402,7 @@ def build_rescale( ) tosa_fb.addOperator( - ts.Op.RESCALE, + TosaOp.Op().RESCALE, [input_node.name, *rescale_inputs], [output_name], attr_rescale, @@ -428,7 +433,7 @@ def build_rescale_to_int32( ts.DType.INT32, [input_zp], [0], - rounding_mode=ts.RoundingMode.SINGLE_ROUND, + rounding_mode=RoundingMode.SINGLE_ROUND, ) # type: ignore[call-arg] return input_A_rescaled_to_int32 @@ -496,7 +501,7 @@ def build_rescale_from_int32_to_dtype( output_type=output_dtype, input_zp=[0], output_zp=[output_zp], - rounding_mode=ts.RoundingMode.SINGLE_ROUND, + rounding_mode=RoundingMode.SINGLE_ROUND, ) # type: ignore[call-arg] return diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index edcef8ceb9d..f6f16724f30 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -9,11 +9,11 @@ from typing import Any import numpy as np +import serializer.tosa_serializer as ts import sympy # type: ignore import torch -import tosa_serializer as ts from executorch.backends.arm.tosa.mapping import extract_tensor_meta from executorch.backends.arm.tosa.specification import TosaSpecification @@ -121,13 +121,11 @@ def broadcast_tensors( name=f"{node.name}_multiples", ) - attr = ts.TosaSerializerAttribute() - attr.TileAttribute() tosa_fb.addOperator( - ts.Op.TILE, + ts.TosaOp.Op().TILE, [reshaped.name, multiple_shapes.name], [tiled.name], - attr, + None, ) broadcast_tensors.append(tiled) @@ -148,7 +146,7 @@ def build_reshape_tosa_1_0( attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute() tosa_graph.addOperator( - ts.Op.RESHAPE, + ts.TosaOp.Op().RESHAPE, [input_name, shape.name], [output_name], attr, @@ -162,7 +160,7 @@ def tosa_shape(shape, dim_order): removed_symints = tuple( [-1 if isinstance(d, torch.SymInt) else d for d in reordered] ) - return list(removed_symints) + return removed_symints def get_resize_parameters_1d( diff --git a/backends/arm/vgf/backend.py b/backends/arm/vgf/backend.py index 7ed0154ab99..3f65456bf8b 100644 --- a/backends/arm/vgf/backend.py +++ b/backends/arm/vgf/backend.py @@ -112,7 +112,7 @@ def vgf_compile( Stdout:\n{process_error.stdout.decode()}" ) - if artifact_path: + if artifact_path is not None: logger.info(f"Emitting debug output to: {vgf_path=}") os.makedirs(artifact_path, exist_ok=True) cp = f"cp {vgf_path} {artifact_path}" diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index db32c0c416f..495a5181266 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -35,9 +35,6 @@ toolchain_url="" toolchain_dir="" toolchain_md5_checksum="" -# Load logging helpers early so option parsing can emit status messages. -source "$et_dir/backends/arm/scripts/utils.sh" - # List of supported options and their descriptions OPTION_LIST=( @@ -148,7 +145,7 @@ function check_options() { shift ;; --setup-test-dependency) - log_step "deps" "Installing test dependency..." + echo "Installing test dependency..." source $et_dir/backends/arm/scripts/install_models_for_test.sh exit 0 ;; @@ -221,6 +218,7 @@ if [[ $is_script_sourced -eq 0 ]]; then check_options "$@" # Import utils + source $et_dir/backends/arm/scripts/utils.sh source $et_dir/backends/arm/scripts/fvp_utils.sh source $et_dir/backends/arm/scripts/toolchain_utils.sh source $et_dir/backends/arm/scripts/vulkan_utils.sh @@ -279,39 +277,11 @@ if [[ $is_script_sourced -eq 0 ]]; then # Create the setup_path.sh used to create the PATH variable for shell create_setup_path - # Setup the TOSA reference model and serialization dependencies + # Setup the tosa_reference_model and dependencies log_step "deps" "Installing TOSA reference model dependencies" - CMAKE_POLICY_VERSION_MINIMUM=3.5 \ - pip install --no-dependencies -r "$et_dir/backends/arm/requirements-arm-tosa.txt" - - pushd "$root_dir" - if [[ ! -d "tosa-tools" ]]; then - git clone https://git.gitlab.arm.com/tosa/tosa-tools.git - fi - - pushd tosa-tools - git checkout 8468d041c50c6d806f3c1c18c66d7ef641e46580 # serialization lib pybindings - if [[ ! -d "reference_model" ]]; then - log_step "main" "[error] Missing reference_model directory in tosa-tools repo." - exit 1 - fi - if [[ ! -d "serialization" ]]; then - log_step "main" "[error] Missing serialization directory in tosa-tools repo." - exit 1 - fi - - - export CMAKE_BUILD_PARALLEL_LEVEL="$(get_parallel_jobs)" - - CMAKE_POLICY_VERSION_MINIMUM=3.5 \ - BUILD_PYBIND=1 \ - pip install --no-dependencies ./reference_model - CMAKE_POLICY_VERSION_MINIMUM=3.5 \ BUILD_PYBIND=1 \ - pip install --no-dependencies ./serialization - popd - popd + pip install --no-dependencies -r $et_dir/backends/arm/requirements-arm-tosa.txt if [[ "${enable_vela}" -eq 1 ]]; then log_step "deps" "Installing Ethos-U Vela compiler"