From 9ac97638d2bd70f90b739ce03b640da6bd625c7d Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 5 Dec 2024 00:55:45 -0800 Subject: [PATCH] restructure debug handle Differential Revision: [D66622890](https://our.internmc.facebook.com/intern/diff/D66622890/) [ghstack-poisoned] --- exir/graph_module.py | 24 ++++++++- exir/passes/debug_handle_generator_pass.py | 57 ++++++++-------------- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/exir/graph_module.py b/exir/graph_module.py index 7a032b5290d..6edd5143e5f 100644 --- a/exir/graph_module.py +++ b/exir/graph_module.py @@ -7,7 +7,7 @@ # pyre-strict from types import FunctionType as function -from typing import Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import torch @@ -68,3 +68,25 @@ def get_control_flow_submodules( control_flow_submodules.append(_get_submodule(graph_module, node, 0)) return control_flow_submodules + +# TODO(gasoonjia): remove this and leverage core pytorch bfs_trace_with_node_process after code freeze +def bfs_trace_with_node_process( + gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None] +) -> None: + """Traverse the graph module and apply node_op to each node.""" + + assert isinstance( + gm, torch.fx.GraphModule + ), f"Expected GraphModule, got {type(gm)}" + + queue = [gm] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + node_op(node) + + control_flow_submodules = [ + submodule + for _, submodule, _ in get_control_flow_submodules(current_graph_module) + ] + queue.extend(control_flow_submodules) diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 0502c47dbb3..1374cf08b28 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -4,32 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.graph_module import bfs_trace_with_node_process from executorch.exir.pass_base import ExportPass from torch.export import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult - class DebugHandleGeneratorPass(ExportPass): def call(self, graph_module: GraphModule) -> PassResult: """Lower a quantized reference model (with reference quantized operator patterns) to executorch backend, that has a canonical set of quantized operators """ - queue = [graph_module] index = 1 - # bfs to traverse all modules including control flow submodules to attached debug handle id - while queue: - current_graph_module = queue.pop(0) - for node in current_graph_module.graph.nodes: - node.meta["debug_handle"] = index - index += 1 - control_flow_submodules = [ - submodule - for _, submodule, _ in get_control_flow_submodules(current_graph_module) - ] - queue.extend(control_flow_submodules) + + def _extract_debug_handles_from_node(node): + nonlocal index + node.meta["debug_handle"] = index + index += 1 + + bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node) + return PassResult(graph_module, True) @@ -38,28 +33,18 @@ def generate_missing_debug_handles(ep: ExportedProgram): This pass is used to generate missing debug handles for the graph module and its submodules. 
""" - def get_control_flow_submodules_list(graph_module): - return [ - submodule for _, submodule, _ in get_control_flow_submodules(graph_module) - ] - max_handle = 0 - queue = [ep.graph_module] - while queue: - current_graph_module = queue.pop(0) - for node in current_graph_module.graph.nodes: - if "debug_handle" in node.meta: - max_handle = max(max_handle, node.meta["debug_handle"]) - control_flow_submodules = get_control_flow_submodules_list(current_graph_module) - queue.extend(control_flow_submodules) + def _extract_max_debug_handle(node): + nonlocal max_handle + if "debug_handle" in node.meta: + max_handle = max(max_handle, node.meta["debug_handle"]) + + def _insert_new_debug_handles(node): + nonlocal max_handle + if node.meta.get("debug_handle", 0) in (0, None): + node.meta["debug_handle"] = max_handle + 1 + max_handle += 1 - queue = [ep.graph_module] - while queue: - current_graph_module = queue.pop(0) - for node in current_graph_module.graph.nodes: - if node.meta.get("debug_handle", 0) in (0, None): - node.meta["debug_handle"] = max_handle + 1 - max_handle += 1 - control_flow_submodules = get_control_flow_submodules_list(current_graph_module) - queue.extend(control_flow_submodules) + bfs_trace_with_node_process(ep.graph_module, _extract_max_debug_handle) + bfs_trace_with_node_process(ep.graph_module, _insert_new_debug_handles)