Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 61 additions & 56 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Type, Union

import torch
Expand Down Expand Up @@ -1136,25 +1137,16 @@ def keep(op):


def _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
partitioner: Partitioner, program: ExportedProgram
) -> bool:
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
can_skip_using_EDGE_DO_NOT_DECOMP = True
for name, program in aten_programs.items():
if partitioner is not None:
for curr_partitioner in partitioner.get(name, []):
(
curr_ops_no_decomp,
check_op_support,
) = curr_partitioner.ops_to_not_decompose(program)
if check_op_support is not None:
can_skip_using_EDGE_DO_NOT_DECOMP = False
return can_skip_using_EDGE_DO_NOT_DECOMP
_, check_op_support = partitioner.ops_to_not_decompose(program)
return check_op_support is None


def _gen_edge_manager_for_partitioners(
Expand All @@ -1177,60 +1169,75 @@ def _gen_edge_manager_for_partitioners(
on nodes with preserved aten targets. They are then replaces with transformed ops to
keep them through the second pass of decompositions
"""
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner, aten_programs
)
ops_set_to_not_decompose_by_program = {}
ops_set_to_not_decompose_by_program = defaultdict(list)
edge_programs: Dict[str, ExportedProgram] = {}
for name, program in aten_programs.items():
# Functionalize program before asking partitioners to preserve ops
program = program.run_decompositions({})

if partitioner is not None:
# preserve all ops listed by all partitioners first
all_ops_no_decomp = set()
all_ops_no_decomp_needing_preservation = []
for curr_partitioner in partitioner.get(name, []):
partitioners_for_program = partitioner.get(name, [])
final_ops_to_preserve = set()

# Decompose by default if there are no partitioners for the method
if not partitioners_for_program:
program = program.run_decompositions(_default_decomposition_table())

# Process each partitioner individually using their specific requirements
for curr_partitioner in partitioners_for_program:
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
all_ops_no_decomp |= set(curr_ops_no_decomp)

# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
# Otherwise there will be issues
if not can_skip_using_EDGE_DO_NOT_DECOMP:
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
list(all_ops_no_decomp)
)
all_ops_no_decomp = set(all_ops_no_decomp)

# Run default decompositions, except for those in all_ops_no_decomp
table = _default_decomposition_table()
for op in all_ops_no_decomp:
if table.pop(op, None) is not None:
all_ops_no_decomp_needing_preservation.append(op)
program = program.run_decompositions(table)

# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
if can_skip_using_EDGE_DO_NOT_DECOMP:
ops_set_to_not_decompose_by_program[name] = (
all_ops_no_decomp_needing_preservation
)
else:
ops_set_to_not_decompose_by_program[name] = (
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
or []
# Check if this partitioner can skip using EDGE_DO_NOT_DECOMP
can_skip_using_edge_do_not_decomp = _can_skip_using_EDGE_DO_NOT_DECOMP(
curr_partitioner, program
)

if not can_skip_using_EDGE_DO_NOT_DECOMP:
program = program.run_decompositions(_default_decomposition_table())
_restore_transformed_ops_to_aten_ops(program)
if can_skip_using_edge_do_not_decomp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but there is still a logical bug in ET's AOT code with the EDGE_DO_NOT_DECOMP namespace / preservation when it comes to SDPA (and maybe other ops).

For CoreML, we get around it by skipping that path, but other backends (e.g., XNNPACK or QNN) will run into it if they preserve SDPA.

Perhaps the runtime time could take a look at this issue? cc @JacobSzwejbka @larryliu0820

# Preserve all ops in curr_ops_no_decomp from decomposition
table = _default_decomposition_table()
ops_needing_preservation = []

for op in curr_ops_no_decomp:
if table.pop(op, None) is not None:
ops_needing_preservation.append(op)

program = program.run_decompositions(table)
final_ops_to_preserve.update(ops_needing_preservation)
else:
# EDGE_DO_NOT_DECOMP path for the partitioner
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
curr_ops_no_decomp
)

# Apply decompositions with this partitioner's preserved ops
table = _default_decomposition_table()
for op in curr_ops_no_decomp:
table.pop(op, None)

# First pass of decompositions with this partitioner's preserved ops
program = program.run_decompositions(table)

# Filter ops using EDGE_DO_NOT_DECOMP
temp_partitioner_dict = {name: [curr_partitioner]}
preserved_ops = (
_replace_aten_ops_with_transformed_ops(
name, program, temp_partitioner_dict
)
or []
)
final_ops_to_preserve.update(preserved_ops)

# Second pass of decompositions with this partitioner's preserved ops after filtering
program = program.run_decompositions(_default_decomposition_table())

# Restore ops from edge_no_decomp_namespace to aten ops
_restore_transformed_ops_to_aten_ops(program)
ops_set_to_not_decompose_by_program[name].extend(final_ops_to_preserve)

edge_programs[name] = program
edge_programs[name] = _generate_edge_program(
config,
program,
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
preserve_ops=ops_set_to_not_decompose_by_program.get(name, []),
)

edge_manager = EdgeProgramManager(
Expand Down Expand Up @@ -1349,9 +1356,6 @@ def to_edge_transform_and_lower( # noqa: C901
elif partitioner is None:
partitioner = {name: [] for name in aten_programs.keys()}

can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner, aten_programs
)
edge_manager = _gen_edge_manager_for_partitioners(
partitioner, aten_programs, config, constant_methods, generate_etrecord
)
Expand All @@ -1377,7 +1381,8 @@ def to_edge_transform_and_lower( # noqa: C901
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
program
)
if not can_skip_using_EDGE_DO_NOT_DECOMP:

if not _can_skip_using_EDGE_DO_NOT_DECOMP(curr_partitioner, program):
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
_sanity_check_graph_for_non_decomp_ops(
Expand Down
4 changes: 2 additions & 2 deletions export/target_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_ios_recipe(
# pyre-ignore
"ios-arm64-coreml-fp32": [CoreMLRecipeType.FP32, XNNPackRecipeType.FP32],
# pyre-ignore
"ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16],
"ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16, XNNPackRecipeType.FP32],
# pyre-ignore
"ios-arm64-coreml-int8": [CoreMLRecipeType.PT2E_INT8_STATIC],
}
Expand Down Expand Up @@ -165,7 +165,7 @@ def get_android_recipe(

android_configs: Dict[str, List[RecipeType]] = {
# pyre-ignore
"android-arm64-snapdragon-fp16": [QNNRecipeType.FP16],
"android-arm64-snapdragon-fp16": [QNNRecipeType.FP16, XNNPackRecipeType.FP32],
}

if target_config not in android_configs:
Expand Down
2 changes: 1 addition & 1 deletion export/tests/test_target_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _get_model_test_configs(
@classmethod
def _get_recipes(cls) -> Dict[str, Tuple[ExportRecipe, str]]:
"""Get available recipes with their configurations based on platform."""
all_recipes = {}
all_recipes: Dict[str, Tuple[ExportRecipe, str]] = {}

# Add iOS recipes
if is_supported_platform_for_coreml_lowering():
Expand Down
Loading