diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 323bda44a2c..6d046d8f2e8 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1169,7 +1169,6 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, - disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided @@ -1185,7 +1184,6 @@ def _get_aot_intermediate_outputs_and_op_names( self._etrecord.exported_program, self._etrecord.export_graph_id, self._etrecord.edge_dialect_program, - disable_debug_handle_valdiation, ): export_program = self._etrecord.exported_program else: @@ -1406,9 +1404,7 @@ def get_exported_program( else self._etrecord.graph_map.get(graph) ) - def calculate_numeric_gap( - self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False - ): + def calculate_numeric_gap(self, distance: str = "MSE"): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. @@ -1420,19 +1416,12 @@ def calculate_numeric_gap( Args: distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". - disable_debug_handle_validation: Often when aten graph has symbolic shape nodes, and inbuilt ops like gt/lt etc., - during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection - between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding - node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This - flag allows one to override such behavior and make best effort comparison. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ aot_intermediate_outputs, aot_debug_handle_to_op_names = ( - self._get_aot_intermediate_outputs_and_op_names( - disable_debug_handle_valdiation - ) + self._get_aot_intermediate_outputs_and_op_names() ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: raise ValueError( @@ -1462,15 +1451,6 @@ def calculate_numeric_gap( ) in mapping.items(): if aot_intermediate_output is None or runtime_intermediate_output is None: continue - # If aot outputs length is > 1 then comparison fails since we dont really have - # any instances where runtime intermediate output is a tuple or list - # This does not happen when edge dialect program is reference for comparison - # but happens in aten graph where ops like unbind remain undecomposed - if ( - isinstance(aot_intermediate_output, Sequence) - and len(aot_intermediate_output) > 1 - ): - continue rows.append( { "aot_ops": find_op_names( diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 8703e260883..ee7ebb2f5ea 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: # Ensure both sequences have the same length if len(a) != len(b): raise ValueError( - f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}." + f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison." ) # Compare each element in the sequences and return the list of results @@ -990,9 +990,6 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]: Returns: the identifiers of all its ancestor nodes """ - if FROM_NODE_KEY not in node.meta: - return [] - node_source = node.meta[FROM_NODE_KEY] node_source = node_source[-1] ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] @@ -1114,7 +1111,6 @@ def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, edge_dialect_program: ExportedProgram, - disable_debug_handle_valdiation: bool = False, ) -> bool: """ Propagate debug handle from edge dialect program back to the exported program while maintain the correctness @@ -1128,10 +1124,6 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. - disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch. - This can happen when we are comparing against aten graph in which case not all debug handles are matched - in aten graph. Example of this is when symbolic shape nodes are re-exported. - Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ # 1. Extract mapping from ancestor node identifiers to debug handles @@ -1145,9 +1137,7 @@ def propagate_back_debug_handle( ) # 3. Verify if every debug handle in edge dialect program has a corresponding node - if not disable_debug_handle_valdiation and not _verify_graph_match( - edge_dialect_program, matched_debug_handles - ): + if not _verify_graph_match(edge_dialect_program, matched_debug_handles): return False # 4. Apply debug handles to the exported program diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index babefb4e2ac..a3afed07ed8 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: ( aot_intermediate_outputs, aot_debug_handle_to_op_name, ) @@ -838,122 +838,6 @@ def _gen_random_runtime_output( ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: return [torch.randn(RAW_DATA_SIZE)] - def test_disable_debug_handle_validation_with_symbolic_shapes(self): - """ - Test that demonstrates the issue with symbolic shape related nodes losing from_node info - during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter - in propagate_back_debug_handle allows validation to be bypassed. - """ - from executorch.devtools.inspector._inspector_utils import ( - propagate_back_debug_handle, - ) - - class SymbolicShapeModel(torch.nn.Module): - """Model that will have symbolic shape related operations after export.""" - - def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - # This will create symbolic shape nodes during dynamic export - batch_size = x.shape[0] - x = x + torch.rand((batch_size, 1)) - # Masking operation that creates gt/lt nodes - valid_mask = mask > 0.5 - x = torch.where(valid_mask, x, torch.zeros_like(x)) - return x - - # Create model and dynamic inputs - model = SymbolicShapeModel() - batch_size = 2 - seq_len = 4 - x = torch.randn(batch_size, seq_len) - mask = torch.rand(batch_size, seq_len) - example_inputs = (x, mask) - - # Export with dynamic shapes to create symbolic shape related nodes - dynamic_shapes = { - "x": {0: torch.export.Dim("batch_size", min=1, max=10)}, - "mask": {0: torch.export.Dim("batch_size", min=1, max=10)}, - } - - exported_program = torch.export.export( - model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True - ) - - """ - In this case origina aten graph has sym_size_int_2 node but when we look at - nodes metadata in edge_program_manager, its sym_size node's from_node says - sym_size_int_3 which is not in the original aten graph. - """ - # Create edge program - this is where from_node info can be lost for symbolic shape nodes - edge_program_manager: EdgeProgramManager = to_edge(exported_program) - edge_program_manager_copy = copy.deepcopy(edge_program_manager) - et_program_manager: ExecutorchProgramManager = ( - edge_program_manager.to_executorch() - ) - - with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: - etrecord_path = tmp_file.name - - # Generate ETRecord with the exported program (aten graph) - generate_etrecord( - etrecord_path, - edge_program_manager_copy, - et_program_manager, - exported_program=exported_program, - ) - - # Create Inspector and get etrecord - with patch.object( - _inspector, "gen_etdump_object", return_value=None - ), patch.object(EventBlock, "_gen_from_etdump"): - inspector_instance = Inspector( - etdump_path=ETDUMP_PATH, - etrecord=etrecord_path, - ) - - # Extract the necessary values from the inspector's etrecord - exported_program_from_etrecord = ( - inspector_instance._etrecord.exported_program - ) - export_graph_id = inspector_instance._etrecord.export_graph_id - edge_dialect_program = inspector_instance._etrecord.edge_dialect_program - - # Ensure we have all the necessary components - self.assertIsNotNone(exported_program_from_etrecord) - self.assertIsNotNone(export_graph_id) - self.assertIsNotNone(edge_dialect_program) - - # Test propagate_back_debug_handle with validation enabled (should fail or return False) - # This demonstrates the issue with symbolic shape nodes losing from_node info - validation_enabled_result = propagate_back_debug_handle( - exported_program_from_etrecord, - export_graph_id, - edge_dialect_program, - disable_debug_handle_valdiation=False, - ) - - # With validation enabled, it should return False when from_node info is lost - self.assertFalse( - validation_enabled_result, - "propagate_back_debug_handle should return False when validation is enabled " - "and symbolic shape nodes lose from_node info", - ) - - # Test propagate_back_debug_handle with validation disabled (should succeed) - # This shows how the disable_debug_handle_valdiation flag allows the function to work - validation_disabled_result = propagate_back_debug_handle( - exported_program_from_etrecord, - export_graph_id, - edge_dialect_program, - disable_debug_handle_valdiation=True, - ) - - # With validation disabled, it should return True even when from_node info is lost - self.assertTrue( - validation_disabled_result, - "propagate_back_debug_handle should return True when validation is disabled, " - "allowing best effort comparison even when from_node info is lost", - ) - def _gen_random_events(self) -> List[Event]: events = [] for i in range(2):