Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 2 additions & 22 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,6 @@ def _consume_etrecord(self) -> None:

def _get_aot_intermediate_outputs_and_op_names(
self,
disable_debug_handle_valdiation: bool = False,
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
"""
Capture intermediate outputs only if _representative_inputs are provided
Expand All @@ -1185,7 +1184,6 @@ def _get_aot_intermediate_outputs_and_op_names(
self._etrecord.exported_program,
self._etrecord.export_graph_id,
self._etrecord.edge_dialect_program,
disable_debug_handle_valdiation,
):
export_program = self._etrecord.exported_program
else:
Expand Down Expand Up @@ -1406,9 +1404,7 @@ def get_exported_program(
else self._etrecord.graph_map.get(graph)
)

def calculate_numeric_gap(
self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False
):
def calculate_numeric_gap(self, distance: str = "MSE"):
"""
Compares logged intermediate outputs from the exported graph (in ETRecord)
with runtime outputs (in ETDump) using a user-specific numerical comparator.
Expand All @@ -1420,19 +1416,12 @@ def calculate_numeric_gap(

Args:
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
disable_debug_handle_validation: Often when the aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc.,
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we lose the connection
between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has a corresponding
node in aten IR, and when such validation fails the numeric debugger falls back to edge IR as the reference graph. This
flag allows one to override such behavior and make a best-effort comparison.

Returns:
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
"""
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
self._get_aot_intermediate_outputs_and_op_names(
disable_debug_handle_valdiation
)
self._get_aot_intermediate_outputs_and_op_names()
)
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
raise ValueError(
Expand Down Expand Up @@ -1462,15 +1451,6 @@ def calculate_numeric_gap(
) in mapping.items():
if aot_intermediate_output is None or runtime_intermediate_output is None:
continue
# If aot outputs length is > 1 then comparison fails since we dont really have
# any instances where runtime intermediate output is a tuple or list
# This does not happen when edge dialect program is reference for comparison
# but happens in aten graph where ops like unbind remain undecomposed
if (
isinstance(aot_intermediate_output, Sequence)
and len(aot_intermediate_output) > 1
):
continue
rows.append(
{
"aot_ops": find_op_names(
Expand Down
14 changes: 2 additions & 12 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
# Ensure both sequences have the same length
if len(a) != len(b):
raise ValueError(
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
)

# Compare each element in the sequences and return the list of results
Expand All @@ -990,9 +990,6 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]:
Returns: the identifiers of all its ancestor nodes
"""

if FROM_NODE_KEY not in node.meta:
return []

node_source = node.meta[FROM_NODE_KEY]
node_source = node_source[-1]
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]
Expand Down Expand Up @@ -1114,7 +1111,6 @@ def propagate_back_debug_handle(
exported_program: ExportedProgram,
exported_program_graph_id: int,
edge_dialect_program: ExportedProgram,
disable_debug_handle_valdiation: bool = False,
) -> bool:
"""
Propagate debug handle from edge dialect program back to the exported program while maintain the correctness
Expand All @@ -1128,10 +1124,6 @@ def propagate_back_debug_handle(
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.

disable_debug_handle_validation is used to avoid _verify_graph_match() in case of a debug handle mismatch.
This can happen when we are comparing against the aten graph, in which case not all debug handles are matched
in the aten graph. An example of this is when symbolic shape nodes are re-exported.

Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
"""
# 1. Extract mapping from ancestor node identifiers to debug handles
Expand All @@ -1145,9 +1137,7 @@ def propagate_back_debug_handle(
)

# 3. Verify if every debug handle in edge dialect program has a corresponding node
if not disable_debug_handle_valdiation and not _verify_graph_match(
edge_dialect_program, matched_debug_handles
):
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
return False

# 4. Apply debug handles to the exported program
Expand Down
118 changes: 1 addition & 117 deletions devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self):
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}

inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: (
aot_intermediate_outputs,
aot_debug_handle_to_op_name,
)
Expand Down Expand Up @@ -838,122 +838,6 @@ def _gen_random_runtime_output(
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
return [torch.randn(RAW_DATA_SIZE)]

def test_disable_debug_handle_validation_with_symbolic_shapes(self):
"""
Test that demonstrates the issue with symbolic shape related nodes losing from_node info
during dynamic shape based export, and shows how the disable_debug_handle_valdiation parameter
in propagate_back_debug_handle allows validation to be bypassed.

NOTE: 'valdiation' is the actual (misspelled) keyword name of the API under test;
the spelling is kept deliberately wherever the parameter itself is referenced.
"""
from executorch.devtools.inspector._inspector_utils import (
propagate_back_debug_handle,
)

class SymbolicShapeModel(torch.nn.Module):
"""Model that will have symbolic shape related operations after export."""

def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# This will create symbolic shape nodes during dynamic export
batch_size = x.shape[0]
x = x + torch.rand((batch_size, 1))
# Masking operation that creates gt/lt nodes
valid_mask = mask > 0.5
x = torch.where(valid_mask, x, torch.zeros_like(x))
return x

# Create model and dynamic inputs
model = SymbolicShapeModel()
batch_size = 2
seq_len = 4
x = torch.randn(batch_size, seq_len)
mask = torch.rand(batch_size, seq_len)
example_inputs = (x, mask)

# Export with dynamic shapes to create symbolic shape related nodes
dynamic_shapes = {
"x": {0: torch.export.Dim("batch_size", min=1, max=10)},
"mask": {0: torch.export.Dim("batch_size", min=1, max=10)},
}

exported_program = torch.export.export(
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
)

"""
In this case the original aten graph has a sym_size_int_2 node but when we look at
nodes metadata in edge_program_manager, its sym_size node's from_node says
sym_size_int_3 which is not in the original aten graph.
"""
# Create edge program - this is where from_node info can be lost for symbolic shape nodes
edge_program_manager: EdgeProgramManager = to_edge(exported_program)
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
et_program_manager: ExecutorchProgramManager = (
edge_program_manager.to_executorch()
)

with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
etrecord_path = tmp_file.name

# Generate ETRecord with the exported program (aten graph)
generate_etrecord(
etrecord_path,
edge_program_manager_copy,
et_program_manager,
exported_program=exported_program,
)

# Create Inspector and get etrecord; etdump generation is patched out
# so no real ETDump file is needed for this test.
with patch.object(
_inspector, "gen_etdump_object", return_value=None
), patch.object(EventBlock, "_gen_from_etdump"):
inspector_instance = Inspector(
etdump_path=ETDUMP_PATH,
etrecord=etrecord_path,
)

# Extract the necessary values from the inspector's etrecord
exported_program_from_etrecord = (
inspector_instance._etrecord.exported_program
)
export_graph_id = inspector_instance._etrecord.export_graph_id
edge_dialect_program = inspector_instance._etrecord.edge_dialect_program

# Ensure we have all the necessary components
self.assertIsNotNone(exported_program_from_etrecord)
self.assertIsNotNone(export_graph_id)
self.assertIsNotNone(edge_dialect_program)

# Test propagate_back_debug_handle with validation enabled (should fail or return False)
# This demonstrates the issue with symbolic shape nodes losing from_node info
validation_enabled_result = propagate_back_debug_handle(
exported_program_from_etrecord,
export_graph_id,
edge_dialect_program,
disable_debug_handle_valdiation=False,
)

# With validation enabled, it should return False when from_node info is lost
self.assertFalse(
validation_enabled_result,
"propagate_back_debug_handle should return False when validation is enabled "
"and symbolic shape nodes lose from_node info",
)

# Test propagate_back_debug_handle with validation disabled (should succeed)
# This shows how the disable_debug_handle_valdiation flag allows the function to work
validation_disabled_result = propagate_back_debug_handle(
exported_program_from_etrecord,
export_graph_id,
edge_dialect_program,
disable_debug_handle_valdiation=True,
)

# With validation disabled, it should return True even when from_node info is lost
self.assertTrue(
validation_disabled_result,
"propagate_back_debug_handle should return True when validation is disabled, "
"allowing best effort comparison even when from_node info is lost",
)

def _gen_random_events(self) -> List[Event]:
events = []
for i in range(2):
Expand Down
Loading