ExecuTorch MLX delegate#16718

Open
metascroy wants to merge 92 commits into main from mlx-delegate

Conversation

@metascroy
Contributor

@metascroy metascroy commented Jan 20, 2026

Summary

This PR adds an MLX backend for ExecuTorch, enabling Metal-accelerated inference on Apple Silicon. It runs Llama, Qwen, Gemma, Whisper, Voxtral, and Parakeet models end-to-end, with 637 passing op tests and multithreaded execution support. For many models it delivers the best performance of any ExecuTorch backend on Apple Silicon, with 2-6x speedups over what was previously possible with ExecuTorch, and up to 30% smaller models than XNNPACK thanks to BF16 support and tied quantized embeddings.

The PR is large due to extensive op coverage, testing, and documentation, but almost all changes are confined to backends/mlx/. The design is described in backends/mlx/README.md.

Suggested review approach:

  1. Review files outside backends/mlx/ carefully — these integrate with ExecuTorch's build system and are the most likely to need changes.
  2. For backends/mlx/, focus on structural design (see README) and test coverage (the CI job is .github/workflows/mlx.yml).

Prerequisite PRs

These fixes were developed alongside the MLX backend. Once merged, this PR can be rebased to remove the duplicated changes:

  • #17257 — Improve lowering time with NamedDataMap
  • #17679 — Allow transform passes in etLLM
  • #17678 — Fix dynamic shape bug in remove_noop_pass
  • #17378 — Fix pocketfft intermittent bus errors on macOS (upstream fix)

Tests

CI is defined in .github/workflows/mlx.yml:

  • test_ops.py: 637 passing op tests
  • Multithreading: launches models on 50 threads and verifies correctness
  • GenAI E2E: parakeet, voxtral, etLLM (stories110m), HF LLMs (llama, qwen, gemma)
  • backend-tester: 380 passed, 1 failed, 86 skipped operator tests; 34 passed, 3 failed, 5 skipped model tests

The 1 failing operator test is a test-side issue being fixed in #17539. The 3 failing model tests will be addressed in follow-ups — they are not an initial focus compared to the GenAI models above.

@pytorch-bot

pytorch-bot bot commented Jan 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16718

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Cancelled Jobs, 1 Unrelated Failure

As of commit d8ee9d2 with merge base 25f2a3f:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Jan 20, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@metascroy metascroy force-pushed the mlx-delegate branch 7 times, most recently from e0e015c to 240b241 Compare February 25, 2026 01:06
@metascroy metascroy changed the title from "[draft] MLX delegate" to "ExecuTorch MLX delegate" Feb 25, 2026
@@ -451,7 +451,9 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
return self

def to_edge_transform_and_lower(
self, partitioners: Optional[List[Partitioner]]
self,
Contributor Author

Will remove once #17679 lands

# Only do this check if all the dims are static.
if all(isinstance(dim, int) for dim in orig_tensor.size()):
if orig_tensor.shape == node.meta["val"].shape:
output_tensor = node.meta["val"]
Contributor Author

Will remove once #17678 lands

Contributor

Approved that PR

metascroy added a commit that referenced this pull request Feb 27, 2026
This introduces a CSE pass to ExecuTorch, which eliminates common
subexpressions that occur in exported programs.

This pass was first developed as part of the MLX delegate
(#16718) to optimize
transformers, but I'm introducing it to ExecuTorch more generally
because I believe it could benefit many other backends.

Examples of common subexpressions that occur in transformers:

* Repeated mask constructions per layer (only needs to be done once)
* Repeated extraction of symints from 1d tensors for cache position
(emits .item calls, which cause tensor materialization)

This pass eliminates these inefficiencies without having to rewrite the
model.
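The CSE idea above can be sketched as value numbering over a straight-line program. This is an illustrative toy, not the actual ExecuTorch pass: the `(dest, op, args)` instruction format is hypothetical, and a real pass must additionally respect side effects and mutation before merging nodes.

```python
def cse(instructions):
    """Eliminate common subexpressions from a list of (dest, op, args).

    Two instructions are duplicates when they share the same op and the
    same (alias-resolved) args; the later one is dropped and its dest is
    rewritten to the earlier result.
    """
    seen = {}    # (op, args) -> dest that first computed it
    alias = {}   # eliminated dest -> surviving dest
    out = []
    for dest, op, args in instructions:
        # Resolve args through aliases so transitive duplicates are found.
        args = tuple(alias.get(a, a) for a in args)
        key = (op, args)
        if key in seen:
            alias[dest] = seen[key]   # reuse the earlier result
        else:
            seen[key] = dest
            out.append((dest, op, args))
    return out

# Repeated per-layer mask construction, as described above:
prog = [
    ("m1", "make_mask", ("pos",)),   # mask built in layer 1
    ("a",  "add", ("x", "m1")),
    ("m2", "make_mask", ("pos",)),   # identical mask rebuilt in layer 2
    ("b",  "add", ("y", "m2")),
]
print(cse(prog))  # the second make_mask is gone; "b" now uses "m1"
```

The second `make_mask` collapses into the first, mirroring how the real pass avoids rebuilding the attention mask once per layer.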
}

// For fields that can be either a literal int or a runtime Vid
table IntOrVid {
Contributor

Why not just use a union type?

Contributor Author

FlatBuffers unions can only contain tables, not scalar primitives. Using a union here would require wrapping the int64 literal in a table (table IntLiteral { value: int64; }), adding pointer indirection and a table allocation for every shape dimension. The is_vid bool keeps the common case (a literal) as a direct inline field read.

Contributor Author

The more important reason is that FlatBuffers doesn't allow vectors of unions.
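A minimal sketch of the encoding being discussed (field names are hypothetical, not the PR's actual schema):

```
// Hypothetical illustration only. A union would force the literal into
// its own table, and per the discussion above vectors cannot hold
// unions, so a discriminator bool keeps the literal inline.
table IntOrVid {
  is_vid: bool = false;  // false: `value` is a literal int64
  value: long = 0;       // literal value, or a Vid index when is_vid is true
}

table SomeShapeNode {
  dims: [IntOrVid];      // a vector of these tables is legal; [SomeUnion] is not
}
```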

scale: float = 1.0;
}

table SdpaNode {
Contributor

What's the story for out-of-tree custom ops? Just not possible?

Contributor Author

What do you mean by out of tree custom ops?

From MLX perspective, they allow custom ops using raw shader code. I guess we could add a node for that to the schema.

causal: bool = false;
}

table AddNode {
Contributor

Similarly, what happens if MLX changes the schema of any of these operators in a BC-breaking way, or even adds a new optional arg with a default value? It seems like there would be a lot of op-versioning issues from explicitly encoding every op's schema in the serialization format. We also have a lot less influence on MLX's op-versioning story than on ATen's (where it's already a big pain).

Contributor Author

AddNode is an abstraction in the MLX delegate that we control (not part of MLX itself).

The runtime takes AddNode and translates it to an MLX call (add op) by parsing the args. So even if mlx changes the signature of their add op, it doesn't mean we break AddNode in the graph. We would just update our runtime appropriately. (With that said, MLX did say they try to maintain BC of their underlying ops).

Maybe I'm missing the concern here, though?

return out

except ImportError:
# Edge IR ops not available (e.g., when building from ATen dialect)
Contributor

When are you doing this?

Contributor Author

I'm not. I'll remove this.



@REGISTRY.register_pattern(name="ET_KV_CACHE_UPDATE")
class ETKVCacheUpdateHandler(PatternHandler):
Contributor

Why are both the custom mutator op here and the slice_update ops from earlier supported?

Contributor Author

I guess you mean the difference between the IndexUpdateNode pattern above and this one?

They are similar, but ET_KV_CACHE_UPDATE has slightly better perf and is tied to a KVCache building block in mlx/llm/cache.py.

In terms of the difference:

  • INDEX_UPDATE matches a functionalized copy_ pattern (used in many HF models) and maps it onto IndexUpdateNode, an in-place update op in the MLX backend. IndexUpdateNode works by finding contiguous ranges in the indices and repeatedly calling slice update: in practice that is one slice-update call for a regular KV cache and up to two for ring-buffer cases, but it requires the indices to be materialized so the ranges can be parsed (which can slow down MLX computation).

  • ET_KV_CACHE_UPDATE emits as one slice update call, and does not require materialization of the indices tensor.

At this point, the benefit is maybe 10-15% extra perf.
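The two update styles can be sketched in plain Python (an illustration only, not the delegate's implementation; a list stands in for the cache tensor):

```python
def contiguous_runs(indices):
    """Split an index list into (start, length) contiguous runs."""
    runs = []
    start, length = indices[0], 1
    for prev, cur in zip(indices, indices[1:]):
        if cur == prev + 1:
            length += 1
        else:
            runs.append((start, length))
            start, length = cur, 1
    runs.append((start, length))
    return runs

def index_update(cache, indices, values):
    # INDEX_UPDATE-style: the indices must be materialized and parsed
    # into runs; each run becomes one slice update (one run for a linear
    # cache, up to two when a ring buffer wraps).
    consumed = 0
    for start, length in contiguous_runs(indices):
        cache[start:start + length] = values[consumed:consumed + length]
        consumed += length
    return cache

def kv_cache_update(cache, pos, values):
    # ET_KV_CACHE_UPDATE-style: the write offset is known directly, so a
    # single slice update suffices and no indices tensor materializes.
    cache[pos:pos + len(values)] = values
    return cache

print(contiguous_runs([6, 7, 0, 1]))  # ring-buffer wrap -> [(6, 2), (0, 2)]
```

For a linear KV cache both paths produce identical contents; the difference is the index materialization and run parsing that INDEX_UPDATE pays for on every step.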



@dataclass
class PatternMatch:
Contributor

XNNPACK has its own pattern-matching util, IIRC. Could we just offer this in exir? cc @GregoryComer

Contributor Author

Happy to reuse any pre-existing utilities if they're in exir.

If the XNNPACK one is based on the torch.fx pattern-finding util, I don't know if it'll work, though. I recall it being very bad at fuzzy matching of args/kwargs: you needed exact matches, so you had to write three or four pattern matchers for each op, which is a bit error-prone.
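The fuzzy-matching complaint can be illustrated with a toy normalizer (hypothetical code, not the PR's matcher): an exact matcher treats add(x, y) and add(x, y, alpha=1) as different patterns, while filling in default kwargs lets one pattern cover both.

```python
class Node:
    """Toy stand-in for a torch.fx node: a target plus args/kwargs."""
    def __init__(self, target, args=(), kwargs=None):
        self.target, self.args, self.kwargs = target, args, dict(kwargs or {})

def normalize(node, defaults):
    """Fill omitted optional kwargs with their defaults, so a single
    pattern covers add(x, y), add(x, y, alpha=1), and so on."""
    merged = {**defaults, **node.kwargs}
    return (node.target, node.args, tuple(sorted(merged.items())))

ADD_DEFAULTS = {"alpha": 1}
a = Node("add", ("x", "y"))                 # kwarg omitted
b = Node("add", ("x", "y"), {"alpha": 1})   # kwarg spelled out explicitly
c = Node("add", ("x", "y"), {"alpha": 2})   # genuinely different op call
print(normalize(a, ADD_DEFAULTS) == normalize(b, ADD_DEFAULTS))  # True
print(normalize(a, ADD_DEFAULTS) == normalize(c, ADD_DEFAULTS))  # False
```

An exact matcher sees `a` and `b` as distinct, forcing one pattern per spelling; normalizing first collapses them while still distinguishing `c`.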


logger.info("Exporting audio preprocessor with MLX backend...")

model = WhisperAudioProcessor(
Contributor

Rename this if it is used in models other than Whisper (like here in Voxtral).

@metascroy
Contributor Author

  • In the runtime unit test, can we compile with -Wall, -Werror, -Wconversion, -Wsign-conversion, and -Wshorten-64-to-32, plus ASan and UBSan, to uncover any security issues?
  • Use overflow-safe arithmetic for all bounds checks: __builtin_add_overflow / __builtin_mul_overflow or c10's safe-arithmetic utilities.
  • Check for null-pointer dereferences.

I added a strict compile test, addressed the security issues, and fixed the null-pointer dereferences.
