From af246961b84a60d91408d507c97980638f85002b Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Thu, 31 Jul 2025 20:47:48 -0700 Subject: [PATCH] [Executorch][Export][2/N] Add to_edge and to_backend stages Pull Request resolved: https://github.com/pytorch/executorch/pull/12937 Address (6) in the rfc: https://github.com/pytorch/executorch/issues/12660 1. Adds stage implementations for `to_edge` and `to_backend` 2. Adds unit tests for the two stages 3. Adds these two stages in the validation pipeline. Fixes #12932 ghstack-source-id: 300019403 @exported-using-ghexport Differential Revision: [D79120576](https://our.internmc.facebook.com/intern/diff/D79120576/) --- export/export.py | 36 +++++++-- export/stages.py | 116 +++++++++++++++++++++++++++- export/tests/test_export_session.py | 11 ++- export/tests/test_export_stages.py | 104 +++++++++++++++++++++++++ export/types.py | 2 + 5 files changed, 259 insertions(+), 10 deletions(-) diff --git a/export/export.py b/export/export.py index f5b0c6149d0..ac9d894fea1 100644 --- a/export/export.py +++ b/export/export.py @@ -24,6 +24,8 @@ QuantizeStage, SourceTransformStage, Stage, + ToBackendStage, + ToEdgeStage, TorchExportStage, ) from .types import StageType @@ -147,7 +149,9 @@ def __init__( ) # Stage registry: map of StageType to Stage instances - self._stage_registry: Dict[StageType, Stage] = self._build_default_stages() + self._stage_registry: Dict[StageType, Stage] = self._build_stages( + self._pipeline_stages + ) # Intialize run context self._run_context: Dict[str, Any] = { @@ -170,10 +174,12 @@ def _get_default_pipeline(self) -> List[StageType]: StageType.TO_EXECUTORCH, ] - def _build_default_stages(self) -> Dict[StageType, Stage]: + def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: + """Build the stage registry from the given stages.""" stage_registry: Dict[StageType, Stage] = {} - for stage_type in self._get_default_pipeline(): + stage = None + for stage_type in stages or self._get_default_pipeline(): if stage_type == StageType.SOURCE_TRANSFORM: stage = SourceTransformStage(self._quant_recipe) elif stage_type == StageType.QUANTIZE: @@ -191,12 +197,24 @@ def _build_default_stages(self) -> Dict[StageType, Stage]: transform_passes=self._export_recipe.edge_transform_passes, compile_config=self._export_recipe.edge_compile_config, ) + elif stage_type == StageType.TO_EDGE: + stage = ToEdgeStage( + edge_compile_config=self._export_recipe.edge_compile_config + ) + elif stage_type == StageType.TO_BACKEND: + stage = ToBackendStage( + partitioners=self._export_recipe.partitioners, + transform_passes=self._export_recipe.edge_transform_passes, + ) elif stage_type == StageType.TO_EXECUTORCH: stage = ExecutorchStage(self._export_recipe.executorch_backend_config) else: - raise ValueError(f"Unknown stage type: {stage_type}") + logging.info( + f"{stage_type} is unknown, you have to register it before executing export()" + ) - stage_registry[stage_type] = stage + if stage: + stage_registry[stage_type] = stage return stage_registry def register_stage(self, stage_type: StageType, stage: Stage) -> None: @@ -241,7 +259,9 @@ def _validate_pipeline_sequence( first_stage = stages[0] first_stage_instance = self._stage_registry.get(first_stage) if first_stage_instance is None: - raise ValueError(f"Stage {first_stage} not found in registry") + raise ValueError( + f"Stage {first_stage} not found in registry, register it using session.register_stage()" + ) if not first_stage_instance.can_start_pipeline: raise ValueError(f"Stage {first_stage} cannot start a pipeline. ") @@ -254,7 +274,9 @@ def _validate_pipeline_sequence( # Get the stage instance to check its valid predecessors stage_instance = self._stage_registry.get(current_stage) if stage_instance is None: - raise ValueError(f"Stage {current_stage} not found in registry") + raise ValueError( + f"Stage {current_stage} not found in registry, , register it using session.register_stage()" + ) valid_predecessors = stage_instance.valid_predecessor_stages diff --git a/export/stages.py b/export/stages.py index 61672e55bb7..fd27c298028 100644 --- a/export/stages.py +++ b/export/stages.py @@ -10,8 +10,9 @@ import torch from executorch.devtools.backend_debug import get_delegation_info +from executorch.exir import EdgeCompileConfig from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.program import to_edge_transform_and_lower +from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.exir.program._program import _transform from executorch.export.recipe import QuantizationRecipe from executorch.export.types import StageType @@ -223,7 +224,7 @@ def stage_type(self) -> str: @property def valid_predecessor_stages(self) -> List["StageType"]: - return [StageType.TO_EDGE_TRANSFORM_AND_LOWER] + return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND] @property def can_start_pipeline(self) -> bool: @@ -354,3 +355,114 @@ def run(self, artifact: PipelineArtifact) -> None: quantized_models[method_name] = quantized_model self._artifact = artifact.copy_with_new_data(quantized_models) + + +class ToEdgeStage(Stage): + """ + Stage: Convert ExportedProgram to EdgeProgramManager. + """ + + def __init__( + self, + edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore + ) -> None: + super().__init__() + self._edge_compile_config = edge_compile_config + + @property + def stage_type(self) -> str: + return StageType.TO_EDGE + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TORCH_EXPORT] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Convert ExportedProgram to EdgeProgramManager. + + Args: + artifact: Contains exported programs and context + """ + exported_programs = artifact.data + constant_methods = artifact.get_context("constant_methods") + + # Convert to edge program manager + edge_program_manager = to_edge( + exported_programs, + constant_methods=constant_methods, + compile_config=self._edge_compile_config, + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + + +class ToBackendStage(Stage): + """ + Stage: Apply transformations and partitioning to EdgeProgramManager. + """ + + def __init__( + self, + partitioners: Optional[List[Any]] = None, + transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + ) -> None: + super().__init__() + self._partitioners = partitioners + self._transform_passes = transform_passes + + @property + def stage_type(self) -> str: + return StageType.TO_BACKEND + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TO_EDGE] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Apply transformations and partitioning to EdgeProgramManager. + + Args: + artifact: Contains edge program manager and context + """ + edge_program_manager = artifact.data + + if edge_program_manager is None: + raise RuntimeError("Edge program manager is not set.") + + # Apply transform passes if available + if self._transform_passes: + edge_program_manager = edge_program_manager.transform( + self._transform_passes + ) + + # Apply partitioners if available + if self._partitioners is not None and len(self._partitioners) > 0: + with validation_disabled(): + # pyre-ignore + for partitioner in self._partitioners: + edge_program_manager = edge_program_manager.to_backend(partitioner) + + # Get delegation info + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index cc9f2a74062..7bef0d01876 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -249,7 +249,7 @@ def _get_export_session(self, stages: List[StageType]): def test_valid_pipeline_sequences(self) -> None: """Test various valid pipeline sequences.""" valid_sequences = [ - # Full pipeline + # Full pipeline with to_edge_transform_lower [ StageType.SOURCE_TRANSFORM, StageType.QUANTIZE, @@ -257,6 +257,15 @@ def test_valid_pipeline_sequences(self) -> None: StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_EXECUTORCH, ], + # Full pipeline with to_edge, to_backend + [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE, + StageType.TO_BACKEND, + StageType.TO_EXECUTORCH, + ], # Skip quantize [ StageType.SOURCE_TRANSFORM, diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 5d83b4f9046..2b3e533723a 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -19,6 +19,8 @@ QuantizeStage, SourceTransformStage, StageType, + ToBackendStage, + ToEdgeStage, TorchExportStage, ) from torch.export import ExportedProgram @@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None: self.assertIn( "Example inputs for method forward not found or empty", str(cm.exception) ) + + +class TestToEdgeStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_exported_program = Mock(spec=ExportedProgram) + self.exported_programs = {"forward": self.mock_exported_program} + self.context = {"constant_methods": None} + + @patch("executorch.export.stages.to_edge") + def test_run_success(self, mock_to_edge: Mock) -> None: + mock_edge_manager = Mock(spec=EdgeProgramManager) + mock_to_edge.return_value = mock_edge_manager + mock_config = Mock() + + stage = ToEdgeStage(edge_compile_config=mock_config) + artifact = PipelineArtifact(data=self.exported_programs, context=self.context) + stage.run(artifact) + + # Verify to_edge was called with correct parameters + mock_to_edge.assert_called_once_with( + self.exported_programs, + constant_methods=None, + compile_config=mock_config, + ) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_manager) + + +class TestToBackendStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_edge_manager = Mock(spec=EdgeProgramManager) + self.context = {} + + @patch("executorch.export.stages.get_delegation_info") + def test_run_success_no_transforms_or_partitioners( + self, mock_get_delegation_info: Mock + ) -> None: + # Test successful execution without transforms or partitioners + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + self.mock_edge_manager.exported_program.return_value = mock_exported_program + + stage = ToBackendStage() + artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context) + stage.run(artifact) + + # Verify get_delegation_info was called + mock_get_delegation_info.assert_called_once_with(mock_graph_module) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, self.mock_edge_manager) + self.assertEqual( + result_artifact.get_context("delegation_info"), mock_delegation_info + ) + + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_partitioners_and_passes( + self, mock_get_delegation_info: Mock + ) -> None: + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + + mock_edge_program_manager = Mock(spec=EdgeProgramManager) + mock_edge_program_manager.transform.return_value = mock_edge_program_manager + mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager + + mock_partitioner = Mock() + mock_transform_passes = [Mock(), Mock()] + stage = ToBackendStage( + partitioners=[mock_partitioner], transform_passes=mock_transform_passes + ) + artifact = PipelineArtifact( + data=mock_edge_program_manager, context=self.context + ) + stage.run(artifact) + + # Verify transform and to_backend called correctly + mock_edge_program_manager.transform.assert_called_once_with( + mock_transform_passes + ) + mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner) + + # Verify artifacts contain the backend manager + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_program_manager) + + def test_run_edge_manager_none(self) -> None: + stage = ToBackendStage() + artifact = PipelineArtifact(data=None, context=self.context) + + with self.assertRaises(RuntimeError) as cm: + stage.run(artifact) + self.assertIn("Edge program manager is not set", str(cm.exception)) diff --git a/export/types.py b/export/types.py index 8ffa287f91a..760f8461d41 100644 --- a/export/types.py +++ b/export/types.py @@ -16,4 +16,6 @@ class StageType(str, Enum): QUANTIZE = "quantize" TORCH_EXPORT = "torch_export" TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" + TO_EDGE = "to_edge" + TO_BACKEND = "to_backend" TO_EXECUTORCH = "to_executorch"