From 4650a00fa4fe42dcecfc9b1a1cd0864f011096d4 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 4 Aug 2025 15:57:43 -0700 Subject: [PATCH] bring etrecord updated "reverted" by gh patch fix bot back Summary: D79279401 D79336982 and D79294945 was landed last week but got "reverted" by gh patch fix on Saturday D78689027 due to didn't merge gh PR on time. This diff brings the updates back. Differential Revision: D79599520 --- devtools/etrecord/_etrecord.py | 201 ++++-- devtools/etrecord/tests/etrecord_test.py | 824 +++++++++++++++++++++++ exir/program/_program.py | 48 +- 3 files changed, 1026 insertions(+), 47 deletions(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index e149aeab650..3b8a71279fd 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -200,6 +200,151 @@ def _save_edge_dialect_program( f"{base_name}_example_inputs", serialized_artifact.example_inputs ) + def add_extra_export_modules( + self, + extra_recorded_export_modules: Dict[ + str, + Union[ + ExportedProgram, + ExirExportedProgram, + EdgeProgramManager, + ], + ], + ) -> None: + """ + Add extra export modules to the ETRecord after it has been created. + + This method allows users to add more export modules they want to record + to an existing ETRecord instance. The modules will be added to the graph_map + and will be included when the ETRecord is saved. + + Args: + extra_recorded_export_modules: A dictionary of graph modules with the key being + the user provided name and the value being the corresponding exported module. + The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`. + """ + if self.graph_map is None: + self.graph_map = {} + + # Now self.graph_map is guaranteed to be non-None + graph_map = self.graph_map + for module_name, export_module in extra_recorded_export_modules.items(): + _add_module_to_graph_map(graph_map, module_name, export_module) + + def add_executorch_program( + self, + executorch_program: Union[ + ExecutorchProgram, + ExecutorchProgramManager, + BundledProgram, + ], + ) -> None: + """ + Add executorch program data to the ETRecord after it has been created. + + This method allows users to add executorch program data they want to record + to an existing ETRecord instance. The executorch program data includes debug handle map, + delegate map, reference outputs, and representative inputs that will be included + when the ETRecord is saved. + + Args: + executorch_program: The ExecuTorch program for this model returned by the call to + `to_executorch()` or the `BundledProgram` of this model. + + Raises: + RuntimeError: If executorch program data already exists in the ETRecord. + """ + # Check if executorch program data already exists + if ( + self._debug_handle_map is not None + or self._delegate_map is not None + or self._reference_outputs is not None + or self._representative_inputs is not None + ): + raise RuntimeError( + "Executorch program data already exists in the ETRecord. " + "Cannot add executorch program data when it already exists." + ) + + # Process executorch program and extract data + debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( + _process_executorch_program(executorch_program) + ) + + # Set the extracted data + self._debug_handle_map = debug_handle_map + self._delegate_map = delegate_map + self._reference_outputs = reference_outputs + self._representative_inputs = representative_inputs + + def add_exported_program( + self, + exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]], + ) -> None: + """ + Add exported program to the ETRecord after it has been created. + + This method allows users to add an exported program they want to record + to an existing ETRecord instance. The exported program will be included + when the ETRecord is saved. + + Args: + exported_program: The exported program for this model returned by the call to + `torch.export()` or a dictionary with method names as keys and exported programs as values. + Can be None, in which case no exported program data will be added. + + Raises: + RuntimeError: If exported program already exists in the ETRecord. + """ + # Check if exported program already exists + if self.exported_program is not None or self.export_graph_id is not None: + raise RuntimeError( + "Exported program already exists in the ETRecord. " + "Cannot add exported program when it already exists." + ) + + # Process exported program and extract data + processed_exported_program, export_graph_id = _process_exported_program( + exported_program + ) + + # Set the extracted data + self.exported_program = processed_exported_program + self.export_graph_id = export_graph_id + + def add_edge_dialect_program( + self, + edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram], + ) -> None: + """ + Add edge dialect program to the ETRecord after it has been created. + + This method allows users to add an edge dialect program they want to record + to an existing ETRecord instance. The edge dialect program will be included + when the ETRecord is saved. + + Args: + edge_dialect_program: The edge dialect program for this model returned by the call to + `to_edge()` or `EdgeProgramManager` for this model. + + Raises: + RuntimeError: If edge dialect program already exists in the ETRecord. + """ + # Check if edge dialect program already exists + if self.edge_dialect_program is not None: + raise RuntimeError( + "Edge dialect program already exists in the ETRecord. " + "Cannot add edge dialect program when it already exists." + ) + + # Process edge dialect program and extract data + processed_edge_dialect_program = _process_edge_dialect_program( + edge_dialect_program + ) + + # Set the extracted data + self.edge_dialect_program = processed_edge_dialect_program + def _get_reference_outputs( bundled_program: BundledProgram, @@ -285,37 +430,24 @@ def generate_etrecord( Returns: None """ - # Process all inputs and prepare data for ETRecord construction - processed_exported_program, export_graph_id = _process_exported_program( - exported_program - ) - graph_map = _process_extra_recorded_modules(extra_recorded_export_modules) - processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program) - debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( - _process_executorch_program(executorch_program) - ) + etrecord = ETRecord() + etrecord.add_exported_program(exported_program) + etrecord.add_edge_dialect_program(edge_dialect_program) + etrecord.add_executorch_program(executorch_program) - # Create ETRecord instance and save - etrecord = ETRecord( - exported_program=processed_exported_program, - export_graph_id=export_graph_id, - edge_dialect_program=processed_edge_dialect_program, - graph_map=graph_map if graph_map else None, - _debug_handle_map=debug_handle_map, - _delegate_map=delegate_map, - _reference_outputs=reference_outputs, - _representative_inputs=representative_inputs, - ) + # Add extra export modules if user provided + if extra_recorded_export_modules is not None: + etrecord.add_extra_export_modules(extra_recorded_export_modules) etrecord.save(et_record) def _process_exported_program( exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]] -) -> tuple[Optional[ExportedProgram], int]: +) -> tuple[Optional[ExportedProgram], Optional[int]]: """Process exported program and return the processed program and export graph id.""" processed_exported_program = None - export_graph_id = 0 + export_graph_id = None if exported_program is not None: if isinstance(exported_program, dict) and "forward" in exported_program: @@ -329,29 +461,6 @@ def _process_exported_program( return processed_exported_program, export_graph_id -def _process_extra_recorded_modules( - extra_recorded_export_modules: Optional[ - Dict[ - str, - Union[ - ExportedProgram, - ExirExportedProgram, - EdgeProgramManager, - ], - ] - ] -) -> Dict[str, ExportedProgram]: - """Process extra recorded export modules and return graph map.""" - graph_map = {} - - if extra_recorded_export_modules is not None: - for module_name, export_module in extra_recorded_export_modules.items(): - _validate_module_name(module_name) - _add_module_to_graph_map(graph_map, module_name, export_module) - - return graph_map - - def _validate_module_name(module_name: str) -> None: """Validate that module name is not a reserved name.""" contains_reserved_name = any( @@ -369,6 +478,8 @@ def _add_module_to_graph_map( export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager], ) -> None: """Add export module to graph map based on its type.""" + _validate_module_name(module_name) + if isinstance(export_module, ExirExportedProgram): graph_map[f"{module_name}/forward"] = export_module.exported_program elif isinstance(export_module, ExportedProgram): diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 9b9f3290162..25ea5a25e1f 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -24,11 +24,50 @@ ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge +from executorch.exir.program._program import to_edge_transform_and_lower from torch.export import export # TODO : T154728484 Add test cases to cover multiple entry points class TestETRecord(unittest.TestCase): + def assert_etrecord_has_no_exported_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no exported program data.""" + self.assertIsNone(etrecord.exported_program) + self.assertIsNone(etrecord.export_graph_id) + + def assert_etrecord_has_no_edge_dialect_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no edge dialect program data.""" + self.assertIsNone(etrecord.edge_dialect_program) + + def assert_etrecord_has_no_executorch_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no executorch program data.""" + self.assertIsNone(etrecord._debug_handle_map) + self.assertIsNone(etrecord._delegate_map) + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + def assert_etrecord_is_empty(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no data at all.""" + self.assert_etrecord_has_no_exported_program(etrecord) + self.assert_etrecord_has_no_edge_dialect_program(etrecord) + self.assert_etrecord_has_no_executorch_program(etrecord) + self.assertIsNone(etrecord.graph_map) + + def assert_legal_etrecord_in_edge_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has all expected data after to_edge_transform_and_lower() or to_edge() stage""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assert_etrecord_has_no_executorch_program(etrecord) + + def assert_etrecord_saveable(self, etrecord: ETRecord) -> None: + """Assert ETRecord contains all essential information for saving""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + def get_test_model(self): f = models.BasicSinMax() captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) @@ -252,6 +291,224 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_to_edge_transform_and_lower_with_etrecord_generation(self): + """Test that to_edge_transform_and_lower generates ETRecord correctly.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=True + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated and attached + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + self.assert_legal_etrecord_in_edge_program(etrecord) + + # Verify the exported program matches the input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the edge dialect program matches the edge manager + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_manager.exported_program().graph_module, + ) + + def test_to_edge_transform_and_lower_without_etrecord_generation(self): + """Test that to_edge_transform_and_lower works correctly without ETRecord generation.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=False (default) + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify that no ETRecord was generated + self.assertIsNone(edge_manager._etrecord) + + # Verify that the edge manager still works correctly + self.assertIsNotNone(edge_manager.exported_program()) + + def test_get_etrecord_from_executorch_program_manager(self): + """Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Test get_etrecord method + etrecord = et_manager.get_etrecord() + self.assertIsNotNone(etrecord) + self.assert_etrecord_saveable(etrecord) + + # Verify the data matches the original input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the executorch program data matches + # ETRecord stores data directly (not JSON serialized), so compare with original data + self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map) + self.assertEqual(etrecord._delegate_map, et_manager.delegate_map) + + def test_get_etrecord_from_executorch_program_manager_without_generation(self): + """Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager without ETRecord + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify no ETRecord on edge manager + self.assertIsNone(edge_manager._etrecord) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Verify no ETRecord on executorch manager + self.assertIsNone(et_manager._etrecord) + + # Test get_etrecord method should raise RuntimeError + with self.assertRaises(RuntimeError) as context: + et_manager.get_etrecord() + + self.assertIn("ETRecord was not generated", str(context.exception)) + + def test_to_edge_transform_and_lower_etrecord_save_and_parse(self): + """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch to get complete ETRecord + et_manager = edge_manager.to_executorch() + etrecord = et_manager.get_etrecord() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_flow2.bin" + + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + # Note: Skip graph structure comparison due to transformation differences + self.check_graph_closeness( + etrecord.exported_program, parsed_etrecord.exported_program + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_manager.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_manager.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(aten_program.graph), + ) + + def test_add_extra_export_modules(self): + """Test add_extra_export_modules when ETRecord already has a graph_map.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing graph_map + initial_graph_map = { + "existing_module/forward": captured_output.exported_program + } + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + graph_map=initial_graph_map, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state + self.assertIsNotNone(etrecord.graph_map) + self.assertIn("existing_module/forward", etrecord.graph_map) + + # Create additional module to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + extra_modules = { + "new_module": captured_output2.exported_program, + } + + # Add extra export modules + etrecord.add_extra_export_modules(extra_modules) + + # Verify both existing and new modules are present + self.assertIn("existing_module/forward", etrecord.graph_map) + self.assertIn("new_module/forward", etrecord.graph_map) + + # Verify the modules are correctly stored + self.check_graph_closeness( + etrecord.graph_map["existing_module/forward"], + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.graph_map["new_module/forward"], + captured_output2.exported_program.graph_module, + ) + + def test_add_extra_export_modules_reserved_name_validation(self): + """Test that add_extra_export_modules validates reserved names.""" + captured_output, edge_output, et_output = self.get_test_model() + + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Test that reserved names are rejected + for reserved_name in ETRecordReservedFileNames: + with self.assertRaises(RuntimeError): + etrecord.add_extra_export_modules( + {reserved_name: captured_output.exported_program} + ) + def test_etrecord_class_constructor_and_save(self): """Test that ETRecord class constructor and save method work correctly.""" captured_output, edge_output, et_output = self.get_test_model() @@ -406,3 +663,570 @@ def test_etrecord_generation_with_exported_program_dict(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + + def test_add_executorch_program(self): + """Test add_executorch_program when ETRecord has no existing executorch program data.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Verify initial state - no executorch program data + self.assert_etrecord_has_no_executorch_program(etrecord) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + # Verify executorch program data is now present + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + # For regular ExecutorchProgram, reference_outputs and representative_inputs should be None + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + def test_add_executorch_program_with_bundled_program(self): + """Test add_executorch_program with BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Verify initial state - no executorch program data + self.assertIsNone(etrecord._debug_handle_map) + self.assertIsNone(etrecord._delegate_map) + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + # Add bundled program + etrecord.add_executorch_program(bundled_program) + + # Verify executorch program data is now present + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIsNotNone(etrecord._representative_inputs) + + # Verify the data matches expected values + expected_reference_outputs = _get_reference_outputs(bundled_program) + expected_representative_inputs = _get_representative_inputs(bundled_program) + + # Compare reference outputs + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][0][0], + expected_reference_outputs["forward"][0][0], + ) + ) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][1][0], + expected_reference_outputs["forward"][1][0], + ) + ) + + # Compare representative inputs + for expected, actual in zip( + etrecord._representative_inputs, expected_representative_inputs + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + self.assertTrue(torch.equal(expected[1], actual[1])) + + def test_add_executorch_program_already_exists_exception(self): + """Test that add_executorch_program raises exception when executorch program data already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify that adding executorch program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_executorch_program(et_output) + + self.assertIn( + "Executorch program data already exists in the ETRecord", + str(context.exception), + ) + + def test_add_executorch_program_partial_data_exists_exception(self): + """Test that add_executorch_program raises exception when partial executorch program data exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with only debug_handle_map (partial data) + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + ) + + # Verify that adding executorch program raises RuntimeError even with partial data + with self.assertRaises(RuntimeError) as context: + etrecord.add_executorch_program(et_output) + + self.assertIn( + "Executorch program data already exists in the ETRecord", + str(context.exception), + ) + + def test_add_executorch_program_and_save(self): + """Test that ETRecord with added executorch program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program(self): + """Test add_exported_program when ETRecord has no existing exported program.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assert_etrecord_has_no_exported_program(etrecord) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + # Verify exported program is now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program_with_dict(self): + """Test add_exported_program with dictionary input.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assertIsNone(etrecord.exported_program) + self.assertIsNone(etrecord.export_graph_id) + + # Add exported program as dictionary + exported_program_dict = {"forward": captured_output.exported_program} + etrecord.add_exported_program(exported_program_dict) + + # Verify exported program is now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program_already_exists_exception(self): + """Test that add_exported_program raises exception when exported program already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing exported program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create another exported program to try to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + # Verify that adding exported program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_exported_program(captured_output2.exported_program) + + self.assertIn( + "Exported program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_exported_program_partial_data_exists_exception(self): + """Test that add_exported_program raises exception when partial exported program data exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with only export_graph_id (partial data) + etrecord = ETRecord( + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify that adding exported program raises RuntimeError even with partial data + with self.assertRaises(RuntimeError) as context: + etrecord.add_exported_program(captured_output.exported_program) + + self.assertIn( + "Exported program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_exported_program_with_none(self): + """Test add_exported_program with None input.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assert_etrecord_has_no_exported_program(etrecord) + + # Add None exported program (should not raise error) + etrecord.add_exported_program(None) + + # Verify exported program is still None + self.assert_etrecord_has_no_exported_program(etrecord) + + def test_add_exported_program_and_save(self): + """Test that ETRecord with added exported program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_exported_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_edge_dialect_program(self): + """Test add_edge_dialect_program when ETRecord has no existing edge dialect program.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no edge dialect program + self.assert_etrecord_has_no_edge_dialect_program(etrecord) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + # Verify edge dialect program is now present + self.assertIsNotNone(etrecord.edge_dialect_program) + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + def test_add_edge_dialect_program_with_exir_exported_program(self): + """Test add_edge_dialect_program with ExirExportedProgram.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no edge dialect program + self.assertIsNone(etrecord.edge_dialect_program) + + # Create ExirExportedProgram from captured output + exir_exported_program = captured_output.to_edge( + exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + ) + + # Add edge dialect program using ExirExportedProgram + etrecord.add_edge_dialect_program(exir_exported_program) + + # Verify edge dialect program is now present + self.assertIsNotNone(etrecord.edge_dialect_program) + self.check_graph_closeness( + etrecord.edge_dialect_program, + exir_exported_program.exported_program.graph_module, + ) + + def test_add_edge_dialect_program_already_exists_exception(self): + """Test that add_edge_dialect_program raises exception when edge dialect program already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create another edge program to try to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + edge_output2 = captured_output2.to_edge( + exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + ) + + # Verify that adding edge dialect program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_edge_dialect_program(edge_output2) + + self.assertIn( + "Edge dialect program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_edge_dialect_program_and_save(self): + """Test that ETRecord with added edge dialect program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_edge_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_all_programs_sequentially(self): + """Test adding all programs sequentially to an empty ETRecord.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an empty ETRecord instance + etrecord = ETRecord() + + # Verify initial state - everything is None + self.assert_etrecord_is_empty(etrecord) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + # Verify all components are now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + + # Verify the data matches expected values + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Test that the complete ETRecord can be saved and parsed + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_complete.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate all metadata + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 8bbe0833b85..63b49d9860d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm): setattr(new_prog, node.target, t) +def _create_empty_etrecord(): + # Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program). + # This also ensures that etrecord-related packages do not affect the export flow. + # @manual + from executorch.devtools.etrecord import ETRecord + + return ETRecord() + + def lift_constant_tensor_pass(ep): """ Takes an ExportedProgram and returns the ExportedProgram modified in-place, @@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners( aten_programs: Dict[str, ExportedProgram], config: EdgeCompileConfig, constant_methods: Optional[Dict[str, Any]], + generate_etrecord: Optional[bool] = False, ) -> "EdgeProgramManager": """ Generates EdgeProgramManager for subsequent lowering to the @@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners( config, list(set().union(*ops_set_to_not_decompose_by_program.values())), ) + + if generate_etrecord: + etrecord = _create_empty_etrecord() + etrecord.add_exported_program(aten_programs) + etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager)) + edge_manager._etrecord = etrecord + return edge_manager @@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901 ] = None, constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, + generate_etrecord: bool = False, ) -> "EdgeProgramManager": """ :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of @@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901 compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. + generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. + Returns: EdgeProgramManager """ @@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901 partitioner, aten_programs ) edge_manager = _gen_edge_manager_for_partitioners( - partitioner, aten_programs, config, constant_methods + partitioner, aten_programs, config, constant_methods, generate_etrecord ) if transform_passes is not None: @@ -1447,6 +1467,8 @@ def __init__( program, self._named_data_store ) + self._etrecord = None + @property def methods(self) -> Set[str]: """ @@ -1643,13 +1665,19 @@ def to_executorch( _copy_module(program.graph_module, new_gm) execution_programs[name] = program - return ExecutorchProgramManager( + et_pm = ExecutorchProgramManager( execution_programs, self._config_methods, config, self._named_data_store.get_named_data_store_output(), ) + if self._etrecord is not None: + self._etrecord.add_executorch_program(et_pm) + et_pm._etrecord = self._etrecord + + return et_pm + class ExecutorchProgramManager: """ @@ -1713,6 +1741,7 @@ def __init__( self._named_data, ) self._buffer: Optional[bytes] = None + self._etrecord = None @property def methods(self) -> Set[str]: @@ -1785,6 +1814,21 @@ def buffer(self) -> bytes: self._buffer = bytes(self._pte_data) return self._buffer + def get_etrecord(self): + """ + Get the generated ETRecord if etrecord generation was enabled. + + Returns: + ETRecord object if generation was enabled, None otherwise + + Raises: + RuntimeError: if ETRecord object was not generated. + """ + + if self._etrecord is None: + raise RuntimeError("ETRecord was not generated") + return self._etrecord + def write_to_file(self, open_file: io.BufferedIOBase) -> None: """ Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over