From 69074bc01e7b5a2a2c7d08990d76eb8932fdb257 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 11 Sep 2025 13:12:48 -0700 Subject: [PATCH] Optionally disable debug handle validateion (#14182) Summary: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This flag allows one to override such behavior and make best effort comparison. Reviewed By: Gasoonjia Differential Revision: D81784685 --- devtools/inspector/_inspector.py | 24 ++++- devtools/inspector/_inspector_utils.py | 14 ++- devtools/inspector/tests/inspector_test.py | 118 ++++++++++++++++++++- 3 files changed, 151 insertions(+), 5 deletions(-) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 6d046d8f2e8..323bda44a2c 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1169,6 +1169,7 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, + disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided @@ -1184,6 +1185,7 @@ def _get_aot_intermediate_outputs_and_op_names( self._etrecord.exported_program, self._etrecord.export_graph_id, self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, ): export_program = self._etrecord.exported_program else: @@ -1404,7 +1406,9 @@ def get_exported_program( else self._etrecord.graph_map.get(graph) ) - def calculate_numeric_gap(self, distance: str = "MSE"): + def calculate_numeric_gap( + self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False + ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. @@ -1416,12 +1420,19 @@ def calculate_numeric_gap(self, distance: str = "MSE"): Args: distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". + disable_debug_handle_validation: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc., + during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection + between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding + node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This + flag allows one to override such behavior and make best effort comparison. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ aot_intermediate_outputs, aot_debug_handle_to_op_names = ( - self._get_aot_intermediate_outputs_and_op_names() + self._get_aot_intermediate_outputs_and_op_names( + disable_debug_handle_valdiation + ) ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: raise ValueError( @@ -1451,6 +1462,15 @@ def calculate_numeric_gap(self, distance: str = "MSE"): ) in mapping.items(): if aot_intermediate_output is None or runtime_intermediate_output is None: continue + # If aot outputs length is > 1 then comparison fails since we dont really have + # any instances where runtime intermediate output is a tuple or list + # This does not happen when edge dialect program is reference for comparison + # but happens in aten graph where ops like unbind remain undecomposed + if ( + isinstance(aot_intermediate_output, Sequence) + and len(aot_intermediate_output) > 1 + ): + continue rows.append( { "aot_ops": find_op_names( diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index ee7ebb2f5ea..8703e260883 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: # Ensure both sequences have the same length if len(a) != len(b): raise ValueError( - f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison." + f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}." ) # Compare each element in the sequences and return the list of results @@ -990,6 +990,9 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]: Returns: the identifiers of all its ancestor nodes """ + if FROM_NODE_KEY not in node.meta: + return [] + node_source = node.meta[FROM_NODE_KEY] node_source = node_source[-1] ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] @@ -1111,6 +1114,7 @@ def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, edge_dialect_program: ExportedProgram, + disable_debug_handle_valdiation: bool = False, ) -> bool: """ Propagate debug handle from edge dialect program back to the exported program while maintain the correctness @@ -1124,6 +1128,10 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. + disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch. + This can happen when we are comparing against aten graph in which case not all debug handles are matched + in aten graph. Example of this is when symbolic shape nodes are re-exported. + Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ # 1. Extract mapping from ancestor node identifiers to debug handles @@ -1137,7 +1145,9 @@ def propagate_back_debug_handle( ) # 3. Verify if every debug handle in edge dialect program has a corresponding node - if not _verify_graph_match(edge_dialect_program, matched_debug_handles): + if not disable_debug_handle_valdiation and not _verify_graph_match( + edge_dialect_program, matched_debug_handles + ): return False # 4. Apply debug handles to the exported program diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index a3afed07ed8..babefb4e2ac 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: ( + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( aot_intermediate_outputs, aot_debug_handle_to_op_name, ) @@ -838,6 +838,122 @@ def _gen_random_runtime_output( ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: return [torch.randn(RAW_DATA_SIZE)] + def test_disable_debug_handle_validation_with_symbolic_shapes(self): + """ + Test that demonstrates the issue with symbolic shape related nodes losing from_node info + during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter + in propagate_back_debug_handle allows validation to be bypassed. + """ + from executorch.devtools.inspector._inspector_utils import ( + propagate_back_debug_handle, + ) + + class SymbolicShapeModel(torch.nn.Module): + """Model that will have symbolic shape related operations after export.""" + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # This will create symbolic shape nodes during dynamic export + batch_size = x.shape[0] + x = x + torch.rand((batch_size, 1)) + # Masking operation that creates gt/lt nodes + valid_mask = mask > 0.5 + x = torch.where(valid_mask, x, torch.zeros_like(x)) + return x + + # Create model and dynamic inputs + model = SymbolicShapeModel() + batch_size = 2 + seq_len = 4 + x = torch.randn(batch_size, seq_len) + mask = torch.rand(batch_size, seq_len) + example_inputs = (x, mask) + + # Export with dynamic shapes to create symbolic shape related nodes + dynamic_shapes = { + "x": {0: torch.export.Dim("batch_size", min=1, max=10)}, + "mask": {0: torch.export.Dim("batch_size", min=1, max=10)}, + } + + exported_program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + """ + In this case origina aten graph has sym_size_int_2 node but when we look at + nodes metadata in edge_program_manager, its sym_size node's from_node says + sym_size_int_3 which is not in the original aten graph. + """ + # Create edge program - this is where from_node info can be lost for symbolic shape nodes + edge_program_manager: EdgeProgramManager = to_edge(exported_program) + edge_program_manager_copy = copy.deepcopy(edge_program_manager) + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: + etrecord_path = tmp_file.name + + # Generate ETRecord with the exported program (aten graph) + generate_etrecord( + etrecord_path, + edge_program_manager_copy, + et_program_manager, + exported_program=exported_program, + ) + + # Create Inspector and get etrecord + with patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object(EventBlock, "_gen_from_etdump"): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=etrecord_path, + ) + + # Extract the necessary values from the inspector's etrecord + exported_program_from_etrecord = ( + inspector_instance._etrecord.exported_program + ) + export_graph_id = inspector_instance._etrecord.export_graph_id + edge_dialect_program = inspector_instance._etrecord.edge_dialect_program + + # Ensure we have all the necessary components + self.assertIsNotNone(exported_program_from_etrecord) + self.assertIsNotNone(export_graph_id) + self.assertIsNotNone(edge_dialect_program) + + # Test propagate_back_debug_handle with validation enabled (should fail or return False) + # This demonstrates the issue with symbolic shape nodes losing from_node info + validation_enabled_result = propagate_back_debug_handle( + exported_program_from_etrecord, + export_graph_id, + edge_dialect_program, + disable_debug_handle_valdiation=False, + ) + + # With validation enabled, it should return False when from_node info is lost + self.assertFalse( + validation_enabled_result, + "propagate_back_debug_handle should return False when validation is enabled " + "and symbolic shape nodes lose from_node info", + ) + + # Test propagate_back_debug_handle with validation disabled (should succeed) + # This shows how the disable_debug_handle_valdiation flag allows the function to work + validation_disabled_result = propagate_back_debug_handle( + exported_program_from_etrecord, + export_graph_id, + edge_dialect_program, + disable_debug_handle_valdiation=True, + ) + + # With validation disabled, it should return True even when from_node info is lost + self.assertTrue( + validation_disabled_result, + "propagate_back_debug_handle should return True when validation is disabled, " + "allowing best effort comparison even when from_node info is lost", + ) + def _gen_random_events(self) -> List[Event]: events = [] for i in range(2):