diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index bf2b99da9da..0502c47dbb3 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -51,7 +51,7 @@ def get_control_flow_submodules_list(graph_module): for node in current_graph_module.graph.nodes: if "debug_handle" in node.meta: max_handle = max(max_handle, node.meta["debug_handle"]) - control_flow_submodules = get_control_flow_submodules_list(ep.graph_module) + control_flow_submodules = get_control_flow_submodules_list(current_graph_module) queue.extend(control_flow_submodules) queue = [ep.graph_module] @@ -61,5 +61,5 @@ def get_control_flow_submodules_list(graph_module): if node.meta.get("debug_handle", 0) in (0, None): node.meta["debug_handle"] = max_handle + 1 max_handle += 1 - control_flow_submodules = get_control_flow_submodules_list(ep.graph_module) + control_flow_submodules = get_control_flow_submodules_list(current_graph_module) queue.extend(control_flow_submodules) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index be2bb4be333..a3aa1f39975 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1018,16 +1018,13 @@ def forward( torch.ones(2, 2), ) - graph_module = ( - to_edge( - export( - f, - inputs, - ) + ep = to_edge( + export( + f, + inputs, ) - .exported_program() - .graph_module - ) + ).exported_program() + graph_module = ep.graph_module def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: queue = [graph_module] @@ -1045,6 +1042,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: DebugHandleGeneratorPass()(graph_module) check_debug_handle_metadata(graph_module) + generate_missing_debug_handles(ep) # Check debug handle still preserved after ScalarToTensorPass ScalarToTensorPass()(graph_module)