Add MLX backend support for Gemma 4 31B#19524
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19524
Note: Links to docs will display an error until the docs builds have been completed. ❗ 2 Active SEVsThere are 2 currently active SEVs. If your PR is affected, please view them below: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
7394a9c to
6423b4b
Compare
There was a problem hiding this comment.
Pull request overview
Adds an MLX (Apple Silicon) backend for the Gemma 4 31B-IT example, sharing the quantized checkpoint with the existing CUDA path. The MLX flow re-packs Int4 weights as IntxUnpackedToInt8Tensor, applies source transforms that swap PyTorch ops for mlx.rope/mlx.kv_cache_update/mlx.custom_sdpa, and exports a single dynamic-seq_len forward method that the C++ runner samples on host. A small runtime fix in MLXInterpreter/custom_ops.py makes fast::rope work with 1-D freqs (for proportional partial RoPE). The shared model/sampler/runner are simplified: temperature is now required, lm_head always runs on the last token only, and main.cpp is unified behind #ifdef EXECUTORCH_BUILD_CUDA.
Changes:
- Add MLX packer, source transforms, and
_export_mlxpath plus an MLX preset/Makefile target. - Drop the optional/
Nonetemperature codepath frommodel.forwardandsampler.sample; always return sampled tokens / last-token logits. - Refactor the C++ runner to support both backends, restructure prefill/decode dispatch, and fix MLX
ropeto passbase=nulloptwhenfreqsis provided.
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| Makefile | Adds gemma4_31b-mlx target. |
| examples/models/gemma4_31b/CMakeLists.txt | Selects CUDA or MLX backend and links MLX delegate + metallib. |
| examples/models/gemma4_31b/CMakePresets.json | Adds gemma4-31b-mlx configure/build/workflow presets (Darwin-gated). |
| examples/models/gemma4_31b/README.md | Documents MLX export and runner build. |
| examples/models/gemma4_31b/model.md | Documents MLX method signature, KV-cache replacement, and pack flow. |
| examples/models/gemma4_31b/export.py | Adds mlx backend dispatch and _export_mlx (single dynamic forward, MLXPartitioner). |
| examples/models/gemma4_31b/inference.py | Default temperature changed from 0.0 to 0.8. |
| examples/models/gemma4_31b/model.py | forward now requires temperature and always uses last-logits-only sampling. |
| examples/models/gemma4_31b/sampler.py | Removes temperature=None pass-through. |
| examples/models/gemma4_31b/mlx_source_transformations.py | Replaces attention/layer/model forward and KV caches with MLX-op variants. |
| examples/models/gemma4_31b/main.cpp | Unifies CUDA/MLX runner; adds host-side logits_to_token for MLX, hard-codes EOS 106, simplifies error handling. |
| examples/models/gemma4_31b/quant/init.py | Re-exports MLX packers. |
| examples/models/gemma4_31b/quant/pack_mlx.py | Implements Int4 → IntxUnpackedToInt8Tensor conversion and embedding regrouping for MLX. |
| examples/models/gemma4_31b/quant/README.md | Documents pack_mlx.py. |
| examples/models/gemma4_31b/quant/tests/test_pack_mlx.py | Unit tests for MLX packing helpers. |
| examples/models/gemma4_31b/tests/test_cuda_pipeline.py | Updates chunked-prefill test to new (temp)-required signature returning token IDs. |
| examples/models/gemma4_31b/tests/test_mlx_pipeline.py | End-to-end MLX pack/forward/export tests on a tiny model. |
| backends/mlx/custom_ops.py | Handles 1-D freqs (partial-rotary) input to rope. |
| backends/mlx/runtime/MLXInterpreter.h | Passes base=nullopt to fast::rope when freqs is provided. |
| backends/mlx/test/test_ops.py | Adds RopeCustomFreqsTest for 1-D-freqs RoPE. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
@claude Review this PR |
|
Claude finished @mergennachin's task in 2m 1s —— View job Review: PR #19524 — Add MLX backend support for Gemma 4 31B
SummaryThis is a well-structured PR that adds Apple Silicon (MLX) backend support for Gemma 4 31B-IT. The architecture cleanly separates concerns: MLX-specific packing happens at load time, source transforms replace ops with MLX custom ops at export time, and the C++ runner uses Issues & Suggestions1. Error handling for CUDA
|
| intx_w = _int4_to_intx_unpacked(int4_w) | ||
| self.assertEqual(intx_w.shape, torch.Size([64, 256])) | ||
|
|
||
| def test_matmul_approximates_original(self): |
There was a problem hiding this comment.
Do we have a test case that covers comparing regrouped linear with original?
| module.weight = nn.Parameter(w, requires_grad=False) | ||
|
|
||
|
|
||
| def pack_embedding_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: |
There was a problem hiding this comment.
Is there a reason why we need separate functions for linear and embedding?
It looks like both functions do the same conversion, except embedding also has the regrouping logic (which linear could also use)?
- pack_mlx.py: converts Int4Tensor → IntxUnpackedToInt8Tensor at pack time (nibble unpack + scale transpose) so the default dispatch produces the dequantize_affine → linear pattern MLX expects. IntxUnpackedToInt8Tensor passes through unchanged. Embedding with incompatible per-axis group_size is regrouped to gs=128. - export.py: add --backend mlx with single-method export (dynamic seq_len), sampler stripping, and MLXPartitioner lowering. No int4_dispatch import — MLX uses the standard dequantize_affine path. - main.cpp: handle both CUDA (prefill+decode, on-device sampling) and MLX (single forward method, host-side argmax) via #ifdef. - CMakeLists.txt / CMakePresets.json / Makefile: add gemma4_31b-mlx build target linking mlxdelegate. - test_pack_mlx.py: 15 tests covering Int4→IntxUnpacked conversion correctness, passthrough, regrouping, error cases. - test_mlx_pipeline.py: 4 e2e tests including export-to-pte. Validated: same CUDA-quantized checkpoint packs for both backends, 100% op delegation to MLX, real 31B checkpoint packs at 4.0 GB RSS. PR authored with Claude.
…LX packer - Replace custom argmax_last_token with llm::logits_to_token for host-side sampling on non-CUDA builds, matching qwen3_5_moe runner. Supports temperature-controlled sampling (was greedy-only). - Add --cuda_graph warning on non-CUDA builds. - Support Int4Tensor embeddings in pack_embedding_for_mlx by converting to IntxUnpackedToInt8Tensor (same as linear path). - Add divisibility guard in _regroup_intx. Co-authored-by: Claude <noreply@anthropic.com>
… temp handling - Always compute logits for the last position only (lm_head on x[:, -1, :]), avoiding the (1, T, 262144) matmul during prefill. Applies to both CUDA and MLX paths. - Remove the temperature=None codepath from model.py forward and sampler.py. Temperature is now always required. MLX _strip_sampler_from_forward handles the no-sampler case independently. - Add mlx_source_transformations.py: replaces generic PyTorch ops with mlx.rope, mlx.kv_cache_update, and mlx.custom_sdpa for optimized Metal kernels. Applied during MLX export before torch.export. - Unify temperature clamping in main.cpp: compute temp_val once before the #ifdef, used by both CUDA (temp_tensor) and MLX (logits_to_token). - Fix generate() default temperature to 0.8 (was 0.0, inconsistent with C++).
…layers) - custom_ops.py: support 1D freqs in the Python fake op. When freqs is 1D, compute inv_freq = 1/freqs and build angles from positions, matching the C++ runtime behavior. 2D freqs path unchanged. - MLXInterpreter.h: pass base=nullopt when freqs is provided. MLX's fast::rope requires exactly one of base or freqs. - mlx_source_transformations.py: pass dims=rotary_dim (not head_dim) with 1D freqs containing only the non-zero rotary frequencies. The old code passed 2D precomputed angles which was incorrect at the C++ level. - test_ops.py: add RopeCustomFreqsTest (3 configs) verifying export and MLX delegation with 1D custom frequencies. Co-authored-by: Claude <noreply@anthropic.com>
The MLX source transformation replaces each attention's forward with one that builds its own masks via custom_sdpa (is_causal / sliding window mask). The model-level _build_masks was still being called and its results passed through every layer — dead code that introduced non-delegatable ops (arange, comparisons, boolean masks) into the traced graph, potentially splitting the MLX delegate and causing CPU-GPU sync points. Strip _build_masks from _clean_forward, replace layer forward to remove mask parameters, and drop the unused attn_mask argument from the MLX attention forward. Verified with tiny model: 1 delegate call, 0 non-delegated ops. Co-authored-by: Claude <noreply@anthropic.com>
Previously _strip_sampler_from_forward (in export.py) and mlx_source_transformations had to be called together — neither worked alone after the layer/attention forwards were changed to drop mask parameters. The model-level forward still called _build_masks and passed 4 args to layers, which now only accept 2. Move the model-level forward replacement into mlx_source_transformations as _replace_model_forward. After the transform, the model has signature (tokens, input_pos) → (B, V) logits — no temperature, no sampler, no masks. Remove the now-redundant _strip_sampler_from_forward from export.py and update the test to use the 2-arg signature. Co-authored-by: Claude <noreply@anthropic.com>
main.cpp: - Fix stale top comment (was claiming MLX did greedy argmax with no temperature; now correctly describes host-side llm::logits_to_token). - Print the first generated token (from the last prefill chunk) using the last prompt token as the streaming-decode prefix — previously it was dropped (off-by-one in the output). - Set stats.first_token_ms after prefill (was set at decode step 0, which mis-named the TTFT metric). - Add hit_eos check before decode loop in case the first generated token is already EOS. - Remove redundant per-stat printfs — print_report covers them. test_mlx_pipeline: - Add test_source_transforms_use_mlx_ops to verify the traced graph contains the expected MLX custom ops (2× rope, 2× kv_cache_update, 1× custom_sdpa per layer). Regression-protects the source transforms against silent reverts to PyTorch ops. Co-authored-by: Claude <noreply@anthropic.com>
All other runners in examples/models/ use Mmap. The two LLM runners (qwen3_5_moe, gemma4_31b) were the outliers using File, likely inherited rather than intentional. File uses pread() syscalls per access; Mmap maps the file lazily and benefits from the OS page cache. For a 21 GB .pte this matters: per-access pread has measurable overhead and prevents kernel-level read-ahead. Co-authored-by: Claude <noreply@anthropic.com>
On memory-constrained machines (e.g. 32 GB with a 21 GB model), plain Mmap causes constant page eviction and re-fault — 0.2 tok/s decode. MmapUseMlockIgnoreErrors calls mlock() to pin pages after mmap, keeping weights resident once loaded. Falls back silently if mlock exceeds the system limit. Measured 4.7 tok/s decode on M3 Pro 32 GB (24x speedup). Co-authored-by: Claude <noreply@anthropic.com>
The error checks were dropped during refactoring. Without them, a failed set_option silently disables weight sharing, causing prefill and decode to allocate separate KV-cache buffers (OOM at runtime with no diagnostic). Also resolve <turn|> EOS token ID from the tokenizer at startup instead of hardcoding token 106. Co-authored-by: Claude <noreply@anthropic.com>
| for layer in self.layers: | ||
| x = layer(x, input_pos) | ||
| x = self.norm(x) | ||
| last = self.lm_head(x[:, -1, :]).float() |
There was a problem hiding this comment.
Do we need to cast here? Does torch.tanh not use higher precision internally?
There was a problem hiding this comment.
The .float() cast is needed. Without it, the tanh would run in bf16 (input dtype)
The embedding packer was a strict superset of the linear packer — both convert Int4Tensor → IntxUnpacked, but embedding also regroups to an MLX-compatible group_size. Regrouping is a no-op for standard linear group sizes (32/64/128), so one function handles both. Add test_regroup_preserves_dequant to verify regrouped linear weights dequantize identically to the originals. Co-authored-by: Claude <noreply@anthropic.com>
Adds Apple Silicon (MLX) backend for the Gemma 4 31B-IT model. The same quantized checkpoint works for both CUDA and MLX — backend-specific packing happens at load time.
Key changes:
Nothing in the CUDA backend code itself. The CUDA-side changes are in the shared model/runner code:
On my 32GB RAM M1 macbook pro