Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16718
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Cancelled Jobs, 1 Unrelated Failure as of commit d8ee9d2 with merge base 25f2a3f.
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -451,7 +451,9 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
         return self

     def to_edge_transform_and_lower(
-        self, partitioners: Optional[List[Partitioner]]
+        self,
```
```python
# Only do this check if all the dims are static.
if all(isinstance(dim, int) for dim in orig_tensor.size()):
    if orig_tensor.shape == node.meta["val"].shape:
        output_tensor = node.meta["val"]
```
This introduces a CSE pass to ExecuTorch, which eliminates common subexpressions that occur in exported programs. This pass was first developed as part of the MLX delegate (#16718) to optimize transformers, but I'm introducing it to ExecuTorch more generally because I believe it could benefit many other backends.

Examples of common subexpressions that occur in transformers:

* Repeated mask constructions per layer (only needs to be done once)
* Repeated extraction of symints from 1-d tensors for cache position (emits `.item` calls, which cause tensor materialization)

This pass eliminates these inefficiencies without having to rewrite the model.
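The core idea can be sketched on a toy IR (a hedged illustration only; the actual pass operates on torch.fx graphs and must additionally skip ops with side effects):

```python
# Hypothetical sketch of common-subexpression elimination over a toy
# node list in topological order. Not the actual ExecuTorch pass.
def cse(nodes):
    """nodes: list of (name, op, args) tuples; args reference earlier names.
    Returns (deduplicated nodes, map of eliminated name -> surviving name)."""
    seen = {}     # (op, args) -> name of first node computing that value
    replace = {}  # eliminated duplicate -> canonical name
    out = []
    for name, op, args in nodes:
        # Rewrite args through earlier replacements so chains of
        # duplicates collapse transitively.
        args = tuple(replace.get(a, a) for a in args)
        key = (op, args)
        if key in seen:
            replace[name] = seen[key]  # duplicate: drop the node
        else:
            seen[key] = name
            out.append((name, op, args))
    return out, replace
```

In the transformer example above, the per-layer mask constructions hash to the same `(op, args)` key, so all but the first are eliminated and their users are rewired to the surviving node.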
```
}

// For fields that can be either a literal int or a runtime Vid
table IntOrVid {
```
Why not just use a union type?
FlatBuffer unions can only contain tables, not scalar primitives. Using a union here would require wrapping the int64 literal in a table (table IntLiteral { value: int64; }), adding pointer indirection and a table allocation for every shape dimension. The is_vid bool keeps the common case (literal) as a direct inline field read.
The more important reason is that FlatBuffers doesn't allow vectors of unions.
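To illustrate the design choice being defended here: with the `is_vid` flag, decoding a dimension is a single branch over an inline field. The sketch below uses a plain dict as a hypothetical stand-in for the decoded table, not the real FlatBuffers accessor API:

```python
def resolve_dim(entry, vid_values):
    """entry: {"is_vid": bool, "value": int} (hypothetical decoded IntOrVid).
    A literal is read inline; a Vid indexes into runtime-resolved values."""
    if entry["is_vid"]:
        return vid_values[entry["value"]]
    return entry["value"]
```

A union-based encoding would put a table behind every dimension, adding an indirection even in the common literal case.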
```
    scale: float = 1.0;
}

table SdpaNode {
```
What's the story for out-of-tree custom ops? Just not possible?
What do you mean by out of tree custom ops?
From the MLX perspective, they allow custom ops using raw shader code. I guess we could add a node for that to the schema.
```
    causal: bool = false;
}

table AddNode {
```
Similarly, what happens if MLX changes the schema of any of these operators in a BC-breaking way, or even adds a new optional arg with default values? Seems like there would be a lot of op-versioning issues from explicitly encoding every op's schema in the serialization format. We also have a lot less influence on their op versioning story compared with ATen (where it's already a big pita).
AddNode is an abstraction in the MLX delegate that we control (not part of MLX itself).
The runtime takes AddNode and translates it to an MLX call (the add op) by parsing the args. So even if MLX changes the signature of their add op, that doesn't break AddNode in the graph; we would just update our runtime appropriately. (With that said, MLX did say they try to maintain BC of their underlying ops.)
Maybe I'm missing the concern here, though?
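The decoupling described here can be sketched roughly as follows. Everything below (the dict-based node encoding, `run_add`, `BACKEND_OPS`) is a hypothetical stand-in for the delegate runtime, not the actual code:

```python
# Hypothetical sketch: the serialized schema (e.g. AddNode) is owned by
# the delegate, and a per-node adapter translates it into a backend call.
# If the backend library's op signature changes, only the adapter body
# changes; serialized AddNodes stay valid.
BACKEND_OPS = {}

def register(node_type):
    def deco(fn):
        BACKEND_OPS[node_type] = fn
        return fn
    return deco

@register("AddNode")
def run_add(node, env):
    # Translation layer: parse AddNode's args, then invoke the backend op
    # (Python `+` stands in for the backend's add kernel here).
    return env[node["lhs"]] + env[node["rhs"]]

def run_node(node, env):
    return BACKEND_OPS[node["type"]](node, env)
```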
backends/mlx/ops.py
```python
    return out

except ImportError:
    # Edge IR ops not available (e.g., when building from ATen dialect)
```
When are you doing this?
I'm not. I'll remove this.
```python
@REGISTRY.register_pattern(name="ET_KV_CACHE_UPDATE")
class ETKVCacheUpdateHandler(PatternHandler):
```
Why is a custom mutator op supported here, but also the slice_update ops earlier?
I guess you mean the difference between the IndexUpdateNode pattern above and this one?
They are similar, but ET_KV_CACHE_UPDATE has slightly better perf and is tied to a KVCache building block in mlx/llm/cache.py.
In terms of the difference:

- INDEX_UPDATE matches a functionalized copy_ pattern (used in many HF models) and maps it onto IndexUpdateNode, which is an in-place update op in the MLX backend. IndexUpdateNode works by finding contiguous ranges in the indices and repeatedly calling slice update. In practice that is 1 slice-update call for a regular KV cache, but up to 2 calls for ring-buffer caches, and it requires the indices to be materialized in order to parse the index ranges (which can slow down MLX computation).
- ET_KV_CACHE_UPDATE emits a single slice-update call and does not require materialization of the indices tensor.

At this point, the benefit is maybe 10-15% extra perf.
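The range-parsing step that IndexUpdateNode performs could look something like this (a hypothetical illustration of the idea, not the backend's actual code):

```python
def contiguous_ranges(indices):
    """Group integer indices into [start, stop) runs; a new run starts
    whenever the next index is not the previous index + 1. A regular
    KV-cache update yields one run, while a ring-buffer update that
    wraps around yields two, hence up to two slice-update calls."""
    if not indices:
        return []
    runs = [[indices[0], indices[0] + 1]]
    for i in indices[1:]:
        if i == runs[-1][1]:
            runs[-1][1] = i + 1   # extend the current run
        else:
            runs.append([i, i + 1])  # start a new run
    return [tuple(r) for r in runs]
```

Note that computing these runs requires reading the concrete index values, which is exactly the materialization cost the discussion above mentions.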
```python
@dataclass
class PatternMatch:
```
XNNPACK has its own pattern-matching util, IIRC. Could we just offer this in exir? cc @GregoryComer
Happy to reuse any pre-existing utilities if they're in exir.
If the XNNPACK one is based on the torch.fx pattern-finding util, I don't know if it'll work, though. I recall it being very bad at fuzzy matching of args/kwargs (you needed exact matches, so you had to write 3-4 pattern matchers for each op, and it's a bit error-prone).
```python
logger.info("Exporting audio preprocessor with MLX backend...")

model = WhisperAudioProcessor(
```
Rename this if it's used in models other than Whisper (like here, in Voxtral).
I added a strict compile test, addressed security issues, and addressed null-ptr references.
Summary
This PR adds an MLX backend for ExecuTorch, enabling Metal-accelerated inference on Apple Silicon. It runs Llama, Qwen, Gemma, Whisper, Voxtral, and Parakeet models end-to-end, with 637 passing op tests and multithreaded execution support. For many models, it offers the best performance among all ExecuTorch backends on Apple Silicon, with 2-6x speedups over what was previously possible with ExecuTorch and up to 30% smaller model sizes than XNNPACK, thanks to BF16 support and tied quantized embeddings.
The PR is large due to extensive op coverage, testing, and documentation, but almost all changes are confined to `backends/mlx/`. The design is described in `backends/mlx/README.md`.

Suggested review approach:

* Review the changes outside `backends/mlx/` carefully — these integrate with ExecuTorch's build system and are the most likely to need changes.
* Within `backends/mlx/`, focus on structural design (see the README) and test coverage (the CI job is `.github/workflows/mlx.yml`).

Prerequisite PRs
These fixes were developed alongside the MLX backend. Once merged, this PR can be rebased to remove the duplicated changes:

* remove_noop_pass

Tests
CI is defined in `.github/workflows/mlx.yml`.

The 1 failing operator test is a test-side issue being fixed in #17539. The 3 failing model tests will be addressed in follow-ups; they are not an initial focus compared to the GenAI models above.