diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index bb8a752de6c..93506e6d985 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -20,6 +20,7 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase @@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None: logger.info(msg) self._logged_msgs.add(msg) + def should_skip_op_for_delegation(self, node_target_name: str) -> bool: + skipped_ops = self.skip_ops_for_coreml_delegation or [] + if node_target_name in skipped_ops: + assert ( + not self.lower_full_graph + ), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner" + self.log_once( + "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: " + + node_target_name + ) + return True + return False + + def should_override_support(self, node) -> bool: + # https://github.com/apple/coremltools/issues/2573 + if ( + node.target + in [ + torch.ops.aten.sub.Tensor, + exir_ops.edge.aten.sub.Tensor, + torch.ops.aten.add.Tensor, + exir_ops.edge.aten.add.Tensor, + ] + and "alpha" in node.kwargs + and node.kwargs["alpha"] != 1 + ): + self.log_once( + "torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support." + ) + return True + + # https://github.com/apple/coremltools/issues/2565 + if node.target in [ + torch.ops.aten.diagonal.default, + torch.ops.aten.diagonal_copy.default, + exir_ops.edge.aten.diagonal.default, + exir_ops.edge.aten.diagonal_copy.default, + ]: + self.log_once( + "torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support." + ) + return True + + # https://github.com/apple/coremltools/issues/2569 + if node.target in [ + torch.ops.aten.acosh.default, + exir_ops.edge.aten.acosh.default, + torch.ops.aten.asinh.default, + exir_ops.edge.aten.asinh.default, + ]: + self.log_once( + "torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support." + ) + return True + + # TODO: enable this after bugs in ExecuTorch's partitioner are fixed + # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args + # # in the placeholders due to partitioning, which CoreML does not support + # if not self.lower_full_graph and any( + # isinstance(arg, torch.fx.Node) + # and isinstance( + # arg.meta.get("val", None), + # (torch.SymInt, torch.SymBool, torch.SymFloat), + # ) + # for arg in node.args + # ): + # self.log_once( + # "Skipping op for CoreML delegation because it contains symbolic args: " + # + node_target_name + # ) + # return True + + return False + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: # get_attr node can always be supported on any backend if node.op == "get_attr": @@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: elif node.op == "call_function": # skip ops if specified by user node_target_name = getattr(node.target, "__name__", "").lower() - if node_target_name in (self.skip_ops_for_coreml_delegation or []): - self.log_once( - "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: " - + node_target_name - ) - assert ( - not self.lower_full_graph - ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True" - return False - # TODO: enable this after bugs in ExecuTorch's partitioner are fixed - # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args - # # in the placeholders due to partitioning, which CoreML does not support - # if not self.lower_full_graph and any( - # isinstance(arg, torch.fx.Node) - # and isinstance( - # arg.meta.get("val", None), - # (torch.SymInt, torch.SymBool, torch.SymFloat), - # ) - # for arg in node.args - # ): - # self.log_once( - # "Skipping op for CoreML delegation because it contains symbolic args: " - # + node_target_name - # ) - # assert not self.lower_full_graph - # return False + if self.should_skip_op_for_delegation(node_target_name): + return False # query coremltools to see if node is supported is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported( node ) + if self.should_override_support(node): + is_supported = False + if not is_supported: if self.lower_full_graph: raise NotImplementedError( @@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: class CoreMLPartitioner(Partitioner): - def __init__( self, *,