diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 2937ce64ccf..25ea5a25e1f 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -24,6 +24,7 @@ ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge +from executorch.exir.program._program import to_edge_transform_and_lower from torch.export import export @@ -52,6 +53,21 @@ def assert_etrecord_is_empty(self, etrecord: ETRecord) -> None: self.assert_etrecord_has_no_executorch_program(etrecord) self.assertIsNone(etrecord.graph_map) + def assert_legal_etrecord_in_edge_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has all expected data after to_edge_transform_and_lower() or to_edge() stage""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assert_etrecord_has_no_executorch_program(etrecord) + + def assert_etrecord_saveable(self, etrecord: ETRecord) -> None: + """Assert ETRecord contains all essential information for saving""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + def get_test_model(self): f = models.BasicSinMax() captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) @@ -275,6 +291,157 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_to_edge_transform_and_lower_with_etrecord_generation(self): + """Test that to_edge_transform_and_lower generates ETRecord correctly.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=True + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated and attached + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + self.assert_legal_etrecord_in_edge_program(etrecord) + + # Verify the exported program matches the input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the edge dialect program matches the edge manager + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_manager.exported_program().graph_module, + ) + + def test_to_edge_transform_and_lower_without_etrecord_generation(self): + """Test that to_edge_transform_and_lower works correctly without ETRecord generation.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=False (default) + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify that no ETRecord was generated + self.assertIsNone(edge_manager._etrecord) + + # Verify that the edge manager still works correctly + self.assertIsNotNone(edge_manager.exported_program()) + + def test_get_etrecord_from_executorch_program_manager(self): + """Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Test get_etrecord method + etrecord = et_manager.get_etrecord() + self.assertIsNotNone(etrecord) + self.assert_etrecord_saveable(etrecord) + + # Verify the data matches the original input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the executorch program data matches + # ETRecord stores data directly (not JSON serialized), so compare with original data + self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map) + self.assertEqual(etrecord._delegate_map, et_manager.delegate_map) + + def test_get_etrecord_from_executorch_program_manager_without_generation(self): + """Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager without ETRecord + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify no ETRecord on edge manager + self.assertIsNone(edge_manager._etrecord) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Verify no ETRecord on executorch manager + self.assertIsNone(et_manager._etrecord) + + # Test get_etrecord method should raise RuntimeError + with self.assertRaises(RuntimeError) as context: + et_manager.get_etrecord() + + self.assertIn("ETRecord was not generated", str(context.exception)) + + def test_to_edge_transform_and_lower_etrecord_save_and_parse(self): + """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch to get complete ETRecord + et_manager = edge_manager.to_executorch() + etrecord = et_manager.get_etrecord() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_flow2.bin" + + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + # Note: Skip graph structure comparison due to transformation differences + self.check_graph_closeness( + etrecord.exported_program, parsed_etrecord.exported_program + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_manager.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_manager.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(aten_program.graph), + ) + def test_add_extra_export_modules(self): """Test add_extra_export_modules when ETRecord already has a graph_map.""" captured_output, edge_output, et_output = self.get_test_model() diff --git a/exir/program/_program.py b/exir/program/_program.py index 8bbe0833b85..63b49d9860d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm): setattr(new_prog, node.target, t) +def _create_empty_etrecord(): + # Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program). + # This also ensures that etrecord-related packages do not affect the export flow. + # @manual + from executorch.devtools.etrecord import ETRecord + + return ETRecord() + + def lift_constant_tensor_pass(ep): """ Takes an ExportedProgram and returns the ExportedProgram modified in-place, @@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners( aten_programs: Dict[str, ExportedProgram], config: EdgeCompileConfig, constant_methods: Optional[Dict[str, Any]], + generate_etrecord: Optional[bool] = False, ) -> "EdgeProgramManager": """ Generates EdgeProgramManager for subsequent lowering to the @@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners( config, list(set().union(*ops_set_to_not_decompose_by_program.values())), ) + + if generate_etrecord: + etrecord = _create_empty_etrecord() + etrecord.add_exported_program(aten_programs) + etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager)) + edge_manager._etrecord = etrecord + return edge_manager @@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901 ] = None, constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, + generate_etrecord: bool = False, ) -> "EdgeProgramManager": """ :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of @@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901 compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. + generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. + Returns: EdgeProgramManager """ @@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901 partitioner, aten_programs ) edge_manager = _gen_edge_manager_for_partitioners( - partitioner, aten_programs, config, constant_methods + partitioner, aten_programs, config, constant_methods, generate_etrecord ) if transform_passes is not None: @@ -1447,6 +1467,8 @@ def __init__( program, self._named_data_store ) + self._etrecord = None + @property def methods(self) -> Set[str]: """ @@ -1643,13 +1665,19 @@ def to_executorch( _copy_module(program.graph_module, new_gm) execution_programs[name] = program - return ExecutorchProgramManager( + et_pm = ExecutorchProgramManager( execution_programs, self._config_methods, config, self._named_data_store.get_named_data_store_output(), ) + if self._etrecord is not None: + self._etrecord.add_executorch_program(et_pm) + et_pm._etrecord = self._etrecord + + return et_pm + class ExecutorchProgramManager: """ @@ -1713,6 +1741,7 @@ def __init__( self._named_data, ) self._buffer: Optional[bytes] = None + self._etrecord = None @property def methods(self) -> Set[str]: @@ -1785,6 +1814,21 @@ def buffer(self) -> bytes: self._buffer = bytes(self._pte_data) return self._buffer + def get_etrecord(self): + """ + Get the generated ETRecord if etrecord generation was enabled. + + Returns: + ETRecord object if generation was enabled, None otherwise + + Raises: + RuntimeError: if ETRecord object was not generated. + """ + + if self._etrecord is None: + raise RuntimeError("ETRecord was not generated") + return self._etrecord + def write_to_file(self, open_file: io.BufferedIOBase) -> None: """ Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over