diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index e149aeab650..3b8a71279fd 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -200,6 +200,151 @@ def _save_edge_dialect_program( f"{base_name}_example_inputs", serialized_artifact.example_inputs ) + def add_extra_export_modules( + self, + extra_recorded_export_modules: Dict[ + str, + Union[ + ExportedProgram, + ExirExportedProgram, + EdgeProgramManager, + ], + ], + ) -> None: + """ + Add extra export modules to the ETRecord after it has been created. + + This method allows users to add more export modules they want to record + to an existing ETRecord instance. The modules will be added to the graph_map + and will be included when the ETRecord is saved. + + Args: + extra_recorded_export_modules: A dictionary of graph modules with the key being + the user provided name and the value being the corresponding exported module. + The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`. + """ + if self.graph_map is None: + self.graph_map = {} + + # Now self.graph_map is guaranteed to be non-None + graph_map = self.graph_map + for module_name, export_module in extra_recorded_export_modules.items(): + _add_module_to_graph_map(graph_map, module_name, export_module) + + def add_executorch_program( + self, + executorch_program: Union[ + ExecutorchProgram, + ExecutorchProgramManager, + BundledProgram, + ], + ) -> None: + """ + Add executorch program data to the ETRecord after it has been created. + + This method allows users to add executorch program data they want to record + to an existing ETRecord instance. The executorch program data includes debug handle map, + delegate map, reference outputs, and representative inputs that will be included + when the ETRecord is saved. + + Args: + executorch_program: The ExecuTorch program for this model returned by the call to + `to_executorch()` or the `BundledProgram` of this model. + + Raises: + RuntimeError: If executorch program data already exists in the ETRecord. + """ + # Check if executorch program data already exists + if ( + self._debug_handle_map is not None + or self._delegate_map is not None + or self._reference_outputs is not None + or self._representative_inputs is not None + ): + raise RuntimeError( + "Executorch program data already exists in the ETRecord. " + "Cannot add executorch program data when it already exists." + ) + + # Process executorch program and extract data + debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( + _process_executorch_program(executorch_program) + ) + + # Set the extracted data + self._debug_handle_map = debug_handle_map + self._delegate_map = delegate_map + self._reference_outputs = reference_outputs + self._representative_inputs = representative_inputs + + def add_exported_program( + self, + exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]], + ) -> None: + """ + Add exported program to the ETRecord after it has been created. + + This method allows users to add an exported program they want to record + to an existing ETRecord instance. The exported program will be included + when the ETRecord is saved. + + Args: + exported_program: The exported program for this model returned by the call to + `torch.export()` or a dictionary with method names as keys and exported programs as values. + Can be None, in which case no exported program data will be added. + + Raises: + RuntimeError: If exported program already exists in the ETRecord. + """ + # Check if exported program already exists + if self.exported_program is not None or self.export_graph_id is not None: + raise RuntimeError( + "Exported program already exists in the ETRecord. " + "Cannot add exported program when it already exists." + ) + + # Process exported program and extract data + processed_exported_program, export_graph_id = _process_exported_program( + exported_program + ) + + # Set the extracted data + self.exported_program = processed_exported_program + self.export_graph_id = export_graph_id + + def add_edge_dialect_program( + self, + edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram], + ) -> None: + """ + Add edge dialect program to the ETRecord after it has been created. + + This method allows users to add an edge dialect program they want to record + to an existing ETRecord instance. The edge dialect program will be included + when the ETRecord is saved. + + Args: + edge_dialect_program: The edge dialect program for this model returned by the call to + `to_edge()` or `EdgeProgramManager` for this model. + + Raises: + RuntimeError: If edge dialect program already exists in the ETRecord. + """ + # Check if edge dialect program already exists + if self.edge_dialect_program is not None: + raise RuntimeError( + "Edge dialect program already exists in the ETRecord. " + "Cannot add edge dialect program when it already exists." + ) + + # Process edge dialect program and extract data + processed_edge_dialect_program = _process_edge_dialect_program( + edge_dialect_program + ) + + # Set the extracted data + self.edge_dialect_program = processed_edge_dialect_program + def _get_reference_outputs( bundled_program: BundledProgram, @@ -285,37 +430,24 @@ def generate_etrecord( Returns: None """ - # Process all inputs and prepare data for ETRecord construction - processed_exported_program, export_graph_id = _process_exported_program( - exported_program - ) - graph_map = _process_extra_recorded_modules(extra_recorded_export_modules) - processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program) - debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( - _process_executorch_program(executorch_program) - ) + etrecord = ETRecord() + etrecord.add_exported_program(exported_program) + etrecord.add_edge_dialect_program(edge_dialect_program) + etrecord.add_executorch_program(executorch_program) - # Create ETRecord instance and save - etrecord = ETRecord( - exported_program=processed_exported_program, - export_graph_id=export_graph_id, - edge_dialect_program=processed_edge_dialect_program, - graph_map=graph_map if graph_map else None, - _debug_handle_map=debug_handle_map, - _delegate_map=delegate_map, - _reference_outputs=reference_outputs, - _representative_inputs=representative_inputs, - ) + # Add extra export modules if user provided + if extra_recorded_export_modules is not None: + etrecord.add_extra_export_modules(extra_recorded_export_modules) etrecord.save(et_record) def _process_exported_program( exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]] -) -> tuple[Optional[ExportedProgram], int]: +) -> tuple[Optional[ExportedProgram], Optional[int]]: """Process exported program and return the processed program and export graph id.""" processed_exported_program = None - export_graph_id = 0 + export_graph_id = None if exported_program is not None: if isinstance(exported_program, dict) and "forward" in exported_program: @@ -329,29 +461,6 @@ def _process_exported_program( return processed_exported_program, export_graph_id -def _process_extra_recorded_modules( - extra_recorded_export_modules: Optional[ - Dict[ - str, - Union[ - ExportedProgram, - ExirExportedProgram, - EdgeProgramManager, - ], - ] - ] -) -> Dict[str, ExportedProgram]: - """Process extra recorded export modules and return graph map.""" - graph_map = {} - - if extra_recorded_export_modules is not None: - for module_name, export_module in extra_recorded_export_modules.items(): - _validate_module_name(module_name) - _add_module_to_graph_map(graph_map, module_name, export_module) - - return graph_map - - def _validate_module_name(module_name: str) -> None: """Validate that module name is not a reserved name.""" contains_reserved_name = any( @@ -369,6 +478,8 @@ def _add_module_to_graph_map( export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager], ) -> None: """Add export module to graph map based on its type.""" + _validate_module_name(module_name) + if isinstance(export_module, ExirExportedProgram): graph_map[f"{module_name}/forward"] = export_module.exported_program elif isinstance(export_module, ExportedProgram): diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 9b9f3290162..25ea5a25e1f 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -24,11 +24,50 @@ ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge +from executorch.exir.program._program import to_edge_transform_and_lower from torch.export import export # TODO : T154728484 Add test cases to cover multiple entry points class TestETRecord(unittest.TestCase): + def assert_etrecord_has_no_exported_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no exported program data.""" + self.assertIsNone(etrecord.exported_program) + self.assertIsNone(etrecord.export_graph_id) + + def assert_etrecord_has_no_edge_dialect_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no edge dialect program data.""" + self.assertIsNone(etrecord.edge_dialect_program) + + def assert_etrecord_has_no_executorch_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no executorch program data.""" + self.assertIsNone(etrecord._debug_handle_map) + self.assertIsNone(etrecord._delegate_map) + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + def assert_etrecord_is_empty(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no data at all.""" + self.assert_etrecord_has_no_exported_program(etrecord) + self.assert_etrecord_has_no_edge_dialect_program(etrecord) + self.assert_etrecord_has_no_executorch_program(etrecord) + self.assertIsNone(etrecord.graph_map) + + def assert_legal_etrecord_in_edge_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has all expected data after to_edge_transform_and_lower() or to_edge() stage""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assert_etrecord_has_no_executorch_program(etrecord) + + def assert_etrecord_saveable(self, etrecord: ETRecord) -> None: + """Assert ETRecord contains all essential information for saving""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + def get_test_model(self): f = models.BasicSinMax() captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) @@ -252,6 +291,224 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_to_edge_transform_and_lower_with_etrecord_generation(self): + """Test that to_edge_transform_and_lower generates ETRecord correctly.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=True + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated and attached + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + self.assert_legal_etrecord_in_edge_program(etrecord) + + # Verify the exported program matches the input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the edge dialect program matches the edge manager + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_manager.exported_program().graph_module, + ) + + def test_to_edge_transform_and_lower_without_etrecord_generation(self): + """Test that to_edge_transform_and_lower works correctly without ETRecord generation.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=False (default) + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify that no ETRecord was generated + self.assertIsNone(edge_manager._etrecord) + + # Verify that the edge manager still works correctly + self.assertIsNotNone(edge_manager.exported_program()) + + def test_get_etrecord_from_executorch_program_manager(self): + """Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Test get_etrecord method + etrecord = et_manager.get_etrecord() + self.assertIsNotNone(etrecord) + self.assert_etrecord_saveable(etrecord) + + # Verify the data matches the original input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the executorch program data matches + # ETRecord stores data directly (not JSON serialized), so compare with original data + self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map) + self.assertEqual(etrecord._delegate_map, et_manager.delegate_map) + + def test_get_etrecord_from_executorch_program_manager_without_generation(self): + """Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager without ETRecord + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify no ETRecord on edge manager + self.assertIsNone(edge_manager._etrecord) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Verify no ETRecord on executorch manager + self.assertIsNone(et_manager._etrecord) + + # Test get_etrecord method should raise RuntimeError + with self.assertRaises(RuntimeError) as context: + et_manager.get_etrecord() + + self.assertIn("ETRecord was not generated", str(context.exception)) + + def test_to_edge_transform_and_lower_etrecord_save_and_parse(self): + """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch to get complete ETRecord + et_manager = edge_manager.to_executorch() + etrecord = et_manager.get_etrecord() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_flow2.bin" + + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + # Note: Skip graph structure comparison due to transformation differences + self.check_graph_closeness( + etrecord.exported_program, parsed_etrecord.exported_program + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_manager.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_manager.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(aten_program.graph), + ) + + def test_add_extra_export_modules(self): + """Test add_extra_export_modules when ETRecord already has a graph_map.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing graph_map + initial_graph_map = { + "existing_module/forward": captured_output.exported_program + } + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + graph_map=initial_graph_map, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state + self.assertIsNotNone(etrecord.graph_map) + self.assertIn("existing_module/forward", etrecord.graph_map) + + # Create additional module to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + extra_modules = { + "new_module": captured_output2.exported_program, + } + + # Add extra export modules + etrecord.add_extra_export_modules(extra_modules) + + # Verify both existing and new modules are present + self.assertIn("existing_module/forward", etrecord.graph_map) + self.assertIn("new_module/forward", etrecord.graph_map) + + # Verify the modules are correctly stored + self.check_graph_closeness( + etrecord.graph_map["existing_module/forward"], + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.graph_map["new_module/forward"], + captured_output2.exported_program.graph_module, + ) + + def test_add_extra_export_modules_reserved_name_validation(self): + """Test that add_extra_export_modules validates reserved names.""" + captured_output, edge_output, et_output = self.get_test_model() + + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Test that reserved names are rejected + for reserved_name in ETRecordReservedFileNames: + with self.assertRaises(RuntimeError): + etrecord.add_extra_export_modules( + {reserved_name: captured_output.exported_program} + ) + def test_etrecord_class_constructor_and_save(self): """Test that ETRecord class constructor and save method work correctly.""" captured_output, edge_output, et_output = self.get_test_model() @@ -406,3 +663,570 @@ def test_etrecord_generation_with_exported_program_dict(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + + def test_add_executorch_program(self): + """Test add_executorch_program when ETRecord has no existing executorch program data.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Verify initial state - no executorch program data + self.assert_etrecord_has_no_executorch_program(etrecord) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + # Verify executorch program data is now present + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + # For regular ExecutorchProgram, reference_outputs and representative_inputs should be None + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + def test_add_executorch_program_with_bundled_program(self): + """Test add_executorch_program with BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Verify initial state - no executorch program data + self.assertIsNone(etrecord._debug_handle_map) + self.assertIsNone(etrecord._delegate_map) + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + # Add bundled program + etrecord.add_executorch_program(bundled_program) + + # Verify executorch program data is now present + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIsNotNone(etrecord._representative_inputs) + + # Verify the data matches expected values + expected_reference_outputs = _get_reference_outputs(bundled_program) + expected_representative_inputs = _get_representative_inputs(bundled_program) + + # Compare reference outputs + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][0][0], + expected_reference_outputs["forward"][0][0], + ) + ) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][1][0], + expected_reference_outputs["forward"][1][0], + ) + ) + + # Compare representative inputs + for expected, actual in zip( + etrecord._representative_inputs, expected_representative_inputs + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + self.assertTrue(torch.equal(expected[1], actual[1])) + + def test_add_executorch_program_already_exists_exception(self): + """Test that add_executorch_program raises exception when executorch program data already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify that adding executorch program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_executorch_program(et_output) + + self.assertIn( + "Executorch program data already exists in the ETRecord", + str(context.exception), + ) + + def test_add_executorch_program_partial_data_exists_exception(self): + """Test that add_executorch_program raises exception when partial executorch program data exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with only debug_handle_map (partial data) + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + ) + + # Verify that adding executorch program raises RuntimeError even with partial data + with self.assertRaises(RuntimeError) as context: + etrecord.add_executorch_program(et_output) + + self.assertIn( + "Executorch program data already exists in the ETRecord", + str(context.exception), + ) + + def test_add_executorch_program_and_save(self): + """Test that ETRecord with added executorch program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program(self): + """Test add_exported_program when ETRecord has no existing exported program.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assert_etrecord_has_no_exported_program(etrecord) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + # Verify exported program is now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program_with_dict(self): + """Test add_exported_program with dictionary input.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assertIsNone(etrecord.exported_program) + self.assertIsNone(etrecord.export_graph_id) + + # Add exported program as dictionary + exported_program_dict = {"forward": captured_output.exported_program} + etrecord.add_exported_program(exported_program_dict) + + # Verify exported program is now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program_already_exists_exception(self): + """Test that add_exported_program raises exception when exported program already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing exported program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create another exported program to try to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + # Verify that adding exported program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_exported_program(captured_output2.exported_program) + + self.assertIn( + "Exported program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_exported_program_partial_data_exists_exception(self): + """Test that add_exported_program raises exception when partial exported program data exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with only export_graph_id (partial data) + etrecord = ETRecord( + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify that adding exported program raises RuntimeError even with partial data + with self.assertRaises(RuntimeError) as context: + etrecord.add_exported_program(captured_output.exported_program) + + self.assertIn( + "Exported program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_exported_program_with_none(self): + """Test add_exported_program with None input.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assert_etrecord_has_no_exported_program(etrecord) + + # Add None exported program (should not raise error) + etrecord.add_exported_program(None) + + # Verify exported program is still None + self.assert_etrecord_has_no_exported_program(etrecord) + + def test_add_exported_program_and_save(self): + """Test that ETRecord with added exported program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_exported_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_edge_dialect_program(self): + """Test add_edge_dialect_program when ETRecord has no existing edge dialect program.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no edge dialect program + self.assert_etrecord_has_no_edge_dialect_program(etrecord) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + # Verify edge dialect program is now present + self.assertIsNotNone(etrecord.edge_dialect_program) + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + def test_add_edge_dialect_program_with_exir_exported_program(self): + """Test add_edge_dialect_program with ExirExportedProgram.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no edge dialect program + self.assertIsNone(etrecord.edge_dialect_program) + + # Create ExirExportedProgram from captured output + exir_exported_program = captured_output.to_edge( + exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + ) + + # Add edge dialect program using ExirExportedProgram + etrecord.add_edge_dialect_program(exir_exported_program) + + # Verify edge dialect program is now present + self.assertIsNotNone(etrecord.edge_dialect_program) + self.check_graph_closeness( + etrecord.edge_dialect_program, + exir_exported_program.exported_program.graph_module, + ) + + def test_add_edge_dialect_program_already_exists_exception(self): + """Test that add_edge_dialect_program raises exception when edge dialect program already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create another edge program to try to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + edge_output2 = captured_output2.to_edge( + exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + ) + + # Verify that adding edge dialect program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_edge_dialect_program(edge_output2) + + self.assertIn( + "Edge dialect program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_edge_dialect_program_and_save(self): + """Test that ETRecord with added edge dialect program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_edge_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_all_programs_sequentially(self): + """Test adding all programs sequentially to an empty ETRecord.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an empty ETRecord instance + etrecord = ETRecord() + + # Verify initial state - everything is None + self.assert_etrecord_is_empty(etrecord) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + # Verify all components are now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + + # Verify the data matches expected values + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Test that the complete ETRecord can be saved and parsed + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_complete.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate all metadata + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 8bbe0833b85..63b49d9860d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm): setattr(new_prog, node.target, t) +def _create_empty_etrecord(): + # Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program). + # This also ensures that etrecord-related packages do not affect the export flow. + # @manual + from executorch.devtools.etrecord import ETRecord + + return ETRecord() + + def lift_constant_tensor_pass(ep): """ Takes an ExportedProgram and returns the ExportedProgram modified in-place, @@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners( aten_programs: Dict[str, ExportedProgram], config: EdgeCompileConfig, constant_methods: Optional[Dict[str, Any]], + generate_etrecord: Optional[bool] = False, ) -> "EdgeProgramManager": """ Generates EdgeProgramManager for subsequent lowering to the @@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners( config, list(set().union(*ops_set_to_not_decompose_by_program.values())), ) + + if generate_etrecord: + etrecord = _create_empty_etrecord() + etrecord.add_exported_program(aten_programs) + etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager)) + edge_manager._etrecord = etrecord + return edge_manager @@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901 ] = None, constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, + generate_etrecord: bool = False, ) -> "EdgeProgramManager": """ :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of @@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901 compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. + generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. + Returns: EdgeProgramManager """ @@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901 partitioner, aten_programs ) edge_manager = _gen_edge_manager_for_partitioners( - partitioner, aten_programs, config, constant_methods + partitioner, aten_programs, config, constant_methods, generate_etrecord ) if transform_passes is not None: @@ -1447,6 +1467,8 @@ def __init__( program, self._named_data_store ) + self._etrecord = None + @property def methods(self) -> Set[str]: """ @@ -1643,13 +1665,19 @@ def to_executorch( _copy_module(program.graph_module, new_gm) execution_programs[name] = program - return ExecutorchProgramManager( + et_pm = ExecutorchProgramManager( execution_programs, self._config_methods, config, self._named_data_store.get_named_data_store_output(), ) + if self._etrecord is not None: + self._etrecord.add_executorch_program(et_pm) + et_pm._etrecord = self._etrecord + + return et_pm + class ExecutorchProgramManager: """ @@ -1713,6 +1741,7 @@ def __init__( self._named_data, ) self._buffer: Optional[bytes] = None + self._etrecord = None @property def methods(self) -> Set[str]: @@ -1785,6 +1814,21 @@ def buffer(self) -> bytes: self._buffer = bytes(self._pte_data) return self._buffer + def get_etrecord(self): + """ + Get the generated ETRecord if etrecord generation was enabled. + + Returns: + ETRecord object if generation was enabled, None otherwise + + Raises: + RuntimeError: if ETRecord object was not generated. + """ + + if self._etrecord is None: + raise RuntimeError("ETRecord was not generated") + return self._etrecord + def write_to_file(self, open_file: io.BufferedIOBase) -> None: """ Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over