diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 8703e260883..a3933ffb993 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs( # Combine all AOT debug_handles into a list aot_combined_debug_handle = [t[0] for t in aot_map.keys()] - if set(aot_combined_debug_handle) != set(runtime_debug_handle): - # AOT combined debug_handle and runtime debug_handle do not match. + # Reason we dont check for exact match: + # in some experiments where we want to rewrite the aten graph that was + # lowered, so as to use custom ops like int4_matmul, we lose some nodes + # on the graph and thus lose some debug handles. And we dont find + # exact match within connected components. + if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)): + # AOT combined debug_handle is not a subset of runtime debug_handle. return (-1,), None # Pick the last intermediate output last_int = runtime_debug_handle[negative_index] key = (last_int,) + if key not in aot_map: + # If the last intermediate output is not in the AOT map, return None + return (-1,), None return runtime_debug_handle, aot_map[key] @@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None: if node.op in ("output", "placeholder"): return node_id = f"{node.name}.{exported_program_graph_id}" - parent_node_id = get_parent_node_identifier(node) + parent_node_ids = get_ancestor_node_identifiers(node) if node_id in ancestors_node_id_to_debug_handle: matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id]) - elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: - matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id]) + elif parent_node_ids: + for parent_node_id in parent_node_ids: + if parent_node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add( + 
ancestors_node_id_to_debug_handle[parent_node_id] + ) + break bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) return matched_debug_handles @@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None: if node.op in ("output", "placeholder"): return node_id = f"{node.name}.{exported_program_graph_id}" - parent_node_id = get_parent_node_identifier(node) + parent_node_ids = get_ancestor_node_identifiers(node) + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE if node_id in ancestors_node_id_to_debug_handle: node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id] - elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: - node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ - parent_node_id - ] - else: - node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE + elif parent_node_ids: + for parent_node_id in parent_node_ids: + if parent_node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ + parent_node_id + ] + break bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 26fe38acfac..8c4bb4b38b9 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -334,13 +334,15 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self): self.assertEqual(actual, expected) def test_map_runtime_aot_intermediate_outputs_partial_match(self): - # Partial match between aot and runtime debug_handles will return empty + # Partial match between aot and runtime debug_handles will return + # matching debug handles from runtime aot_intermediate_outputs = {(2,): 100, (9,): 300} runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) 
- expected = {} + # Only (8, 9) maps: its last debug handle, 9, is also an AOT debug handle + expected = {((8, 9), 300): ((8, 9), 300)} self.assertEqual(actual, expected) def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):