Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
# Combine all AOT debug_handles into a list
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]

if set(aot_combined_debug_handle) != set(runtime_debug_handle):
# AOT combined debug_handle and runtime debug_handle do not match.
# Reason we dont check for exact match:
# in some experiments where we want to rewrite the aten graph that was
# lowered, so as to use custom ops like int4_matmul, we lose some nodes
# on the graph and thus lose some debug handles. And we dont find
# exact match within connected components.
if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)):
# AOT combined debug_handle is not a subset of runtime debug_handle.
return (-1,), None

# Pick the last intermediate output
last_int = runtime_debug_handle[negative_index]
key = (last_int,)
if key not in aot_map:
# If the last intermediate output is not in the AOT map, return None
return (-1,), None
return runtime_debug_handle, aot_map[key]


Expand Down Expand Up @@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None:
if node.op in ("output", "placeholder"):
return
node_id = f"{node.name}.{exported_program_graph_id}"
parent_node_id = get_parent_node_identifier(node)
parent_node_ids = get_ancestor_node_identifiers(node)
if node_id in ancestors_node_id_to_debug_handle:
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
elif parent_node_ids:
for parent_node_id in parent_node_ids:
if parent_node_id in ancestors_node_id_to_debug_handle:
matched_debug_handles.add(
ancestors_node_id_to_debug_handle[parent_node_id]
)
break

bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
return matched_debug_handles
Expand Down Expand Up @@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None:
if node.op in ("output", "placeholder"):
return
node_id = f"{node.name}.{exported_program_graph_id}"
parent_node_id = get_parent_node_identifier(node)
parent_node_ids = get_ancestor_node_identifiers(node)
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
if node_id in ancestors_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
parent_node_id
]
else:
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
elif parent_node_ids:
for parent_node_id in parent_node_ids:
if parent_node_id in ancestors_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
parent_node_id
]
break

bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)

Expand Down
6 changes: 4 additions & 2 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,15 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_partial_match(self):
# Partial match between aot and runtime debug_handles will return empty
# Partial match between aot and runtime debug_handles will return
# matching debug handles from runtime
aot_intermediate_outputs = {(2,): 100, (9,): 300}
runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)}
actual = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)
expected = {}
# Since the runtime output debug handle of 9 is there in aot debug handle
expected = {((8, 9), 300): ((8, 9), 300)}
self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
Expand Down
Loading