# Core ML Backend

Core ML delegate is the ExecuTorch solution for taking advantage of Apple's [CoreML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With CoreML, a model can run on the CPU, GPU, and the Apple Neural Engine (ANE).

## Features

- Dynamic dispatch to the CPU, GPU, and ANE.
- Supports fp32 and fp16 computation.

## Target Requirements

Below are the minimum OS requirements on various hardware for running a CoreML-delegated ExecuTorch model:
- [macOS](https://developer.apple.com/macos) >= 13.0
- [iOS](https://developer.apple.com/ios/) >= 16.0
- [iPadOS](https://developer.apple.com/ipados/) >= 16.0
- [tvOS](https://developer.apple.com/tvos/) >= 16.0

## Development Requirements
To develop, you need:

- [macOS](https://developer.apple.com/macos) >= 13.0.
- [Xcode](https://developer.apple.com/documentation/xcode) >= 14.1.

Before starting, make sure you install the Xcode Command Line Tools:

```bash
xcode-select --install
```

Finally, you must install the CoreML backend by running the following script:
```bash
cd executorch
sh ./backends/apple/coreml/scripts/install_requirements.sh
```

----

## Using the CoreML Backend
To target the CoreML backend during the export and lowering process, pass an instance of the `CoreMLPartitioner` to `to_edge_transform_and_lower`. The example below demonstrates this process using the MobileNet V2 model from torchvision.

```python
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

et_program = to_edge_transform_and_lower(
    torch.export.export(mobilenet_v2, sample_inputs),
    partitioner=[CoreMLPartitioner()],
).to_executorch()

with open("mv2_coreml.pte", "wb") as file:
    et_program.write_to_file(file)
```

### Partitioner API

The CoreML partitioner API allows for configuration of the model delegation to CoreML. Passing a `CoreMLPartitioner` instance with no additional parameters will run as much of the model as possible on the CoreML backend with default settings. This is the most common use case. For advanced use cases, the partitioner exposes the following options via the [constructor](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/partition/coreml_partitioner.py#L60):


- `skip_ops_for_coreml_delegation`: Allows you to skip ops for delegation by CoreML. By default, all ops that CoreML supports will be delegated. See [here](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/test/test_coreml_partitioner.py#L42) for an example of skipping an op for delegation.
- `compile_specs`: A list of CompileSpec for the CoreML backend. These control low-level details of CoreML delegation, such as the compute unit (CPU, GPU, ANE), the iOS deployment target, and the compute precision (FP16, FP32). These are discussed more below.
- `take_over_mutable_buffer`: A boolean that indicates whether PyTorch mutable buffers in stateful models should be converted to [CoreML MLState](https://developer.apple.com/documentation/coreml/mlstate). If set to false, mutable buffers in the PyTorch graph are converted to graph inputs and outputs of the CoreML lowered module under the hood. Setting `take_over_mutable_buffer` to true generally results in better performance, but using MLState requires iOS >= 18.0, macOS >= 15.0, and Xcode >= 16.0.
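
The snippet below sketches how these options fit together. The operator name passed to `skip_ops_for_coreml_delegation` is illustrative only; use the operator names that actually appear in your exported graph.

```python
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# A sketch: skip delegation of one op and keep mutable buffers as graph
# inputs/outputs instead of MLState. The op name below is illustrative.
partitioner = CoreMLPartitioner(
    skip_ops_for_coreml_delegation=["aten.add.Tensor"],
    take_over_mutable_buffer=False,
)
```

This partitioner instance can then be passed to `to_edge_transform_and_lower` exactly as in the example above.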

#### CoreML CompileSpec

A list of CompileSpec is constructed with [CoreMLBackend.generate_compile_specs](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L210). Below are the available options:
- `compute_unit`: This controls the compute units (CPU, GPU, ANE) that are used by CoreML. The default value is coremltools.ComputeUnit.ALL. The available options from coremltools are:
  - coremltools.ComputeUnit.ALL (uses the CPU, GPU, and ANE)
  - coremltools.ComputeUnit.CPU_ONLY (uses the CPU only)
  - coremltools.ComputeUnit.CPU_AND_GPU (uses both the CPU and GPU, but not the ANE)
  - coremltools.ComputeUnit.CPU_AND_NE (uses both the CPU and ANE, but not the GPU)
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., coremltools.target.iOS18). The default value is coremltools.target.iOS15.
- `compute_precision`: The compute precision used by CoreML (coremltools.precision.FLOAT16, coremltools.precision.FLOAT32). The default value is coremltools.precision.FLOAT16. Note that the compute precision is applied no matter what dtype is specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the CoreML backend by default. Also note that the ANE only supports FP16 precision.
- `model_type`: Whether the model should be compiled to the CoreML [mlmodelc format](https://developer.apple.com/documentation/coreml/downloading-and-compiling-a-model-on-the-user-s-device) during .pte creation ([CoreMLBackend.MODEL_TYPE.COMPILED_MODEL](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L71)), or whether it should be compiled to mlmodelc on device ([CoreMLBackend.MODEL_TYPE.MODEL](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L70)). Using CoreMLBackend.MODEL_TYPE.COMPILED_MODEL and doing compilation ahead of time should improve the first time on-device model load time.
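
As a rough sketch, the compile specs and partitioner can be combined as follows; the specific option values here are illustrative, not recommendations:

```python
import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Run on the CPU and ANE with FP16 precision, targeting iOS 17.
compile_specs = CoreMLBackend.generate_compile_specs(
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT16,
)
partitioner = CoreMLPartitioner(compile_specs=compile_specs)
```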

### Testing the Model

After generating the CoreML-delegated .pte, the model can be tested from Python using the ExecuTorch runtime Python bindings. This can be used to sanity check the model and evaluate numerical accuracy. See [Testing the Model](using-executorch-export.md#testing-the-model) for more information.
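
For example, assuming the runtime Python bindings are installed and that `mv2_coreml.pte` was produced by the export example above, a minimal smoke test might look like:

```python
import torch
from executorch.runtime import Runtime

# Load the delegated program and run it with a random input.
runtime = Runtime.get()
program = runtime.load_program("mv2_coreml.pte")
method = program.load_method("forward")
outputs = method.execute([torch.randn(1, 3, 224, 224)])
print(outputs)
```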

----

## Quantization

To quantize a PyTorch model for the CoreML backend, use the `CoreMLQuantizer`. `Quantizers` are backend-specific, and the `CoreMLQuantizer` is configured to quantize models in a way that leverages the quantization support available in the CoreML backend.

### 8-bit Quantization using the PT2E Flow

To perform 8-bit quantization with the PT2E flow, perform the following steps:

1) Define [coremltools.optimize.torch.quantization.LinearQuantizerConfig](https://apple.github.io/coremltools/source/coremltools.optimize.torch.quantization.html#coremltools.optimize.torch.quantization.LinearQuantizerConfig) and use it to create an instance of a `CoreMLQuantizer`.
2) Use `torch.export.export_for_training` to export a graph module that will be prepared for quantization.
3) Call `prepare_pt2e` to prepare the model for quantization.
4) For static quantization, run the prepared model with representative samples to calibrate the quantized tensor activation ranges.
5) Call `convert_pt2e` to quantize the model.
6) Export and lower the model using the standard flow.

The output of `convert_pt2e` is a PyTorch model which can be exported and lowered using the normal flow. As it is a regular PyTorch model, it can also be used to evaluate the accuracy of the quantized model using standard PyTorch techniques.

```python
import torch
import coremltools as ct
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.apple.coreml.compiler import CoreMLBackend

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

# Step 1: Define a LinearQuantizerConfig and create an instance of a CoreMLQuantizer
quantization_config = ct.optimize.torch.quantization.LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": ct.optimize.torch.quantization.QuantizationScheme.symmetric,
            "milestones": [0, 0, 10, 10],
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(quantization_config)

# Step 2: Export the model for training
training_gm = torch.export.export_for_training(mobilenet_v2, sample_inputs).module()

# Step 3: Prepare the model for quantization
prepared_model = prepare_pt2e(training_gm, quantizer)

# Step 4: Calibrate the model on representative data
# Replace with your own calibration data
for calibration_sample in [torch.randn(1, 3, 224, 224)]:
    prepared_model(calibration_sample)

# Step 5: Convert the calibrated model to a quantized model
quantized_model = convert_pt2e(prepared_model)

# Step 6: Export the quantized model to CoreML
et_program = to_edge_transform_and_lower(
    torch.export.export(quantized_model, sample_inputs),
    partitioner=[
        CoreMLPartitioner(
            # iOS17 is required for the quantized ops in this example
            compile_specs=CoreMLBackend.generate_compile_specs(
                minimum_deployment_target=ct.target.iOS17
            )
        )
    ],
).to_executorch()
```
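
Because `quantized_model` is a regular PyTorch module, its accuracy can be checked in eager mode before lowering. A minimal sketch, assuming a user-provided `eval_loader` (hypothetical) that yields `(image, label)` batches:

```python
# Evaluate top-1 accuracy of the quantized model in eager mode.
# `eval_loader` is a hypothetical DataLoader of (image, label) batches.
correct = total = 0
with torch.no_grad():
    for image, label in eval_loader:
        logits = quantized_model(image)
        correct += (logits.argmax(dim=-1) == label).sum().item()
        total += label.numel()
print(f"Top-1 accuracy: {correct / total:.4f}")
```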

See [PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html) for more information.

----

## Runtime integration

To run the model on-device, use the standard ExecuTorch runtime APIs. See [Running on Device](getting-started.md#running-on-device) for more information, including building the iOS frameworks.

When building from source, pass `-DEXECUTORCH_BUILD_COREML=ON` when configuring the CMake build to compile the CoreML backend.

To use the backend, link against the `coremldelegate` target. Due to the use of static registration, it may be necessary to link with whole-archive. This can typically be done by passing `"$<LINK_LIBRARY:WHOLE_ARCHIVE,coremldelegate>"` to `target_link_libraries`.

```
# CMakeLists.txt
add_subdirectory("executorch")
...
target_link_libraries(
    my_target
    PRIVATE executorch
    extension_module_static
    extension_tensor
    optimized_native_cpu_ops_lib
    coremldelegate)
```
No additional steps are necessary to use the backend beyond linking the target. A CoreML-delegated .pte file will automatically run on the registered backend.

----

## Advanced

### Extracting the mlpackage

[CoreML *.mlpackage files](https://apple.github.io/coremltools/docs-guides/source/convert-to-ml-program.html#save-ml-programs-as-model-packages) can be extracted from a CoreML-delegated `*.pte` file. This can help with debugging and profiling for users who are more familiar with `*.mlpackage` files:
```bash
python examples/apple/coreml/scripts/extract_coreml_models.py -m /path/to/model.pte
```

Note that if the ExecuTorch model has graph breaks, there may be multiple extracted `*.mlpackage` files.