Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
elif stage_type == StageType.QUANTIZE:
stage = QuantizeStage(self._quant_recipe)
elif stage_type == StageType.TORCH_EXPORT:
pre_edge_passes = None
if self._export_recipe.pre_edge_transform_passes is not None:
pre_edge_passes = list(
self._export_recipe.pre_edge_transform_passes
aten_transform_passes = None
if self._export_recipe.aten_transform_passes is not None:
aten_transform_passes = list(
self._export_recipe.aten_transform_passes
)
stage = TorchExportStage(pre_edge_passes)
stage = TorchExportStage(aten_transform_passes)
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_EDGE:
Expand Down
24 changes: 15 additions & 9 deletions export/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import Callable, List, Optional, Sequence
from typing import Callable, List, Optional

import torch
from executorch.exir import ExportedProgram

from executorch.exir._warnings import experimental

Expand Down Expand Up @@ -117,12 +118,15 @@ class LoweringRecipe:

Attributes:
partitioners: Optional list of partitioners for model partitioning
edge_transform_passes: Optional sequence of transformation passes to apply
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
and return a list of passes (PassType) to be executed during lowering stages.
edge_compile_config: Optional edge compilation configuration
"""

partitioners: Optional[List[Partitioner]] = None
edge_transform_passes: Optional[Sequence[PassType]] = None
edge_transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType]]]
) = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None

Expand All @@ -141,8 +145,8 @@ class ExportRecipe:
Attributes:
name: Optional name for the recipe
quantization_recipe: Optional quantization recipe for model quantization
pre_edge_transform_passes: Optional function to apply transformation passes
before edge lowering
aten_transform_passes: Optional list of functions to apply transformation passes to the program before edge lowering.
These callables are invoked to modify and return the transformed program.
lowering_recipe: Optional lowering recipe for model lowering and partitioning
executorch_backend_config: Optional backend configuration for ExecuTorch
pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline.
Expand All @@ -151,7 +155,9 @@ class ExportRecipe:

name: Optional[str] = None
quantization_recipe: Optional[QuantizationRecipe] = None
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
aten_transform_passes: Optional[
List[Callable[[str, ExportedProgram], ExportedProgram]]
] = None
lowering_recipe: Optional[LoweringRecipe] = None
# pyre-ignore[11]: Type not defined
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
Expand Down Expand Up @@ -240,8 +246,8 @@ def _combine_recipes( # noqa: C901

for recipe in backend_recipes:
# Collect pre-edge transform passes
if recipe.pre_edge_transform_passes:
all_pre_edge_passes.extend(recipe.pre_edge_transform_passes)
if recipe.aten_transform_passes:
all_pre_edge_passes.extend(recipe.aten_transform_passes)

# Collect partitioners from lowering recipes
if recipe.lowering_recipe and recipe.lowering_recipe.partitioners:
Expand Down Expand Up @@ -307,7 +313,7 @@ def _combine_recipes( # noqa: C901
return cls(
name=recipe_name,
quantization_recipe=combined_quantization_recipe,
pre_edge_transform_passes=all_pre_edge_passes,
aten_transform_passes=all_pre_edge_passes,
lowering_recipe=combined_lowering_recipe,
executorch_backend_config=combined_backend_config,
)
74 changes: 57 additions & 17 deletions export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import copy
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

import torch
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig
from executorch.exir import EdgeCompileConfig, ExportedProgram
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.program import to_edge, to_edge_transform_and_lower
from executorch.exir.program._program import _transform
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.types import StageType
from torch import nn
Expand Down Expand Up @@ -107,10 +107,12 @@ class TorchExportStage(Stage):

def __init__(
self,
pre_edge_transform_passes: Optional[List[PassType]] = None,
aten_transform_passes: Optional[
List[Callable[[str, ExportedProgram], ExportedProgram]]
] = None,
) -> None:
super().__init__()
self._pre_edge_transform_passes = pre_edge_transform_passes
self._aten_transform_passes = aten_transform_passes

@property
def stage_type(self) -> str:
Expand Down Expand Up @@ -149,9 +151,13 @@ def run(self, artifact: PipelineArtifact) -> None:
)

# Apply pre-edge transform passes if available
for pass_ in self._pre_edge_transform_passes or []:
exported_programs[method_name] = _transform(
exported_programs[method_name], pass_
for pass_ in self._aten_transform_passes or []:
if not callable(pass_):
raise ValueError(
"Aten transform passes must be a callable that can transform and return an exported program"
)
exported_programs[method_name] = pass_(
method_name, exported_programs[method_name]
)

self._artifact = artifact.copy_with_new_data(exported_programs)
Expand All @@ -165,7 +171,9 @@ class EdgeTransformAndLowerStage(Stage):
def __init__(
self,
partitioners: Optional[List[Any]] = None,
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType]]]
) = None,
compile_config: Optional[Any] = None,
) -> None:
self._partitioners = partitioners
Expand Down Expand Up @@ -205,11 +213,28 @@ def run(self, artifact: PipelineArtifact) -> None:
constant_methods = artifact.get_context("constant_methods")
generate_etrecord = artifact.get_context("generate_etrecord", False)

# per method transform passes
transform_passes = defaultdict(list)
for method_name, ep in exported_programs.items():
# Resolve transform passes from callable
for pass_ in self._transform_passes or []:
if not callable(pass_):
raise ValueError(
"Transform passes must be a callable that resolves to a list of passes"
)
passes = pass_(method_name, ep)
if isinstance(passes, list):
transform_passes[method_name].extend(passes)
else:
raise ValueError(
"Transform passes must be a callable that resolves to a list of passes"
)

with validation_disabled():
edge_program_manager = to_edge_transform_and_lower(
exported_programs,
partitioner=self._partitioners,
transform_passes=self._transform_passes,
transform_passes=transform_passes,
constant_methods=constant_methods,
compile_config=self._compile_config,
generate_etrecord=generate_etrecord,
Expand Down Expand Up @@ -396,7 +421,7 @@ def run(self, artifact: PipelineArtifact) -> None:
captured_graph = torch.export.export(model, inputs, strict=True).module()

quantizer = self._get_quantizer_for_prepare_pt2e(
self._quantization_recipe.quantizers
self._quantization_recipe.quantizers # pyre-ignore
)
prepared_model = prepare_pt2e(captured_graph, quantizer)

Expand Down Expand Up @@ -471,7 +496,9 @@ class ToBackendStage(Stage):
def __init__(
self,
partitioners: Optional[List[Any]] = None,
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType]]]
) = None,
) -> None:
super().__init__()
self._partitioners = partitioners
Expand Down Expand Up @@ -513,11 +540,24 @@ def run(self, artifact: PipelineArtifact) -> None:
if edge_program_manager is None:
raise RuntimeError("Edge program manager is not set.")

# Apply transform passes if available
if self._transform_passes:
edge_program_manager = edge_program_manager.transform(
self._transform_passes
)
# per method transform passes
transform_passes = defaultdict(list)
for method_name in edge_program_manager.methods:
# Resolve transform passes if it's a callable
ep = edge_program_manager.exported_program(method_name)
for pass_ in self._transform_passes or []:
if not callable(pass_):
raise ValueError(
"Transform passes must be a callable that resolves to a list of passes"
)
passes = pass_(method_name, ep)
if isinstance(passes, list):
transform_passes[method_name].extend(passes)
else:
raise ValueError("Transform passes must return list of passes")

# Apply transform passes
edge_program_manager = edge_program_manager.transform(transform_passes)

# Apply partitioners if available
if self._partitioners is not None and len(self._partitioners) > 0:
Expand Down
Loading
Loading