Skip to content

Commit 867a66b

Browse files
[Executorch][Export][2/N] Add getter utils to help with getting the artifacts easily
Add getter utility methods to ExportSession for easier artifact retrieval. - `get_exported_program` - to get output from export stage - `get_edge_program_manager` - multiple stages produce edge program manager, for example, to_edge, transform, to_backend, go in reverse order to get the most processed edge program manager. Differential Revision: [D87576720](https://our.internmc.facebook.com/intern/diff/D87576720/) [ghstack-poisoned]
1 parent 362a96f commit 867a66b

File tree

2 files changed

+208
-5
lines changed

2 files changed

+208
-5
lines changed

export/export.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
1010

1111
import torch
12+
from executorch.exir import EdgeProgramManager
1213
from executorch.exir._warnings import experimental
1314
from executorch.exir.program import ExecutorchProgramManager
1415
from executorch.exir.schema import Program
@@ -461,17 +462,65 @@ def export(self) -> None:
461462
def get_stage_artifacts(self) -> Dict[StageType, PipelineArtifact]:
462463
return self._stage_to_artifacts
463464

464-
def save_pte_file(self, path: str) -> None:
465+
def get_exported_program(self, method_name: str = "forward") -> ExportedProgram:
465466
"""
466-
Save the exported program to a PTE file.
467+
Get the ExportedProgram for a specific method after torch export.
467468
468469
Args:
469-
path: Path where the PTE file will be saved
470+
method_name: Name of the method to get exported program for, defaults to "forward"
471+
472+
Returns:
473+
The ExportedProgram for the specified method
470474
471475
Raises:
472-
RuntimeError: If the executorch program manager is not initialized
476+
RuntimeError: If torch export stage has not been run
477+
KeyError: If the method name is not found in exported programs
473478
"""
474-
self.get_executorch_program_manager().save(path)
479+
artifact = self._stage_to_artifacts.get(StageType.TORCH_EXPORT)
480+
if artifact is None or artifact.data is None:
481+
raise RuntimeError(
482+
"Exported program is not available. Run Torch Export Stage first."
483+
)
484+
485+
exported_programs = artifact.data
486+
if method_name not in exported_programs:
487+
raise KeyError(
488+
f"Method name '{method_name}' not found in exported programs. "
489+
f"Available methods: {list(exported_programs.keys())}"
490+
)
491+
492+
return exported_programs[method_name]
493+
494+
def get_edge_program_manager(self) -> "EdgeProgramManager":
495+
"""
496+
Get the EdgeProgramManager after edge lowering stages.
497+
498+
This method checks multiple stages in order of preference:
499+
1. TO_EDGE_TRANSFORM_AND_LOWER (combined stage)
500+
2. TO_BACKEND (separate stage with backend delegation)
501+
3. TO_EDGE (separate stage without backend delegation)
502+
503+
Returns:
504+
The EdgeProgramManager
505+
506+
Raises:
507+
RuntimeError: If no edge stage has been run
508+
"""
509+
# Check stages in order of preference
510+
for stage_type in [
511+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
512+
StageType.TO_BACKEND,
513+
StageType.TO_EDGE,
514+
]:
515+
artifact = self._stage_to_artifacts.get(stage_type)
516+
if artifact is not None and artifact.data is not None:
517+
logging.info(f"Returning edge program manager from stage {stage_type}")
518+
return artifact.data
519+
520+
raise RuntimeError(
521+
"Edge program manager is not available. "
522+
"Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, or TO_BACKEND."
523+
)
475524

476525
def get_executorch_program(self) -> Program:
477526
"""

export/tests/test_export_session.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,3 +817,157 @@ def test_exported_program_valid_pipeline(self) -> None:
817817

818818
# Should not raise during validation
819819
session._validate_pipeline_sequence(recipe.pipeline_stages)
820+
821+
822+
class TestIntermediateStateGetters(unittest.TestCase):
823+
"""Test convenience getters for intermediate pipeline states."""
824+
825+
def setUp(self) -> None:
826+
self.model = SimpleTestModel()
827+
self.example_inputs = [(torch.randn(2, 10),)]
828+
829+
def test_get_exported_program_after_torch_export(self) -> None:
830+
"""Test that get_exported_program works after torch export stage."""
831+
recipe = ExportRecipe(
832+
name="test",
833+
pipeline_stages=[
834+
StageType.TORCH_EXPORT,
835+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
836+
StageType.TO_EXECUTORCH,
837+
],
838+
)
839+
840+
session = ExportSession(
841+
model=self.model,
842+
example_inputs=self.example_inputs,
843+
export_recipe=recipe,
844+
)
845+
846+
session.export()
847+
848+
exported_program = session.get_exported_program()
849+
self.assertIsNotNone(exported_program)
850+
self.assertIsInstance(exported_program, torch.export.ExportedProgram)
851+
852+
def test_get_exported_program_before_export_fails(self) -> None:
853+
"""Test that get_exported_program fails before torch export stage."""
854+
recipe = ExportRecipe(name="test")
855+
856+
session = ExportSession(
857+
model=self.model,
858+
example_inputs=self.example_inputs,
859+
export_recipe=recipe,
860+
)
861+
862+
with self.assertRaises(RuntimeError) as cm:
863+
session.get_exported_program()
864+
self.assertIn("Exported program is not available", str(cm.exception))
865+
866+
def test_get_exported_program_invalid_method_name(self) -> None:
867+
"""Test that get_exported_program fails with invalid method name."""
868+
recipe = ExportRecipe(name="test")
869+
870+
session = ExportSession(
871+
model=self.model,
872+
example_inputs=self.example_inputs,
873+
export_recipe=recipe,
874+
)
875+
876+
session.export()
877+
878+
with self.assertRaises(KeyError) as cm:
879+
session.get_exported_program("nonexistent_method")
880+
self.assertIn("Method name 'nonexistent_method' not found", str(cm.exception))
881+
882+
def test_get_exported_program_multi_method(self) -> None:
883+
"""Test get_exported_program with multi-method model."""
884+
model_dict = {
885+
"forward": self.model,
886+
"inference": SimpleTestModel(),
887+
}
888+
inputs_dict = {
889+
"forward": self.example_inputs,
890+
"inference": [(torch.randn(1, 10),)],
891+
}
892+
893+
recipe = ExportRecipe(name="multi_method_test")
894+
895+
session = ExportSession(
896+
model=model_dict,
897+
example_inputs=inputs_dict,
898+
export_recipe=recipe,
899+
)
900+
901+
session.export()
902+
903+
forward_ep = session.get_exported_program("forward")
904+
inference_ep = session.get_exported_program("inference")
905+
906+
self.assertIsNotNone(forward_ep)
907+
self.assertIsNotNone(inference_ep)
908+
self.assertIsInstance(forward_ep, torch.export.ExportedProgram)
909+
self.assertIsInstance(inference_ep, torch.export.ExportedProgram)
910+
911+
def test_get_edge_program_manager_with_transform_and_lower(self) -> None:
912+
"""Test get_edge_program_manager with TO_EDGE_TRANSFORM_AND_LOWER stage."""
913+
recipe = ExportRecipe(
914+
name="test",
915+
pipeline_stages=[
916+
StageType.TORCH_EXPORT,
917+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
918+
StageType.TO_EXECUTORCH,
919+
],
920+
)
921+
922+
session = ExportSession(
923+
model=self.model,
924+
example_inputs=self.example_inputs,
925+
export_recipe=recipe,
926+
)
927+
928+
session.export()
929+
930+
edge_manager = session.get_edge_program_manager()
931+
self.assertIsNotNone(edge_manager)
932+
933+
def test_get_edge_program_manager_with_separate_stages(self) -> None:
934+
"""Test get_edge_program_manager with separate TO_EDGE and TO_BACKEND stages."""
935+
recipe = ExportRecipe(
936+
name="test",
937+
pipeline_stages=[
938+
StageType.TORCH_EXPORT,
939+
StageType.TO_EDGE,
940+
StageType.TO_BACKEND,
941+
StageType.TO_EXECUTORCH,
942+
],
943+
)
944+
945+
session = ExportSession(
946+
model=self.model,
947+
example_inputs=self.example_inputs,
948+
export_recipe=recipe,
949+
)
950+
951+
session.export()
952+
953+
edge_manager = session.get_edge_program_manager()
954+
self.assertIsNotNone(edge_manager)
955+
956+
def test_get_edge_program_manager_before_edge_stage_fails(self) -> None:
957+
"""Test that get_edge_program_manager fails before edge stages."""
958+
recipe = ExportRecipe(
959+
name="test",
960+
pipeline_stages=[StageType.TORCH_EXPORT],
961+
)
962+
963+
session = ExportSession(
964+
model=self.model,
965+
example_inputs=self.example_inputs,
966+
export_recipe=recipe,
967+
)
968+
969+
session.export()
970+
971+
with self.assertRaises(RuntimeError) as cm:
972+
session.get_edge_program_manager()
973+
self.assertIn("Edge program manager is not available", str(cm.exception))

0 commit comments

Comments
 (0)