From acbf935b153f680383c31b6802634b1590d05db9 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 15 Sep 2025 11:37:20 -0700 Subject: [PATCH] Allow for matching debug handles with partial overlap between aten graph and runtime (#14306) Summary: When aten graph is modified for debug, for instance using int4 matmul, it wont have complete overlap with debug handles recorded by the delegate. For example, original model will have chose_qparams,q, dq, dq, linear nodes. Delegate will record debug hanlde for all of those. Say those are (4, 5, 6, 7, 8). When int4 matmul rewrite pass, from torchao, is applied, we just inherit from_node information from linear node. Thus only the last debug handle 8 is associated with custom op int4 node. Thus when we map delegate debug handles with custom op we find overlap for 8 only. This diff allows to look for overlapping match instead of exact match. Plus it also changes the code for AOT debug handle so that we can look for all ancestor nodes instead of just parent node. This is also needed so as to allow for numerical comparison despite passes applied on original aten graph. Reviewed By: Gasoonjia Differential Revision: D82229367 --- devtools/inspector/_inspector_utils.py | 39 +++++++++++++------ .../inspector/tests/inspector_utils_test.py | 6 ++- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 8703e260883..a3933ffb993 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs( # Combine all AOT debug_handles into a list aot_combined_debug_handle = [t[0] for t in aot_map.keys()] - if set(aot_combined_debug_handle) != set(runtime_debug_handle): - # AOT combined debug_handle and runtime debug_handle do not match. + # Reason we dont check for exact match: + # in some experiments where we want to rewrite the aten graph that was + # lowered, so as to use custom ops like int4_matmul, we lose some nodes + # on the graph and thus lose some debug handles. And we dont find + # exact match within connected components. + if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)): + # AOT combined debug_handle is not a subset of runtime debug_handle. return (-1,), None # Pick the last intermediate output last_int = runtime_debug_handle[negative_index] key = (last_int,) + if key not in aot_map: + # If the last intermediate output is not in the AOT map, return None + return (-1,), None return runtime_debug_handle, aot_map[key] @@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None: if node.op in ("output", "placeholder"): return node_id = f"{node.name}.{exported_program_graph_id}" - parent_node_id = get_parent_node_identifier(node) + parent_node_ids = get_ancestor_node_identifiers(node) if node_id in ancestors_node_id_to_debug_handle: matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id]) - elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: - matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id]) + elif parent_node_ids: + for parent_node_id in parent_node_ids: + if parent_node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add( + ancestors_node_id_to_debug_handle[parent_node_id] + ) + break bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) return matched_debug_handles @@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None: if node.op in ("output", "placeholder"): return node_id = f"{node.name}.{exported_program_graph_id}" - parent_node_id = get_parent_node_identifier(node) + parent_node_ids = get_ancestor_node_identifiers(node) + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE if node_id in ancestors_node_id_to_debug_handle: node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id] - elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: - node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ - parent_node_id - ] - else: - node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE + elif parent_node_ids: + for parent_node_id in parent_node_ids: + if parent_node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ + parent_node_id + ] + break bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 26fe38acfac..8c4bb4b38b9 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -334,13 +334,15 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self): self.assertEqual(actual, expected) def test_map_runtime_aot_intermediate_outputs_partial_match(self): - # Partial match between aot and runtime debug_handles will return empty + # Partial match between aot and runtime debug_handles will return + # matching debug handles from runtime aot_intermediate_outputs = {(2,): 100, (9,): 300} runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) - expected = {} + # Since the runtime output debug handle of 9 is there in aot debug handle + expected = {((8, 9), 300): ((8, 9), 300)} self.assertEqual(actual, expected) def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):