diff --git a/exir/program/_program.py b/exir/program/_program.py index a33d715ca3b..9298eb3e88d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -11,6 +11,7 @@ import io import logging import os +from collections import defaultdict from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Type, Union import torch @@ -1136,7 +1137,7 @@ def keep(op): def _can_skip_using_EDGE_DO_NOT_DECOMP( - partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram] + partitioner: Partitioner, program: ExportedProgram ) -> bool: # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to @@ -1144,17 +1145,8 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP( # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/ # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support # As a temp fix, we give a more reliable path for backends that do not specify check_op_support - can_skip_using_EDGE_DO_NOT_DECOMP = True - for name, program in aten_programs.items(): - if partitioner is not None: - for curr_partitioner in partitioner.get(name, []): - ( - curr_ops_no_decomp, - check_op_support, - ) = curr_partitioner.ops_to_not_decompose(program) - if check_op_support is not None: - can_skip_using_EDGE_DO_NOT_DECOMP = False - return can_skip_using_EDGE_DO_NOT_DECOMP + _, check_op_support = partitioner.ops_to_not_decompose(program) + return check_op_support is None def _gen_edge_manager_for_partitioners( @@ -1177,60 +1169,75 @@ def _gen_edge_manager_for_partitioners( on nodes with preserved aten targets. They are then replaces with transformed ops to keep them through the second pass of decompositions """ - can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( - partitioner, aten_programs - ) - ops_set_to_not_decompose_by_program = {} + ops_set_to_not_decompose_by_program = defaultdict(list) edge_programs: Dict[str, ExportedProgram] = {} for name, program in aten_programs.items(): # Functionalize program before asking partitioners to preserve ops program = program.run_decompositions({}) if partitioner is not None: - # preserve all ops listed by all partitioners first - all_ops_no_decomp = set() - all_ops_no_decomp_needing_preservation = [] - for curr_partitioner in partitioner.get(name, []): + partitioners_for_program = partitioner.get(name, []) + final_ops_to_preserve = set() + + # Decompose by default if there are no partitioners for the method + if not partitioners_for_program: + program = program.run_decompositions(_default_decomposition_table()) + + # Process each partitioner individually using their specific requirements + for curr_partitioner in partitioners_for_program: curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) - all_ops_no_decomp |= set(curr_ops_no_decomp) - # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops - # Otherwise there will be issues - if not can_skip_using_EDGE_DO_NOT_DECOMP: - all_ops_no_decomp = _remove_invalid_ops_for_not_decompose( - list(all_ops_no_decomp) - ) - all_ops_no_decomp = set(all_ops_no_decomp) - - # Run default decompositions, except for those in all_ops_no_decomp - table = _default_decomposition_table() - for op in all_ops_no_decomp: - if table.pop(op, None) is not None: - all_ops_no_decomp_needing_preservation.append(op) - program = program.run_decompositions(table) - - # Among all the preserved aten ops, use the check_op_fn to do an additional - # check on which ops need to be preserved and which ops need to be decomposed - # Those which are truly preserved will be replaced with transformed ops - if can_skip_using_EDGE_DO_NOT_DECOMP: - ops_set_to_not_decompose_by_program[name] = ( - all_ops_no_decomp_needing_preservation - ) - else: - ops_set_to_not_decompose_by_program[name] = ( - _replace_aten_ops_with_transformed_ops(name, program, partitioner) - or [] + # Check if this partitioner can skip using EDGE_DO_NOT_DECOMP + can_skip_using_edge_do_not_decomp = _can_skip_using_EDGE_DO_NOT_DECOMP( + curr_partitioner, program ) - if not can_skip_using_EDGE_DO_NOT_DECOMP: - program = program.run_decompositions(_default_decomposition_table()) - _restore_transformed_ops_to_aten_ops(program) + if can_skip_using_edge_do_not_decomp: + # Preserve all ops in curr_ops_no_decomp from decomposition + table = _default_decomposition_table() + ops_needing_preservation = [] + + for op in curr_ops_no_decomp: + if table.pop(op, None) is not None: + ops_needing_preservation.append(op) + + program = program.run_decompositions(table) + final_ops_to_preserve.update(ops_needing_preservation) + else: + # EDGE_DO_NOT_DECOMP path for the partitioner + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( + curr_ops_no_decomp + ) + + # Apply decompositions with this partitioner's preserved ops + table = _default_decomposition_table() + for op in curr_ops_no_decomp: + table.pop(op, None) + + # First pass of decompositions with this partitioner's preserved ops + program = program.run_decompositions(table) + + # Filter ops using EDGE_DO_NOT_DECOMP + temp_partitioner_dict = {name: [curr_partitioner]} + preserved_ops = ( + _replace_aten_ops_with_transformed_ops( + name, program, temp_partitioner_dict + ) + or [] + ) + final_ops_to_preserve.update(preserved_ops) + + # Second pass of decompositions with this partitioner's preserved ops after filtering + program = program.run_decompositions(_default_decomposition_table()) + + # Restore ops from edge_no_decomp_namespace to aten ops + _restore_transformed_ops_to_aten_ops(program) + ops_set_to_not_decompose_by_program[name].extend(final_ops_to_preserve) - edge_programs[name] = program edge_programs[name] = _generate_edge_program( config, program, - preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])), + preserve_ops=ops_set_to_not_decompose_by_program.get(name, []), ) edge_manager = EdgeProgramManager( @@ -1349,9 +1356,6 @@ def to_edge_transform_and_lower( # noqa: C901 elif partitioner is None: partitioner = {name: [] for name in aten_programs.keys()} - can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( - partitioner, aten_programs - ) edge_manager = _gen_edge_manager_for_partitioners( partitioner, aten_programs, config, constant_methods, generate_etrecord ) @@ -1377,7 +1381,8 @@ def to_edge_transform_and_lower( # noqa: C901 curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( program ) - if not can_skip_using_EDGE_DO_NOT_DECOMP: + + if not _can_skip_using_EDGE_DO_NOT_DECOMP(curr_partitioner, program): curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) _sanity_check_graph_for_non_decomp_ops( diff --git a/export/target_recipes.py b/export/target_recipes.py index 2d2eba46b0a..eac35c08bf7 100644 --- a/export/target_recipes.py +++ b/export/target_recipes.py @@ -93,7 +93,7 @@ def get_ios_recipe( # pyre-ignore "ios-arm64-coreml-fp32": [CoreMLRecipeType.FP32, XNNPackRecipeType.FP32], # pyre-ignore - "ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16], + "ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16, XNNPackRecipeType.FP32], # pyre-ignore "ios-arm64-coreml-int8": [CoreMLRecipeType.PT2E_INT8_STATIC], } @@ -165,7 +165,7 @@ def get_android_recipe( android_configs: Dict[str, List[RecipeType]] = { # pyre-ignore - "android-arm64-snapdragon-fp16": [QNNRecipeType.FP16], + "android-arm64-snapdragon-fp16": [QNNRecipeType.FP16, XNNPackRecipeType.FP32], } if target_config not in android_configs: diff --git a/export/tests/test_target_recipes.py b/export/tests/test_target_recipes.py index 61725e58f3a..48f7dfc67db 100644 --- a/export/tests/test_target_recipes.py +++ b/export/tests/test_target_recipes.py @@ -387,7 +387,7 @@ def _get_model_test_configs( @classmethod def _get_recipes(cls) -> Dict[str, Tuple[ExportRecipe, str]]: """Get available recipes with their configurations based on platform.""" - all_recipes = {} + all_recipes: Dict[str, Tuple[ExportRecipe, str]] = {} # Add iOS recipes if is_supported_platform_for_coreml_lowering():