From 1882bc1aa577d1f4fba6c76215f97a64f9489977 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 17 Oct 2025 14:37:34 +0200 Subject: [PATCH 1/2] Arm backend: Break out processing per graph_module in backend. This will enable us to process multiple submodules contained in a partitioned ExportedProgram. Signed-off-by: Erik Lundell Change-Id: I82e41b1e9ff2409ca31e86e4a89747e694ab4ea4 --- backends/arm/_passes/arm_pass_manager.py | 28 ++--- backends/arm/tosa/backend.py | 125 +++++++++++++---------- 2 files changed, 88 insertions(+), 65 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b1eea847792..728eb984ae9 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module - def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + def _tosa_INT_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) - ) + self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) @@ -218,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() - return self._transform(exported_program.graph_module) + return self._transform(graph_module) - def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + def _tosa_FP_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) @@ -255,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeBatchNormNoStatsPass()) self.add_pass(DecomposeVarPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) - ) + self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeAddSubAlphaPass()) @@ -305,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() - return self._transform(exported_program.graph_module) + return self._transform(graph_module) - def transform_to_backend_pipeline(self, exported_program: ExportedProgram): + def transform_to_backend_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ): """Apply passes before transforming program to backend""" if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"): - return self._tosa_FP_pipeline(exported_program) + return self._tosa_FP_pipeline(exported_program, graph_module) elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"): - return self._tosa_INT_pipeline(exported_program) + return self._tosa_INT_pipeline(exported_program, graph_module) else: raise NotImplementedError( f"No pass pipeline implemented for {self.tosa_spec=}" diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 8fb50707952..8b62761ec47 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -27,7 +27,7 @@ from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram -from torch.fx import Graph, Node +from torch.fx import Graph, GraphModule, Node # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -52,13 +52,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): # Walk backwards so we touch every producer q.extend(n.all_input_nodes) - out = next(n for n in ep_graph.nodes if n.op == "output") + out = ep_graph.output_node() + # First argument of output node is tuple of outputs + output_list = cast(tuple, out.args[0]) seen: Set[Node] = set() - for idx, val in enumerate(out.args[0]): + for idx, val in enumerate(output_list): bfs_mark([val], idx, seen) return node2external_id +def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]): + def _external_id(n: Node, node_2_id, fallback: int) -> int: + return node_2_id.get(n.name, fallback) + + out_node = graph_module.graph.output_node() + out_list = cast(tuple, out_node.args[0]) + _counter = count() + + # sort nodes by the key that is id + def _sort_key(t: Node) -> int: + return _external_id(t, node_to_id_map, next(_counter)) + + orig_ord = tuple(sorted(out_list, key=_sort_key)) + + current_order = tuple(out_list) + if orig_ord != current_order: + replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord + out_node.args = (replacement,) + graph_module.graph.lint() + graph_module.recompile() + + return graph_module + + def arm_get_first_delegation_tag(graph_module) -> str: """Get the first delegation tag from the graph_module or return empty string.""" for node in graph_module.graph.nodes: @@ -93,9 +119,9 @@ def _preprocess( # noqa: C901 artifact_path = compile_spec.get_intermediate_path() tosa_spec = compile_spec.tosa_spec dump_debug_info = compile_spec.tosa_debug_mode - - # Assign to every node external id - node_2_id = _annotate_external_ids(edge_program.graph) + debug_hook = None + if dump_debug_info is not None: + debug_hook = DebugHook(dump_debug_info) logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") @@ -116,43 +142,57 @@ def _preprocess( # noqa: C901 f"doesn't match specification {tosa_spec}" ) + TOSABackend._preprocess_module( + edge_program.graph_module, + edge_program, + compile_spec, + tosa_graph, + debug_hook, + ) + # Serialize and return the TOSA flatbuffer. + binary = tosa_graph.serialize() + + if artifact_path: + tag = arm_get_first_delegation_tag(edge_program.graph_module) + debug_tosa_dump( + binary, + artifact_path, + suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), + ) + + if debug_hook is not None: + if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: + json_output = debug_hook.serialize() + with open(f"{artifact_path}/debug.json", "w") as f: + f.write(json_output) + + return PreprocessResult(processed_bytes=binary) + + @staticmethod + def _preprocess_module( + graph_module: GraphModule, + edge_program: ExportedProgram, + compile_spec: TosaCompileSpec, + tosa_graph: ts.TosaSerializer, + debug_hook: DebugHook | None, + ): + """Convert 'graph_module' to a tosa_graph""" + tosa_spec = compile_spec.tosa_spec + node_to_id_map = _annotate_external_ids(graph_module.graph) + artifact_path = compile_spec.get_intermediate_path() + # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore - exported_program=edge_program + exported_program=edge_program, graph_module=graph_module ) - debug_hook = None - if dump_debug_info is not None: - debug_hook = DebugHook(dump_debug_info) - # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) - - # Re-shuffle output nodes to preserve author's order - def _external_id(n: Node, node_2_id, fallback: int) -> int: - return node_2_id.get(n.name, fallback) - - out_node = next(n for n in graph_module.graph.nodes if n.op == "output") - _counter = count() - - # sort nodes by the key that is id - def _sort_key(t: Node) -> int: - return _external_id(t, node_2_id, next(_counter)) - - orig_ord = tuple(sorted(out_node.args[0], key=_sort_key)) - - current_order = tuple(out_node.args[0]) - if orig_ord != current_order: - replacement = ( - list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord - ) - out_node.args = (replacement,) - graph_module.graph.lint() - graph_module.recompile() + graph_module = _sort_outputs(graph_module, node_to_id_map) input_count = 0 for node in graph_module.graph.nodes: @@ -176,25 +216,6 @@ def _sort_key(t: Node) -> int: debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path) raise - # Serialize and return the TOSA flatbuffer. - binary = tosa_graph.serialize() - - if artifact_path: - tag = arm_get_first_delegation_tag(graph_module) - debug_tosa_dump( - binary, - artifact_path, - suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), - ) - - if debug_hook is not None: - if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: - json_output = debug_hook.serialize() - with open(f"{artifact_path}/debug.json", "w") as f: - f.write(json_output) - - return PreprocessResult(processed_bytes=binary) - @staticmethod def filter_tosa_compile_specs( compile_spec: ArmCompileSpec, From 29fba268cd84d775172a42d2c471c6724db55c98 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 17 Oct 2025 15:10:41 +0200 Subject: [PATCH 2/2] Arm backend: Enable serializing to different regions. Each conditional submodule in the graph_module gets its own region. The TOSA reference model requires all tensor names in one model to be unique, regardless of region. Pytorch's naming semantics, however don't guarantee this. To fix this, attach a suffix containing the submodule name to tensors in submodules. Signed-off-by: Erik Lundell Change-Id: I910a7d71f0b5da2d2d9219746efd012f9bd251fb --- backends/arm/process_node.py | 10 ++++------ backends/arm/tosa/backend.py | 31 +++++++++++++++++++++++++------ backends/arm/tosa/mapping.py | 4 +++- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 7dd8f9a7d38..a4f3ba7fb5c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -159,7 +159,7 @@ def process_inputs_to_buffers( buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) tosa_graph.addConst( - buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name + buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name ) @@ -216,11 +216,9 @@ def process_placeholder( raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") -def process_output( - node: torch.fx.Node, - tosa_graph: Any, -): +def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification): for output in cast(tuple[torch.fx.Node, ...], node.args[0]): + output_arg = TosaArg(output, tosa_spec) tosa_graph.addOutputTensor( - tosa_graph.currRegion.currBasicBlock.tensors[output.name] + tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name] ) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 8b62761ec47..75cdd28c6da 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -24,11 +24,14 @@ process_placeholder, ) from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.graph_module import get_control_flow_submodules from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node + # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -169,12 +172,13 @@ def _preprocess( # noqa: C901 return PreprocessResult(processed_bytes=binary) @staticmethod - def _preprocess_module( + def _preprocess_module( # noqa: C901 graph_module: GraphModule, edge_program: ExportedProgram, compile_spec: TosaCompileSpec, tosa_graph: ts.TosaSerializer, debug_hook: DebugHook | None, + submodule_name: str | None = None, ): """Convert 'graph_module' to a tosa_graph""" tosa_spec = compile_spec.tosa_spec @@ -194,7 +198,13 @@ def _preprocess_module( node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) graph_module = _sort_outputs(graph_module, node_to_id_map) - input_count = 0 + if submodule_name is not None: + tosa_graph.startRegion(submodule_name) + tosa_graph.currRegion.addBasicBlock(submodule_name) + suffix = f"_{submodule_name}" + for loop_node in graph_module.graph.nodes: + loop_node.meta[TOSA_TENSOR_NAME_META] = suffix + for node in graph_module.graph.nodes: node = cast(Node, node) try: @@ -204,18 +214,27 @@ def _preprocess_module( if len(node.users) == 0: continue process_placeholder(node, tosa_graph, edge_program, tosa_spec) - if node.name in edge_program.graph_signature.user_inputs: - input_count += 1 elif node.op == "output": - process_output(node, tosa_graph) + process_output(node, tosa_graph, tosa_spec) else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. raise RuntimeError(f"{node.name} is unsupported op {node.op}") except Exception: - debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path) + debug_fail(node, graph_module, tosa_graph, artifact_path) raise + # Recursively preprocess controlflow submodules. + for name, submodule, _ in get_control_flow_submodules(graph_module): + TOSABackend._preprocess_module( + submodule, + edge_program, + compile_spec, + tosa_graph, + debug_hook, + submodule_name=name, + ) + @staticmethod def filter_tosa_compile_specs( compile_spec: ArmCompileSpec, diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 2287a727009..5162d2c6a53 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -17,6 +17,8 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification +TOSA_TENSOR_NAME_META = "tosa_tensor_name" + UNSUPPORTED_DTYPES = ( torch.float64, torch.double, @@ -144,7 +146,7 @@ def __process_node(self, argument: torch.fx.Node): argument (torch.fx.Node): FX node to inspect. """ - self.name: str = argument.name + self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "") output_dtype, self.shape, self.dim_order = extract_tensor_meta( argument.meta, self.tosa_spec )