diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b1eea847792..728eb984ae9 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module - def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + def _tosa_INT_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) - ) + self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) @@ -218,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() - return self._transform(exported_program.graph_module) + return self._transform(graph_module) - def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + def _tosa_FP_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ) -> GraphModule: self.add_pass(AnnotateOutputDimOrderPass()) self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) @@ -255,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeBatchNormNoStatsPass()) self.add_pass(DecomposeVarPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) - ) + self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeAddSubAlphaPass()) @@ -305,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() - return self._transform(exported_program.graph_module) + return self._transform(graph_module) - def transform_to_backend_pipeline(self, exported_program: ExportedProgram): + def transform_to_backend_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ): """Apply passes before transforming program to backend""" if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"): - return self._tosa_FP_pipeline(exported_program) + return self._tosa_FP_pipeline(exported_program, graph_module) elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"): - return self._tosa_INT_pipeline(exported_program) + return self._tosa_INT_pipeline(exported_program, graph_module) else: raise NotImplementedError( f"No pass pipeline implemented for {self.tosa_spec=}" diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 54797b825b7..f9694c1abdf 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -158,7 +158,7 @@ def process_inputs_to_buffers( buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) tosa_graph.addConst( - buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name + buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name ) @@ -215,11 +215,9 @@ def process_placeholder( raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") -def process_output( - node: torch.fx.Node, - tosa_graph: Any, -): +def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification): for output in cast(tuple[torch.fx.Node, ...], node.args[0]): + output_arg = TosaArg(output, tosa_spec) tosa_graph.addOutputTensor( - tosa_graph.currRegion.currBasicBlock.tensors[output.name] + tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name] ) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 8fb50707952..75cdd28c6da 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -24,10 +24,13 @@ process_placeholder, ) from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.graph_module import get_control_flow_submodules from torch.export.exported_program import ExportedProgram -from torch.fx import Graph, Node +from torch.fx import Graph, GraphModule, Node + # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -52,13 +55,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): # Walk backwards so we touch every producer q.extend(n.all_input_nodes) - out = next(n for n in ep_graph.nodes if n.op == "output") + out = ep_graph.output_node() + # First argument of output node is tuple of outputs + output_list = cast(tuple, out.args[0]) seen: Set[Node] = set() - for idx, val in enumerate(out.args[0]): + for idx, val in enumerate(output_list): bfs_mark([val], idx, seen) return node2external_id +def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]): + def _external_id(n: Node, node_2_id, fallback: int) -> int: + return node_2_id.get(n.name, fallback) + + out_node = graph_module.graph.output_node() + out_list = cast(tuple, out_node.args[0]) + _counter = count() + + # sort nodes by the key that is id + def _sort_key(t: Node) -> int: + return _external_id(t, node_to_id_map, next(_counter)) + + orig_ord = tuple(sorted(out_list, key=_sort_key)) + + current_order = tuple(out_list) + if orig_ord != current_order: + replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord + out_node.args = (replacement,) + graph_module.graph.lint() + graph_module.recompile() + + return graph_module + + def arm_get_first_delegation_tag(graph_module) -> str: """Get the first delegation tag from the graph_module or return empty string.""" for node in graph_module.graph.nodes: @@ -93,9 +122,9 @@ def _preprocess( # noqa: C901 artifact_path = compile_spec.get_intermediate_path() tosa_spec = compile_spec.tosa_spec dump_debug_info = compile_spec.tosa_debug_mode - - # Assign to every node external id - node_2_id = _annotate_external_ids(edge_program.graph) + debug_hook = None + if dump_debug_info is not None: + debug_hook = DebugHook(dump_debug_info) logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") @@ -116,45 +145,66 @@ def _preprocess( # noqa: C901 f"doesn't match specification {tosa_spec}" ) + TOSABackend._preprocess_module( + edge_program.graph_module, + edge_program, + compile_spec, + tosa_graph, + debug_hook, + ) + # Serialize and return the TOSA flatbuffer. + binary = tosa_graph.serialize() + + if artifact_path: + tag = arm_get_first_delegation_tag(edge_program.graph_module) + debug_tosa_dump( + binary, + artifact_path, + suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), + ) + + if debug_hook is not None: + if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: + json_output = debug_hook.serialize() + with open(f"{artifact_path}/debug.json", "w") as f: + f.write(json_output) + + return PreprocessResult(processed_bytes=binary) + + @staticmethod + def _preprocess_module( # noqa: C901 + graph_module: GraphModule, + edge_program: ExportedProgram, + compile_spec: TosaCompileSpec, + tosa_graph: ts.TosaSerializer, + debug_hook: DebugHook | None, + submodule_name: str | None = None, + ): + """Convert 'graph_module' to a tosa_graph""" + tosa_spec = compile_spec.tosa_spec + node_to_id_map = _annotate_external_ids(graph_module.graph) + artifact_path = compile_spec.get_intermediate_path() + # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore - exported_program=edge_program + exported_program=edge_program, graph_module=graph_module ) - debug_hook = None - if dump_debug_info is not None: - debug_hook = DebugHook(dump_debug_info) - # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) + graph_module = _sort_outputs(graph_module, node_to_id_map) - # Re-shuffle output nodes to preserve author's order - def _external_id(n: Node, node_2_id, fallback: int) -> int: - return node_2_id.get(n.name, fallback) - - out_node = next(n for n in graph_module.graph.nodes if n.op == "output") - _counter = count() - - # sort nodes by the key that is id - def _sort_key(t: Node) -> int: - return _external_id(t, node_2_id, next(_counter)) + if submodule_name is not None: + tosa_graph.startRegion(submodule_name) + tosa_graph.currRegion.addBasicBlock(submodule_name) + suffix = f"_{submodule_name}" + for loop_node in graph_module.graph.nodes: + loop_node.meta[TOSA_TENSOR_NAME_META] = suffix - orig_ord = tuple(sorted(out_node.args[0], key=_sort_key)) - - current_order = tuple(out_node.args[0]) - if orig_ord != current_order: - replacement = ( - list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord - ) - out_node.args = (replacement,) - graph_module.graph.lint() - graph_module.recompile() - - input_count = 0 for node in graph_module.graph.nodes: node = cast(Node, node) try: @@ -164,37 +214,27 @@ def _sort_key(t: Node) -> int: if len(node.users) == 0: continue process_placeholder(node, tosa_graph, edge_program, tosa_spec) - if node.name in edge_program.graph_signature.user_inputs: - input_count += 1 elif node.op == "output": - process_output(node, tosa_graph) + process_output(node, tosa_graph, tosa_spec) else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. raise RuntimeError(f"{node.name} is unsupported op {node.op}") except Exception: - debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path) + debug_fail(node, graph_module, tosa_graph, artifact_path) raise - # Serialize and return the TOSA flatbuffer. - binary = tosa_graph.serialize() - - if artifact_path: - tag = arm_get_first_delegation_tag(graph_module) - debug_tosa_dump( - binary, - artifact_path, - suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), + # Recursively preprocess controlflow submodules. + for name, submodule, _ in get_control_flow_submodules(graph_module): + TOSABackend._preprocess_module( + submodule, + edge_program, + compile_spec, + tosa_graph, + debug_hook, + submodule_name=name, ) - if debug_hook is not None: - if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: - json_output = debug_hook.serialize() - with open(f"{artifact_path}/debug.json", "w") as f: - f.write(json_output) - - return PreprocessResult(processed_bytes=binary) - @staticmethod def filter_tosa_compile_specs( compile_spec: ArmCompileSpec, diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 2287a727009..5162d2c6a53 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -17,6 +17,8 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification +TOSA_TENSOR_NAME_META = "tosa_tensor_name" + UNSUPPORTED_DTYPES = ( torch.float64, torch.double, @@ -144,7 +146,7 @@ def __process_node(self, argument: torch.fx.Node): argument (torch.fx.Node): FX node to inspect. """ - self.name: str = argument.name + self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "") output_dtype, self.shape, self.dim_order = extract_tensor_meta( argument.meta, self.tosa_spec )