44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from executorch .exir .graph_module import get_control_flow_submodules
7+ from executorch .exir .graph_module import bfs_trace_with_node_process
88from executorch .exir .pass_base import ExportPass
99from torch .export import ExportedProgram
1010from torch .fx import GraphModule
@@ -17,19 +17,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
1717 to executorch backend, that has a canonical set of quantized operators
1818 """
1919
20- queue = [graph_module ]
2120 index = 1
22- # bfs to traverse all modules including control flow submodules to attached debug handle id
23- while queue :
24- current_graph_module = queue .pop (0 )
25- for node in current_graph_module .graph .nodes :
26- node .meta ["debug_handle" ] = index
27- index += 1
28- control_flow_submodules = [
29- submodule
30- for _ , submodule , _ in get_control_flow_submodules (current_graph_module )
31- ]
32- queue .extend (control_flow_submodules )
21+
22+ def _extract_debug_handles_from_node (node ):
23+ nonlocal index
24+ node .meta ["debug_handle" ] = index
25+ index += 1
26+
27+ bfs_trace_with_node_process (graph_module , _extract_debug_handles_from_node )
28+
3329 return PassResult (graph_module , True )
3430
3531
@@ -38,28 +34,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
3834 This pass is used to generate missing debug handles for the graph module and its submodules.
3935 """
4036
41- def get_control_flow_submodules_list (graph_module ):
42- return [
43- submodule for _ , submodule , _ in get_control_flow_submodules (graph_module )
44- ]
45-
4637 max_handle = 0
47- queue = [ep .graph_module ]
4838
49- while queue :
50- current_graph_module = queue .pop (0 )
51- for node in current_graph_module .graph .nodes :
52- if "debug_handle" in node .meta :
53- max_handle = max (max_handle , node .meta ["debug_handle" ])
54- control_flow_submodules = get_control_flow_submodules_list (current_graph_module )
55- queue .extend (control_flow_submodules )
39+ def _extract_max_debug_handle (node ):
40+ nonlocal max_handle
41+ if "debug_handle" in node .meta :
42+ max_handle = max (max_handle , node .meta ["debug_handle" ])
43+
44+ def _insert_new_debug_handles (node ):
45+ nonlocal max_handle
46+ if node .meta .get ("debug_handle" , 0 ) in (0 , None ):
47+ node .meta ["debug_handle" ] = max_handle + 1
48+ max_handle += 1
5649
57- queue = [ep .graph_module ]
58- while queue :
59- current_graph_module = queue .pop (0 )
60- for node in current_graph_module .graph .nodes :
61- if node .meta .get ("debug_handle" , 0 ) in (0 , None ):
62- node .meta ["debug_handle" ] = max_handle + 1
63- max_handle += 1
64- control_flow_submodules = get_control_flow_submodules_list (current_graph_module )
65- queue .extend (control_flow_submodules )
50+ bfs_trace_with_node_process (ep .graph_module , _extract_max_debug_handle )
51+ bfs_trace_with_node_process (ep .graph_module , _insert_new_debug_handles )
0 commit comments