diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index 3906dcb1030..a144d7e4eaf 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -52,6 +52,7 @@ class ETRecordReservedFileNames(StrEnum): ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module" DEBUG_HANDLE_MAP_NAME = "debug_handle_map" DELEGATE_MAP_NAME = "delegate_map" + INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME = "instruction_id_to_num_outs_map" REFERENCE_OUTPUTS = "reference_outputs" REPRESENTATIVE_INPUTS = "representative_inputs" @@ -67,6 +68,9 @@ def __init__( _delegate_map: Optional[ Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] ] = None, + _instruction_id_to_num_outs_map: Optional[ + Dict[str, Dict[int, Union[int, List[int]]]] + ] = None, _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None, _representative_inputs: Optional[List[ProgramInput]] = None, ): @@ -92,6 +96,7 @@ def __init__( self.graph_map = graph_map self._debug_handle_map = _debug_handle_map self._delegate_map = _delegate_map + self._instruction_id_to_num_outs_map = _instruction_id_to_num_outs_map self._reference_outputs = _reference_outputs self._representative_inputs = _representative_inputs @@ -172,6 +177,12 @@ def _save_metadata(self, etrecord_zip: ZipFile) -> None: json.dumps(self._delegate_map), ) + if self._instruction_id_to_num_outs_map is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME, + json.dumps(self._instruction_id_to_num_outs_map), + ) + if self._reference_outputs is not None: etrecord_zip.writestr( ETRecordReservedFileNames.REFERENCE_OUTPUTS, @@ -284,6 +295,7 @@ def add_executorch_program( if ( self._debug_handle_map is not None or self._delegate_map is not None + or self._instruction_id_to_num_outs_map is not None or self._reference_outputs is not None or self._representative_inputs is not None ): @@ -293,13 +305,18 @@ def add_executorch_program( ) # Process executorch program and extract data - debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( - _process_executorch_program(executorch_program) - ) + ( + debug_handle_map, + delegate_map, + instruction_id_to_num_outs_map, + reference_outputs, + representative_inputs, + ) = _process_executorch_program(executorch_program) # Set the extracted data self._debug_handle_map = debug_handle_map self._delegate_map = delegate_map + self._instruction_id_to_num_outs_map = instruction_id_to_num_outs_map self._reference_outputs = reference_outputs self._representative_inputs = representative_inputs @@ -593,7 +610,9 @@ def _process_executorch_program( executorch_program: Union[ ExecutorchProgram, ExecutorchProgramManager, BundledProgram ] -) -> tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[List]]: +) -> tuple[ + Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict], Optional[List] +]: """Process executorch program and return debug maps and bundled program data.""" if isinstance(executorch_program, BundledProgram): reference_outputs = _get_reference_outputs(executorch_program) @@ -602,11 +621,30 @@ def _process_executorch_program( debug_handle_map = executorch_program.executorch_program.debug_handle_map # pyre-ignore[16]: Item `None` of `typing.Union[None, exir.program._program.ExecutorchProgram, exir.program._program.ExecutorchProgramManager]` has no attribute `debug_handle_map` delegate_map = executorch_program.executorch_program.delegate_map - return debug_handle_map, delegate_map, reference_outputs, representative_inputs + # 
pyre-ignore[16]: Item `None` of `typing.Union[None, exir.program._program.ExecutorchProgram, exir.program._program.ExecutorchProgramManager]` has no attribute `instruction_id_to_num_outs_map` + instruction_id_to_num_outs_map = ( + executorch_program.executorch_program.instruction_id_to_num_outs_map + ) + return ( + debug_handle_map, + delegate_map, + instruction_id_to_num_outs_map, + reference_outputs, + representative_inputs, + ) else: debug_handle_map = executorch_program.debug_handle_map delegate_map = executorch_program.delegate_map - return debug_handle_map, delegate_map, None, None + instruction_id_to_num_outs_map = ( + executorch_program.instruction_id_to_num_outs_map + ) + return ( + debug_handle_map, + delegate_map, + instruction_id_to_num_outs_map, + None, + None, + ) def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 @@ -640,6 +678,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 graph_map: Dict[str, ExportedProgram] = {} debug_handle_map = None delegate_map = None + instruction_id_to_num_outs_map = None exported_program = None edge_dialect_program = None reference_outputs = None @@ -659,6 +698,12 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 delegate_map = json.loads( etrecord_zip.read(ETRecordReservedFileNames.DELEGATE_MAP_NAME) ) + elif entry == ETRecordReservedFileNames.INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME: + instruction_id_to_num_outs_map = json.loads( + etrecord_zip.read( + ETRecordReservedFileNames.INSTRUCTION_ID_TO_NUM_OUTS_MAP_NAME + ) + ) elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER: continue elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM: @@ -724,6 +769,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 graph_map=graph_map, _debug_handle_map=debug_handle_map, _delegate_map=delegate_map, + _instruction_id_to_num_outs_map=instruction_id_to_num_outs_map, _reference_outputs=reference_outputs, _representative_inputs=representative_inputs, export_graph_id=export_graph_id, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 44b383da0e4..535de2e9a56 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -219,6 +219,10 @@ def test_etrecord_generation(self): etrecord._debug_handle_map, json.loads(json.dumps(et_output.debug_handle_map)), ) + self.assertEqual( + etrecord._instruction_id_to_num_outs_map, + json.loads(json.dumps(et_output.instruction_id_to_num_outs_map)), + ) def test_etrecord_generation_with_bundled_program(self): ( diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index c7b4655ca11..6d046d8f2e8 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -317,6 +317,8 @@ class Event: op_type: List of op types corresponding to the event. delegate_debug_identifier: Supplemental identifier used in combination with instruction id. debug_handles: Debug handles in the model graph to which this event is correlated. + num_outputs: Indicates the number of outputs generated by the node. + Right now only used for call_delegate nodes that output more than one tensor. stack_trace: A dictionary mapping the name of each associated op to its stack trace. module_hierarchy: A dictionary mapping the name of each associated op to its module hierarchy. is_delegated_op: Whether or not the event was delegated. 
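Reviewer note: the new instruction_id_to_num_outs_map rides through the ETRecord zip the same way as debug_handle_map and delegate_map, i.e. json.dumps on save and json.loads on parse, so integer instruction ids come back as string keys. A minimal round-trip sketch with made-up ids and counts (not code from this diff):

import json

# Hypothetical per-method map as the emitter would produce it: {method: {instruction_id: num_outs}}.
instruction_id_to_num_outs_map = {"forward": {0: 1, 1: 2}}

# ETRecord._save_metadata writes json.dumps(...); parse_etrecord reads it back with json.loads(...).
restored = json.loads(json.dumps(instruction_id_to_num_outs_map))

# JSON stringifies int keys, matching the str(event._instruction_id) lookup in _gen_resolve_debug_handles.
assert restored == {"forward": {"0": 1, "1": 2}}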
@@ -337,6 +339,7 @@ class Event: op_types: List[str] = dataclasses.field(default_factory=list) delegate_debug_identifier: Optional[Union[int, str]] = None debug_handles: Optional[Union[int, Sequence[int]]] = None + num_outputs: int = 1 stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict) module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict) is_delegated_op: Optional[bool] = None @@ -928,6 +931,7 @@ def _gen_resolve_debug_handles( self, handle_map: Dict[str, List[int]], delegate_map: Optional[Dict[str, DelegateMetadata]] = None, + instruction_id_to_num_outs_map: Optional[Dict[int, int]] = None, ): """ Given mappings from instruction id to debug handles, populate the @@ -945,6 +949,10 @@ def _gen_resolve_debug_handles( if (instruction_id := str(event._instruction_id)) not in handle_map: continue + num_outputs = 1 + if instruction_id_to_num_outs_map is not None: + num_outputs = instruction_id_to_num_outs_map.get(instruction_id, 1) + event.num_outputs = num_outputs # For non-delegated event, handles are found in handle_map if (delegate_debug_id := event.delegate_debug_identifier) is None: event.debug_handles = handle_map[instruction_id] @@ -1131,6 +1139,11 @@ def _consume_etrecord(self) -> None: if self._etrecord._delegate_map is not None else None ), + ( + self._etrecord._instruction_id_to_num_outs_map[FORWARD] + if self._etrecord._instruction_id_to_num_outs_map is not None + else None + ), ) # (2) Event Metadata Association @@ -1196,7 +1205,7 @@ def _get_aot_intermediate_outputs_and_op_names( # TODO: Make it more extensible to further merge overlapping debug handles def _get_runtime_intermediate_outputs_and_op_names( self, - ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: + ) -> Tuple[Dict[DebugHandle, Tuple[Any, int]], Dict[DebugHandle, List[str]]]: """ Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings) from the event blocks, along with the corresponding debug handles and op names mapping. @@ -1217,12 +1226,15 @@ def _get_runtime_intermediate_outputs_and_op_names( debug_handle = (debug_handle,) else: debug_handle = tuple(debug_handle) - current_entry = debug_handle_to_output.get(debug_handle, (-1, None)) + current_entry = debug_handle_to_output.get( + debug_handle, (-1, None, event.num_outputs) + ) # When event has same debug_handle, only keep the one with the largest instruction id if event._instruction_id > current_entry[0]: debug_handle_to_output[debug_handle] = ( event._instruction_id, event.debug_data, + event.num_outputs, ) # TODO: One debug handle can be associated with multiple op names debug_handle_to_op_names[debug_handle] = [event.name] @@ -1231,7 +1243,7 @@ def _get_runtime_intermediate_outputs_and_op_names( debug_handle_to_output ) return { - k: v[1] for k, v in debug_handle_to_output.items() + k: (v[1], v[2]) for k, v in debug_handle_to_output.items() }, debug_handle_to_op_names def to_dataframe( diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 2bda03b4873..ee7ebb2f5ea 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -99,11 +99,14 @@ class NodeData: - source: A string indicating the origin of the node (either FROM_AOT or FROM_RUNTIME). - debug_handle: A tuple representing the unique identifier for the output.
- output: The actual output data associated with the debug handle. + - num_outputs: Indicates the number of outputs generated by the node. + Right now only used for call_delegate nodes that output more than one tensor. """ source: NodeSource debug_handle: tuple[int] output: Any + num_outputs: int class NodeFilter: @@ -578,8 +581,8 @@ def _merge_runtime_debug_handles( def merge_runtime_overlapping_debug_handles( - runtime_intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]] -) -> Dict[DebugHandle, Tuple[int, Any]]: + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[int, Any, int]] +) -> Dict[DebugHandle, Tuple[int, Any, int]]: """ Merges runtimes with overlapping debug handles into a single key in the dict. @@ -592,12 +595,17 @@ def merge_runtime_overlapping_debug_handles( """ if len(runtime_intermediate_outputs) == 0: return {} - merged: Dict[DebugHandle, Tuple[int, Any]] = {} + merged: Dict[DebugHandle, Tuple[int, Any, int]] = {} for debug_handle, ( instruction_id, debug_data, + num_outputs, ) in runtime_intermediate_outputs.items(): - curr_debug_handle, last_value = debug_handle, (instruction_id, debug_data) + curr_debug_handle, last_value = debug_handle, ( + instruction_id, + debug_data, + num_outputs, + ) # Collect any existing keys that overlap with the current key to_remove = [] for existing_debug_handle, existing_value in merged.items(): @@ -634,7 +642,9 @@ def _debug_handles_have_overlap( def _combine_aot_overlapped_intermediate_outputs( - aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any] + aot_nodes: List[Tuple[DebugHandle, Any]], + runtime_node: Tuple[DebugHandle, Any, int], + negative_index: int, ) -> Tuple[DebugHandle, Any]: """ Ensure the AOT combined debug_handles are the same as the runtime debug_handles (order ignored), @@ -642,7 +652,7 @@ def _combine_aot_overlapped_intermediate_outputs( """ # Map AOT single element debug_handles to outputs aot_map = dict(aot_nodes) - runtime_debug_handle, _ = runtime_node + runtime_debug_handle, _, _ = runtime_node # Combine all AOT debug_handles into a list aot_combined_debug_handle = [t[0] for t in aot_map.keys()] @@ -652,14 +662,14 @@ def _combine_aot_overlapped_intermediate_outputs( return (-1,), None # Pick the last intermediate output - last_int = runtime_debug_handle[-1] + last_int = runtime_debug_handle[negative_index] key = (last_int,) return runtime_debug_handle, aot_map[key] def _create_debug_handle_overlap_graph( aot_intermediate_outputs: Dict[DebugHandle, Any], - runtime_intermediate_outputs: Dict[DebugHandle, Any], + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]], ) -> Tuple[List[NodeData], Dict[int, List[int]]]: """ Create a graph representing overlapping debug handles between AOT and runtime outputs. 
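Reviewer note: _combine_aot_overlapped_intermediate_outputs now receives a negative_index so the caller can pair each output of a multi-output delegate with a handle counted from the end of the runtime debug-handle tuple (i = 0 picks the last handle, i = 1 the second-to-last, and so on). A small illustrative sketch with invented handles and tensors, mirroring only the indexing convention:

import torch

# Hypothetical AOT nodes: one single-element debug handle per produced tensor.
aot_nodes = [((3,), torch.tensor([1.0])), ((4,), torch.tensor([2.0, 3.0]))]
aot_map = dict(aot_nodes)

# Hypothetical runtime node: one delegate event covering handles (3, 4) that logged two outputs.
runtime_debug_handle = (3, 4)
runtime_outputs = [torch.tensor([1.0]), torch.tensor([2.0, 3.0])]
num_outputs = 2

for i in range(num_outputs):
    negative_index = -1 * (i + 1)
    key = (runtime_debug_handle[negative_index],)  # i = 0 -> handle 4, i = 1 -> handle 3
    print(key, aot_map[key], runtime_outputs[negative_index])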
@@ -675,9 +685,14 @@ def _create_debug_handle_overlap_graph( """ nodes = [] for debug_handle, output in aot_intermediate_outputs.items(): - nodes.append(NodeData(NodeSource.AOT, debug_handle, output)) - for debug_handle, output in runtime_intermediate_outputs.items(): - nodes.append(NodeData(NodeSource.RUNTIME, debug_handle, output)) + # TODO: for aot outputs also derive the number of output tensors generated by the node + nodes.append(NodeData(NodeSource.AOT, debug_handle, output, 1)) + for debug_handle, value in runtime_intermediate_outputs.items(): + nodes.append( + NodeData( + NodeSource.RUNTIME, debug_handle, output=value[0], num_outputs=value[1] + ) + ) edges = {i: [] for i in range(len(nodes))} for i in range(len(nodes)): @@ -730,7 +745,7 @@ def dfs(node_id, component): def map_runtime_aot_intermediate_outputs( aot_intermediate_outputs: Dict[DebugHandle, Any], - runtime_intermediate_outputs: Dict[DebugHandle, Any], + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]], ) -> Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]: """ Map the runtime intermediate outputs to the AOT intermediate outputs @@ -757,7 +772,11 @@ def map_runtime_aot_intermediate_outputs( if nodes[node_id].source == NodeSource.AOT ] runtime_list = [ - (nodes[node_id].debug_handle, nodes[node_id].output) + ( + nodes[node_id].debug_handle, + nodes[node_id].output, + nodes[node_id].num_outputs, + ) for node_id in comp if nodes[node_id].source == NodeSource.RUNTIME ] @@ -772,50 +791,74 @@ def map_runtime_aot_intermediate_outputs( f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}" ) - runtime_debug_handle, runtime_intermediate_output = runtime_list[0] - - # Combine aot debug handles into a single key - aot_combined_debug_handle, aot_intermediate_output = ( - _combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0]) + runtime_debug_handle, runtime_intermediate_output, num_outputs = ( + runtime_list[0] ) + # iterate through each of the output from runtime, + # get the corresponding debug handle + # and map it to the aot debug handle + # and create a dictionary that maps aot debug handle + aot output to + # runtime debug handle + runtime output + # Note this works only for delegate case for now. + for i in range(num_outputs): + + negative_index = -1 * (i + 1) + aot_mapped_runtime_intermediate_output = runtime_intermediate_output + # Combine aot debug handles into a single key + aot_combined_debug_handle, aot_intermediate_output = ( + _combine_aot_overlapped_intermediate_outputs( + aot_list, runtime_list[0], negative_index + ) + ) - if aot_combined_debug_handle == (-1,): - # Skip this mapping if the aot combined debug handle and runtime debug handle do not exact match. - continue + if aot_combined_debug_handle == (-1,): + # Skip this mapping if the aot combined debug handle and runtime debug handle do not exact match. 
+ continue - if isinstance(aot_intermediate_output, Sequence): - if not isinstance(runtime_intermediate_output, Sequence): - raise TypeError( - "runtime intermediate output should be a sequence when aot intermediate output is a sequence" + if isinstance(aot_intermediate_output, Sequence): + if not isinstance(runtime_intermediate_output, Sequence): + raise TypeError( + "runtime intermediate output should be a sequence when aot intermediate output is a sequence" + ) + last_element = runtime_intermediate_output[negative_index] + # TODO: this (last_element = list) is never really the case because runtime never returns output as a list + # for delegate case. + if isinstance(last_element, list) and all( + isinstance(t, torch.Tensor) for t in last_element + ): + # If the last element is a list of tensors (delegate case) + aot_mapped_runtime_intermediate_output = last_element + elif isinstance(last_element, torch.Tensor): + # If the last element is a tensor, as is always the case for runtime. + # However, now we have a strange condition where aot_intermediate_output is a list of tensors + # while runtime_intermediate_output is a single tensor. So we should never really come here. + # TODO: fix this + aot_mapped_runtime_intermediate_output = ( + runtime_intermediate_output + ) + else: + raise ValueError( + "The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence" + ) + # List can't be used as a key, so convert to tuple + aot_intermediate_output = tuple(aot_intermediate_output) + aot_mapped_runtime_intermediate_output = tuple( + aot_mapped_runtime_intermediate_output ) - last_element = runtime_intermediate_output[-1] - if isinstance(last_element, list) and all( - isinstance(t, torch.Tensor) for t in last_element - ): - # If the last element is a list of tensors (delegate case) - runtime_intermediate_output = last_element - elif isinstance(last_element, torch.Tensor): - # If the last element is a tensor (non-delegate case) - pass - else: - raise ValueError( - "The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence" + + elif isinstance(runtime_intermediate_output, Sequence): + # delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list + aot_mapped_runtime_intermediate_output = ( + runtime_intermediate_output[negative_index] ) - # List can't be used as a key, so convert to tuple - aot_intermediate_output = tuple(aot_intermediate_output) - runtime_intermediate_output = tuple(runtime_intermediate_output) - - elif isinstance(runtime_intermediate_output, Sequence): - # delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list - runtime_intermediate_output = runtime_intermediate_output[-1] - - # Create a mapping between runtime and aot - aot_runtime_mapping[ - (aot_combined_debug_handle, aot_intermediate_output) - ] = ( - runtime_debug_handle, - runtime_intermediate_output, - ) + + # Create a mapping between runtime and aot + aot_runtime_mapping[ + (aot_combined_debug_handle, aot_intermediate_output) + ] = ( + runtime_debug_handle, + aot_mapped_runtime_intermediate_output, + ) return aot_runtime_mapping diff --git a/devtools/inspector/tests/TARGETS b/devtools/inspector/tests/TARGETS index 250cc76fd31..e036cf9e074 100644 --- a/devtools/inspector/tests/TARGETS +++ b/devtools/inspector/tests/TARGETS @@ -1,11 +1,15 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", 
"python_unittest") load("@fbcode_macros//build_defs:python_library.bzl", "python_library") +load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci") oncall("executorch") python_unittest( name = "inspector_test", srcs = ["inspector_test.py"], + labels = ci.labels( + ci.buckconfig("executorch.event_tracer_enabled", "true"), + ), deps = [ "//executorch/devtools:lib", "//executorch/devtools/debug_format:et_schema", diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 37dc7921923..cf1fdb7ec00 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -7,6 +7,7 @@ # pyre-unsafe import copy +import os import random import statistics import tempfile @@ -21,7 +22,9 @@ import torch import torch.fx +import torch.utils._pytree as pytree +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.devtools import generate_etrecord, parse_etrecord from executorch.devtools.debug_format.et_schema import OperatorNode from executorch.devtools.etdump.schema_flatcc import ProfileEvent @@ -52,6 +55,10 @@ EdgeProgramManager, ExecutorchProgramManager, to_edge, + to_edge_transform_and_lower, +) +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, ) from torch.export import export, ExportedProgram @@ -633,7 +640,9 @@ def test_get_runtime_intermediate_outputs_and_op_names(self): self.assertIn((4,), runtime_outputs) self.assertIn((4,), op_names) self.assertTrue( - torch.allclose(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0])) + torch.allclose( + runtime_outputs[(4,)][0][0], torch.tensor([4.0, 5.0, 6.0]) + ) ) self.assertEqual(op_names[(4,)], ["op_3"]) @@ -641,8 +650,6 @@ def test_get_runtime_intermediate_outputs_and_op_names(self): for key in range(5, 9): self.assertIn((key,), runtime_outputs) self.assertIn((key,), op_names) - self.assertEqual(runtime_outputs[(key,)][0].size(0), RAW_DATA_SIZE) - self.assertEqual(op_names[(key,)], [f"op_{key-1}"]) def test_calculate_numeric_gap(self): # Create a context manager to patch functions called by Inspector.__init__ @@ -668,8 +675,8 @@ def test_calculate_numeric_gap(self): } runtime_intermediate_outputs = { - (0,): torch.tensor([2.0, 1.0, 4.0]), - (1,): torch.tensor([3.0, 6.0, 5.0]), + (0,): ([torch.tensor([2.0, 1.0, 4.0])], 1), + (1,): ([torch.tensor([3.0, 6.0, 5.0])], 1), } aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} @@ -709,12 +716,121 @@ def test_calculate_numeric_gap(self): self.assertTrue( torch.allclose( row["runtime_intermediate_output"], - runtime_intermediate_outputs[key], + runtime_intermediate_outputs[key][0][0], ) ) # gap should equal 3.0 self.assertEqual(row["gap"][0], 3.0) + @unittest.skip("ci config values are not propagated") + def test_intermediate_tensor_comparison_with_torch_export(self): + """Test intermediate tensor comparison using torch.export.export_for_training and to_edge_transform_and_lower.""" + + class SimpleTestModel(torch.nn.Module): + """A simple test model for demonstration purposes.""" + + def __init__(self, hidden_size: int = 32, num_layers: int = 2): + super().__init__() + self.layers = torch.nn.ModuleList( + [ + torch.nn.Linear(hidden_size, hidden_size) + for _ in range(num_layers) + ] + ) + self.activation = torch.nn.ReLU() + self.output_layer = torch.nn.Linear(hidden_size, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.activation(self.layers[0](x)) + y = self.activation(self.layers[1](x)) + 
return y, self.output_layer(x) + + # Create test model and inputs + model = SimpleTestModel(hidden_size=32, num_layers=2) + model.eval() + + # Create representative inputs (smaller for faster testing) + batch_size, seq_len, hidden_size = 1, 8, 32 + input_tensor = torch.randn(batch_size, seq_len, hidden_size) + example_inputs = (input_tensor,) + representative_inputs = [example_inputs] + + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model.pte") + etrecord_path = os.path.join(tmp_dir, "etrecord.bin") + + # Step 1: Export using torch.export.export_for_training + exported_program = torch.export.export_for_training(model, example_inputs) + self.assertIsNotNone(exported_program) + + # Step 2: Lower to XNNPACK with generate_etrecord=True + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_program_manager = to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], + compile_config=edge_compile_config, + generate_etrecord=True, + ) + self.assertIsNotNone(edge_program_manager) + + # Step 3: Generate ETRecord from edge program manager + # Step 4: Convert to executorch and save as PTE + executorch_program = edge_program_manager.to_executorch() + et_record = executorch_program.get_etrecord() + self.assertIsNotNone(et_record) + + # Update with representative inputs + flattened_x = pytree.tree_flatten(representative_inputs[0])[0] + et_record.update_representative_inputs(flattened_x) + et_record.save(etrecord_path) + + with open(model_path, "wb") as f: + executorch_program.write_to_file(f) + + # Step 5: Test intermediate output comparison using pybind APIs + # Read the PTE file + with open(model_path, "rb") as f: + pte_buffer = f.read() + + etdump_path = os.path.join(tmp_dir, "etdump.etdp") + debug_buffer_path = os.path.join(tmp_dir, "debug_buffer.bin") + + # Load the PTE file with ETDump enabled using pybind API + executorch_module = _load_for_executorch_from_buffer( + pte_buffer, + enable_etdump=True, + debug_buffer_size=1024 * 1024, # 1MB for testing + ) + self.assertIsNotNone(executorch_module) + + # Run the model with the given input using pybind API + flattened_x = pytree.tree_flatten(representative_inputs[0])[0] + executorch_module.run_method("forward", tuple(flattened_x)) + + # Write the ETDump results to a file using pybind API + executorch_module.write_etdump_result_to_file( + etdump_path, debug_buffer_path + ) + + # Step 6: Use Inspector API to compare intermediate outputs + try: + inspector = Inspector( + etdump_path=etdump_path, + etrecord=etrecord_path, + debug_buffer_path=debug_buffer_path, + ) + except FileNotFoundError as e: + new_message = f"{e} You likely need to run the test with --config executorch.event_tracer_enabled=true" + raise RuntimeError(new_message) from e + self.assertIsNotNone(inspector) + + # Calculate numerical gap using SNR metric + df = inspector.calculate_numeric_gap("SNR") + + # Verify that we got some intermediate tensor comparisons + # The exact number will depend on the model structure and partitioning + self.assertEqual(len(df), 2) + def _gen_random_float_list(self) -> List[float]: return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)] diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index ea8c0e653af..26fe38acfac 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -233,20 +233,20 @@ def test_compare_results_uint8(self): def 
test_merge_overlapping_debug_handles_basic(self): big_tensor = torch.rand(100, 100) intermediate_outputs = { - (1, 2, 3): (1, "val1"), - (2, 3, 4, 5): (2, "val2"), - (6, 7, 8): (3, "val3"), - (10, 11): (4, "val4"), - (11, 12): (5, big_tensor), + (1, 2, 3): (1, "val1", 1), + (2, 3, 4, 5): (2, "val2", 1), + (6, 7, 8): (3, "val3", 1), + (10, 11): (4, "val4", 1), + (11, 12): (5, big_tensor, 1), } # basic merge behavior intermediate_outputs = merge_runtime_overlapping_debug_handles( intermediate_outputs ) expected_intermediate_outputs = { - (1, 2, 3, 4, 5): (2, "val2"), - (6, 7, 8): (3, "val3"), - (10, 11, 12): (5, big_tensor), + (1, 2, 3, 4, 5): (2, "val2", 1), + (6, 7, 8): (3, "val3", 1), + (10, 11, 12): (5, big_tensor, 1), } self.assertEqual(intermediate_outputs, expected_intermediate_outputs) self.assertIs(expected_intermediate_outputs[(10, 11, 12)][1], big_tensor) @@ -258,11 +258,11 @@ def test_merge_overlapping_debug_handles_non_continuous(self): tensor4 = torch.randn(6, 7) tensor5 = torch.randn(8, 9) intermediate_outputs = { - (1, 10): (1, tensor1), - (2, 5): (2, tensor2), - (1, 7, 9): (3, tensor3), - (11, 13): (4, tensor4), - (11, 15): (5, tensor5), + (1, 10): (1, tensor1, 1), + (2, 5): (2, tensor2, 1), + (1, 7, 9): (3, tensor3, 1), + (11, 13): (4, tensor4, 1), + (11, 15): (5, tensor5, 1), } intermediate_outputs = merge_runtime_overlapping_debug_handles( intermediate_outputs @@ -280,22 +280,22 @@ def test_merge_overlapping_debug_handles_non_continuous(self): def test_merge_overlapping_debug_handles_edge_cases(self): intermediate_outputs = { - (9,): (1, "val1"), + (9,): (1, "val1", 1), ( 9, 9, 9, - ): (2, "val2"), + ): (2, "val2", 1), ( 9, 9, - ): (3, "val3"), + ): (3, "val3", 1), } intermediate_outputs = merge_runtime_overlapping_debug_handles( intermediate_outputs ) expected_intermediate_outputs = { - (9,): (3, "val3"), + (9,): (3, "val3", 1), } self.assertEqual(intermediate_outputs, expected_intermediate_outputs) @@ -312,7 +312,7 @@ def test_map_runtime_aot_intermediate_outputs_empty_inputs(self): def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self): # Single element tuple aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300} - runtime_intermediate_outputs = {(0,): 150, (1,): 250, (2,): 350} + runtime_intermediate_outputs = {(0,): (150, 1), (1,): (250, 1), (2,): (350, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) @@ -326,7 +326,7 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self): def test_map_runtime_aot_intermediate_outputs_no_overlaps(self): # No overlaps between aot and runtime debug_handles aot_intermediate_outputs = {(0,): 100, (4,): 300} - runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300} + runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) @@ -336,7 +336,7 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self): def test_map_runtime_aot_intermediate_outputs_partial_match(self): # Partial match between aot and runtime debug_handles will return empty aot_intermediate_outputs = {(2,): 100, (9,): 300} - runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300} + runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) @@ -346,7 +346,7 @@ def test_map_runtime_aot_intermediate_outputs_partial_match(self): def 
test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self): # Multiple aot debug_handles map to one runtime debug_handle aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400} - runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300} + runtime_intermediate_outputs = {(2, 3, 1): (250, 1), (8, 9): (300, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) @@ -357,35 +357,45 @@ def test_map_runtime_aot_intermediate_outputs_delegated(self): # Currently, runtime_intermediate_output logs all delegate call arguments # Test that the map function correctly extracted out the delegated outputs aot_intermediate_outputs = { - (1,): torch.tensor([4, 1]), + (1,): torch.tensor([1, 2, 3]), (2,): torch.tensor([4, 5]), (3,): torch.tensor([10, 10, 13]), (4,): torch.tensor([10, 11, 12]), (5,): torch.tensor([13, 14, 15, 16, 21]), - (6,): torch.tensor([13, 14, 15, 16, 17]), + (6,): torch.tensor([2]), } runtime_intermediate_outputs = { - (1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])], - (3, 4): [ - torch.tensor([6, 7, 8, 9]), - torch.tensor(1), - torch.tensor([10, 11, 12]), - ], - (5, 6): [ - torch.tensor([1]), - torch.tensor([2]), - torch.tensor([13, 14, 15, 16, 17]), - ], + (1, 2): ([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], 2), + (3, 4): ( + [ + torch.tensor([10, 10, 13]), + torch.tensor([10, 11, 12]), + ], + 2, + ), + (5, 6): ( + [ + torch.tensor([13, 14, 15, 16, 21]), + torch.tensor([2]), + ], + 2, + ), } actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) expected = { + ((1, 2), torch.tensor([1, 2, 3])): ((1, 2), torch.tensor([1, 2, 3])), ((1, 2), torch.tensor([4, 5])): ((1, 2), torch.tensor([4, 5])), + ((3, 4), torch.tensor([10, 10, 13])): ((3, 4), torch.tensor([10, 10, 13])), ((3, 4), torch.tensor([10, 11, 12])): ((3, 4), torch.tensor([10, 11, 12])), - ((5, 6), torch.tensor([13, 14, 15, 16, 17])): ( + ((5, 6), torch.tensor([13, 14, 15, 16, 21])): ( (5, 6), - torch.tensor([13, 14, 15, 16, 17]), + torch.tensor([13, 14, 15, 16, 21]), + ), + ((5, 6), torch.tensor([2])): ( + (5, 6), + torch.tensor([2]), ), } self.assertEqual(len(actual), len(expected)) @@ -399,8 +409,10 @@ def test_map_runtime_aot_intermediate_outputs_delegated(self): act_runtime_key, act_runtime_value, ) in actual.items(): - if exp_aot_key == act_aot_key and torch.allclose( - exp_aot_value, act_aot_value + if ( + exp_aot_key == act_aot_key + and exp_aot_value.numel() == act_aot_value.numel() + and torch.allclose(exp_aot_value, act_aot_value) ): found = True self.assertEqual(exp_runtime_key, act_runtime_key) diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index cb849dde11a..0618871bd40 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -45,6 +45,11 @@ class EmitterOutput: str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]] ] + # This dictionary maps the method name to the corresponding dict which + # contains the mapping of the instruction ids to the number of outputs + # generated by each instruction. + instruction_id_to_num_outs_map: Dict[str, Dict[int, int]] + mutable_data: Optional[List[Buffer]] # Constants are optionally stored in external files. @@ -148,6 +153,7 @@ def emit_program( plans = [] debug_handle_map = {} method_to_delegate_debug_id_map = {} + instruction_id_to_num_outs_map = {} program_state = _ProgramState() # emit each entry point in order according to name. 
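Reviewer note: the shape of EmitterOutput.instruction_id_to_num_outs_map mirrors debug_handle_map, keyed by method name and then instruction id. Judging from the emitter change below, only delegate instructions receive an entry, and consumers fall back to 1 via .get(instruction_id, 1). A hypothetical example of the resulting structure (ids and counts invented):

# One method, one call_delegate instruction (id 2) that returns two tensors.
instruction_id_to_num_outs_map = {"forward": {2: 2}}

# Non-delegate instruction ids are simply absent and default to a single output.
assert instruction_id_to_num_outs_map["forward"].get(0, 1) == 1
assert instruction_id_to_num_outs_map["forward"].get(2, 1) == 2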
@@ -176,6 +182,7 @@ def emit_program( method_to_delegate_debug_id_map[name] = ( emitter.instr_id_to_delegate_debug_id_map ) + instruction_id_to_num_outs_map[name] = emitter.instruction_id_to_num_outs_map training_metadata = _get_training_metadata(methods) if len(training_metadata) > 0: @@ -188,6 +195,7 @@ def emit_program( return EmitterOutput( debug_handle_map=debug_handle_map, method_to_delegate_debug_id_map=method_to_delegate_debug_id_map, + instruction_id_to_num_outs_map=instruction_id_to_num_outs_map, program=Program( version=EXECUTORCH_SCHEMA_VERSION, execution_plan=plans, diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 80ba389c270..6995f9f73a9 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -253,6 +253,7 @@ def __init__( self.concrete_output_ids: List[_AbstractValue] = [] self.debug_handle_map: Dict[int, Union[int, List[int]]] = {} + self.instruction_id_to_num_outs_map: Dict[int, int] = {} self.instr_id_to_delegate_debug_id_map: Dict[ int, Dict[str, Union[str, _DelegateDebugIdentifierMap]] ] = {} @@ -1003,7 +1004,11 @@ def _add_debug_handle( and node.meta.get("debug_handle") is not None ): debug_handle_list.append(node.meta.get("debug_handle")) + output_node = lowered_module.original_module.graph.output_node() + outputs = output_node.args[0] + num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1 self.debug_handle_map[emitter_id] = debug_handle_list + self.instruction_id_to_num_outs_map[emitter_id] = num_outputs # Debug handle for this node is the emitter_id which is essentially the index of the # instruction in the chain. self.node.meta["debug_handle"] = emitter_id diff --git a/exir/program/_program.py b/exir/program/_program.py index 921a2b1fab4..b6eba6b5d18 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -644,6 +644,14 @@ def delegate_map( return self._emitter_output.method_to_delegate_debug_id_map return self._get_emitter_output().method_to_delegate_debug_id_map + @property + def instruction_id_to_num_outs_map( + self, + ) -> Dict[str, Dict[int, Union[int, List[int]]]]: + if self._emitter_output: + return self._emitter_output.instruction_id_to_num_outs_map + return self._get_emitter_output().instruction_id_to_num_outs_map + @property def graph_module(self) -> torch.fx.GraphModule: return self.exported_program.graph_module @@ -1860,6 +1868,12 @@ def delegate_map( ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]: return self._emitter_output.method_to_delegate_debug_id_map + @property + def instruction_id_to_num_outs_map( + self, + ) -> Dict[str, Dict[int, Union[int, List[int]]]]: + return self._emitter_output.instruction_id_to_num_outs_map + @property def executorch_program(self) -> Program: """
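Reviewer note: the output counting added to _add_debug_handle leans on torch.fx graph semantics, where the graph's single output node stores the returned values in args[0] (a tuple or list for multi-output graphs). A standalone sketch of that behavior using plain symbolic tracing rather than an actual lowered module:

import torch
import torch.fx

class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2

gm = torch.fx.symbolic_trace(TwoOutputs())
# Equivalent to graph.output_node() on newer torch; args[0] of the output node holds the returned values.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
outputs = output_node.args[0]
num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1
assert num_outputs == 2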