Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from executorch.exir import EdgeProgramManager
from executorch.exir._warnings import experimental
from executorch.exir.program import ExecutorchProgramManager
from executorch.exir.schema import Program
Expand Down Expand Up @@ -461,17 +462,65 @@ def export(self) -> None:
def get_stage_artifacts(self) -> Dict[StageType, PipelineArtifact]:
return self._stage_to_artifacts

def save_pte_file(self, path: str) -> None:
def get_exported_program(self, method_name: str = "forward") -> ExportedProgram:
"""
Save the exported program to a PTE file.
Get the ExportedProgram for a specific method after torch export.

Args:
path: Path where the PTE file will be saved
method_name: Name of the method to get exported program for, defaults to "forward"

Returns:
The ExportedProgram for the specified method

Raises:
RuntimeError: If the executorch program manager is not initialized
RuntimeError: If torch export stage has not been run
KeyError: If the method name is not found in exported programs
"""
self.get_executorch_program_manager().save(path)
artifact = self._stage_to_artifacts.get(StageType.TORCH_EXPORT)
if artifact is None or artifact.data is None:
raise RuntimeError(
"Exported program is not available. Run Torch Export Stage first."
)

exported_programs = artifact.data
if method_name not in exported_programs:
raise KeyError(
f"Method name '{method_name}' not found in exported programs. "
f"Available methods: {list(exported_programs.keys())}"
)

return exported_programs[method_name]

def get_edge_program_manager(self) -> "EdgeProgramManager":
"""
Get the EdgeProgramManager after edge lowering stages.

This method checks multiple stages in order of preference:
1. TO_EDGE_TRANSFORM_AND_LOWER (combined stage)
2. TO_BACKEND (separate stage with backend delegation)
3. TO_EDGE (separate stage without backend delegation)

Returns:
The EdgeProgramManager

Raises:
RuntimeError: If no edge stage has been run
"""
# Check stages in order of preference
for stage_type in [
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_BACKEND,
StageType.TO_EDGE,
]:
artifact = self._stage_to_artifacts.get(stage_type)
if artifact is not None and artifact.data is not None:
logging.info(f"Returning edge program manager from stage {stage_type}")
return artifact.data

raise RuntimeError(
"Edge program manager is not available. "
"Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, or TO_BACKEND."
)

def get_executorch_program(self) -> Program:
"""
Expand Down
154 changes: 154 additions & 0 deletions export/tests/test_export_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,3 +817,157 @@ def test_exported_program_valid_pipeline(self) -> None:

# Should not raise during validation
session._validate_pipeline_sequence(recipe.pipeline_stages)


class TestIntermediateStateGetters(unittest.TestCase):
"""Test convenience getters for intermediate pipeline states."""

def setUp(self) -> None:
self.model = SimpleTestModel()
self.example_inputs = [(torch.randn(2, 10),)]

def test_get_exported_program_after_torch_export(self) -> None:
"""Test that get_exported_program works after torch export stage."""
recipe = ExportRecipe(
name="test",
pipeline_stages=[
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
],
)

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

session.export()

exported_program = session.get_exported_program()
self.assertIsNotNone(exported_program)
self.assertIsInstance(exported_program, torch.export.ExportedProgram)

def test_get_exported_program_before_export_fails(self) -> None:
"""Test that get_exported_program fails before torch export stage."""
recipe = ExportRecipe(name="test")

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

with self.assertRaises(RuntimeError) as cm:
session.get_exported_program()
self.assertIn("Exported program is not available", str(cm.exception))

def test_get_exported_program_invalid_method_name(self) -> None:
"""Test that get_exported_program fails with invalid method name."""
recipe = ExportRecipe(name="test")

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

session.export()

with self.assertRaises(KeyError) as cm:
session.get_exported_program("nonexistent_method")
self.assertIn("Method name 'nonexistent_method' not found", str(cm.exception))

def test_get_exported_program_multi_method(self) -> None:
"""Test get_exported_program with multi-method model."""
model_dict = {
"forward": self.model,
"inference": SimpleTestModel(),
}
inputs_dict = {
"forward": self.example_inputs,
"inference": [(torch.randn(1, 10),)],
}

recipe = ExportRecipe(name="multi_method_test")

session = ExportSession(
model=model_dict,
example_inputs=inputs_dict,
export_recipe=recipe,
)

session.export()

forward_ep = session.get_exported_program("forward")
inference_ep = session.get_exported_program("inference")

self.assertIsNotNone(forward_ep)
self.assertIsNotNone(inference_ep)
self.assertIsInstance(forward_ep, torch.export.ExportedProgram)
self.assertIsInstance(inference_ep, torch.export.ExportedProgram)

def test_get_edge_program_manager_with_transform_and_lower(self) -> None:
"""Test get_edge_program_manager with TO_EDGE_TRANSFORM_AND_LOWER stage."""
recipe = ExportRecipe(
name="test",
pipeline_stages=[
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
],
)

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

session.export()

edge_manager = session.get_edge_program_manager()
self.assertIsNotNone(edge_manager)

def test_get_edge_program_manager_with_separate_stages(self) -> None:
"""Test get_edge_program_manager with separate TO_EDGE and TO_BACKEND stages."""
recipe = ExportRecipe(
name="test",
pipeline_stages=[
StageType.TORCH_EXPORT,
StageType.TO_EDGE,
StageType.TO_BACKEND,
StageType.TO_EXECUTORCH,
],
)

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

session.export()

edge_manager = session.get_edge_program_manager()
self.assertIsNotNone(edge_manager)

def test_get_edge_program_manager_before_edge_stage_fails(self) -> None:
"""Test that get_edge_program_manager fails before edge stages."""
recipe = ExportRecipe(
name="test",
pipeline_stages=[StageType.TORCH_EXPORT],
)

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

session.export()

with self.assertRaises(RuntimeError) as cm:
session.get_edge_program_manager()
self.assertIn("Edge program manager is not available", str(cm.exception))
Loading