From 867a66b92011d9817357887e4ca60a1d35ec7ec8 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Sun, 7 Dec 2025 22:20:28 -0800 Subject: [PATCH] [Executorch][Export][2/N] Add getter utils to help with getting the artifacts easily Add getter utility methods to ExportSession for easier artifact retrieval. - `get_exported_program` - to get output from export stage - `get_edge_program_manager` - multiple stages produce edge program manager, for example, to_edge, transform, to_backend, go in reverse order to get the most processed edge program manager. Differential Revision: [D87576720](https://our.internmc.facebook.com/intern/diff/D87576720/) [ghstack-poisoned] --- export/export.py | 59 ++++++++++- export/tests/test_export_session.py | 154 ++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+), 5 deletions(-) diff --git a/export/export.py b/export/export.py index 226a3a06eda..f09b7818c6b 100644 --- a/export/export.py +++ b/export/export.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch +from executorch.exir import EdgeProgramManager from executorch.exir._warnings import experimental from executorch.exir.program import ExecutorchProgramManager from executorch.exir.schema import Program @@ -461,17 +462,65 @@ def export(self) -> None: def get_stage_artifacts(self) -> Dict[StageType, PipelineArtifact]: return self._stage_to_artifacts - def save_pte_file(self, path: str) -> None: + def get_exported_program(self, method_name: str = "forward") -> ExportedProgram: """ - Save the exported program to a PTE file. + Get the ExportedProgram for a specific method after torch export. Args: - path: Path where the PTE file will be saved + method_name: Name of the method to get exported program for, defaults to "forward" + + Returns: + The ExportedProgram for the specified method Raises: - RuntimeError: If the executorch program manager is not initialized + RuntimeError: If torch export stage has not been run + KeyError: If the method name is not found in exported programs """ - self.get_executorch_program_manager().save(path) + artifact = self._stage_to_artifacts.get(StageType.TORCH_EXPORT) + if artifact is None or artifact.data is None: + raise RuntimeError( + "Exported program is not available. Run Torch Export Stage first." + ) + + exported_programs = artifact.data + if method_name not in exported_programs: + raise KeyError( + f"Method name '{method_name}' not found in exported programs. " + f"Available methods: {list(exported_programs.keys())}" + ) + + return exported_programs[method_name] + + def get_edge_program_manager(self) -> "EdgeProgramManager": + """ + Get the EdgeProgramManager after edge lowering stages. + + This method checks multiple stages in order of preference: + 1. TO_EDGE_TRANSFORM_AND_LOWER (combined stage) + 2. TO_BACKEND (separate stage with backend delegation) + 3. TO_EDGE (separate stage without backend delegation) + + Returns: + The EdgeProgramManager + + Raises: + RuntimeError: If no edge stage has been run + """ + # Check stages in order of preference + for stage_type in [ + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_BACKEND, + StageType.TO_EDGE, + ]: + artifact = self._stage_to_artifacts.get(stage_type) + if artifact is not None and artifact.data is not None: + logging.info(f"Returning edge program manager from stage {stage_type}") + return artifact.data + + raise RuntimeError( + "Edge program manager is not available. " + "Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, or TO_BACKEND." + ) def get_executorch_program(self) -> Program: """ diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index d28c369eaa6..d3a3d68d42b 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -817,3 +817,157 @@ def test_exported_program_valid_pipeline(self) -> None: # Should not raise during validation session._validate_pipeline_sequence(recipe.pipeline_stages) + + +class TestIntermediateStateGetters(unittest.TestCase): + """Test convenience getters for intermediate pipeline states.""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + + def test_get_exported_program_after_torch_export(self) -> None: + """Test that get_exported_program works after torch export stage.""" + recipe = ExportRecipe( + name="test", + pipeline_stages=[ + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + session.export() + + exported_program = session.get_exported_program() + self.assertIsNotNone(exported_program) + self.assertIsInstance(exported_program, torch.export.ExportedProgram) + + def test_get_exported_program_before_export_fails(self) -> None: + """Test that get_exported_program fails before torch export stage.""" + recipe = ExportRecipe(name="test") + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + with self.assertRaises(RuntimeError) as cm: + session.get_exported_program() + self.assertIn("Exported program is not available", str(cm.exception)) + + def test_get_exported_program_invalid_method_name(self) -> None: + """Test that get_exported_program fails with invalid method name.""" + recipe = ExportRecipe(name="test") + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + session.export() + + with self.assertRaises(KeyError) as cm: + session.get_exported_program("nonexistent_method") + self.assertIn("Method name 'nonexistent_method' not found", str(cm.exception)) + + def test_get_exported_program_multi_method(self) -> None: + """Test get_exported_program with multi-method model.""" + model_dict = { + "forward": self.model, + "inference": SimpleTestModel(), + } + inputs_dict = { + "forward": self.example_inputs, + "inference": [(torch.randn(1, 10),)], + } + + recipe = ExportRecipe(name="multi_method_test") + + session = ExportSession( + model=model_dict, + example_inputs=inputs_dict, + export_recipe=recipe, + ) + + session.export() + + forward_ep = session.get_exported_program("forward") + inference_ep = session.get_exported_program("inference") + + self.assertIsNotNone(forward_ep) + self.assertIsNotNone(inference_ep) + self.assertIsInstance(forward_ep, torch.export.ExportedProgram) + self.assertIsInstance(inference_ep, torch.export.ExportedProgram) + + def test_get_edge_program_manager_with_transform_and_lower(self) -> None: + """Test get_edge_program_manager with TO_EDGE_TRANSFORM_AND_LOWER stage.""" + recipe = ExportRecipe( + name="test", + pipeline_stages=[ + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + session.export() + + edge_manager = session.get_edge_program_manager() + self.assertIsNotNone(edge_manager) + + def test_get_edge_program_manager_with_separate_stages(self) -> None: + """Test get_edge_program_manager with separate TO_EDGE and TO_BACKEND stages.""" + recipe = ExportRecipe( + name="test", + pipeline_stages=[ + StageType.TORCH_EXPORT, + StageType.TO_EDGE, + StageType.TO_BACKEND, + StageType.TO_EXECUTORCH, + ], + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + session.export() + + edge_manager = session.get_edge_program_manager() + self.assertIsNotNone(edge_manager) + + def test_get_edge_program_manager_before_edge_stage_fails(self) -> None: + """Test that get_edge_program_manager fails before edge stages.""" + recipe = ExportRecipe( + name="test", + pipeline_stages=[StageType.TORCH_EXPORT], + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + session.export() + + with self.assertRaises(RuntimeError) as cm: + session.get_edge_program_manager() + self.assertIn("Edge program manager is not available", str(cm.exception))