Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions backends/apple/coreml/partition/coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import logging
import inspect
from typing import Callable, List, Optional, Tuple

import coremltools as ct
Expand Down Expand Up @@ -222,10 +223,38 @@ def __init__(
self.take_over_mutable_buffer
), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"


def _check_if_called_from_to_backend(self) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the indentation here is off. As written, this is defined inside init

"""
Check if the partition method is being called from the deprecated
to_backend workflow.
Returns True if called from deprecated direct to_backend, False if called
from to_edge_transform_and_lower.
"""
stack = inspect.stack()

for frame_info in stack:
if frame_info.function == "to_edge_transform_and_lower":
return False

for frame_info in stack:
if frame_info.function == "to_backend":
filename = frame_info.filename
if "program/_program.py" in filename:
return True
return False
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Run the CapabilityBasedPartitioner to return the largest possible
# subgraphs containing the nodes with the tags
logger.info("CoreMLPartitioner::partition")
# Check if we're being called from the deprecated to_backend workflow
if self._check_if_called_from_to_backend():
logger.warning("Using the old `to_edge()` flow with CoreML may result in performance regression. "
"The recommended flow is to use `to_edge_transform_and_lower()` with the CoreML partitioner. "
"See the documentation for more details: "
"https://github.com/pytorch/executorch/blob/main/docs/source/backends/coreml/coreml-overview.md#using-the-core-ml-backend"
)
partition_tags = {}

capability_partitioner = CapabilityBasedPartitioner(
Expand Down
78 changes: 78 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import copy
import sys
import unittest
import io
import logging

import coremltools as ct

Expand All @@ -16,6 +18,7 @@
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir import to_edge, to_edge_transform_and_lower


@torch.library.custom_op("unsupported::linear", mutates_args=())
Expand Down Expand Up @@ -346,3 +349,78 @@ def forward(self, x):
test_runner.test_lower_full_graph()
# test_runner.test_symint_arg()
test_runner.test_take_over_constant_data_false()

def test_deprecation_warning_for_to_backend_workflow(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't these be methods inside TestCoreMLPartitioner?

Then in the main section you can add test_runner.test_deprecation_warning_for_to_backend_workflow() and test_runner.test_no_warning_for_to_edge_transform_and_lower_workflow()

"""
Test that the deprecated to_edge + to_backend workflow shows a deprecation
warning.
"""
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 5)

def forward(self, x):
return self.linear(x)

model = SimpleModel()
x = torch.randn(1, 10)

exported_model = torch.export.export(model, (x,))

# Capture log output to check for deprecation warning
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.WARNING)

logger = logging.getLogger(
"executorch.backends.apple.coreml.partition.coreml_partitioner"
)
logger.addHandler(ch)
logger.setLevel(logging.WARNING)

edge = to_edge(exported_model)
partitioner = CoreMLPartitioner()

edge.to_backend(partitioner)

log_contents = log_capture_string.getvalue()
self.assertIn("DEPRECATION WARNING", log_contents)
self.assertIn("to_edge() + to_backend()", log_contents)
self.assertIn("to_edge_transform_and_lower()", log_contents)

def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
"""
Test that the recommended to_edge_transform_and_lower workflow does NOT
show a deprecation warning.
"""
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 5)

def forward(self, x):
return self.linear(x)

model = SimpleModel()
x = torch.randn(1, 10)

exported_model = torch.export.export(model, (x,))

# Capture log output to check for deprecation warning
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.WARNING)

logger = logging.getLogger(
"executorch.backends.apple.coreml.partition.coreml_partitioner"
)
logger.addHandler(ch)
logger.setLevel(logging.WARNING)

partitioner = CoreMLPartitioner()

to_edge_transform_and_lower(exported_model, partitioner=[partitioner])

log_contents = log_capture_string.getvalue()
self.assertNotIn("DEPRECATION WARNING", log_contents)