From 4ec3b3d33a3dbfdd81375473bca021f9c99eda90 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 30 Jul 2025 17:02:38 -0700 Subject: [PATCH 1/3] up --- .../coreml/partition/coreml_partitioner.py | 79 ++++++++++++++----- 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index bb8a752de6c..6b6e9df53e5 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -20,6 +20,7 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase @@ -56,6 +57,57 @@ def log_once(self, msg: str) -> None: logger.info(msg) self._logged_msgs.add(msg) + def should_skip_op_for_delegation(self, node_target_name: str) -> bool: + skipped_ops = self.skip_ops_for_coreml_delegation or [] + + if node_target_name in skipped_ops: + return True + + # For backwards compatibility + split_name = node_target_name.split("::") + if len(split_name) == 2: + namespace, name_without_namespace = split_name + if namespace == "aten" and name_without_namespace in skipped_ops: + return True + + return False + + def should_override_support(self, node) -> bool: + # https://github.com/apple/coremltools/issues/2573 + if ( + node.target + in [ + torch.ops.aten.sub.Tensor, + exir_ops.edge.aten.sub.Tensor, + torch.ops.aten.add.Tensor, + exir_ops.edge.aten.add.Tensor, + ] + and "alpha" in node.kwargs + ): + self.log_once( + "torch.ops.aten.{sub, add}.Tensor with alpha is not supported by CoreML. Overriding support." 
+ ) + return True + + # TODO: enable this after bugs in ExecuTorch's partitioner are fixed + # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args + # # in the placeholders due to partitioning, which CoreML does not support + # if not self.lower_full_graph and any( + # isinstance(arg, torch.fx.Node) + # and isinstance( + # arg.meta.get("val", None), + # (torch.SymInt, torch.SymBool, torch.SymFloat), + # ) + # for arg in node.args + # ): + # self.log_once( + # "Skipping op for CoreML delegation because it contains symbolic args: " + # + node_target_name + # ) + # return True + + return False + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: # get_attr node can always be supported on any backend if node.op == "get_attr": @@ -63,8 +115,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: # check if the PyTorch op get called is supported in Core ML elif node.op == "call_function": # skip ops if specified by user - node_target_name = getattr(node.target, "__name__", "").lower() - if node_target_name in (self.skip_ops_for_coreml_delegation or []): + node_target_name = node.target.name().lower() + + if self.should_skip_op_for_delegation(node_target_name): self.log_once( "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: " + node_target_name @@ -74,28 +127,13 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True" return False - # TODO: enable this after bugs in ExecuTorch's partitioner are fixed - # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args - # # in the placeholders due to partitioning, which CoreML does not support - # if not self.lower_full_graph and any( - # isinstance(arg, torch.fx.Node) - # and isinstance( - # arg.meta.get("val", None), - # (torch.SymInt, torch.SymBool, torch.SymFloat), - # ) - # for arg in node.args - # ): - # self.log_once( - # "Skipping op for CoreML delegation because it contains symbolic args: " - # + node_target_name - # ) - # assert not self.lower_full_graph - # return False - # query coremltools to see if node is supported is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported( node ) + if self.should_override_support(node): + is_supported = False + if not is_supported: if self.lower_full_graph: raise NotImplementedError( @@ -126,7 +164,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: class CoreMLPartitioner(Partitioner): - def __init__( self, *, From 66c928a4d9b812bf940d933aa163e396d624b584 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 30 Jul 2025 17:26:44 -0700 Subject: [PATCH 2/3] up --- .../coreml/partition/coreml_partitioner.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index 6b6e9df53e5..aa6f8652e4b 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -89,6 +89,30 @@ def should_override_support(self, node) -> bool: ) return True + # https://github.com/apple/coremltools/issues/2565 + if node.target in [ + torch.ops.aten.diagonal.default, + torch.ops.aten.diagonal_copy.default, + exir_ops.edge.aten.diagonal.default, + exir_ops.edge.aten.diagonal_copy.default, + 
]: + self.log_once( + "torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support." + ) + return True + + # https://github.com/apple/coremltools/issues/2569 + if node.target in [ + torch.ops.aten.acosh.default, + exir_ops.edge.aten.acosh.default, + torch.ops.aten.asinh.default, + exir_ops.edge.aten.asinh.default, + ]: + self.log_once( + "torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support." + ) + return True + # TODO: enable this after bugs in ExecuTorch's partitioner are fixed # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args # # in the placeholders due to partitioning, which CoreML does not support From 10a4fb976e08a8bd00a7674fafcd9a843154f96d Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 31 Jul 2025 10:59:51 -0700 Subject: [PATCH 3/3] up --- .../coreml/partition/coreml_partitioner.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index aa6f8652e4b..93506e6d985 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -59,17 +59,15 @@ def log_once(self, msg: str) -> None: def should_skip_op_for_delegation(self, node_target_name: str) -> bool: skipped_ops = self.skip_ops_for_coreml_delegation or [] - if node_target_name in skipped_ops: + assert ( + not self.lower_full_graph + ), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner" + self.log_once( + "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: " + + node_target_name + ) return True - - # For backwards compatibility - split_name = node_target_name.split("::") - if len(split_name) == 2: - namespace, name_without_namespace = split_name - if namespace == "aten" and name_without_namespace in skipped_ops: - return True - return False def should_override_support(self, node) -> bool: @@ -83,9 +81,10 @@ def should_override_support(self, node) -> bool: exir_ops.edge.aten.add.Tensor, ] and "alpha" in node.kwargs + and node.kwargs["alpha"] != 1 ): self.log_once( - "torch.ops.aten.{sub, add}.Tensor with alpha is not supported by CoreML. Overriding support." + "torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support." ) return True @@ -139,16 +138,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: # check if the PyTorch op get called is supported in Core ML elif node.op == "call_function": # skip ops if specified by user - node_target_name = node.target.name().lower() + node_target_name = getattr(node.target, "__name__", "").lower() if self.should_skip_op_for_delegation(node_target_name): - self.log_once( - "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: " - + node_target_name - ) - assert ( - not self.lower_full_graph - ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True" return False # query coremltools to see if node is supported
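
Note on the alpha override (patches 1 and 3): aten.sub.Tensor and aten.add.Tensor take an alpha kwarg that scales the second operand, so only alpha != 1 actually changes semantics, which is why patch 3 narrows the guard from "alpha is present" to "alpha != 1". A minimal sketch of the semantics being guarded, in plain PyTorch (nothing below is taken from the patch itself):

    import torch

    x, y = torch.randn(3), torch.randn(3)

    # sub(x, y, alpha=a) computes x - a * y; add(x, y, alpha=a) computes x + a * y
    assert torch.allclose(torch.sub(x, y, alpha=2.0), x - 2.0 * y)
    assert torch.allclose(torch.add(x, y, alpha=2.0), x + 2.0 * y)

    # alpha=1 is the default and changes nothing, so such nodes remain delegable
    assert torch.allclose(torch.sub(x, y, alpha=1.0), x - y)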
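
Note on the op lists in patch 2: each override checks both the torch.ops.aten and exir_ops.edge.aten variants of a target, presumably because the partitioner can encounter either form depending on whether the graph has been converted to the edge dialect yet. A small sketch (the module name M is illustrative, not from the patch) showing how these targets surface as node.target on an exported graph:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            # diagonal and acosh are among the ops patch 2 keeps off CoreML
            return torch.acosh(torch.diagonal(x) + 2.0)

    ep = torch.export.export(M(), (torch.eye(3),))
    for node in ep.graph.nodes:
        if node.op == "call_function":
            # prints e.g. aten.diagonal.default, aten.add.Tensor, aten.acosh.default
            print(node.target)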
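
Finally, an end-to-end usage sketch of the behavior these patches refine. The import paths and the assumption that lower_full_graph defaults to False come from the surrounding repo rather than from this diff, so treat the snippet as illustrative, not as part of the change:

    import torch
    from executorch.backends.apple.coreml.partition import CoreMLPartitioner
    from executorch.exir import to_edge_transform_and_lower

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return torch.sub(x, y, alpha=2.0)

    ep = torch.export.export(Model(), (torch.randn(2), torch.randn(2)))

    # With lower_full_graph=False, a node rejected by should_override_support
    # (here, sub with alpha != 1) is left out of the CoreML partition and falls
    # back to the portable runtime.
    et = to_edge_transform_and_lower(
        ep, partitioner=[CoreMLPartitioner()]
    ).to_executorch()

    # With CoreMLPartitioner(lower_full_graph=True) the same rejection raises
    # NotImplementedError instead, and patch 3's assert likewise forbids
    # combining lower_full_graph=True with skip_ops_for_coreml_delegation.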