107 changes: 80 additions & 27 deletions backends/apple/coreml/partition/coreml_partitioner.py
@@ -20,6 +20,7 @@
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None:
logger.info(msg)
self._logged_msgs.add(msg)

def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
skipped_ops = self.skip_ops_for_coreml_delegation or []
if node_target_name in skipped_ops:
assert (
not self.lower_full_graph
), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
self.log_once(
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+ node_target_name
)
return True
return False
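
For context, a minimal usage sketch of the skip list (a sketch, assuming the skip_ops_for_coreml_delegation and lower_full_graph constructor keywords used above; the op-name string is illustrative and must equal node.target.__name__.lower() for the skipped op):

from executorch.backends.apple.coreml.partition.coreml_partitioner import CoreMLPartitioner

# Keep the listed op on the default backend; everything else remains eligible
# for CoreML delegation.
partitioner = CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.mul.tensor"])

# If lower_full_graph=True were also passed, should_skip_op_for_delegation would
# assert as soon as a listed op is encountered during partitioning, because a
# skipped op can never be delegated.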

def should_override_support(self, node) -> bool:
# https://github.com/apple/coremltools/issues/2573
if (
node.target
in [
torch.ops.aten.sub.Tensor,
exir_ops.edge.aten.sub.Tensor,
torch.ops.aten.add.Tensor,
exir_ops.edge.aten.add.Tensor,
]
and "alpha" in node.kwargs
and node.kwargs["alpha"] != 1
):
self.log_once(
"torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
)
return True

# https://github.com/apple/coremltools/issues/2565
if node.target in [
torch.ops.aten.diagonal.default,
torch.ops.aten.diagonal_copy.default,
exir_ops.edge.aten.diagonal.default,
exir_ops.edge.aten.diagonal_copy.default,
]:
self.log_once(
"torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
)
return True

# https://github.com/apple/coremltools/issues/2569
if node.target in [
torch.ops.aten.acosh.default,
exir_ops.edge.aten.acosh.default,
torch.ops.aten.asinh.default,
exir_ops.edge.aten.asinh.default,
]:
self.log_once(
"torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
)
return True

# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
# # in the placeholders due to partitioning, which CoreML does not support
# if not self.lower_full_graph and any(
# isinstance(arg, torch.fx.Node)
# and isinstance(
# arg.meta.get("val", None),
# (torch.SymInt, torch.SymBool, torch.SymFloat),
# )
# for arg in node.args
# ):
# self.log_once(
# "Skipping op for CoreML delegation because it contains symbolic args: "
        #     + str(node.target)
# )
# return True

return False
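
Each override above points at a concrete coremltools issue; a toy module containing the affected ops (an illustrative sketch) would look like:

import torch

class OverriddenOps(torch.nn.Module):
    # Each op below maps to one of the targets checked in should_override_support:
    # coremltools may report the node as supported, but the partitioner still
    # keeps it out of the CoreML partitions.
    def forward(self, x, y):
        a = torch.sub(x, y, alpha=2)  # aten.sub.Tensor with alpha != 1
        b = torch.diagonal(x)         # aten.diagonal.default / diagonal_copy
        c = torch.acosh(x.abs() + 1)  # aten.acosh.default
        return a.sum() + b.sum() + c.sum()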

def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
# get_attr node can always be supported on any backend
if node.op == "get_attr":
@@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
elif node.op == "call_function":
# skip ops if specified by user
node_target_name = getattr(node.target, "__name__", "").lower()
if node_target_name in (self.skip_ops_for_coreml_delegation or []):
self.log_once(
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+ node_target_name
)
assert (
not self.lower_full_graph
), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
return False

# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
# # in the placeholders due to partitioning, which CoreML does not support
# if not self.lower_full_graph and any(
# isinstance(arg, torch.fx.Node)
# and isinstance(
# arg.meta.get("val", None),
# (torch.SymInt, torch.SymBool, torch.SymFloat),
# )
# for arg in node.args
# ):
# self.log_once(
# "Skipping op for CoreML delegation because it contains symbolic args: "
# + node_target_name
# )
# assert not self.lower_full_graph
# return False
if self.should_skip_op_for_delegation(node_target_name):
return False

# query coremltools to see if node is supported
is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
node
)
if self.should_override_support(node):
is_supported = False

if not is_supported:
if self.lower_full_graph:
raise NotImplementedError(
@@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
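
End to end, nodes that are skipped or whose support is overridden simply stay on the default ExecuTorch backend. A hedged sketch of the usual lowering flow (assuming the standard torch.export and executorch.exir.to_edge_transform_and_lower entry points; the toy model is illustrative):

import torch

from executorch.backends.apple.coreml.partition.coreml_partitioner import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class ToyModel(torch.nn.Module):
    def forward(self, x, y):
        # aten.sub.Tensor with alpha != 1 is rejected by should_override_support;
        # the relu can still be delegated to CoreML.
        return torch.sub(x, y, alpha=2) + torch.relu(x)


exported = torch.export.export(ToyModel().eval(), (torch.rand(4, 4), torch.rand(4, 4)))
edge = to_edge_transform_and_lower(exported, partitioner=[CoreMLPartitioner()])
et_program = edge.to_executorch()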


class CoreMLPartitioner(Partitioner):

def __init__(
self,
*,