diff --git a/backends/apple/coreml/TARGETS b/backends/apple/coreml/TARGETS
index 22cb20d9065..444e886b4e6 100644
--- a/backends/apple/coreml/TARGETS
+++ b/backends/apple/coreml/TARGETS
@@ -61,16 +61,21 @@ runtime.python_library(
 )
 
 runtime.python_library(
-    name = "recipes",
-    srcs = glob([
-        "recipes/*.py",
-    ]),
+    name = "coreml_recipes",
+    srcs = [
+        "recipes/__init__.py",
+        "recipes/coreml_recipe_provider.py"
+    ],
     visibility = [
         "@EXECUTORCH_CLIENTS",
+        "//executorch/export/...",
     ],
     deps = [
         "fbsource//third-party/pypi/coremltools:coremltools",
+        ":coreml_recipe_types",
         ":backend",
+        ":partitioner",
+        ":quantizer",
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
@@ -80,6 +85,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "coreml_recipe_types",
+    srcs = [
+        "recipes/coreml_recipe_types.py",
+    ],
+    visibility = [
+        "@EXECUTORCH_CLIENTS",
+        "//executorch/export/...",
+    ],
+    deps = [
+        "//executorch/export:recipe",
+    ],
+)
+
 runtime.cxx_python_extension(
     name = "executorchcoreml",
     srcs = [
@@ -124,7 +143,7 @@ runtime.python_test(
         "fbsource//third-party/pypi/pytest:pytest",
         ":partitioner",
         ":quantizer",
-        ":recipes",
+        ":coreml_recipes",
         "//caffe2:torch",
         "//pytorch/vision:torchvision",
         "fbsource//third-party/pypi/scikit-learn:scikit-learn",
diff --git a/backends/apple/coreml/recipes/coreml_recipe_provider.py b/backends/apple/coreml/recipes/coreml_recipe_provider.py
index 90b798f9e0c..77e15aeced3 100644
--- a/backends/apple/coreml/recipes/coreml_recipe_provider.py
+++ b/backends/apple/coreml/recipes/coreml_recipe_provider.py
@@ -3,6 +3,7 @@
 
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
+import logging
 from typing import Any, Optional, Sequence
 
 import coremltools as ct
@@ -111,8 +112,9 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non
 
         unexpected = set(kwargs.keys()) - expected_keys
         if unexpected:
-            raise ValueError(
-                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
+            logging.warning(
+                f"CoreML recipe '{recipe_type.value}' ignoring unexpected parameters: {list(unexpected)}. "
+                f"Expected parameters: {list(expected_keys)}"
             )
 
         self._validate_base_parameters(kwargs)
@@ -121,7 +123,13 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non
 
     def _get_expected_keys(self, recipe_type: RecipeType) -> set:
         """Get expected parameter keys for a recipe type"""
-        common_keys = {"minimum_deployment_target", "compute_unit"}
+        common_keys = {
+            "minimum_deployment_target",
+            "compute_unit",
+            "skip_ops_for_coreml_delegation",
+            "lower_full_graph",
+            "take_over_constant_data",
+        }
 
         if recipe_type in [
             CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
@@ -377,9 +385,19 @@ def _get_coreml_lowering_recipe(
         if minimum_deployment_target and minimum_deployment_target < ct.target.iOS18:
             take_over_mutable_buffer = False
 
+        # Extract additional partitioner parameters
+        skip_ops_for_coreml_delegation = kwargs.get(
+            "skip_ops_for_coreml_delegation", None
+        )
+        lower_full_graph = kwargs.get("lower_full_graph", False)
+        take_over_constant_data = kwargs.get("take_over_constant_data", True)
+
         partitioner = CoreMLPartitioner(
             compile_specs=compile_specs,
             take_over_mutable_buffer=take_over_mutable_buffer,
+            skip_ops_for_coreml_delegation=skip_ops_for_coreml_delegation,
+            lower_full_graph=lower_full_graph,
+            take_over_constant_data=take_over_constant_data,
         )
 
         edge_compile_config = EdgeCompileConfig(
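Taken together, the provider changes above relax kwarg validation (unexpected keys now log a warning instead of raising) and thread three new partitioner knobs through the recipe. A minimal sketch of how a caller would use them, assuming only names that appear in this diff:

    from executorch.backends.apple.coreml.recipes import CoreMLRecipeType
    from executorch.export.recipe import ExportRecipe

    # Keep linear ops out of the CoreML partition; an unknown kwarg would now
    # only produce a logging.warning rather than a ValueError.
    recipe = ExportRecipe.get_recipe(
        CoreMLRecipeType.FP32,
        skip_ops_for_coreml_delegation=["aten.linear.default"],
        lower_full_graph=False,
        take_over_constant_data=True,
    )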
" + f"Expected parameters: {list(expected_keys)}" ) self._validate_base_parameters(kwargs) @@ -121,7 +123,13 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non def _get_expected_keys(self, recipe_type: RecipeType) -> set: """Get expected parameter keys for a recipe type""" - common_keys = {"minimum_deployment_target", "compute_unit"} + common_keys = { + "minimum_deployment_target", + "compute_unit", + "skip_ops_for_coreml_delegation", + "lower_full_graph", + "take_over_constant_data", + } if recipe_type in [ CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, @@ -377,9 +385,19 @@ def _get_coreml_lowering_recipe( if minimum_deployment_target and minimum_deployment_target < ct.target.iOS18: take_over_mutable_buffer = False + # Extract additional partitioner parameters + skip_ops_for_coreml_delegation = kwargs.get( + "skip_ops_for_coreml_delegation", None + ) + lower_full_graph = kwargs.get("lower_full_graph", False) + take_over_constant_data = kwargs.get("take_over_constant_data", True) + partitioner = CoreMLPartitioner( compile_specs=compile_specs, take_over_mutable_buffer=take_over_mutable_buffer, + skip_ops_for_coreml_delegation=skip_ops_for_coreml_delegation, + lower_full_graph=lower_full_graph, + take_over_constant_data=take_over_constant_data, ) edge_compile_config = EdgeCompileConfig( diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py index 7a78836b2bc..78d5a30063c 100644 --- a/backends/apple/coreml/test/test_coreml_recipes.py +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -185,14 +185,6 @@ def test_int4_weight_only_per_group_validation(self): ) self.assertIn("must be positive", str(cm.exception)) - # Test unexpected parameter - with self.assertRaises(ValueError) as cm: - self.provider.create_recipe( - CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL, - group_size=32, # group_size not valid for per-channel - ) - self.assertIn("unexpected parameters", str(cm.exception)) - def test_int8_weight_only_per_channel(self): """Test INT8 weight-only per-channel quantization""" model = TestHelperModules.TwoLinearModule().eval() @@ -385,23 +377,6 @@ def forward(self, x): self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2) self._compare_eager_unquantized_model_outputs(session, model, example_inputs) - def test_pt2e_recipes_parameter_rejection(self): - """Test that PT2E recipes reject TorchAO-specific parameters""" - # PT2E recipes should reject TorchAO-specific parameters - pt2e_recipes = [ - CoreMLRecipeType.PT2E_INT8_STATIC, - CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY, - ] - torchao_params = ["filter_fn", "group_size", "bits", "block_size"] - - for recipe_type in pt2e_recipes: - for param in torchao_params: - with self.subTest(recipe=recipe_type.value, param=param): - kwargs = {param: "dummy_value"} - with self.assertRaises(ValueError) as cm: - self.provider.create_recipe(recipe_type, **kwargs) - self.assertIn("unexpected parameters", str(cm.exception).lower()) - def test_filter_fn_comprehensive(self): """Comprehensive test for filter_fn parameter functionality""" diff --git a/backends/xnnpack/TARGETS b/backends/xnnpack/TARGETS index 62a703bddb7..d5c6d6303d2 100644 --- a/backends/xnnpack/TARGETS +++ b/backends/xnnpack/TARGETS @@ -36,10 +36,7 @@ runtime.python_library( ], deps = [ ":xnnpack_preprocess", - "//executorch/export:lib", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", "//executorch/backends/xnnpack/utils:xnnpack_utils", - 
"//executorch/backends/xnnpack/recipes:xnnpack_recipe_provider", - "//executorch/backends/xnnpack/recipes:xnnpack_recipe_types", ], ) diff --git a/backends/xnnpack/__init__.py b/backends/xnnpack/__init__.py index 01b73101c86..b87dfab4f02 100644 --- a/backends/xnnpack/__init__.py +++ b/backends/xnnpack/__init__.py @@ -4,18 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from executorch.export import recipe_registry - # Exposed Partitioners in XNNPACK Package from .partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, XnnpackPartitioner, ) -from .recipes.xnnpack_recipe_provider import XNNPACKRecipeProvider -from .recipes.xnnpack_recipe_types import XNNPackRecipeType - -# Auto-register XNNPACK recipe provider -recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider()) # Exposed Configs in XNNPACK Package from .utils.configs import ( @@ -34,7 +27,6 @@ "XnnpackDynamicallyQuantizedPartitioner", "XnnpackPartitioner", "XnnpackBackend", - "XNNPackRecipeType", "capture_graph_for_xnnpack", "get_xnnpack_capture_config", "get_xnnpack_edge_compile_config", diff --git a/backends/xnnpack/recipes/TARGETS b/backends/xnnpack/recipes/TARGETS index 60968a5085d..6b6c1ddfe82 100644 --- a/backends/xnnpack/recipes/TARGETS +++ b/backends/xnnpack/recipes/TARGETS @@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") oncall("executorch") +runtime.python_library( + name = "xnnpack_recipes", + srcs = [ + "__init__.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/export:recipe_registry", + ":xnnpack_recipe_provider", + ":xnnpack_recipe_types", + ], +) + runtime.python_library( name = "xnnpack_recipe_provider", srcs = [ @@ -30,6 +46,6 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ - "//executorch/export:lib", + "//executorch/export:recipe", ], ) diff --git a/backends/xnnpack/recipes/__init__.py b/backends/xnnpack/recipes/__init__.py new file mode 100644 index 00000000000..3fa8e9496e6 --- /dev/null +++ b/backends/xnnpack/recipes/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.export import recipe_registry + +from .xnnpack_recipe_provider import XNNPACKRecipeProvider +from .xnnpack_recipe_types import XNNPackRecipeType + +# Auto-register XNNPACK recipe provider +recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider()) + + +__all__ = [ + "XNNPACKRecipeProvider", + "XNNPackRecipeType", +] diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py index 436eb2db158..2c80d528c45 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py @@ -6,6 +6,7 @@ # pyre-strict +import logging from typing import Any, Optional, Sequence import torch @@ -180,9 +181,9 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non expected_keys = {"group_size"} unexpected = set(kwargs.keys()) - expected_keys if unexpected: - raise ValueError( - f"Recipe '{recipe_type.value}' only accepts 'group_size' parameter. 
" - f"Unexpected parameters: {list(unexpected)}" + logging.warning( + f"XNNPACK recipe '{recipe_type.value}' ignoring unexpected parameters: {list(unexpected)}. " + f"Only 'group_size' is supported for this recipe." ) if "group_size" in kwargs: group_size = kwargs["group_size"] @@ -193,7 +194,7 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non elif kwargs: # All other recipes don't expect any kwargs unexpected = list(kwargs.keys()) - raise ValueError( - f"Recipe '{recipe_type.value}' does not accept any parameters. " - f"Unexpected parameters: {unexpected}" + logging.warning( + f"XNNPACK recipe '{recipe_type.value}' ignoring unexpected parameters: {unexpected}. " + f"This recipe does not accept any parameters." ) diff --git a/backends/xnnpack/recipes/xnnpack_recipe_types.py b/backends/xnnpack/recipes/xnnpack_recipe_types.py index 61117b94502..82296f912c2 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_types.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_types.py @@ -12,23 +12,25 @@ class XNNPackRecipeType(RecipeType): """XNNPACK-specific recipe types""" - FP32 = "fp32" + FP32 = "xnnpack_fp32" ## PT2E-based quantization recipes # INT8 Dynamic Quantization - PT2E_INT8_DYNAMIC_PER_CHANNEL = "pt2e_int8_dynamic_per_channel" + PT2E_INT8_DYNAMIC_PER_CHANNEL = "xnnpack_pt2e_int8_dynamic_per_channel" # INT8 Static Quantization, needs calibration dataset - PT2E_INT8_STATIC_PER_CHANNEL = "pt2e_int8_static_per_channel" - PT2E_INT8_STATIC_PER_TENSOR = "pt2e_int8_static_per_tensor" + PT2E_INT8_STATIC_PER_CHANNEL = "xnnpack_pt2e_int8_static_per_channel" + PT2E_INT8_STATIC_PER_TENSOR = "xnnpack_pt2e_int8_static_per_tensor" ## TorchAO-based quantization recipes # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0 TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = ( - "torchao_int8da_int4w_per_channel" + "xnnpack_torchao_int8da_int4w_per_channel" ) # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32 # can be overriden by group_size kwarg - TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "torchao_int8da_int4w_per_tensor" + TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = ( + "xnnpack_torchao_int8da_int4w_per_tensor" + ) @classmethod def get_backend_name(cls) -> str: diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS index 5679f336fef..5f3581b6aeb 100644 --- a/backends/xnnpack/test/TARGETS +++ b/backends/xnnpack/test/TARGETS @@ -105,7 +105,7 @@ runtime.python_test( "HTTPS_PROXY": "http://fwdproxy:8080", }, deps = [ - "//executorch/backends/xnnpack:xnnpack_delegate", + "//executorch/backends/xnnpack/recipes:xnnpack_recipes", "//executorch/export:lib", "//pytorch/vision:torchvision", # @manual "//executorch/backends/xnnpack/test/tester:tester", diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py index e4bd6f1f4c1..aa470bdcb50 100644 --- a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py +++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py @@ -286,15 +286,6 @@ def test_all_models_with_recipes(self) -> None: if os.path.exists("dog.jpg"): os.remove("dog.jpg") - def test_validate_recipe_kwargs_fp32(self) -> None: - provider = XNNPACKRecipeProvider() - - with self.assertRaises(ValueError) as cm: - provider.create_recipe(XNNPackRecipeType.FP32, invalid_param=123) - - error_msg = str(cm.exception) - self.assertIn("Recipe 'fp32' does not accept any parameters", error_msg) - def 
diff --git a/export/TARGETS b/export/TARGETS
index 816a3a1a289..ae41393d883 100644
--- a/export/TARGETS
+++ b/export/TARGETS
@@ -110,3 +110,16 @@ runtime.python_library(
         "types.py",
     ],
 )
+
+runtime.python_library(
+    name = "target_recipes",
+    srcs = [
+        "target_recipes.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/coremltools:coremltools",
+        "//executorch/export:recipe",
+        "//executorch/backends/xnnpack/recipes:xnnpack_recipes",
+        "//executorch/backends/apple/coreml:coreml_recipes",
+    ]
+)
diff --git a/export/recipe.py b/export/recipe.py
index 086d57f3e38..811270cdbf8 100644
--- a/export/recipe.py
+++ b/export/recipe.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
@@ -184,3 +185,129 @@ def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe":
                 f"Supported: {[r.value for r in supported]}"
             )
         return export_recipe
+
+    @classmethod
+    def combine(
+        cls, recipes: List["ExportRecipe"], recipe_name: Optional[str] = None
+    ) -> "ExportRecipe":
+        """
+        Combine multiple ExportRecipe objects into a single recipe.
+
+        Args:
+            recipes: List of ExportRecipe objects to combine
+            recipe_name: Optional name for the combined recipe
+
+        Returns:
+            A new ExportRecipe that combines all input recipes
+
+        Example:
+            recipe1 = ExportRecipe.get_recipe(CoreMLRecipeType.FP32)
+            recipe2 = ExportRecipe.get_recipe(XNNPackRecipeType.FP32)
+            combined_recipe = ExportRecipe.combine(
+                [recipe1, recipe2],
+                recipe_name="multi_backend_coreml_xnnpack_fp32"
+            )
+        """
+        if not recipes:
+            raise ValueError("Recipes cannot be empty")
+
+        if len(recipes) == 1:
+            return recipes[0]
+
+        return cls._combine_recipes(recipes, recipe_name)
+
+    @classmethod
+    def _combine_recipes(  # noqa: C901
+        cls, backend_recipes: List["ExportRecipe"], recipe_name: Optional[str] = None
+    ) -> "ExportRecipe":
+        """
+        Utility to combine multiple backend recipes into a single multi-backend recipe.
+
+        Args:
+            backend_recipes: List of ExportRecipe objects to combine
+            recipe_name: Optional name for the combined recipe
+
+        Returns:
+            Combined ExportRecipe for multi-backend deployment
+        """
+        # Extract components from individual recipes
+        all_partitioners = []
+        all_quantizers = []
+        all_ao_quantization_configs = []
+        all_pre_edge_passes = []
+        all_transform_passes = []
+        combined_backend_config = None
+
+        for recipe in backend_recipes:
+            # Collect pre-edge transform passes
+            if recipe.pre_edge_transform_passes:
+                all_pre_edge_passes.extend(recipe.pre_edge_transform_passes)
+
+            # Collect partitioners from lowering recipes
+            if recipe.lowering_recipe and recipe.lowering_recipe.partitioners:
+                all_partitioners.extend(recipe.lowering_recipe.partitioners)
+
+            # Collect transform passes from lowering recipes
+            if recipe.lowering_recipe and recipe.lowering_recipe.edge_transform_passes:
+                all_transform_passes.extend(
+                    recipe.lowering_recipe.edge_transform_passes
+                )
+
+            # Collect for quantize stage
+            if quantization_recipe := recipe.quantization_recipe:
+                # Collect PT2E quantizers
+                if quantization_recipe.quantizers:
+                    all_quantizers.extend(quantization_recipe.quantizers)
+
+                # Collect source transform configs
+                if quantization_recipe.ao_quantization_configs:
+                    all_ao_quantization_configs.extend(
+                        quantization_recipe.ao_quantization_configs
+                    )
+
+            # Use the first backend config as base
+            if combined_backend_config is None and recipe.executorch_backend_config:
+                combined_backend_config = copy.deepcopy(
+                    recipe.executorch_backend_config
+                )
+
+        # Create combined quantization recipe
+        combined_quantization_recipe = None
+        if all_quantizers or all_ao_quantization_configs:
+            combined_quantization_recipe = QuantizationRecipe(
+                quantizers=all_quantizers if all_quantizers else None,
+                ao_quantization_configs=(
+                    all_ao_quantization_configs if all_ao_quantization_configs else None
+                ),
+            )
+
+        # Create combined lowering recipe
+        combined_lowering_recipe = None
+        if all_partitioners or all_transform_passes:
+            edge_compile_config = None
+            for recipe in backend_recipes:
+                if (
+                    recipe.lowering_recipe
+                    and recipe.lowering_recipe.edge_compile_config
+                ):
+                    edge_compile_config = recipe.lowering_recipe.edge_compile_config
+                    break
+
+            combined_lowering_recipe = LoweringRecipe(
+                partitioners=all_partitioners if all_partitioners else None,
+                edge_transform_passes=(
+                    all_transform_passes if all_transform_passes else None
+                ),
+                edge_compile_config=edge_compile_config or EdgeCompileConfig(),
+            )
+
+        recipe_name = recipe_name or "_".join(
+            [r.name for r in backend_recipes if r.name is not None]
+        )
+
+        return cls(
+            name=recipe_name,
+            quantization_recipe=combined_quantization_recipe,
+            pre_edge_transform_passes=all_pre_edge_passes,
+            lowering_recipe=combined_lowering_recipe,
+            executorch_backend_config=combined_backend_config,
+        )
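A hedged sketch of the merge semantics implemented above, reusing only recipe types from this diff: partitioners, quantizers, and passes are concatenated in list order, while the first recipe supplies the base backend config and edge compile config.

    from executorch.backends.apple.coreml.recipes import CoreMLRecipeType
    from executorch.backends.xnnpack.recipes import XNNPackRecipeType
    from executorch.export.recipe import ExportRecipe

    coreml_fp32 = ExportRecipe.get_recipe(CoreMLRecipeType.FP32)
    xnnpack_fp32 = ExportRecipe.get_recipe(XNNPackRecipeType.FP32)

    combined = ExportRecipe.combine(
        [coreml_fp32, xnnpack_fp32], recipe_name="ios_multi_backend_fp32"
    )
    # Partition order follows list order, so in this pairing CoreML gets first
    # claim on the graph and XNNPACK picks up the remaining ops; this matches
    # the delegate layout asserted in the new target-recipe tests below.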
diff --git a/export/stages.py b/export/stages.py
index 2b3f8a42440..7b583e27943 100644
--- a/export/stages.py
+++ b/export/stages.py
@@ -309,9 +309,14 @@ def run(self, artifact: PipelineArtifact) -> None:
         # Apply torchao quantize_ to each model
         for _, model in artifact.data.items():
             # pyre-ignore
-            for ao_config in self._quantization_recipe.ao_quantization_configs:
-                quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
-                unwrap_tensor_subclass(model)
+            if len(self._quantization_recipe.ao_quantization_configs) > 1:
+                raise ValueError(
+                    "AO quantization configs cannot be reliably composed together; "
+                    "multiple configs are disallowed for source transform at this point"
+                )
+
+            ao_config = self._quantization_recipe.ao_quantization_configs[0]
+            quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
+            unwrap_tensor_subclass(model)
 
         self._artifact = artifact.copy_with_new_data(self._transformed_models)
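The stage now assumes exactly one AO config per recipe. A hedged illustration, assuming QuantizationRecipe is importable from executorch.export.recipe (it is referenced there above); cfg_a and cfg_b are hypothetical stand-ins for real torchao configs:

    from executorch.export.recipe import QuantizationRecipe

    # One config: applied via quantize_() exactly as before.
    ok = QuantizationRecipe(ao_quantization_configs=[cfg_a])

    # Two configs: the source-transform stage now raises ValueError instead
    # of applying quantize_() back-to-back on the same modules.
    rejected = QuantizationRecipe(ao_quantization_configs=[cfg_a, cfg_b])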
diff --git a/export/target_recipes.py b/export/target_recipes.py
new file mode 100644
index 00000000000..76e0cacc7b4
--- /dev/null
+++ b/export/target_recipes.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Target-specific recipe functions for simplified multi-backend deployment.
+
+This module provides platform-specific functions that abstract away backend
+selection and combine multiple backends optimally for target hardware.
+"""
+
+from typing import Dict, List
+
+import coremltools as ct
+
+# pyre-ignore
+from executorch.backends.apple.coreml.recipes import CoreMLRecipeType
+from executorch.backends.xnnpack.recipes import XNNPackRecipeType
+from executorch.export.recipe import ExportRecipe, RecipeType
+
+
+## iOS target configs
+# The following list of recipes is not exhaustive for CoreML; refer to CoreMLRecipeType for more detailed recipes.
+IOS_CONFIGS: Dict[str, List[RecipeType]] = {
+    # pyre-ignore
+    "ios-arm64-coreml-fp32": [CoreMLRecipeType.FP32, XNNPackRecipeType.FP32],
+    # pyre-ignore
+    "ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16],
+    # pyre-ignore
+    "ios-arm64-coreml-int8": [CoreMLRecipeType.PT2E_INT8_STATIC],
+}
+
+
+def _create_target_recipe(
+    target_config: str, recipes: List[RecipeType], **kwargs
+) -> ExportRecipe:
+    """
+    Create a combined recipe for a target.
+
+    Args:
+        target_config: Human-readable hardware configuration name
+        recipes: List of backend recipe types to combine
+        **kwargs: Additional parameters - each backend will use what it needs
+
+    Returns:
+        Combined ExportRecipe for the hardware configuration
+    """
+    if not recipes:
+        raise ValueError(f"No backends configured for: {target_config}")
+
+    # Create individual backend recipes
+    backend_recipes = []
+    for recipe_type in recipes:
+        try:
+            backend_recipe = ExportRecipe.get_recipe(recipe_type, **kwargs)
+            backend_recipes.append(backend_recipe)
+        except Exception as e:
+            raise ValueError(
+                f"Failed to create {recipe_type.value} recipe for {target_config}: {e}"
+            ) from e
+
+    # Combine into single recipe
+    if len(backend_recipes) == 1:
+        return backend_recipes[0]
+
+    return ExportRecipe.combine(backend_recipes, recipe_name=target_config)
+
+
+# iOS recipe
+def get_ios_recipe(
+    target_config: str = "ios-arm64-coreml-fp16", **kwargs
+) -> ExportRecipe:
+    """
+    Get an iOS-optimized recipe for the specified hardware configuration.
+
+    Supported configurations:
+    - 'ios-arm64-coreml-fp32': CoreML + XNNPACK fallback (FP32)
+    - 'ios-arm64-coreml-fp16': CoreML FP16 recipe
+    - 'ios-arm64-coreml-int8': CoreML INT8 quantization recipe
+
+    Args:
+        target_config: iOS configuration string
+        **kwargs: Additional parameters for backend recipes
+
+    Returns:
+        ExportRecipe configured for iOS deployment
+
+    Raises:
+        ValueError: If target configuration is not supported
+
+    Example:
+        recipe = get_ios_recipe('ios-arm64-coreml-int8')
+        session = export(model, recipe, example_inputs)
+    """
+    if target_config not in IOS_CONFIGS:
+        supported = list(IOS_CONFIGS.keys())
+        raise ValueError(
+            f"Unsupported iOS configuration: '{target_config}'. "
+            f"Supported: {supported}"
+        )
+
+    kwargs = kwargs or {}
+
+    if target_config == "ios-arm64-coreml-int8":
+        if "minimum_deployment_target" not in kwargs:
+            kwargs["minimum_deployment_target"] = ct.target.iOS17
+
+    backend_recipes = IOS_CONFIGS[target_config]
+    return _create_target_recipe(target_config, backend_recipes, **kwargs)
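One consequence of the defaulting logic above: the int8 config implies ct.target.iOS17, so callers only pass minimum_deployment_target to override it. A short usage sketch:

    import coremltools as ct
    from executorch.export.target_recipes import get_ios_recipe

    recipe = get_ios_recipe()  # defaults to "ios-arm64-coreml-fp16"
    recipe = get_ios_recipe("ios-arm64-coreml-int8")  # implies ct.target.iOS17
    recipe = get_ios_recipe(
        "ios-arm64-coreml-int8", minimum_deployment_target=ct.target.iOS18
    )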
" + f"Supported: {supported}" + ) + + kwargs = kwargs or {} + + if target_config == "ios-arm64-coreml-int8": + if "minimum_deployment_target" not in kwargs: + kwargs["minimum_deployment_target"] = ct.target.iOS17 + + backend_recipes = IOS_CONFIGS[target_config] + return _create_target_recipe(target_config, backend_recipes, **kwargs) diff --git a/export/tests/TARGETS b/export/tests/TARGETS index 068c3436b6a..71f28b64df7 100644 --- a/export/tests/TARGETS +++ b/export/tests/TARGETS @@ -31,3 +31,17 @@ runtime.python_test( "//executorch/runtime:runtime", ] ) + +runtime.python_test( + name = "test_target_recipes", + srcs = [ + "test_target_recipes.py", + ], + deps = [ + "//executorch/export:lib", + "//executorch/export:target_recipes", + "//executorch/runtime:runtime", + "//executorch/backends/xnnpack/recipes:xnnpack_recipes", + "//executorch/backends/apple/coreml:coreml_recipes", + ] +) diff --git a/export/tests/test_target_recipes.py b/export/tests/test_target_recipes.py new file mode 100644 index 00000000000..d781ffea945 --- /dev/null +++ b/export/tests/test_target_recipes.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +import unittest + +import torch +from executorch.backends.apple.coreml.recipes import CoreMLRecipeProvider # pyre-ignore +from executorch.backends.xnnpack.recipes.xnnpack_recipe_provider import ( + XNNPACKRecipeProvider, +) +from executorch.export import export, recipe_registry +from executorch.export.target_recipes import get_ios_recipe +from executorch.runtime import Runtime + + +class TestTargetRecipes(unittest.TestCase): + """Test target recipes.""" + + def setUp(self) -> None: + torch._dynamo.reset() + super().setUp() + recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider()) + # pyre-ignore + recipe_registry.register_backend_recipe_provider(CoreMLRecipeProvider()) + + def tearDown(self) -> None: + super().tearDown() + + def test_ios_fp32_recipe_with_xnnpack_fallback(self) -> None: + # Linear ops skipped by coreml but handled by xnnpack + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 2) + + def forward(self, x, y): + a = self.linear1(x) + b = a + y + c = b - x + result = self.linear2(c) + return result + + model = Model() + model.eval() + + example_inputs = [(torch.randn(2, 4), torch.randn(2, 4))] + + # Export using multi-backend target recipe with CoreML configured to skip linear operations + recipe = get_ios_recipe( + "ios-arm64-coreml-fp32", + skip_ops_for_coreml_delegation=["aten.linear.default"], + ) + + # Export the model + session = export( + model=model, example_inputs=example_inputs, export_recipe=recipe + ) + + # Verify we can create executable + executorch_program = session.get_executorch_program() + # session.print_delegation_info() + + self.assertIsNotNone( + executorch_program, "ExecuTorch program should not be None" + ) + + # Assert there is an execution plan + self.assertTrue(len(executorch_program.execution_plan) == 1) + + # Check number of partitions created + self.assertTrue(len(executorch_program.execution_plan[0].delegates) == 3) + + # First delegate backend is Xnnpack + self.assertEqual( + executorch_program.execution_plan[0].delegates[0].id, + "XnnpackBackend", + ) + + # Second delegate backend is CoreML + 
diff --git a/export/tests/test_target_recipes.py b/export/tests/test_target_recipes.py
new file mode 100644
index 00000000000..d781ffea945
--- /dev/null
+++ b/export/tests/test_target_recipes.py
@@ -0,0 +1,170 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+import unittest
+
+import torch
+from executorch.backends.apple.coreml.recipes import CoreMLRecipeProvider  # pyre-ignore
+from executorch.backends.xnnpack.recipes.xnnpack_recipe_provider import (
+    XNNPACKRecipeProvider,
+)
+from executorch.export import export, recipe_registry
+from executorch.export.target_recipes import get_ios_recipe
+from executorch.runtime import Runtime
+
+
+class TestTargetRecipes(unittest.TestCase):
+    """Test target recipes."""
+
+    def setUp(self) -> None:
+        torch._dynamo.reset()
+        super().setUp()
+        recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider())
+        # pyre-ignore
+        recipe_registry.register_backend_recipe_provider(CoreMLRecipeProvider())
+
+    def tearDown(self) -> None:
+        super().tearDown()
+
+    def test_ios_fp32_recipe_with_xnnpack_fallback(self) -> None:
+        # Linear ops are skipped by CoreML but handled by XNNPACK
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(4, 4)
+                self.linear2 = torch.nn.Linear(4, 2)
+
+            def forward(self, x, y):
+                a = self.linear1(x)
+                b = a + y
+                c = b - x
+                result = self.linear2(c)
+                return result
+
+        model = Model()
+        model.eval()
+
+        example_inputs = [(torch.randn(2, 4), torch.randn(2, 4))]
+
+        # Export using multi-backend target recipe with CoreML configured to skip linear operations
+        recipe = get_ios_recipe(
+            "ios-arm64-coreml-fp32",
+            skip_ops_for_coreml_delegation=["aten.linear.default"],
+        )
+
+        # Export the model
+        session = export(
+            model=model, example_inputs=example_inputs, export_recipe=recipe
+        )
+
+        # Verify we can create executable
+        executorch_program = session.get_executorch_program()
+        # session.print_delegation_info()
+
+        self.assertIsNotNone(
+            executorch_program, "ExecuTorch program should not be None"
+        )
+
+        # Assert there is an execution plan
+        self.assertTrue(len(executorch_program.execution_plan) == 1)
+
+        # Check number of partitions created
+        self.assertTrue(len(executorch_program.execution_plan[0].delegates) == 3)
+
+        # First delegate backend is XNNPACK
+        self.assertEqual(
+            executorch_program.execution_plan[0].delegates[0].id,
+            "XnnpackBackend",
+        )
+
+        # Second delegate backend is CoreML
+        self.assertEqual(
+            executorch_program.execution_plan[0].delegates[1].id,
+            "CoreMLBackend",
+        )
+
+        # Third delegate backend is XNNPACK
+        self.assertEqual(
+            executorch_program.execution_plan[0].delegates[2].id,
+            "XnnpackBackend",
+        )
+
+        et_runtime: Runtime = Runtime.get()
+        backend_registry = et_runtime.backend_registry
+        logging.info(
+            f"backends registered: {et_runtime.backend_registry.registered_backend_names}"
+        )
+        if backend_registry.is_available(
+            "CoreMLBackend"
+        ) and backend_registry.is_available("XnnpackBackend"):
+            logging.info("Running with CoreML and XNNPACK backends")
+            et_output = session.run_method("forward", example_inputs[0])
+            logging.info(f"et output {et_output}")
+
+    def test_ios_quant_recipes(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(4, 4)
+                self.linear2 = torch.nn.Linear(4, 2)
+
+            def forward(self, x, y):
+                a = self.linear1(x)
+                b = a + y
+                c = b - x
+                result = self.linear2(c)
+                return result
+
+        model = Model()
+        model.eval()
+
+        example_inputs = [(torch.randn(2, 4), torch.randn(2, 4))]
+
+        for recipe in [
+            get_ios_recipe("ios-arm64-coreml-fp16"),
+            get_ios_recipe("ios-arm64-coreml-int8"),
+        ]:
+            # Export the model
+            session = export(
+                model=model, example_inputs=example_inputs, export_recipe=recipe
+            )
+
+            # Verify we can create executable
+            executorch_program = session.get_executorch_program()
+            session.print_delegation_info()
+
+            self.assertIsNotNone(
+                executorch_program, "ExecuTorch program should not be None"
+            )
+
+            # Assert there is an execution plan
+            self.assertTrue(len(executorch_program.execution_plan) == 1)
+
+            # Check number of partitions created
+            self.assertTrue(len(executorch_program.execution_plan[0].delegates) == 1)
+
+            # Delegate backend is CoreML
+            self.assertEqual(
+                executorch_program.execution_plan[0].delegates[0].id,
+                "CoreMLBackend",
+            )
+
+            # Check number of instructions
+            instructions = executorch_program.execution_plan[0].chains[0].instructions
+            self.assertIsNotNone(instructions)
+            self.assertEqual(len(instructions), 1)
+
+            et_runtime: Runtime = Runtime.get()
+            backend_registry = et_runtime.backend_registry
+            logging.info(
+                f"backends registered: {et_runtime.backend_registry.registered_backend_names}"
+            )
+            if backend_registry.is_available("CoreMLBackend"):
+                et_output = session.run_method("forward", example_inputs[0])
+                logging.info(f"et output {et_output}")