Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
QuantizeStage,
SourceTransformStage,
Stage,
ToBackendStage,
ToEdgeStage,
TorchExportStage,
)
from .types import StageType
Expand Down Expand Up @@ -147,7 +149,9 @@ def __init__(
)

# Stage registry: map of StageType to Stage instances
self._stage_registry: Dict[StageType, Stage] = self._build_default_stages()
self._stage_registry: Dict[StageType, Stage] = self._build_stages(
self._pipeline_stages
)

# Intialize run context
self._run_context: Dict[str, Any] = {
Expand All @@ -170,10 +174,12 @@ def _get_default_pipeline(self) -> List[StageType]:
StageType.TO_EXECUTORCH,
]

def _build_default_stages(self) -> Dict[StageType, Stage]:
def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
"""Build the stage registry from the given stages."""
stage_registry: Dict[StageType, Stage] = {}

for stage_type in self._get_default_pipeline():
stage = None
for stage_type in stages or self._get_default_pipeline():
if stage_type == StageType.SOURCE_TRANSFORM:
stage = SourceTransformStage(self._quant_recipe)
elif stage_type == StageType.QUANTIZE:
Expand All @@ -191,12 +197,24 @@ def _build_default_stages(self) -> Dict[StageType, Stage]:
transform_passes=self._export_recipe.edge_transform_passes,
compile_config=self._export_recipe.edge_compile_config,
)
elif stage_type == StageType.TO_EDGE:
stage = ToEdgeStage(
edge_compile_config=self._export_recipe.edge_compile_config
)
elif stage_type == StageType.TO_BACKEND:
stage = ToBackendStage(
partitioners=self._export_recipe.partitioners,
transform_passes=self._export_recipe.edge_transform_passes,
)
elif stage_type == StageType.TO_EXECUTORCH:
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
else:
raise ValueError(f"Unknown stage type: {stage_type}")
logging.info(
f"{stage_type} is unknown, you have to register it before executing export()"
)

stage_registry[stage_type] = stage
if stage:
stage_registry[stage_type] = stage
return stage_registry

def register_stage(self, stage_type: StageType, stage: Stage) -> None:
Expand Down Expand Up @@ -241,7 +259,9 @@ def _validate_pipeline_sequence(
first_stage = stages[0]
first_stage_instance = self._stage_registry.get(first_stage)
if first_stage_instance is None:
raise ValueError(f"Stage {first_stage} not found in registry")
raise ValueError(
f"Stage {first_stage} not found in registry, register it using session.register_stage()"
)

if not first_stage_instance.can_start_pipeline:
raise ValueError(f"Stage {first_stage} cannot start a pipeline. ")
Expand All @@ -254,7 +274,9 @@ def _validate_pipeline_sequence(
# Get the stage instance to check its valid predecessors
stage_instance = self._stage_registry.get(current_stage)
if stage_instance is None:
raise ValueError(f"Stage {current_stage} not found in registry")
raise ValueError(
f"Stage {current_stage} not found in registry, , register it using session.register_stage()"
)

valid_predecessors = stage_instance.valid_predecessor_stages

Expand Down
116 changes: 114 additions & 2 deletions export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

import torch
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.program import to_edge_transform_and_lower
from executorch.exir.program import to_edge, to_edge_transform_and_lower
from executorch.exir.program._program import _transform
from executorch.export.recipe import QuantizationRecipe
from executorch.export.types import StageType
Expand Down Expand Up @@ -223,7 +224,7 @@ def stage_type(self) -> str:

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [StageType.TO_EDGE_TRANSFORM_AND_LOWER]
return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND]

@property
def can_start_pipeline(self) -> bool:
Expand Down Expand Up @@ -354,3 +355,114 @@ def run(self, artifact: PipelineArtifact) -> None:
quantized_models[method_name] = quantized_model

self._artifact = artifact.copy_with_new_data(quantized_models)


class ToEdgeStage(Stage):
"""
Stage: Convert ExportedProgram to EdgeProgramManager.
"""

def __init__(
self,
edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore
) -> None:
super().__init__()
self._edge_compile_config = edge_compile_config

@property
def stage_type(self) -> str:
return StageType.TO_EDGE

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [StageType.TORCH_EXPORT]

@property
def can_start_pipeline(self) -> bool:
return False

def run(self, artifact: PipelineArtifact) -> None:
"""
Convert ExportedProgram to EdgeProgramManager.

Args:
artifact: Contains exported programs and context
"""
exported_programs = artifact.data
constant_methods = artifact.get_context("constant_methods")

# Convert to edge program manager
edge_program_manager = to_edge(
exported_programs,
constant_methods=constant_methods,
compile_config=self._edge_compile_config,
)

self._artifact = artifact.copy_with_new_data(edge_program_manager)


class ToBackendStage(Stage):
"""
Stage: Apply transformations and partitioning to EdgeProgramManager.
"""

def __init__(
self,
partitioners: Optional[List[Any]] = None,
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
) -> None:
super().__init__()
self._partitioners = partitioners
self._transform_passes = transform_passes

@property
def stage_type(self) -> str:
return StageType.TO_BACKEND

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [StageType.TO_EDGE]

@property
def can_start_pipeline(self) -> bool:
return False

def run(self, artifact: PipelineArtifact) -> None:
"""
Apply transformations and partitioning to EdgeProgramManager.

Args:
artifact: Contains edge program manager and context
"""
edge_program_manager = artifact.data

if edge_program_manager is None:
raise RuntimeError("Edge program manager is not set.")

# Apply transform passes if available
if self._transform_passes:
edge_program_manager = edge_program_manager.transform(
self._transform_passes
)

# Apply partitioners if available
if self._partitioners is not None and len(self._partitioners) > 0:
with validation_disabled():
# pyre-ignore
for partitioner in self._partitioners:
edge_program_manager = edge_program_manager.to_backend(partitioner)

# Get delegation info
delegation_info = get_delegation_info(
edge_program_manager.exported_program().graph_module
)

self._artifact = artifact.copy_with_new_data(edge_program_manager)
self._artifact.add_context("delegation_info", delegation_info)

@property
def delegation_info(self) -> Any:
"""
Returns the delegation info.
"""
return self._artifact.get_context("delegation_info")
11 changes: 10 additions & 1 deletion export/tests/test_export_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,23 @@ def _get_export_session(self, stages: List[StageType]):
def test_valid_pipeline_sequences(self) -> None:
"""Test various valid pipeline sequences."""
valid_sequences = [
# Full pipeline
# Full pipeline with to_edge_transform_lower
[
StageType.SOURCE_TRANSFORM,
StageType.QUANTIZE,
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
],
# Full pipeline with to_edge, to_backend
[
StageType.SOURCE_TRANSFORM,
StageType.QUANTIZE,
StageType.TORCH_EXPORT,
StageType.TO_EDGE,
StageType.TO_BACKEND,
StageType.TO_EXECUTORCH,
],
# Skip quantize
[
StageType.SOURCE_TRANSFORM,
Expand Down
104 changes: 104 additions & 0 deletions export/tests/test_export_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
QuantizeStage,
SourceTransformStage,
StageType,
ToBackendStage,
ToEdgeStage,
TorchExportStage,
)
from torch.export import ExportedProgram
Expand Down Expand Up @@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None:
self.assertIn(
"Example inputs for method forward not found or empty", str(cm.exception)
)


class TestToEdgeStage(unittest.TestCase):
def setUp(self) -> None:
self.mock_exported_program = Mock(spec=ExportedProgram)
self.exported_programs = {"forward": self.mock_exported_program}
self.context = {"constant_methods": None}

@patch("executorch.export.stages.to_edge")
def test_run_success(self, mock_to_edge: Mock) -> None:
mock_edge_manager = Mock(spec=EdgeProgramManager)
mock_to_edge.return_value = mock_edge_manager
mock_config = Mock()

stage = ToEdgeStage(edge_compile_config=mock_config)
artifact = PipelineArtifact(data=self.exported_programs, context=self.context)
stage.run(artifact)

# Verify to_edge was called with correct parameters
mock_to_edge.assert_called_once_with(
self.exported_programs,
constant_methods=None,
compile_config=mock_config,
)

# Verify artifacts are set correctly
result_artifact = stage.get_artifacts()
self.assertEqual(result_artifact.data, mock_edge_manager)


class TestToBackendStage(unittest.TestCase):
def setUp(self) -> None:
self.mock_edge_manager = Mock(spec=EdgeProgramManager)
self.context = {}

@patch("executorch.export.stages.get_delegation_info")
def test_run_success_no_transforms_or_partitioners(
self, mock_get_delegation_info: Mock
) -> None:
# Test successful execution without transforms or partitioners
mock_delegation_info = {"delegation": "info"}
mock_get_delegation_info.return_value = mock_delegation_info
mock_exported_program = Mock()
mock_graph_module = Mock()
mock_exported_program.graph_module = mock_graph_module
self.mock_edge_manager.exported_program.return_value = mock_exported_program

stage = ToBackendStage()
artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context)
stage.run(artifact)

# Verify get_delegation_info was called
mock_get_delegation_info.assert_called_once_with(mock_graph_module)

# Verify artifacts are set correctly
result_artifact = stage.get_artifacts()
self.assertEqual(result_artifact.data, self.mock_edge_manager)
self.assertEqual(
result_artifact.get_context("delegation_info"), mock_delegation_info
)

@patch("executorch.export.stages.get_delegation_info")
def test_run_with_partitioners_and_passes(
self, mock_get_delegation_info: Mock
) -> None:
mock_delegation_info = {"delegation": "info"}
mock_get_delegation_info.return_value = mock_delegation_info
mock_exported_program = Mock()
mock_graph_module = Mock()
mock_exported_program.graph_module = mock_graph_module

mock_edge_program_manager = Mock(spec=EdgeProgramManager)
mock_edge_program_manager.transform.return_value = mock_edge_program_manager
mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager

mock_partitioner = Mock()
mock_transform_passes = [Mock(), Mock()]
stage = ToBackendStage(
partitioners=[mock_partitioner], transform_passes=mock_transform_passes
)
artifact = PipelineArtifact(
data=mock_edge_program_manager, context=self.context
)
stage.run(artifact)

# Verify transform and to_backend called correctly
mock_edge_program_manager.transform.assert_called_once_with(
mock_transform_passes
)
mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner)

# Verify artifacts contain the backend manager
result_artifact = stage.get_artifacts()
self.assertEqual(result_artifact.data, mock_edge_program_manager)

def test_run_edge_manager_none(self) -> None:
stage = ToBackendStage()
artifact = PipelineArtifact(data=None, context=self.context)

with self.assertRaises(RuntimeError) as cm:
stage.run(artifact)
self.assertIn("Edge program manager is not set", str(cm.exception))
2 changes: 2 additions & 0 deletions export/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ class StageType(str, Enum):
QUANTIZE = "quantize"
TORCH_EXPORT = "torch_export"
TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower"
TO_EDGE = "to_edge"
TO_BACKEND = "to_backend"
TO_EXECUTORCH = "to_executorch"
Loading