metascroy (Contributor)

This updates CoreML docs to:

  • Discuss the new partitioner options
  • Discuss quantize_ support
  • Discuss backward compatibility guarantees

@metascroy metascroy requested a review from mergennachin as a code owner August 5, 2025 00:33

pytorch-bot bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/13120

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures, 2 Unrelated Failures

As of commit ffc7040 with merge base ec35f56:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the "CLA Signed" label Aug 5, 2025
@metascroy metascroy requested a review from YifanShenSZ August 5, 2025 00:33

github-actions bot commented Aug 5, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


Both of the above examples will export and lower to CoreML with the to_edge_transform_and_lower API.
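A minimal sketch of that flow, assuming a toy module (`TinyModel` is a placeholder) and the `CoreMLPartitioner` import path used in the ExecuTorch docs:

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


# Export to an ExportedProgram, then delegate the graph to Core ML.
ep = torch.export.export(TinyModel().eval(), (torch.randn(1, 8),))
et_program = to_edge_transform_and_lower(
    ep, partitioner=[CoreMLPartitioner()]
).to_executorch()
```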
Contributor


How does the codebook one actually lower to Core ML? I tried looking up choose_qparams_and_quantize_codebook in ET and coremltools but didn't find anything.

metascroy (Contributor Author)


From a user's perspective, it should just lower: after quantize_, you can run torch.export.export, and then to_edge_transform_and_lower.

In terms of how it works, I added the ability to register custom MIL ops in ET CoreML, and I used that to register the dequantize_codebook quant primitive that is produced by CodebookWeightOnlyConfig.
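For illustration, a sketch of that user-facing flow. The torchao import path for CodebookWeightOnlyConfig is an assumption (it has lived under torchao's prototype namespace and may differ across releases), and the dtype/block_size values are illustrative:

```python
import torch
from torchao.quantization import quantize_

# Assumed import path; check your torchao version.
from torchao.prototype.quantization.codebook import CodebookWeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()

# Palettize linear weights (dtype and block_size values are illustrative).
quantize_(model, CodebookWeightOnlyConfig(dtype=torch.uint4, block_size=[-1, 16]))

# From here, export and lowering proceed exactly as in the unquantized case.
ep = torch.export.export(model, (torch.randn(1, 64),))
```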


See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.

### LLM quantization with quantize_
Contributor


@metascroy Is there a minimum_deployment_target required/published for torchao quantization and PT2E quantization? I remember you mentioned it is None by default, but how do we enforce it if one is using a quantization recipe?

metascroy (Contributor Author)


CoreML should select the required minimum_deployment_target automatically. For PT2E, it should select iOS17.

But for quantize_, I noticed it currently only works for iOS18 (need to investigate further): #13122

In terms of how we enforce it: it should work automatically for PT2E, but let me know if it doesn't. For quantize_, I'll try to make it work automatically, but as a stop-gap, we can explicitly set iOS18 if quantize_ is used in the recipe.
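As a sketch of that stop-gap, using the CoreMLBackend.generate_compile_specs API the docs describe (import paths assumed from the ExecuTorch Core ML backend):

```python
import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Pin iOS18 explicitly whenever quantize_ is part of the recipe.
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18,
)
partitioner = CoreMLPartitioner(compile_specs=compile_specs)
```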

- `coremltools.ComputeUnit.CPU_AND_NE` (uses both the CPU and ANE, but not the GPU)
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). The default value is `coremltools.target.iOS15`.
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). By default, the smallest deployment target needed to deploy the model is selected. During export, you will see a warning about the "CoreML specification version" that was used for the model, which maps onto a deployment target as discussed [here](https://apple.github.io/coremltools/mlmodel/Format/Model.html#model). If you need to control the deployment target, please specify it explicitly.
- `compute_precision`: The compute precision used by CoreML (`coremltools.precision.FLOAT16` or `coremltools.precision.FLOAT32`). The default value is `coremltools.precision.FLOAT16`. Note that the compute precision is applied no matter what dtype is specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the CoreML backend by default. Also note that the ANE only supports FP16 precision.
Contributor


Noob question: it seems we publish the default as FLOAT16 in the generate_compile_specs function. What happens when a quantizer is used? Does the backend ignore this, or is it up to the user to make sure there is no compute_precision in the compile specs?

metascroy (Contributor Author)


Even for a quantized model, there is a compute precision. Compute precision controls the precision of the non-quantized ops in the model.
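A hedged illustration of that point: to keep the non-quantized ops in FP32 rather than the FP16 default, the compile spec would be set explicitly (same assumed API as above):

```python
import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend

# Quantized weights stay quantized; compute_precision only governs the
# remaining float ops that Core ML executes.
compile_specs = CoreMLBackend.generate_compile_specs(
    compute_precision=ct.precision.FLOAT32,
)
```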

@metascroy (Contributor Author)

@YifanShenSZ can I get a review on these doc updates for the CoreML backend?

Comment on lines 1 to 3
# CoreML Backend

Core ML delegate is the ExecuTorch solution to take advantage of Apple's [CoreML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With CoreML, a model can run on CPU, GPU, and the Apple Neural Engine (ANE).
CoreML delegate is the ExecuTorch solution to take advantage of Apple's [CoreML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With CoreML, a model can run on CPU, GPU, and the Apple Neural Engine (ANE).
Contributor


Official name is "Core ML", not "CoreML": https://developer.apple.com/documentation/coreml

Comment on lines +158 to +160
# When using an enumerated shape compile spec, you must specify lower_full_graph=True
# in the CoreMLPartitioner. We do not support using enumerated shapes
# for partially exported models
Contributor


Is there any error or warning thrown, or does it just silently fail to lower?

metascroy (Contributor Author)


It throws an error asking users to set lower_full_graph=True.
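For reference, a sketch of the guarded usage (construction of the enumerated-shape compile spec is elided; the partitioner import path is assumed as above):

```python
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# With an enumerated-shape compile spec, the whole graph must be delegated;
# without lower_full_graph=True the partitioner raises the error above.
partitioner = CoreMLPartitioner(lower_full_graph=True)
```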

Comment on lines 278 to 279
* Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity). Using 4-bit or PerGroup quantization requires exporting with minimum_deployment_target >= ct.target.iOS18. Using 8-bit quantization with per-axis granularity is supported on ct.target.iOS16+. See [CoreML `CompileSpec`](#coreml-compilespec) for more information on setting the deployment target.
* Quantize embedding/linear layers with CodebookWeightOnlyConfig (with dtype torch.uint1 through torch.uint8, using various block sizes). Quantizing with CodebookWeightOnlyConfig requires exporting with minimum_deployment_target >= ct.target.iOS18; see [CoreML `CompileSpec`](#coreml-compilespec) for more information on setting the deployment target.
Contributor


@abhinaykukkadapu are these part of some coreml recipe?


@kimishpatel (Contributor) left a comment


This diff has been out for a while. Are you planning to land?

@metascroy (Contributor Author)

> This diff has been out for a while. Are you planning to land?

Yes, I hope to land this week. @YifanShenSZ approved the changes last week, but I need a stamp from someone in PyTorch to land.

@metascroy (Contributor Author)

@mergennachin @abhinaykukkadapu @kimishpatel @digantdesai can I get a stamp on this doc update?


The Core ML backend also supports quantizing models with the [torchao](https://github.com/pytorch/ao) quantize_ API. This is most commonly used for LLMs, which require more advanced quantization. Since quantize_ is not backend-aware, it is important to use a config that is compatible with Core ML:

* Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity). Using 4-bit or PerGroup quantization requires exporting with minimum_deployment_target >= ct.target.iOS18. Using 8-bit quantization with per-axis granularity is supported on ct.target.iOS16+. See [Core ML `CompileSpec`](#coreml-compilespec) for more information on setting the deployment target.
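For example, a Core ML-compatible quantize_ call along these lines (the torchao granularity import path is assumed, and the model/group size are illustrative):

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).eval()

# 4-bit, per-group weight-only quantization: per the docs above, this
# requires minimum_deployment_target >= ct.target.iOS18.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)
```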
Contributor


@metascroy Do you have to update minimum_deployment_target to iOS16 after this PR: #13896?

metascroy (Contributor Author)


It's already updated in the docs: they say iOS16 for 8-bit per-channel, and iOS18 for everything else.

@metascroy metascroy merged commit 8d1684e into main Sep 10, 2025
106 of 116 checks passed
@metascroy metascroy deleted the update-coreml-docs branch September 10, 2025 16:42
StrycekSimon pushed a commit to nxp-upstream/executorch that referenced this pull request Sep 23, 2025
This updates CoreML docs to:

* Discuss the new partitioner options
* Discuss quantize_ support
* Discuss backward compatibility guarantees