diff --git a/export/export.py b/export/export.py index ab15067c561..86a932d153c 100644 --- a/export/export.py +++ b/export/export.py @@ -195,12 +195,12 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: elif stage_type == StageType.QUANTIZE: stage = QuantizeStage(self._quant_recipe) elif stage_type == StageType.TORCH_EXPORT: - pre_edge_passes = None - if self._export_recipe.pre_edge_transform_passes is not None: - pre_edge_passes = list( - self._export_recipe.pre_edge_transform_passes + aten_transform_passes = None + if self._export_recipe.aten_transform_passes is not None: + aten_transform_passes = list( + self._export_recipe.aten_transform_passes ) - stage = TorchExportStage(pre_edge_passes) + stage = TorchExportStage(aten_transform_passes) elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EDGE: diff --git a/export/recipe.py b/export/recipe.py index 811270cdbf8..18f4b8aebb9 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -7,9 +7,10 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass from enum import Enum, EnumMeta -from typing import Callable, List, Optional, Sequence +from typing import Callable, List, Optional import torch +from executorch.exir import ExportedProgram from executorch.exir._warnings import experimental @@ -117,12 +118,15 @@ class LoweringRecipe: Attributes: partitioners: Optional list of partitioners for model partitioning - edge_transform_passes: Optional sequence of transformation passes to apply + edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments + and return a list of passes (PassType) to be executed during lowering stages. edge_compile_config: Optional edge compilation configuration """ partitioners: Optional[List[Partitioner]] = None - edge_transform_passes: Optional[Sequence[PassType]] = None + edge_transform_passes: ( + None | List[Callable[[str, ExportedProgram], List[PassType]]] + ) = None # pyre-ignore[11]: Type not defined edge_compile_config: Optional[EdgeCompileConfig] = None @@ -141,8 +145,8 @@ class ExportRecipe: Attributes: name: Optional name for the recipe quantization_recipe: Optional quantization recipe for model quantization - pre_edge_transform_passes: Optional function to apply transformation passes - before edge lowering + aten_transform_passes: Optional list of functions to apply transformation passes to the program before edge lowering. + These callables are invoked to modify and return the transformed program. lowering_recipe: Optional lowering recipe for model lowering and partitioning executorch_backend_config: Optional backend configuration for ExecuTorch pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline. @@ -151,7 +155,9 @@ class ExportRecipe: name: Optional[str] = None quantization_recipe: Optional[QuantizationRecipe] = None - pre_edge_transform_passes: Optional[Sequence[PassType]] = None + aten_transform_passes: Optional[ + List[Callable[[str, ExportedProgram], ExportedProgram]] + ] = None lowering_recipe: Optional[LoweringRecipe] = None # pyre-ignore[11]: Type not defined executorch_backend_config: Optional[ExecutorchBackendConfig] = None @@ -240,8 +246,8 @@ def _combine_recipes( # noqa: C901 for recipe in backend_recipes: # Collect pre-edge transform passes - if recipe.pre_edge_transform_passes: - all_pre_edge_passes.extend(recipe.pre_edge_transform_passes) + if recipe.aten_transform_passes: + all_pre_edge_passes.extend(recipe.aten_transform_passes) # Collect partitioners from lowering recipes if recipe.lowering_recipe and recipe.lowering_recipe.partitioners: @@ -307,7 +313,7 @@ def _combine_recipes( # noqa: C901 return cls( name=recipe_name, quantization_recipe=combined_quantization_recipe, - pre_edge_transform_passes=all_pre_edge_passes, + aten_transform_passes=all_pre_edge_passes, lowering_recipe=combined_lowering_recipe, executorch_backend_config=combined_backend_config, ) diff --git a/export/stages.py b/export/stages.py index 7b583e27943..609e7d197b9 100644 --- a/export/stages.py +++ b/export/stages.py @@ -7,14 +7,14 @@ import copy import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Sequence +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional import torch from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, ExportedProgram from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.program import to_edge, to_edge_transform_and_lower -from executorch.exir.program._program import _transform from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.types import StageType from torch import nn @@ -107,10 +107,12 @@ class TorchExportStage(Stage): def __init__( self, - pre_edge_transform_passes: Optional[List[PassType]] = None, + aten_transform_passes: Optional[ + List[Callable[[str, ExportedProgram], ExportedProgram]] + ] = None, ) -> None: super().__init__() - self._pre_edge_transform_passes = pre_edge_transform_passes + self._aten_transform_passes = aten_transform_passes @property def stage_type(self) -> str: @@ -149,9 +151,13 @@ def run(self, artifact: PipelineArtifact) -> None: ) # Apply pre-edge transform passes if available - for pass_ in self._pre_edge_transform_passes or []: - exported_programs[method_name] = _transform( - exported_programs[method_name], pass_ + for pass_ in self._aten_transform_passes or []: + if not callable(pass_): + raise ValueError( + "Aten transform passes must be a callable that can transform and return an exported program" + ) + exported_programs[method_name] = pass_( + method_name, exported_programs[method_name] ) self._artifact = artifact.copy_with_new_data(exported_programs) @@ -165,7 +171,9 @@ class EdgeTransformAndLowerStage(Stage): def __init__( self, partitioners: Optional[List[Any]] = None, - transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + transform_passes: ( + None | List[Callable[[str, ExportedProgram], List[PassType]]] + ) = None, compile_config: Optional[Any] = None, ) -> None: self._partitioners = partitioners @@ -205,11 +213,28 @@ def run(self, artifact: PipelineArtifact) -> None: constant_methods = artifact.get_context("constant_methods") generate_etrecord = artifact.get_context("generate_etrecord", False) + # per method transform passes + transform_passes = defaultdict(list) + for method_name, ep in exported_programs.items(): + # Resolve transform passes from callable + for pass_ in self._transform_passes or []: + if not callable(pass_): + raise ValueError( + "Transform passes must be a callable that resolves to a list of passes" + ) + passes = pass_(method_name, ep) + if isinstance(passes, list): + transform_passes[method_name].extend(passes) + else: + raise ValueError( + "Transform passes must be a callable that resolves to a list of passes" + ) + with validation_disabled(): edge_program_manager = to_edge_transform_and_lower( exported_programs, partitioner=self._partitioners, - transform_passes=self._transform_passes, + transform_passes=transform_passes, constant_methods=constant_methods, compile_config=self._compile_config, generate_etrecord=generate_etrecord, @@ -396,7 +421,7 @@ def run(self, artifact: PipelineArtifact) -> None: captured_graph = torch.export.export(model, inputs, strict=True).module() quantizer = self._get_quantizer_for_prepare_pt2e( - self._quantization_recipe.quantizers + self._quantization_recipe.quantizers # pyre-ignore ) prepared_model = prepare_pt2e(captured_graph, quantizer) @@ -471,7 +496,9 @@ class ToBackendStage(Stage): def __init__( self, partitioners: Optional[List[Any]] = None, - transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + transform_passes: ( + None | List[Callable[[str, ExportedProgram], List[PassType]]] + ) = None, ) -> None: super().__init__() self._partitioners = partitioners @@ -513,11 +540,24 @@ def run(self, artifact: PipelineArtifact) -> None: if edge_program_manager is None: raise RuntimeError("Edge program manager is not set.") - # Apply transform passes if available - if self._transform_passes: - edge_program_manager = edge_program_manager.transform( - self._transform_passes - ) + # per method transform passes + transform_passes = defaultdict(list) + for method_name in edge_program_manager.methods: + # Resolve transform passes if it's a callable + ep = edge_program_manager.exported_program(method_name) + for pass_ in self._transform_passes or []: + if not callable(pass_): + raise ValueError( + "Transform passes must be a callable that resolves to a list of passes" + ) + passes = pass_(method_name, ep) + if isinstance(passes, list): + transform_passes[method_name].extend(passes) + else: + raise ValueError("Transform passes must return list of passes") + + # Apply transform passes + edge_program_manager = edge_program_manager.transform(transform_passes) # Apply partitioners if available if self._partitioners is not None and len(self._partitioners) > 0: diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index d4629a1aea7..608aa5adb3c 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -7,7 +7,7 @@ # pyre-strict import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, PropertyMock import torch from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager @@ -99,6 +99,66 @@ def test_get_artifacts_before_run(self) -> None: stage.get_artifacts() self.assertIn("Stage: TorchExportStage not executed", str(cm.exception)) + @patch("torch.export.export") + def test_export_stage_with_aten_transform_passes( + self, mock_torch_export: Mock + ) -> None: + """Test TorchExportStage with aten_transform_passes.""" + mock_exported_program = Mock(spec=ExportedProgram) + mock_transformed_program = Mock(spec=ExportedProgram) + mock_torch_export.return_value = mock_exported_program + + # Create a mock aten transform pass that we can verify + mock_aten_transform_pass = Mock() + mock_aten_transform_pass.return_value = mock_transformed_program + aten_transform_passes = [mock_aten_transform_pass] + + stage = TorchExportStage(aten_transform_passes=aten_transform_passes) + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + + stage.run(artifact) + + # Verify torch.export.export was called + mock_torch_export.assert_called_once_with( + self.model, + self.example_inputs[0], + dynamic_shapes=None, + strict=True, + ) + + # Verify the aten transform pass was called with correct parameters + mock_aten_transform_pass.assert_called_once_with( + "forward", mock_exported_program + ) + + # Verify artifacts contain the transformed program + result_artifact = stage.get_artifacts() + self.assertIn("forward", result_artifact.data) + self.assertEqual(result_artifact.data["forward"], mock_transformed_program) + + @patch("torch.export.export") + def test_export_stage_invalid_aten_transform_pass( + self, mock_torch_export: Mock + ) -> None: + """Test TorchExportStage with invalid aten_transform_pass (not callable).""" + mock_exported_program = Mock(spec=ExportedProgram) + mock_torch_export.return_value = mock_exported_program + + # Use a non-callable object as transform pass + invalid_transform_pass = "not_callable" + aten_transform_passes = [invalid_transform_pass] + + # pyre-ignore + stage = TorchExportStage(aten_transform_passes=aten_transform_passes) + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + + with self.assertRaises(ValueError) as cm: + stage.run(artifact) + self.assertIn( + "Aten transform passes must be a callable that can transform and return an exported program", + str(cm.exception), + ) + class TestEdgeTransformAndLowerStage(unittest.TestCase): def setUp(self) -> None: @@ -106,12 +166,32 @@ def setUp(self) -> None: self.exported_programs = {"forward": self.mock_exported_program} self.context = {"constant_methods": None} - def test_run_with_partitioners_and_config(self) -> None: + @patch("executorch.export.stages.to_edge_transform_and_lower") + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_partitioners_and_config( + self, mock_get_delegation_info: Mock, mock_to_edge_transform_and_lower: Mock + ) -> None: """Test execution with partitioners and compile config""" + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_partitioners = [Mock()] - mock_transform_passes = [Mock()] mock_compile_config = Mock() + # Create a mock transform pass callable that we can verify + mock_transform_pass = Mock() + mock_pass1 = Mock() + mock_pass2 = Mock() + mock_transform_pass.return_value = [mock_pass1, mock_pass2] + mock_transform_passes = [mock_transform_pass] + + mock_edge_program_manager = Mock(spec=EdgeProgramManager) + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + mock_edge_program_manager.exported_program.return_value = mock_exported_program + mock_to_edge_transform_and_lower.return_value = mock_edge_program_manager + stage = EdgeTransformAndLowerStage( partitioners=mock_partitioners, transform_passes=mock_transform_passes, @@ -124,6 +204,33 @@ def test_run_with_partitioners_and_config(self) -> None: self.assertEqual(stage._transform_passes, mock_transform_passes) self.assertEqual(stage._compile_config, mock_compile_config) + # Test the run method + artifact = PipelineArtifact(data=self.exported_programs, context=self.context) + stage.run(artifact) + + # Verify the transform pass callable was called with correct parameters + mock_transform_pass.assert_called_once_with( + "forward", self.mock_exported_program + ) + + # Verify to_edge_transform_and_lower was called with the expected structure + expected_transform_passes = {"forward": [mock_pass1, mock_pass2]} + mock_to_edge_transform_and_lower.assert_called_once_with( + self.exported_programs, + partitioner=mock_partitioners, + transform_passes=expected_transform_passes, + constant_methods=None, + compile_config=mock_compile_config, + generate_etrecord=False, + ) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_program_manager) + self.assertEqual( + result_artifact.get_context("delegation_info"), mock_delegation_info + ) + class TestExecutorchStage(unittest.TestCase): def setUp(self) -> None: @@ -380,7 +487,10 @@ def test_run_success_no_transforms_or_partitioners( mock_exported_program = Mock() mock_graph_module = Mock() mock_exported_program.graph_module = mock_graph_module + + self.mock_edge_manager.transform.return_value = self.mock_edge_manager self.mock_edge_manager.exported_program.return_value = mock_exported_program + self.mock_edge_manager.methods = {"forward"} stage = ToBackendStage() artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context) @@ -409,9 +519,21 @@ def test_run_with_partitioners_and_passes( mock_edge_program_manager = Mock(spec=EdgeProgramManager) mock_edge_program_manager.transform.return_value = mock_edge_program_manager mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager + mock_edge_program_manager.exported_program.return_value = mock_exported_program + + # Use PropertyMock for the methods property + methods_property_mock = PropertyMock(return_value={"forward"}) + type(mock_edge_program_manager).methods = methods_property_mock mock_partitioner = Mock() - mock_transform_passes = [Mock(), Mock()] + + # Create a mock transform pass callable that we can verify + mock_transform_pass = Mock() + mock_pass1 = Mock() + mock_pass2 = Mock() + mock_transform_pass.return_value = [mock_pass1, mock_pass2] + mock_transform_passes = [mock_transform_pass] + stage = ToBackendStage( partitioners=[mock_partitioner], transform_passes=mock_transform_passes ) @@ -420,10 +542,19 @@ def test_run_with_partitioners_and_passes( ) stage.run(artifact) - # Verify transform and to_backend called correctly + # Verify that the methods property was accessed + methods_property_mock.assert_called_once() + + # Verify the transform pass callable was called with correct parameters + mock_transform_pass.assert_called_once_with("forward", mock_exported_program) + + # Verify transform was called with the expected structure + expected_transform_passes = {"forward": [mock_pass1, mock_pass2]} mock_edge_program_manager.transform.assert_called_once_with( - mock_transform_passes + expected_transform_passes ) + + # Verify to_backend called correctly mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner) # Verify artifacts contain the backend manager