diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py
index 6d046d8f2e8..323bda44a2c 100644
--- a/devtools/inspector/_inspector.py
+++ b/devtools/inspector/_inspector.py
@@ -1169,6 +1169,7 @@ def _consume_etrecord(self) -> None:
 
     def _get_aot_intermediate_outputs_and_op_names(
         self,
+        disable_debug_handle_validation: bool = False,
     ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
         """
         Capture intermediate outputs only if _representative_inputs are provided
@@ -1184,6 +1185,7 @@ def _get_aot_intermediate_outputs_and_op_names(
             self._etrecord.exported_program,
             self._etrecord.export_graph_id,
             self._etrecord.edge_dialect_program,
+            disable_debug_handle_validation,
         ):
             export_program = self._etrecord.exported_program
         else:
@@ -1404,7 +1406,9 @@ def get_exported_program(
             else self._etrecord.graph_map.get(graph)
         )
 
-    def calculate_numeric_gap(self, distance: str = "MSE"):
+    def calculate_numeric_gap(
+        self, distance: str = "MSE", disable_debug_handle_validation: bool = False
+    ):
         """
         Compares logged intermediate outputs from the exported graph (in ETRecord)
         with runtime outputs (in ETDump) using a user-specific numerical comparator.
@@ -1416,12 +1420,19 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
         Args:
             distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
+            disable_debug_handle_validation: When the aten graph has symbolic shape nodes and built-in ops like
+                gt/lt, re-exporting such a graph often drops 'from_node' information from node.meta, so we lose
+                the connection between edge IR nodes and aten nodes for those ops. By default we validate that
+                every edge IR node has a corresponding node in the aten IR; when that validation fails, the numeric
+                debugger falls back to the edge IR as the reference graph. This flag overrides that behavior and makes a best-effort comparison.
 
         Returns:
             pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs
             from both stages and their computed numerical gaps.
         """
         aot_intermediate_outputs, aot_debug_handle_to_op_names = (
-            self._get_aot_intermediate_outputs_and_op_names()
+            self._get_aot_intermediate_outputs_and_op_names(
+                disable_debug_handle_validation
+            )
         )
         if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
             raise ValueError(
@@ -1451,6 +1462,15 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
         ) in mapping.items():
             if aot_intermediate_output is None or runtime_intermediate_output is None:
                 continue
+            # If the AOT output is a sequence with more than one element, the comparison fails,
+            # since we don't have any instances where a runtime intermediate output is a tuple or list.
+            # This does not happen when the edge dialect program is the reference for comparison,
+            # but it does in the aten graph, where ops like unbind remain undecomposed.
+            if (
+                isinstance(aot_intermediate_output, Sequence)
+                and len(aot_intermediate_output) > 1
+            ):
+                continue
             rows.append(
                 {
                     "aot_ops": find_op_names(
diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py
index ee7ebb2f5ea..8703e260883 100644
--- a/devtools/inspector/_inspector_utils.py
+++ b/devtools/inspector/_inspector_utils.py
@@ -965,7 +965,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
         # Ensure both sequences have the same length
         if len(a) != len(b):
             raise ValueError(
-                f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
+ f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}." ) # Compare each element in the sequences and return the list of results @@ -990,6 +990,9 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]: Returns: the identifiers of all its ancestor nodes """ + if FROM_NODE_KEY not in node.meta: + return [] + node_source = node.meta[FROM_NODE_KEY] node_source = node_source[-1] ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] @@ -1111,6 +1114,7 @@ def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, edge_dialect_program: ExportedProgram, + disable_debug_handle_valdiation: bool = False, ) -> bool: """ Propagate debug handle from edge dialect program back to the exported program while maintain the correctness @@ -1124,6 +1128,10 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. + disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch. + This can happen when we are comparing against aten graph in which case not all debug handles are matched + in aten graph. Example of this is when symbolic shape nodes are re-exported. + Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ # 1. Extract mapping from ancestor node identifiers to debug handles @@ -1137,7 +1145,9 @@ def propagate_back_debug_handle( ) # 3. Verify if every debug handle in edge dialect program has a corresponding node - if not _verify_graph_match(edge_dialect_program, matched_debug_handles): + if not disable_debug_handle_valdiation and not _verify_graph_match( + edge_dialect_program, matched_debug_handles + ): return False # 4. Apply debug handles to the exported program diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index a3afed07ed8..babefb4e2ac 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -681,7 +681,7 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: ( + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( aot_intermediate_outputs, aot_debug_handle_to_op_name, ) @@ -838,6 +838,122 @@ def _gen_random_runtime_output( ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: return [torch.randn(RAW_DATA_SIZE)] + def test_disable_debug_handle_validation_with_symbolic_shapes(self): + """ + Test that demonstrates the issue with symbolic shape related nodes losing from_node info + during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter + in propagate_back_debug_handle allows validation to be bypassed. 
+ """ + from executorch.devtools.inspector._inspector_utils import ( + propagate_back_debug_handle, + ) + + class SymbolicShapeModel(torch.nn.Module): + """Model that will have symbolic shape related operations after export.""" + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # This will create symbolic shape nodes during dynamic export + batch_size = x.shape[0] + x = x + torch.rand((batch_size, 1)) + # Masking operation that creates gt/lt nodes + valid_mask = mask > 0.5 + x = torch.where(valid_mask, x, torch.zeros_like(x)) + return x + + # Create model and dynamic inputs + model = SymbolicShapeModel() + batch_size = 2 + seq_len = 4 + x = torch.randn(batch_size, seq_len) + mask = torch.rand(batch_size, seq_len) + example_inputs = (x, mask) + + # Export with dynamic shapes to create symbolic shape related nodes + dynamic_shapes = { + "x": {0: torch.export.Dim("batch_size", min=1, max=10)}, + "mask": {0: torch.export.Dim("batch_size", min=1, max=10)}, + } + + exported_program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + """ + In this case origina aten graph has sym_size_int_2 node but when we look at + nodes metadata in edge_program_manager, its sym_size node's from_node says + sym_size_int_3 which is not in the original aten graph. + """ + # Create edge program - this is where from_node info can be lost for symbolic shape nodes + edge_program_manager: EdgeProgramManager = to_edge(exported_program) + edge_program_manager_copy = copy.deepcopy(edge_program_manager) + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: + etrecord_path = tmp_file.name + + # Generate ETRecord with the exported program (aten graph) + generate_etrecord( + etrecord_path, + edge_program_manager_copy, + et_program_manager, + exported_program=exported_program, + ) + + # Create Inspector and get etrecord + with patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object(EventBlock, "_gen_from_etdump"): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=etrecord_path, + ) + + # Extract the necessary values from the inspector's etrecord + exported_program_from_etrecord = ( + inspector_instance._etrecord.exported_program + ) + export_graph_id = inspector_instance._etrecord.export_graph_id + edge_dialect_program = inspector_instance._etrecord.edge_dialect_program + + # Ensure we have all the necessary components + self.assertIsNotNone(exported_program_from_etrecord) + self.assertIsNotNone(export_graph_id) + self.assertIsNotNone(edge_dialect_program) + + # Test propagate_back_debug_handle with validation enabled (should fail or return False) + # This demonstrates the issue with symbolic shape nodes losing from_node info + validation_enabled_result = propagate_back_debug_handle( + exported_program_from_etrecord, + export_graph_id, + edge_dialect_program, + disable_debug_handle_valdiation=False, + ) + + # With validation enabled, it should return False when from_node info is lost + self.assertFalse( + validation_enabled_result, + "propagate_back_debug_handle should return False when validation is enabled " + "and symbolic shape nodes lose from_node info", + ) + + # Test propagate_back_debug_handle with validation disabled (should succeed) + # This shows how the disable_debug_handle_valdiation flag allows the function to work + validation_disabled_result = propagate_back_debug_handle( + 
+                    exported_program_from_etrecord,
+                    export_graph_id,
+                    edge_dialect_program,
+                    disable_debug_handle_validation=True,
+                )
+
+                # With validation disabled, it should return True even when from_node info is lost
+                self.assertTrue(
+                    validation_disabled_result,
+                    "propagate_back_debug_handle should return True when validation is disabled, "
+                    "allowing best effort comparison even when from_node info is lost",
+                )
+
     def _gen_random_events(self) -> List[Event]:
         events = []
         for i in range(2):