From 6b37502842f81cd8c269664b99dbcd149a3f1ff6 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 31 Jul 2025 02:10:20 -0700 Subject: [PATCH] make to_edge_transform_and_lower support etrecord generation Differential Revision: [D79336982](https://our.internmc.facebook.com/intern/diff/D79336982/) [ghstack-poisoned] --- devtools/etrecord/tests/etrecord_test.py | 162 +++++++++++++++++++++++ exir/program/_program.py | 48 ++++++- 2 files changed, 208 insertions(+), 2 deletions(-) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 2937ce64ccf..11b687e25d0 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -24,6 +24,7 @@ ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge +from executorch.exir.program._program import to_edge_transform_and_lower from torch.export import export @@ -275,6 +276,167 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_to_edge_transform_and_lower_with_etrecord_generation(self): + """Test that to_edge_transform_and_lower generates ETRecord correctly.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=True + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated and attached + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify that ETRecord has the expected data + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + + # Verify the exported program matches the input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the edge dialect program matches the edge manager + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_manager.exported_program().graph_module, + ) + + def test_to_edge_transform_and_lower_without_etrecord_generation(self): + """Test that to_edge_transform_and_lower works correctly without ETRecord generation.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Test with generate_etrecord=False (default) + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify that no ETRecord was generated + self.assertIsNone(edge_manager._etrecord) + + # Verify that the edge manager still works correctly + self.assertIsNotNone(edge_manager.exported_program()) + + def test_get_etrecord_from_executorch_program_manager(self): + """Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Test get_etrecord method + etrecord = et_manager.get_etrecord() + self.assertIsNotNone(etrecord) + + # Verify that the returned ETRecord has all expected data + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + + # Verify the data matches the original input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the executorch program data matches + # ETRecord stores data directly (not JSON serialized), so compare with original data + self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map) + self.assertEqual(etrecord._delegate_map, et_manager.delegate_map) + + def test_get_etrecord_from_executorch_program_manager_without_generation(self): + """Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager without ETRecord + edge_manager = to_edge_transform_and_lower(aten_program) + + # Verify no ETRecord on edge manager + self.assertIsNone(edge_manager._etrecord) + + # Convert to executorch + et_manager = edge_manager.to_executorch() + + # Verify no ETRecord on executorch manager + self.assertIsNone(et_manager._etrecord) + + # Test get_etrecord method should raise RuntimeError + with self.assertRaises(RuntimeError) as context: + et_manager.get_etrecord() + + self.assertIn("ETRecord was not generated", str(context.exception)) + + def test_to_edge_transform_and_lower_etrecord_save_and_parse(self): + """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed.""" + f = models.BasicSinMax() + aten_program = export(f, f.get_random_inputs(), strict=True) + + # Generate edge manager with ETRecord + edge_manager = to_edge_transform_and_lower( + aten_program, + generate_etrecord=True, + ) + + # Convert to executorch to get complete ETRecord + et_manager = edge_manager.to_executorch() + etrecord = et_manager.get_etrecord() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_flow2.bin" + + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + # Note: Skip graph structure comparison due to transformation differences + self.check_graph_closeness( + etrecord.exported_program, parsed_etrecord.exported_program + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_manager.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_manager.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(aten_program.graph), + ) + def test_add_extra_export_modules(self): """Test add_extra_export_modules when ETRecord already has a graph_map.""" captured_output, edge_output, et_output = self.get_test_model() diff --git a/exir/program/_program.py b/exir/program/_program.py index 8bbe0833b85..63b49d9860d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm): setattr(new_prog, node.target, t) +def _create_empty_etrecord(): + # Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program). + # This also ensures that etrecord-related packages do not affect the export flow. + # @manual + from executorch.devtools.etrecord import ETRecord + + return ETRecord() + + def lift_constant_tensor_pass(ep): """ Takes an ExportedProgram and returns the ExportedProgram modified in-place, @@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners( aten_programs: Dict[str, ExportedProgram], config: EdgeCompileConfig, constant_methods: Optional[Dict[str, Any]], + generate_etrecord: Optional[bool] = False, ) -> "EdgeProgramManager": """ Generates EdgeProgramManager for subsequent lowering to the @@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners( config, list(set().union(*ops_set_to_not_decompose_by_program.values())), ) + + if generate_etrecord: + etrecord = _create_empty_etrecord() + etrecord.add_exported_program(aten_programs) + etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager)) + edge_manager._etrecord = etrecord + return edge_manager @@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901 ] = None, constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, + generate_etrecord: bool = False, ) -> "EdgeProgramManager": """ :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of @@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901 compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. + generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. + Returns: EdgeProgramManager """ @@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901 partitioner, aten_programs ) edge_manager = _gen_edge_manager_for_partitioners( - partitioner, aten_programs, config, constant_methods + partitioner, aten_programs, config, constant_methods, generate_etrecord ) if transform_passes is not None: @@ -1447,6 +1467,8 @@ def __init__( program, self._named_data_store ) + self._etrecord = None + @property def methods(self) -> Set[str]: """ @@ -1643,13 +1665,19 @@ def to_executorch( _copy_module(program.graph_module, new_gm) execution_programs[name] = program - return ExecutorchProgramManager( + et_pm = ExecutorchProgramManager( execution_programs, self._config_methods, config, self._named_data_store.get_named_data_store_output(), ) + if self._etrecord is not None: + self._etrecord.add_executorch_program(et_pm) + et_pm._etrecord = self._etrecord + + return et_pm + class ExecutorchProgramManager: """ @@ -1713,6 +1741,7 @@ def __init__( self._named_data, ) self._buffer: Optional[bytes] = None + self._etrecord = None @property def methods(self) -> Set[str]: @@ -1785,6 +1814,21 @@ def buffer(self) -> bytes: self._buffer = bytes(self._pte_data) return self._buffer + def get_etrecord(self): + """ + Get the generated ETRecord if etrecord generation was enabled. + + Returns: + ETRecord object if generation was enabled, None otherwise + + Raises: + RuntimeError: if ETRecord object was not generated. + """ + + if self._etrecord is None: + raise RuntimeError("ETRecord was not generated") + return self._etrecord + def write_to_file(self, open_file: io.BufferedIOBase) -> None: """ Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over