From 04f413af18be346b658d06dacde3c4fd694fa5e2 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Mon, 22 Sep 2025 12:51:35 -0700 Subject: [PATCH] Fix op decomposition issue when multiple partitioners with conflicting expectations are run (#14458) Summary: Context: I'm trying to enable sequential recipes targeting multiple backends (such as `CoreML.FP32 + XNNPack.FP32`, where xnnpack acts as a fallback for the ops), and I don't think we've ever tested lowering a model to multiple backends — or this edge case has simply never been hit. I hit this case when I tried to lower the Vision Transformer (ViT) model. I've seen a similar problem occur with the SDPA op (discussed [here](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/)); that has been fixed, and I think the fix works, but it didn't consider the case where there are multiple partitioners with *conflicting decomposition requirements and filters for the op no-decomp namespace*. ## Error with coreml + xnnpack ``` [2025-09-22T11:01:47.263-07:00] ValueError: Cannot view a tensor with shape torch.Size([197, 1, 12, 64]) and strides (64, 151296, 12608, 1) as a tensor with shape (197, 768)! 
[2025-09-22T11:01:47.263-07:00] [2025-09-22T11:01:47.263-07:00] While executing %view_8 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%permute_6, [197, 768]), kwargs = {}) [2025-09-22T11:01:47.263-07:00] Original traceback: [2025-09-22T11:01:47.263-07:00] File "/data/users/abhinayk/fbsource/buck-out/v2/gen/fbcode/afd2a63214a057a8/executorch/export/tests/__test_target_recipes__/test_target_recipes#link-tree/torchvision/models/vision_transformer.py", line 298, in forward [2025-09-22T11:01:47.263-07:00] x = self.encoder(x) [2025-09-22T11:01:47.263-07:00] File "/data/users/abhinayk/fbsource/buck-out/v2/gen/fbcode/afd2a63214a057a8/executorch/export/tests/__test_target_recipes__/test_target_recipes#link-tree/torchvision/models/vision_transformer.py", line 157, in forward [2025-09-22T11:01:47.263-07:00] return self.ln(self.layers(self.dropout(input))) [2025-09-22T11:01:47.263-07:00] File "/data/users/abhinayk/fbsource/buck-out/v2/gen/fbcode/afd2a63214a057a8/executorch/export/tests/__test_target_recipes__/test_target_recipes#link-tree/torchvision/models/vision_transformer.py", line 113, in forward [2025-09-22T11:01:47.263-07:00] x, _ = self.self_attention(x, x, x, need_weights=False) ``` ## Error with QNN + XNNPACK The core problem here is that if there are two partitioners — a Backend A partitioner (asking to preserve ops x and y) and a Backend B partitioner (asking to preserve op z) — and Backend A doesn't understand z and wants to decompose it, we currently `union` all the ops to preserve from the multiple partitioners, and it errors out. In this specific case, xnnpack asks to preserve `aten.max_pool2d` but QNN doesn't understand it. 
``` [2025-09-22T12:18:33.640-07:00] partition_list = capability_partitioner.propose_partitions() [2025-09-22T12:18:33.640-07:00] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [2025-09-22T12:18:33.640-07:00] File "/data/users/abhinayk/fbsource/buck-out/v2/gen/fbcode/afd2a63214a057a8/executorch/export/tests/__test_target_recipes__/test_target_recipes#link-tree/torch/fx/passes/infra/partitioner.py", line 226, in propose_partitions [2025-09-22T12:18:33.640-07:00] if self._is_node_supported(node) and node not in assignment: [2025-09-22T12:18:33.640-07:00] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [2025-09-22T12:18:33.640-07:00] File "/data/users/abhinayk/fbsource/buck-out/v2/gen/fbcode/afd2a63214a057a8/executorch/export/tests/__test_target_recipes__/test_target_recipes#link-tree/torch/fx/passes/infra/partitioner.py", line 87, in _is_node_supported [2025-09-22T12:18:33.640-07:00] return self.operator_support.is_node_supported( [2025-09-22T12:18:33.640-07:00] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [2025-09-22T12:18:33.640-07:00] File "/data/users/abhinayk/fbsource/buck-out/v2/gen/fbcode/afd2a63214a057a8/executorch/export/tests/__test_target_recipes__/test_target_recipes#link-tree/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 100, in is_node_supported [2025-09-22T12:18:33.640-07:00] op_wrapper = self.node_visitors[node.target.__name__].define_node( [2025-09-22T12:18:33.640-07:00] ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^ [2025-09-22T12:18:33.640-07:00] KeyError: 'aten.max_pool2d.default' ``` **Note**: Lowering to a single backend works with coreml, xnnpack, or QNN individually; it is the combination that hits this error. Changes: - The fix is to run decomposition filtering and skipping per partitioner, rather than maintaining the same rule in global scope. - I've additionally refactored the code to make it more readable by removing multiple boolean checks. 
Differential Revision: D82936479 --- exir/program/_program.py | 117 +++++++++++++++------------- export/target_recipes.py | 4 +- export/tests/test_target_recipes.py | 2 +- 3 files changed, 64 insertions(+), 59 deletions(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index a33d715ca3b..9298eb3e88d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -11,6 +11,7 @@ import io import logging import os +from collections import defaultdict from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Type, Union import torch @@ -1136,7 +1137,7 @@ def keep(op): def _can_skip_using_EDGE_DO_NOT_DECOMP( - partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram] + partitioner: Partitioner, program: ExportedProgram ) -> bool: # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to @@ -1144,17 +1145,8 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP( # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/ # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support # As a temp fix, we give a more reliable path for backends that do not specify check_op_support - can_skip_using_EDGE_DO_NOT_DECOMP = True - for name, program in aten_programs.items(): - if partitioner is not None: - for curr_partitioner in partitioner.get(name, []): - ( - curr_ops_no_decomp, - check_op_support, - ) = curr_partitioner.ops_to_not_decompose(program) - if check_op_support is not None: - can_skip_using_EDGE_DO_NOT_DECOMP = False - return can_skip_using_EDGE_DO_NOT_DECOMP + _, check_op_support = partitioner.ops_to_not_decompose(program) + return check_op_support is None def _gen_edge_manager_for_partitioners( @@ -1177,60 +1169,75 @@ def _gen_edge_manager_for_partitioners( on nodes with preserved aten targets. 
They are then replaces with transformed ops to keep them through the second pass of decompositions """ - can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( - partitioner, aten_programs - ) - ops_set_to_not_decompose_by_program = {} + ops_set_to_not_decompose_by_program = defaultdict(list) edge_programs: Dict[str, ExportedProgram] = {} for name, program in aten_programs.items(): # Functionalize program before asking partitioners to preserve ops program = program.run_decompositions({}) if partitioner is not None: - # preserve all ops listed by all partitioners first - all_ops_no_decomp = set() - all_ops_no_decomp_needing_preservation = [] - for curr_partitioner in partitioner.get(name, []): + partitioners_for_program = partitioner.get(name, []) + final_ops_to_preserve = set() + + # Decompose by default if there are no partitioners for the method + if not partitioners_for_program: + program = program.run_decompositions(_default_decomposition_table()) + + # Process each partitioner individually using their specific requirements + for curr_partitioner in partitioners_for_program: curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) - all_ops_no_decomp |= set(curr_ops_no_decomp) - # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops - # Otherwise there will be issues - if not can_skip_using_EDGE_DO_NOT_DECOMP: - all_ops_no_decomp = _remove_invalid_ops_for_not_decompose( - list(all_ops_no_decomp) - ) - all_ops_no_decomp = set(all_ops_no_decomp) - - # Run default decompositions, except for those in all_ops_no_decomp - table = _default_decomposition_table() - for op in all_ops_no_decomp: - if table.pop(op, None) is not None: - all_ops_no_decomp_needing_preservation.append(op) - program = program.run_decompositions(table) - - # Among all the preserved aten ops, use the check_op_fn to do an additional - # check on which ops need to be preserved and which ops need to be decomposed - # Those which are 
truly preserved will be replaced with transformed ops - if can_skip_using_EDGE_DO_NOT_DECOMP: - ops_set_to_not_decompose_by_program[name] = ( - all_ops_no_decomp_needing_preservation - ) - else: - ops_set_to_not_decompose_by_program[name] = ( - _replace_aten_ops_with_transformed_ops(name, program, partitioner) - or [] + # Check if this partitioner can skip using EDGE_DO_NOT_DECOMP + can_skip_using_edge_do_not_decomp = _can_skip_using_EDGE_DO_NOT_DECOMP( + curr_partitioner, program ) - if not can_skip_using_EDGE_DO_NOT_DECOMP: - program = program.run_decompositions(_default_decomposition_table()) - _restore_transformed_ops_to_aten_ops(program) + if can_skip_using_edge_do_not_decomp: + # Preserve all ops in curr_ops_no_decomp from decomposition + table = _default_decomposition_table() + ops_needing_preservation = [] + + for op in curr_ops_no_decomp: + if table.pop(op, None) is not None: + ops_needing_preservation.append(op) + + program = program.run_decompositions(table) + final_ops_to_preserve.update(ops_needing_preservation) + else: + # EDGE_DO_NOT_DECOMP path for the partitioner + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( + curr_ops_no_decomp + ) + + # Apply decompositions with this partitioner's preserved ops + table = _default_decomposition_table() + for op in curr_ops_no_decomp: + table.pop(op, None) + + # First pass of decompositions with this partitioner's preserved ops + program = program.run_decompositions(table) + + # Filter ops using EDGE_DO_NOT_DECOMP + temp_partitioner_dict = {name: [curr_partitioner]} + preserved_ops = ( + _replace_aten_ops_with_transformed_ops( + name, program, temp_partitioner_dict + ) + or [] + ) + final_ops_to_preserve.update(preserved_ops) + + # Second pass of decompositions with this partitioner's preserved ops after filtering + program = program.run_decompositions(_default_decomposition_table()) + + # Restore ops from edge_no_decomp_namespace to aten ops + _restore_transformed_ops_to_aten_ops(program) + 
ops_set_to_not_decompose_by_program[name].extend(final_ops_to_preserve) - edge_programs[name] = program edge_programs[name] = _generate_edge_program( config, program, - preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])), + preserve_ops=ops_set_to_not_decompose_by_program.get(name, []), ) edge_manager = EdgeProgramManager( @@ -1349,9 +1356,6 @@ def to_edge_transform_and_lower( # noqa: C901 elif partitioner is None: partitioner = {name: [] for name in aten_programs.keys()} - can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( - partitioner, aten_programs - ) edge_manager = _gen_edge_manager_for_partitioners( partitioner, aten_programs, config, constant_methods, generate_etrecord ) @@ -1377,7 +1381,8 @@ def to_edge_transform_and_lower( # noqa: C901 curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( program ) - if not can_skip_using_EDGE_DO_NOT_DECOMP: + + if not _can_skip_using_EDGE_DO_NOT_DECOMP(curr_partitioner, program): curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) _sanity_check_graph_for_non_decomp_ops( diff --git a/export/target_recipes.py b/export/target_recipes.py index 2d2eba46b0a..eac35c08bf7 100644 --- a/export/target_recipes.py +++ b/export/target_recipes.py @@ -93,7 +93,7 @@ def get_ios_recipe( # pyre-ignore "ios-arm64-coreml-fp32": [CoreMLRecipeType.FP32, XNNPackRecipeType.FP32], # pyre-ignore - "ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16], + "ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16, XNNPackRecipeType.FP32], # pyre-ignore "ios-arm64-coreml-int8": [CoreMLRecipeType.PT2E_INT8_STATIC], } @@ -165,7 +165,7 @@ def get_android_recipe( android_configs: Dict[str, List[RecipeType]] = { # pyre-ignore - "android-arm64-snapdragon-fp16": [QNNRecipeType.FP16], + "android-arm64-snapdragon-fp16": [QNNRecipeType.FP16, XNNPackRecipeType.FP32], } if target_config not in android_configs: diff --git 
a/export/tests/test_target_recipes.py b/export/tests/test_target_recipes.py index 61725e58f3a..48f7dfc67db 100644 --- a/export/tests/test_target_recipes.py +++ b/export/tests/test_target_recipes.py @@ -387,7 +387,7 @@ def _get_model_test_configs( @classmethod def _get_recipes(cls) -> Dict[str, Tuple[ExportRecipe, str]]: """Get available recipes with their configurations based on platform.""" - all_recipes = {} + all_recipes: Dict[str, Tuple[ExportRecipe, str]] = {} # Add iOS recipes if is_supported_platform_for_coreml_lowering():