diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index a209da8adb7..e4ddcce1ce7 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -48,6 +48,7 @@ EXCLUDED_COLUMNS_WHEN_PRINTING, EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT, EXCLUDED_EVENTS_WHEN_PRINTING, + find_op_names, find_populated_event, FORWARD, gen_etdump_object, @@ -68,6 +69,7 @@ from executorch.devtools.inspector.numerical_comparator import ( L1Comparator, MSEComparator, + SNRComparator, ) from executorch.exir import ExportedProgram @@ -1084,8 +1086,6 @@ def __init__( # Key str is method name; value is list of ProgramOutputs because of list of test cases self._reference_outputs: Dict[str, List[ProgramOutput]] = {} self._enable_module_hierarchy = enable_module_hierarchy - self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None - self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None self._consume_etrecord() def _consume_etrecord(self) -> None: @@ -1146,19 +1146,26 @@ def _consume_etrecord(self) -> None: event_block.reference_output = self._reference_outputs[FORWARD][ index ] - # Capture intermediate outputs only if _representative_inputs are provided - # when using bundled program to create the etrecord + + def _get_aot_intermediate_outputs_and_op_names( + self, + ) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]: + """ + Capture intermediate outputs only if _representative_inputs are provided + when using bundled program to create the etrecord + """ if self._etrecord._representative_inputs is None: - return + return {}, {} export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() - self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping( + aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module ) capturer = IntermediateOutputCapturer(graph_module) - self._aot_intermediate_outputs = 
"calculate_numeric_gap error: The aot debug information is required but not populated"
+    Return the names of operators whose debug handles are contained in the target debug handle.
+    The handles in `debug_handle_to_op_name` have already been merged and are matched as-is:
+    an op name is included when its entire handle tuple is a subset of `target_debug_handle`.
return_value=None - ), patch.object( - _inspector, "gen_etdump_object", return_value=None - ), patch.object( - EventBlock, "_gen_from_etdump" - ), patch.object( - _inspector, "gen_graphs_from_etrecord" - ): - # Call the constructor of Inspector - inspector_instance = Inspector( - etdump_path=ETDUMP_PATH, - etrecord=ETRECORD_PATH, - ) - self.assertIsNone(inspector_instance._aot_intermediate_outputs) - - def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): + def test_etrecord_populates_correct_aot_intermediate_outputs(self): with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: etrecord_path = tmp_file.name mod = model_registry["ConvLinearModel"]() @@ -505,7 +488,6 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): generate_etrecord( etrecord_path, edge_program_manager_copy, et_program_manager ) - original_consume_etrecord = Inspector._consume_etrecord with patch.object( Inspector, "_consume_etrecord", return_value=None ), patch.object( @@ -529,11 +511,17 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): _representative_inputs=aten_model.example_inputs[0], ) inspector_instance._etrecord = etrecord - Inspector._consume_etrecord = original_consume_etrecord - inspector_instance._consume_etrecord() + aot_intermediate_outputs, aot_debug_handle_to_op_name = ( + inspector_instance._get_aot_intermediate_outputs_and_op_names() + ) self.assertTrue( check_if_final_outputs_match( - "ConvLinearModel", inspector_instance._aot_intermediate_outputs + "ConvLinearModel", aot_intermediate_outputs + ) + ) + self.assertTrue( + check_if_debug_handle_to_op_name_match( + "ConvLinearModel", aot_debug_handle_to_op_name ) ) @@ -605,6 +593,7 @@ def test_calculate_numeric_gap(self): ), patch.object( _inspector, "gen_graphs_from_etrecord" ): + # Call the constructor of Inspector inspector_instance = Inspector( etdump_path=ETDUMP_PATH, @@ -621,9 +610,15 @@ def test_calculate_numeric_gap(self): (1,): 
+            # Dummy key to get the expected aot/runtime intermediate outputs
b9d4b1882b8..ef36bd6a178 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -83,6 +83,26 @@ def get_expected_intermediate_outputs(): (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], } + @staticmethod + def get_expected_debug_handle_to_op_name(): + """ + Returns the expected debug handle and op name mapping for this model for the given input. + """ + return { + (10,): "aten_convolution_default", + (11,): "aten_view_copy_default", + (12,): "aten_permute_copy_default", + (13,): "aten_addmm_default", + (14,): "aten_add_tensor", + (15,): "aten_sub_tensor", + (16,): "aten_mul_tensor", + (17,): "aten_add_tensor_1", + (18,): "aten_div_tensor", + (19,): "aten_relu_default", + (20,): "aten_sigmoid_default", + (21,): "aten_split_with_sizes_copy_default", + } + # Global model registry model_registry = { @@ -116,3 +136,21 @@ def check_if_final_outputs_match(model_name, actual_outputs_with_handles): if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5): return False return True + + +def check_if_debug_handle_to_op_name_match(model_name, actual_debug_handle_to_op_name): + """ + Checks if the actual op names match the expected op names for the specified model. + Returns True if all match, otherwise returns False. 
+ """ + model_instance = model_registry[model_name] + expected_debug_handle_to_op_name = ( + model_instance.get_expected_debug_handle_to_op_name() + ) + if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name): + return False + for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items(): + actual_op_name = actual_debug_handle_to_op_name.get(debug_handle) + if actual_op_name != expected_op_name: + return False + return True diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 6d12cb13c5f..b540f8dccd1 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -32,6 +32,7 @@ convert_to_float_tensor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, + find_op_names, find_populated_event, gen_graphs_from_etrecord, get_aot_debug_handle_to_op_name_mapping, @@ -472,6 +473,23 @@ def test_node_op_type_mismatch(self): # Test that the filter doesn't match the mock node (op_type mismatch) self.assertFalse(node_filter.matches(mock_node_op_type_mismatch)) + def test_find_op_names_empty_debug_handle(self): + debug_handle = () + debug_handle_to_op_name = {(1, 2): "op1", (3, 4): "op2"} + self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), []) + + def test_find_op_names_no_matching_handles(self): + debug_handle = (1, 2) + debug_handle_to_op_name = {(3, 4): "op1", (5, 6): "op2"} + self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), []) + + def test_find_op_names_matching_handles(self): + debug_handle = (1, 2, 3) + debug_handle_to_op_name = {(1, 2): "op1", (2, 3): "op2", (4, 5, 6): "op3"} + self.assertEqual( + find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"] + ) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]]