@@ -817,3 +817,157 @@ def test_exported_program_valid_pipeline(self) -> None:
817817
818818 # Should not raise during validation
819819 session ._validate_pipeline_sequence (recipe .pipeline_stages )
820+
821+
822+ class TestIntermediateStateGetters (unittest .TestCase ):
823+ """Test convenience getters for intermediate pipeline states."""
824+
825+ def setUp (self ) -> None :
826+ self .model = SimpleTestModel ()
827+ self .example_inputs = [(torch .randn (2 , 10 ),)]
828+
829+ def test_get_exported_program_after_torch_export (self ) -> None :
830+ """Test that get_exported_program works after torch export stage."""
831+ recipe = ExportRecipe (
832+ name = "test" ,
833+ pipeline_stages = [
834+ StageType .TORCH_EXPORT ,
835+ StageType .TO_EDGE_TRANSFORM_AND_LOWER ,
836+ StageType .TO_EXECUTORCH ,
837+ ],
838+ )
839+
840+ session = ExportSession (
841+ model = self .model ,
842+ example_inputs = self .example_inputs ,
843+ export_recipe = recipe ,
844+ )
845+
846+ session .export ()
847+
848+ exported_program = session .get_exported_program ()
849+ self .assertIsNotNone (exported_program )
850+ self .assertIsInstance (exported_program , torch .export .ExportedProgram )
851+
852+ def test_get_exported_program_before_export_fails (self ) -> None :
853+ """Test that get_exported_program fails before torch export stage."""
854+ recipe = ExportRecipe (name = "test" )
855+
856+ session = ExportSession (
857+ model = self .model ,
858+ example_inputs = self .example_inputs ,
859+ export_recipe = recipe ,
860+ )
861+
862+ with self .assertRaises (RuntimeError ) as cm :
863+ session .get_exported_program ()
864+ self .assertIn ("Exported program is not available" , str (cm .exception ))
865+
866+ def test_get_exported_program_invalid_method_name (self ) -> None :
867+ """Test that get_exported_program fails with invalid method name."""
868+ recipe = ExportRecipe (name = "test" )
869+
870+ session = ExportSession (
871+ model = self .model ,
872+ example_inputs = self .example_inputs ,
873+ export_recipe = recipe ,
874+ )
875+
876+ session .export ()
877+
878+ with self .assertRaises (KeyError ) as cm :
879+ session .get_exported_program ("nonexistent_method" )
880+ self .assertIn ("Method name 'nonexistent_method' not found" , str (cm .exception ))
881+
882+ def test_get_exported_program_multi_method (self ) -> None :
883+ """Test get_exported_program with multi-method model."""
884+ model_dict = {
885+ "forward" : self .model ,
886+ "inference" : SimpleTestModel (),
887+ }
888+ inputs_dict = {
889+ "forward" : self .example_inputs ,
890+ "inference" : [(torch .randn (1 , 10 ),)],
891+ }
892+
893+ recipe = ExportRecipe (name = "multi_method_test" )
894+
895+ session = ExportSession (
896+ model = model_dict ,
897+ example_inputs = inputs_dict ,
898+ export_recipe = recipe ,
899+ )
900+
901+ session .export ()
902+
903+ forward_ep = session .get_exported_program ("forward" )
904+ inference_ep = session .get_exported_program ("inference" )
905+
906+ self .assertIsNotNone (forward_ep )
907+ self .assertIsNotNone (inference_ep )
908+ self .assertIsInstance (forward_ep , torch .export .ExportedProgram )
909+ self .assertIsInstance (inference_ep , torch .export .ExportedProgram )
910+
911+ def test_get_edge_program_manager_with_transform_and_lower (self ) -> None :
912+ """Test get_edge_program_manager with TO_EDGE_TRANSFORM_AND_LOWER stage."""
913+ recipe = ExportRecipe (
914+ name = "test" ,
915+ pipeline_stages = [
916+ StageType .TORCH_EXPORT ,
917+ StageType .TO_EDGE_TRANSFORM_AND_LOWER ,
918+ StageType .TO_EXECUTORCH ,
919+ ],
920+ )
921+
922+ session = ExportSession (
923+ model = self .model ,
924+ example_inputs = self .example_inputs ,
925+ export_recipe = recipe ,
926+ )
927+
928+ session .export ()
929+
930+ edge_manager = session .get_edge_program_manager ()
931+ self .assertIsNotNone (edge_manager )
932+
933+ def test_get_edge_program_manager_with_separate_stages (self ) -> None :
934+ """Test get_edge_program_manager with separate TO_EDGE and TO_BACKEND stages."""
935+ recipe = ExportRecipe (
936+ name = "test" ,
937+ pipeline_stages = [
938+ StageType .TORCH_EXPORT ,
939+ StageType .TO_EDGE ,
940+ StageType .TO_BACKEND ,
941+ StageType .TO_EXECUTORCH ,
942+ ],
943+ )
944+
945+ session = ExportSession (
946+ model = self .model ,
947+ example_inputs = self .example_inputs ,
948+ export_recipe = recipe ,
949+ )
950+
951+ session .export ()
952+
953+ edge_manager = session .get_edge_program_manager ()
954+ self .assertIsNotNone (edge_manager )
955+
956+ def test_get_edge_program_manager_before_edge_stage_fails (self ) -> None :
957+ """Test that get_edge_program_manager fails before edge stages."""
958+ recipe = ExportRecipe (
959+ name = "test" ,
960+ pipeline_stages = [StageType .TORCH_EXPORT ],
961+ )
962+
963+ session = ExportSession (
964+ model = self .model ,
965+ example_inputs = self .example_inputs ,
966+ export_recipe = recipe ,
967+ )
968+
969+ session .export ()
970+
971+ with self .assertRaises (RuntimeError ) as cm :
972+ session .get_edge_program_manager ()
973+ self .assertIn ("Edge program manager is not available" , str (cm .exception ))
0 commit comments