From eaa16636c37e70d5e5ee279a2428d9067cbf34c5 Mon Sep 17 00:00:00 2001
From: Abhinay Kukkadapu
Date: Wed, 10 Sep 2025 09:35:06 -0700
Subject: [PATCH 1/2] [Executorch][Export/Recipes] Modify
 pre_edge_transform_passes, edge_transform_passes definition to take exported
 program and method name.

Pull Request resolved: https://github.com/pytorch/executorch/pull/14125

While adding QNN recipes, I realized there are gaps in how the recipe logic
handles pre-edge and edge transform passes: some transform passes need the
exported program, so they must be resolved dynamically at runtime.

Changes made:
- `pre_edge_transform_passes` has been renamed to `aten_transform_passes`
- `aten_transform_passes`: now accepts a list of transformation functions,
  each of which takes `(method_name, ExportedProgram)` and returns an
  `ExportedProgram`
   - `aten_transform_passes: Optional[List[Callable[[str, ExportedProgram], ExportedProgram]]]`
- `edge_transform_passes`: now a list of callables, each of which resolves to
  a list of passes; callables are needed because some passes require the
  `ExportedProgram`, which is only available during execution and not when the
  recipe is created
   - `edge_transform_passes: None | List[Callable[[str, ExportedProgram], List[PassType]]]`

A usage sketch of the new pass API appears after the patches below.

ghstack-source-id: 308796607

Differential Revision: [D81730890](https://our.internmc.facebook.com/intern/diff/D81730890/)
---
 export/export.py                   |  10 +-
 export/recipe.py                   |  24 +++--
 export/stages.py                   |  74 +++++++++++----
 export/tests/test_export_stages.py | 143 +++++++++++++++++++++++++++--
 4 files changed, 214 insertions(+), 37 deletions(-)

diff --git a/export/export.py b/export/export.py
index ab15067c561..86a932d153c 100644
--- a/export/export.py
+++ b/export/export.py
@@ -195,12 +195,12 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
         elif stage_type == StageType.QUANTIZE:
             stage = QuantizeStage(self._quant_recipe)
         elif stage_type == StageType.TORCH_EXPORT:
-            pre_edge_passes = None
-            if self._export_recipe.pre_edge_transform_passes is not None:
-                pre_edge_passes = list(
-                    self._export_recipe.pre_edge_transform_passes
+            aten_transform_passes = None
+            if self._export_recipe.aten_transform_passes is not None:
+                aten_transform_passes = list(
+                    self._export_recipe.aten_transform_passes
                 )
-            stage = TorchExportStage(pre_edge_passes)
+            stage = TorchExportStage(aten_transform_passes)
         elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
             stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
         elif stage_type == StageType.TO_EDGE:
diff --git a/export/recipe.py b/export/recipe.py
index 811270cdbf8..18f4b8aebb9 100644
--- a/export/recipe.py
+++ b/export/recipe.py
@@ -7,9 +7,10 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
-from typing import Callable, List, Optional, Sequence
+from typing import Callable, List, Optional
 
 import torch
+from executorch.exir import ExportedProgram
 from executorch.exir._warnings import experimental
 
 
@@ -117,12 +118,15 @@ class LoweringRecipe:
 
     Attributes:
         partitioners: Optional list of partitioners for model partitioning
-        edge_transform_passes: Optional sequence of transformation passes to apply
+        edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
+            and return a list of passes (PassType) to be executed during lowering stages.
edge_compile_config: Optional edge compilation configuration """ partitioners: Optional[List[Partitioner]] = None - edge_transform_passes: Optional[Sequence[PassType]] = None + edge_transform_passes: ( + None | List[Callable[[str, ExportedProgram], List[PassType]]] + ) = None # pyre-ignore[11]: Type not defined edge_compile_config: Optional[EdgeCompileConfig] = None @@ -141,8 +145,8 @@ class ExportRecipe: Attributes: name: Optional name for the recipe quantization_recipe: Optional quantization recipe for model quantization - pre_edge_transform_passes: Optional function to apply transformation passes - before edge lowering + aten_transform_passes: Optional list of functions to apply transformation passes to the program before edge lowering. + These callables are invoked to modify and return the transformed program. lowering_recipe: Optional lowering recipe for model lowering and partitioning executorch_backend_config: Optional backend configuration for ExecuTorch pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline. @@ -151,7 +155,9 @@ class ExportRecipe: name: Optional[str] = None quantization_recipe: Optional[QuantizationRecipe] = None - pre_edge_transform_passes: Optional[Sequence[PassType]] = None + aten_transform_passes: Optional[ + List[Callable[[str, ExportedProgram], ExportedProgram]] + ] = None lowering_recipe: Optional[LoweringRecipe] = None # pyre-ignore[11]: Type not defined executorch_backend_config: Optional[ExecutorchBackendConfig] = None @@ -240,8 +246,8 @@ def _combine_recipes( # noqa: C901 for recipe in backend_recipes: # Collect pre-edge transform passes - if recipe.pre_edge_transform_passes: - all_pre_edge_passes.extend(recipe.pre_edge_transform_passes) + if recipe.aten_transform_passes: + all_pre_edge_passes.extend(recipe.aten_transform_passes) # Collect partitioners from lowering recipes if recipe.lowering_recipe and recipe.lowering_recipe.partitioners: @@ -307,7 +313,7 @@ def _combine_recipes( # noqa: C901 return cls( name=recipe_name, quantization_recipe=combined_quantization_recipe, - pre_edge_transform_passes=all_pre_edge_passes, + aten_transform_passes=all_pre_edge_passes, lowering_recipe=combined_lowering_recipe, executorch_backend_config=combined_backend_config, ) diff --git a/export/stages.py b/export/stages.py index 7b583e27943..609e7d197b9 100644 --- a/export/stages.py +++ b/export/stages.py @@ -7,14 +7,14 @@ import copy import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Sequence +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional import torch from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, ExportedProgram from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.program import to_edge, to_edge_transform_and_lower -from executorch.exir.program._program import _transform from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.types import StageType from torch import nn @@ -107,10 +107,12 @@ class TorchExportStage(Stage): def __init__( self, - pre_edge_transform_passes: Optional[List[PassType]] = None, + aten_transform_passes: Optional[ + List[Callable[[str, ExportedProgram], ExportedProgram]] + ] = None, ) -> None: super().__init__() - self._pre_edge_transform_passes = pre_edge_transform_passes + self._aten_transform_passes = aten_transform_passes 
@property def stage_type(self) -> str: @@ -149,9 +151,13 @@ def run(self, artifact: PipelineArtifact) -> None: ) # Apply pre-edge transform passes if available - for pass_ in self._pre_edge_transform_passes or []: - exported_programs[method_name] = _transform( - exported_programs[method_name], pass_ + for pass_ in self._aten_transform_passes or []: + if not callable(pass_): + raise ValueError( + "Aten transform passes must be a callable that can transform and return an exported program" + ) + exported_programs[method_name] = pass_( + method_name, exported_programs[method_name] ) self._artifact = artifact.copy_with_new_data(exported_programs) @@ -165,7 +171,9 @@ class EdgeTransformAndLowerStage(Stage): def __init__( self, partitioners: Optional[List[Any]] = None, - transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + transform_passes: ( + None | List[Callable[[str, ExportedProgram], List[PassType]]] + ) = None, compile_config: Optional[Any] = None, ) -> None: self._partitioners = partitioners @@ -205,11 +213,28 @@ def run(self, artifact: PipelineArtifact) -> None: constant_methods = artifact.get_context("constant_methods") generate_etrecord = artifact.get_context("generate_etrecord", False) + # per method transform passes + transform_passes = defaultdict(list) + for method_name, ep in exported_programs.items(): + # Resolve transform passes from callable + for pass_ in self._transform_passes or []: + if not callable(pass_): + raise ValueError( + "Transform passes must be a callable that resolves to a list of passes" + ) + passes = pass_(method_name, ep) + if isinstance(passes, list): + transform_passes[method_name].extend(passes) + else: + raise ValueError( + "Transform passes must be a callable that resolves to a list of passes" + ) + with validation_disabled(): edge_program_manager = to_edge_transform_and_lower( exported_programs, partitioner=self._partitioners, - transform_passes=self._transform_passes, + transform_passes=transform_passes, constant_methods=constant_methods, compile_config=self._compile_config, generate_etrecord=generate_etrecord, @@ -396,7 +421,7 @@ def run(self, artifact: PipelineArtifact) -> None: captured_graph = torch.export.export(model, inputs, strict=True).module() quantizer = self._get_quantizer_for_prepare_pt2e( - self._quantization_recipe.quantizers + self._quantization_recipe.quantizers # pyre-ignore ) prepared_model = prepare_pt2e(captured_graph, quantizer) @@ -471,7 +496,9 @@ class ToBackendStage(Stage): def __init__( self, partitioners: Optional[List[Any]] = None, - transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + transform_passes: ( + None | List[Callable[[str, ExportedProgram], List[PassType]]] + ) = None, ) -> None: super().__init__() self._partitioners = partitioners @@ -513,11 +540,24 @@ def run(self, artifact: PipelineArtifact) -> None: if edge_program_manager is None: raise RuntimeError("Edge program manager is not set.") - # Apply transform passes if available - if self._transform_passes: - edge_program_manager = edge_program_manager.transform( - self._transform_passes - ) + # per method transform passes + transform_passes = defaultdict(list) + for method_name in edge_program_manager.methods: + # Resolve transform passes if it's a callable + ep = edge_program_manager.exported_program(method_name) + for pass_ in self._transform_passes or []: + if not callable(pass_): + raise ValueError( + "Transform passes must be a callable that resolves to a list of passes" + ) + passes = pass_(method_name, ep) 
+ if isinstance(passes, list): + transform_passes[method_name].extend(passes) + else: + raise ValueError("Transform passes must return list of passes") + + # Apply transform passes + edge_program_manager = edge_program_manager.transform(transform_passes) # Apply partitioners if available if self._partitioners is not None and len(self._partitioners) > 0: diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index d4629a1aea7..608aa5adb3c 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -7,7 +7,7 @@ # pyre-strict import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, PropertyMock import torch from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager @@ -99,6 +99,66 @@ def test_get_artifacts_before_run(self) -> None: stage.get_artifacts() self.assertIn("Stage: TorchExportStage not executed", str(cm.exception)) + @patch("torch.export.export") + def test_export_stage_with_aten_transform_passes( + self, mock_torch_export: Mock + ) -> None: + """Test TorchExportStage with aten_transform_passes.""" + mock_exported_program = Mock(spec=ExportedProgram) + mock_transformed_program = Mock(spec=ExportedProgram) + mock_torch_export.return_value = mock_exported_program + + # Create a mock aten transform pass that we can verify + mock_aten_transform_pass = Mock() + mock_aten_transform_pass.return_value = mock_transformed_program + aten_transform_passes = [mock_aten_transform_pass] + + stage = TorchExportStage(aten_transform_passes=aten_transform_passes) + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + + stage.run(artifact) + + # Verify torch.export.export was called + mock_torch_export.assert_called_once_with( + self.model, + self.example_inputs[0], + dynamic_shapes=None, + strict=True, + ) + + # Verify the aten transform pass was called with correct parameters + mock_aten_transform_pass.assert_called_once_with( + "forward", mock_exported_program + ) + + # Verify artifacts contain the transformed program + result_artifact = stage.get_artifacts() + self.assertIn("forward", result_artifact.data) + self.assertEqual(result_artifact.data["forward"], mock_transformed_program) + + @patch("torch.export.export") + def test_export_stage_invalid_aten_transform_pass( + self, mock_torch_export: Mock + ) -> None: + """Test TorchExportStage with invalid aten_transform_pass (not callable).""" + mock_exported_program = Mock(spec=ExportedProgram) + mock_torch_export.return_value = mock_exported_program + + # Use a non-callable object as transform pass + invalid_transform_pass = "not_callable" + aten_transform_passes = [invalid_transform_pass] + + # pyre-ignore + stage = TorchExportStage(aten_transform_passes=aten_transform_passes) + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + + with self.assertRaises(ValueError) as cm: + stage.run(artifact) + self.assertIn( + "Aten transform passes must be a callable that can transform and return an exported program", + str(cm.exception), + ) + class TestEdgeTransformAndLowerStage(unittest.TestCase): def setUp(self) -> None: @@ -106,12 +166,32 @@ def setUp(self) -> None: self.exported_programs = {"forward": self.mock_exported_program} self.context = {"constant_methods": None} - def test_run_with_partitioners_and_config(self) -> None: + @patch("executorch.export.stages.to_edge_transform_and_lower") + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_partitioners_and_config( 
+ self, mock_get_delegation_info: Mock, mock_to_edge_transform_and_lower: Mock + ) -> None: """Test execution with partitioners and compile config""" + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_partitioners = [Mock()] - mock_transform_passes = [Mock()] mock_compile_config = Mock() + # Create a mock transform pass callable that we can verify + mock_transform_pass = Mock() + mock_pass1 = Mock() + mock_pass2 = Mock() + mock_transform_pass.return_value = [mock_pass1, mock_pass2] + mock_transform_passes = [mock_transform_pass] + + mock_edge_program_manager = Mock(spec=EdgeProgramManager) + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + mock_edge_program_manager.exported_program.return_value = mock_exported_program + mock_to_edge_transform_and_lower.return_value = mock_edge_program_manager + stage = EdgeTransformAndLowerStage( partitioners=mock_partitioners, transform_passes=mock_transform_passes, @@ -124,6 +204,33 @@ def test_run_with_partitioners_and_config(self) -> None: self.assertEqual(stage._transform_passes, mock_transform_passes) self.assertEqual(stage._compile_config, mock_compile_config) + # Test the run method + artifact = PipelineArtifact(data=self.exported_programs, context=self.context) + stage.run(artifact) + + # Verify the transform pass callable was called with correct parameters + mock_transform_pass.assert_called_once_with( + "forward", self.mock_exported_program + ) + + # Verify to_edge_transform_and_lower was called with the expected structure + expected_transform_passes = {"forward": [mock_pass1, mock_pass2]} + mock_to_edge_transform_and_lower.assert_called_once_with( + self.exported_programs, + partitioner=mock_partitioners, + transform_passes=expected_transform_passes, + constant_methods=None, + compile_config=mock_compile_config, + generate_etrecord=False, + ) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_program_manager) + self.assertEqual( + result_artifact.get_context("delegation_info"), mock_delegation_info + ) + class TestExecutorchStage(unittest.TestCase): def setUp(self) -> None: @@ -380,7 +487,10 @@ def test_run_success_no_transforms_or_partitioners( mock_exported_program = Mock() mock_graph_module = Mock() mock_exported_program.graph_module = mock_graph_module + + self.mock_edge_manager.transform.return_value = self.mock_edge_manager self.mock_edge_manager.exported_program.return_value = mock_exported_program + self.mock_edge_manager.methods = {"forward"} stage = ToBackendStage() artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context) @@ -409,9 +519,21 @@ def test_run_with_partitioners_and_passes( mock_edge_program_manager = Mock(spec=EdgeProgramManager) mock_edge_program_manager.transform.return_value = mock_edge_program_manager mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager + mock_edge_program_manager.exported_program.return_value = mock_exported_program + + # Use PropertyMock for the methods property + methods_property_mock = PropertyMock(return_value={"forward"}) + type(mock_edge_program_manager).methods = methods_property_mock mock_partitioner = Mock() - mock_transform_passes = [Mock(), Mock()] + + # Create a mock transform pass callable that we can verify + mock_transform_pass = Mock() + mock_pass1 = Mock() + mock_pass2 = Mock() + mock_transform_pass.return_value = 
[mock_pass1, mock_pass2]
+        mock_transform_passes = [mock_transform_pass]
+
         stage = ToBackendStage(
             partitioners=[mock_partitioner], transform_passes=mock_transform_passes
         )
@@ -420,10 +542,19 @@ def test_run_with_partitioners_and_passes(
         )
         stage.run(artifact)
 
-        # Verify transform and to_backend called correctly
+        # Verify that the methods property was accessed
+        methods_property_mock.assert_called_once()
+
+        # Verify the transform pass callable was called with correct parameters
+        mock_transform_pass.assert_called_once_with("forward", mock_exported_program)
+
+        # Verify transform was called with the expected structure
+        expected_transform_passes = {"forward": [mock_pass1, mock_pass2]}
         mock_edge_program_manager.transform.assert_called_once_with(
-            mock_transform_passes
+            expected_transform_passes
         )
+
+        # Verify to_backend called correctly
         mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner)
 
         # Verify artifacts contain the backend manager

From 610219f73866265d240836e69d70defa949fa7d6 Mon Sep 17 00:00:00 2001
From: Abhinay Kukkadapu
Date: Wed, 10 Sep 2025 09:35:07 -0700
Subject: [PATCH 2/2] [Executorch][QNN Recipes] Introduce QNN fp16 recipe

Pull Request resolved: https://github.com/pytorch/executorch/pull/14126

This diff adds a QNN fp16 recipe, with tests that run on the HTP simulator.

Fixes: #13101

A usage sketch of the new recipe appears after the patches below.

ghstack-source-id: 308796609

Differential Revision: [D81945971](https://our.internmc.facebook.com/intern/diff/D81945971/)
---
 backends/qualcomm/_passes/TARGETS             |   1 +
 backends/qualcomm/recipes/TARGETS             |  53 ++++++
 backends/qualcomm/recipes/__init__.py         |  16 ++
 .../qualcomm/recipes/qnn_recipe_provider.py   | 179 ++++++++++++++++++
 backends/qualcomm/recipes/qnn_recipe_types.py |  28 +++
 5 files changed, 277 insertions(+)
 create mode 100644 backends/qualcomm/recipes/TARGETS
 create mode 100644 backends/qualcomm/recipes/__init__.py
 create mode 100644 backends/qualcomm/recipes/qnn_recipe_provider.py
 create mode 100644 backends/qualcomm/recipes/qnn_recipe_types.py

diff --git a/backends/qualcomm/_passes/TARGETS b/backends/qualcomm/_passes/TARGETS
index a824ca9f6e5..62a0fc43a78 100644
--- a/backends/qualcomm/_passes/TARGETS
+++ b/backends/qualcomm/_passes/TARGETS
@@ -12,6 +12,7 @@ runtime.python_library(
     ],
     deps = [
         "//executorch/backends/transforms:addmm_mm_to_linear",
+        "//executorch/backends/transforms:decompose_sdpa",
         "//executorch/exir/backend:backend_details",
         "//executorch/exir/backend:compile_spec_schema",
     ],
diff --git a/backends/qualcomm/recipes/TARGETS b/backends/qualcomm/recipes/TARGETS
new file mode 100644
index 00000000000..12d1bac6f12
--- /dev/null
+++ b/backends/qualcomm/recipes/TARGETS
@@ -0,0 +1,53 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "qnn_recipes",
+    srcs = [
+        "__init__.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/export:recipe_registry",
+        ":qnn_recipe_provider",
+        ":qnn_recipe_types",
+    ],
+)
+
+runtime.python_library(
+    name = "qnn_recipe_provider",
+    srcs = [
+        "qnn_recipe_provider.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/export:lib",
+        "//executorch/backends/qualcomm/partition:partition",
+        "//executorch/backends/qualcomm/serialization:serialization",
+        "//executorch/backends/qualcomm/utils:utils",
+        "//executorch/backends/qualcomm/_passes:passes",
+        ":qnn_recipe_types",
+    ],
+)
+
+runtime.python_library(
+    name = "qnn_recipe_types",
+    srcs = [
+        "qnn_recipe_types.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/export:lib",
+    ],
+)
diff --git a/backends/qualcomm/recipes/__init__.py b/backends/qualcomm/recipes/__init__.py
new file mode 100644
index 00000000000..ee0985584d6
--- /dev/null
+++ b/backends/qualcomm/recipes/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""QNN Recipe module for ExecuTorch"""
+from executorch.export import recipe_registry
+
+from .qnn_recipe_provider import QNNRecipeProvider
+from .qnn_recipe_types import QNNRecipeType
+
+# Auto-register QNN recipe provider
+recipe_registry.register_backend_recipe_provider(QNNRecipeProvider())
+
+__all__ = ["QNNRecipeProvider", "QNNRecipeType"]
diff --git a/backends/qualcomm/recipes/qnn_recipe_provider.py b/backends/qualcomm/recipes/qnn_recipe_provider.py
new file mode 100644
index 00000000000..fcfab0c3bd1
--- /dev/null
+++ b/backends/qualcomm/recipes/qnn_recipe_provider.py
@@ -0,0 +1,179 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+from typing import Any, Optional, Sequence
+
+from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
+from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
+from executorch.backends.qualcomm.recipes.qnn_recipe_types import (
+    QNN_BACKEND,
+    QNNRecipeType,
+)
+from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
+from executorch.backends.qualcomm.utils.utils import (
+    generate_htp_compiler_spec,
+    generate_qnn_executorch_compiler_spec,
+    get_soc_to_chipset_map,
+    qnn_edge_config,
+)
+from executorch.export import (
+    BackendRecipeProvider,
+    ExportRecipe,
+    LoweringRecipe,
+    RecipeType,
+)
+
+
+class QNNRecipeProvider(BackendRecipeProvider):
+    @property
+    def backend_name(self) -> str:
+        return QNN_BACKEND
+
+    def get_supported_recipes(self) -> Sequence[RecipeType]:
+        return list(QNNRecipeType)
+
+    def create_recipe(
+        self, recipe_type: RecipeType, **kwargs: Any
+    ) -> Optional[ExportRecipe]:
+        """Create a QNN recipe for different precisions and SoC targets"""
+
+        if recipe_type not in self.get_supported_recipes():
+            return None
+
+        self._validate_recipe_kwargs(recipe_type, kwargs)
+
+        if recipe_type == QNNRecipeType.FP16:
+            return self._build_fp16_recipe(recipe_type, kwargs)
+
+        return None
+
+    def _validate_recipe_kwargs(self, recipe_type: RecipeType, kwargs: Any) -> None:
+        """Validate kwargs for each recipe type"""
+        expected_keys = self._get_expected_keys(recipe_type)
+
+        unexpected = set(kwargs.keys()) - expected_keys
+        if unexpected:
+            logging.warning(
+                f"QNN Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}, ignoring them"
+            )
+
+        self._validate_soc_parameter(kwargs)
+        self._validate_partitioner_parameters(kwargs)
+
+    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
+        """Get expected parameter keys for a recipe type"""
+        _ = recipe_type
+        common_keys = {
+            "soc_model",
+            "skip_node_id_set",
+            "skip_node_op_set",
+            "skip_mutable_buffer",
+        }
+        return common_keys
+
+    def _validate_soc_parameter(self, kwargs: Any) -> None:
+        """Validate soc_model parameter"""
+        if "soc_model" in kwargs:
+            soc_model = kwargs["soc_model"]
+            if isinstance(soc_model, str):
+                try:
+                    soc_model = get_soc_to_chipset_map()[soc_model]
+                    kwargs["soc_model"] = soc_model
+                except KeyError:
+                    raise ValueError(
+                        f"Invalid SoC model '{soc_model}'. Supported models: {list(get_soc_to_chipset_map().keys())}"
+                    )
+            elif not isinstance(soc_model, QcomChipset):
+                raise ValueError(
+                    f"Parameter 'soc_model' must be a QcomChipset enum or string, got {type(soc_model)}"
+                )
+        else:
+            raise ValueError("Parameter 'soc_model' is required")
+
+    def _validate_partitioner_parameters(self, kwargs: Any) -> None:
+        """Validate partitioner parameters"""
+        if "skip_node_id_set" in kwargs:
+            skip_node_id_set = kwargs["skip_node_id_set"]
+            if skip_node_id_set is not None and not isinstance(skip_node_id_set, set):
+                raise ValueError(
+                    f"Parameter 'skip_node_id_set' must be a set or None, got {type(skip_node_id_set)}"
+                )
+
+        if "skip_node_op_set" in kwargs:
+            skip_node_op_set = kwargs["skip_node_op_set"]
+            if skip_node_op_set is not None and not isinstance(skip_node_op_set, set):
+                raise ValueError(
+                    f"Parameter 'skip_node_op_set' must be a set or None, got {type(skip_node_op_set)}"
+                )
+
+        if "skip_mutable_buffer" in kwargs:
+            skip_mutable_buffer = kwargs["skip_mutable_buffer"]
+            if not isinstance(skip_mutable_buffer, bool):
+                raise ValueError(
+                    f"Parameter 'skip_mutable_buffer' must be a boolean, got {type(skip_mutable_buffer)}"
+                )
+
+    def _build_fp16_recipe(
+        self,
+        recipe_type: RecipeType,
+        kwargs: Any,
+    ) -> ExportRecipe:
+        soc_model = kwargs["soc_model"]
+        skip_node_id_set = kwargs.get("skip_node_id_set", None)
+        skip_node_op_set = kwargs.get("skip_node_op_set", None)
+        skip_mutable_buffer = kwargs.get("skip_mutable_buffer", False)
+
+        lowering_recipe = self._get_qnn_lowering_recipe(
+            use_fp16=True,
+            soc_model=soc_model,
+            skip_node_id_set=skip_node_id_set,
+            skip_node_op_set=skip_node_op_set,
+            skip_mutable_buffer=skip_mutable_buffer,
+        )
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            aten_transform_passes=[
+                lambda method_, ep: QnnPassManager().transform_for_export_pipeline(ep)
+            ],
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _get_qnn_lowering_recipe(
+        self,
+        use_fp16: bool,
+        soc_model: QcomChipset,
+        skip_node_id_set: Optional[set] = None,
+        skip_node_op_set: Optional[set] = None,
+        skip_mutable_buffer: bool = False,
+    ) -> LoweringRecipe:
+        """Get QNN lowering recipe with optional precision and SoC target"""
+        backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
+
+        compile_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=soc_model,
+            backend_options=backend_options,
+        )
+
+        partitioner = QnnPartitioner(
+            compiler_specs=compile_specs,
+            skip_node_id_set=skip_node_id_set,
+            skip_node_op_set=skip_node_op_set,
+            skip_mutable_buffer=skip_mutable_buffer,
+        )
+
+        edge_compile_config = qnn_edge_config()
+
+        return LoweringRecipe(
+            partitioners=[partitioner],
+            edge_transform_passes=[
+                lambda method_, ep: QnnPassManager().get_to_edge_transform_passes(ep)
+            ],
+            edge_compile_config=edge_compile_config,
+        )
diff --git a/backends/qualcomm/recipes/qnn_recipe_types.py b/backends/qualcomm/recipes/qnn_recipe_types.py
new file mode 100644
index 00000000000..f41f494fce6
--- /dev/null
+++ b/backends/qualcomm/recipes/qnn_recipe_types.py
@@ -0,0 +1,28 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-strict + +from executorch.export import RecipeType + + +QNN_BACKEND: str = "qnn" + + +class QNNRecipeType(RecipeType): + """QNN-specific recipe types""" + + # FP16 precision recipe, accepts kwargs: + # 1. soc_model + # 2. skip_node_id_set + # 3. skip_node_op_set + # 4. skip_mutable_buffer + + FP16 = "qnn_fp16" + + @classmethod + def get_backend_name(cls) -> str: + return QNN_BACKEND
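
Usage sketch for PATCH 1/2: a minimal example of the callable-based pass API.
The pass bodies and the names `my_aten_pass` and `resolve_edge_passes` are
illustrative placeholders; only the recipe fields and the callable signatures
come from the patch.

```python
from executorch.exir import ExportedProgram
from executorch.export import ExportRecipe, LoweringRecipe


# An aten-level pass: runs right after torch.export, before edge lowering.
# It receives (method_name, ExportedProgram) and must return an ExportedProgram.
def my_aten_pass(method_name: str, ep: ExportedProgram) -> ExportedProgram:
    return ep  # placeholder: inspect/transform `ep` here


# An edge-pass resolver: invoked once per method at lowering time, when the
# ExportedProgram finally exists, and returns the list of passes to run.
def resolve_edge_passes(method_name: str, ep: ExportedProgram) -> list:
    return []  # placeholder: choose passes based on `ep`


recipe = ExportRecipe(
    name="custom_recipe",
    aten_transform_passes=[my_aten_pass],
    lowering_recipe=LoweringRecipe(
        edge_transform_passes=[resolve_edge_passes],
    ),
)
```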
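Usage sketch for PATCH 2/2: requesting the QNN fp16 recipe. This assumes the
`ExportRecipe.get_recipe` entry point used by other backend recipe providers,
and `"SM8650"` is an illustrative SoC name; `soc_model` may also be passed as
a `QcomChipset` enum value.

```python
# Importing the package registers QNNRecipeProvider with the recipe registry.
from executorch.backends.qualcomm.recipes import QNNRecipeType
from executorch.export import ExportRecipe

# soc_model is required; string values are resolved via get_soc_to_chipset_map().
recipe = ExportRecipe.get_recipe(QNNRecipeType.FP16, soc_model="SM8650")
```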