
Add MLX backend support for Gemma 4 31B #19524

Merged
mergennachin merged 13 commits into main from gemma4_mlx on May 19, 2026

@mergennachin
Contributor

@mergennachin mergennachin commented May 12, 2026

Adds Apple Silicon (MLX) backend for the Gemma 4 31B-IT model. The same quantized checkpoint works for both CUDA and MLX — backend-specific packing happens at load time.

Key changes:

  • MLX packer converts Int4Tensor → IntxUnpackedToInt8Tensor for MLX's quantized linear fusion
  • Source transforms replace PyTorch ops with mlx.rope, mlx.kv_cache_update, mlx.custom_sdpa for optimized Metal kernels
  • Proportional partial RoPE (full-attention layers) passes 1D frequencies to mlx.rope with dims=rotary_dim; the C++ runtime is fixed to pass base=nullopt when freqs is provided
  • Single-method export with dynamic seq_len and host-side sampling
  • C++ runner supports both backends via #ifdef, using shared logits_to_token for MLX sampling
  • Last-logits-only optimization: lm_head always runs on last position only, removing the full-logits codepath entirely
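
To make the first bullet more concrete, here is a minimal pure-Python sketch of what a nibble unpack followed by group-wise dequantization could look like. The byte layout (low nibble first, unsigned storage with an offset of 8) and the helper names `unpack_int4` and `dequantize` are assumptions chosen for illustration; they are not taken from pack_mlx.py or from the Int4Tensor implementation.

```python
def unpack_int4(packed: bytes) -> list[int]:
    """Unpack two 4-bit values per byte into signed integers in [-8, 7].

    Assumed layout (hypothetical): the low nibble holds the first value,
    the high nibble the second, both stored unsigned with an offset of 8.
    """
    values = []
    for byte in packed:
        low = byte & 0x0F
        high = (byte >> 4) & 0x0F
        values.append(low - 8)
        values.append(high - 8)
    return values


def dequantize(values: list[int], scales: list[float], group_size: int) -> list[float]:
    """Multiply each value by the scale of its contiguous group."""
    return [v * scales[i // group_size] for i, v in enumerate(values)]


packed = bytes([0x80, 0xF7])          # four 4-bit values in two bytes
unpacked = unpack_int4(packed)        # one small integer per weight
weights = dequantize(unpacked, scales=[0.5, 2.0], group_size=2)
print(unpacked, weights)
```

After unpacking, every weight occupies its own int8 slot, which is the general shape of data an "unpacked to int8" representation carries alongside its per-group scales.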

Nothing changes in the CUDA backend code itself. The CUDA-side changes are in the shared model/runner code:

  • model.py: forward() now always does last-logits-only and temperature is required (no None path). Affects both CUDA and MLX.
  • sampler.py: Removed temperature=None passthrough.
  • main.cpp: Unified temp_val clamping before the #ifdef. CUDA path behavior unchanged.
  • inference.py: Default temperature changed from 0.0 to 0.8 to match C++ runner default.
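
Since temperature is now always required, host-side sampling reduces to one function over the last-position logits. The sketch below is a generic temperature sampler, not the runner's `logits_to_token`; the name `sample_token` and the epsilon threshold that switches to greedy decoding are assumptions for illustration.

```python
import math
import random


def sample_token(logits: list[float], temperature: float, rng: random.Random) -> int:
    """Pick a token ID from last-position logits.

    Assumption: temperatures at or below a small epsilon are treated as
    greedy decoding (argmax) instead of dividing by a near-zero value.
    """
    eps = 1e-6  # hypothetical threshold, not taken from the PR
    if temperature <= eps:
        return max(range(len(logits)), key=lambda i: logits[i])

    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
greedy = sample_token([0.1, 2.0, -1.0], temperature=0.0, rng=rng)
sampled = sample_token([0.1, 2.0, -1.0], temperature=0.8, rng=rng)
print(greedy, sampled)
```

With a default of 0.8 instead of 0.0, repeated runs of the same prompt can produce different outputs, which is the user-visible effect of the inference.py change.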

On my M1 MacBook Pro with 32 GB of RAM:

(executorch_dev) mnachin@mnachin-mbp executorch % cmake-out/examples/models/gemma4_31b/gemma4_31b_runner --model_path ~/repos/models/gemma-4-31B-it-HQQ-INT4/model.pte --tokenizer_path ~/repos/models/gemma-4-31B-it-HQQ-INT4/tokenizer.json --prompt "Write a short joke about saving RAM." --max_new_tokens 128
I tokenizers:regex.cpp:27] Registering override fallback regex
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1779218557.174278 43844526 re2.cc:237] Error parsing '((\<pad\>|ool\|\>1\x00\x00\
                                                                                             �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x0...': invalid UTF-8
I tokenizers:re2_regex.cpp:27] Re2 failed to compile regex: ((\<pad\>|ool\|\>1\x00\x00\
                                                                                       �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x00\x00\\\<|\<tool_response\|\>|\<mask\>|\<\|\"\|\>|all\|\>j\x00\x00\\|\<channel\|\>|\<\|turn\>|\<turn\|\>|\<\|image\>|\<\|$
I tokenizers:regex_lookahead.cpp:27] Creating PCRE2 regex
I tokenizers:pcre2_regex.cpp:48] PCRE2 UTF-8 validation failed at offset 27: UTF-8 error: byte 2 top bits not 0x80. Retrying without UTF flags.
Loading model...
Prompt tokens: 24
Why did the programmer get kicked out of the library?

He kept trying to free the memory.<turn|>
PyTorchObserver {"prefill_token_per_sec":7.56859,"decode_token_per_sec":2.09161,"prompt_tokens":24,"generated_tokens":20,"model_load_start_ms":1779218556804,"model_load_end_ms":1779218560048,"inference_start_ms":1779218560052,"inference_end_ms":1779218572785,"prompt_eval_end_ms":1779218563223,"first_token_ms":1779218563223,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
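
The throughput figures in the PyTorchObserver line can be reproduced from its timestamps. The sketch below assumes that the prefill rate is prompt tokens divided by the time from inference start to the end of prompt evaluation, and that the decode rate is generated tokens divided by the remaining time; this is inferred from the numbers matching, not from the runner source.

```python
import json

# Fields copied from the PyTorchObserver line above.
report = json.loads(
    '{"prompt_tokens": 24, "generated_tokens": 20, '
    '"model_load_start_ms": 1779218556804, '
    '"model_load_end_ms": 1779218560048, '
    '"inference_start_ms": 1779218560052, '
    '"prompt_eval_end_ms": 1779218563223, '
    '"inference_end_ms": 1779218572785, '
    '"SCALING_FACTOR_UNITS_PER_SECOND": 1000}'
)

units = report["SCALING_FACTOR_UNITS_PER_SECOND"]
load_seconds = (report["model_load_end_ms"] - report["model_load_start_ms"]) / units
prefill_seconds = (report["prompt_eval_end_ms"] - report["inference_start_ms"]) / units
decode_seconds = (report["inference_end_ms"] - report["prompt_eval_end_ms"]) / units

prefill_tok_per_sec = report["prompt_tokens"] / prefill_seconds
decode_tok_per_sec = report["generated_tokens"] / decode_seconds

print(f"load {load_seconds:.3f}s, prefill {prefill_tok_per_sec:.5f} tok/s, "
      f"decode {decode_tok_per_sec:.5f} tok/s")
```

Under that assumption the run spent about 3.2 s loading, 3.2 s on the 24-token prefill, and 9.6 s generating 20 tokens.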

@pytorch-bot

pytorch-bot Bot commented May 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19524

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 12, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Comment thread examples/models/gemma4_31b/main.cpp Outdated
Comment thread examples/models/gemma4_31b/quant/pack_mlx.py Outdated
Comment thread examples/models/gemma4_31b/quant/pack_mlx.py
@mergennachin mergennachin force-pushed the gemma4_mlx branch 2 times, most recently from 7394a9c to 6423b4b Compare May 18, 2026 21:24
@mergennachin mergennachin marked this pull request as ready for review May 19, 2026 19:26
Copilot AI review requested due to automatic review settings May 19, 2026 19:26
@mergennachin mergennachin requested a review from lucylq as a code owner May 19, 2026 19:26
Contributor

Copilot AI left a comment

Pull request overview

Adds an MLX (Apple Silicon) backend for the Gemma 4 31B-IT example, sharing the quantized checkpoint with the existing CUDA path. The MLX flow re-packs Int4 weights as IntxUnpackedToInt8Tensor, applies source transforms that swap PyTorch ops for mlx.rope/mlx.kv_cache_update/mlx.custom_sdpa, and exports a single dynamic-seq_len forward method that the C++ runner samples on host. A small runtime fix in MLXInterpreter/custom_ops.py makes fast::rope work with 1-D freqs (for proportional partial RoPE). The shared model/sampler/runner are simplified: temperature is now required, lm_head always runs on the last token only, and main.cpp is unified behind #ifdef EXECUTORCH_BUILD_CUDA.

Changes:

  • Add MLX packer, source transforms, and _export_mlx path plus an MLX preset/Makefile target.
  • Drop the optional/None temperature codepath from model.forward and sampler.sample; always return sampled tokens / last-token logits.
  • Refactor the C++ runner to support both backends, restructure prefill/decode dispatch, and fix MLX rope to pass base=nullopt when freqs is provided.

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| Makefile | Adds gemma4_31b-mlx target. |
| examples/models/gemma4_31b/CMakeLists.txt | Selects CUDA or MLX backend and links MLX delegate + metallib. |
| examples/models/gemma4_31b/CMakePresets.json | Adds gemma4-31b-mlx configure/build/workflow presets (Darwin-gated). |
| examples/models/gemma4_31b/README.md | Documents MLX export and runner build. |
| examples/models/gemma4_31b/model.md | Documents MLX method signature, KV-cache replacement, and pack flow. |
| examples/models/gemma4_31b/export.py | Adds mlx backend dispatch and _export_mlx (single dynamic forward, MLXPartitioner). |
| examples/models/gemma4_31b/inference.py | Default temperature changed from 0.0 to 0.8. |
| examples/models/gemma4_31b/model.py | forward now requires temperature and always uses last-logits-only sampling. |
| examples/models/gemma4_31b/sampler.py | Removes temperature=None pass-through. |
| examples/models/gemma4_31b/mlx_source_transformations.py | Replaces attention/layer/model forward and KV caches with MLX-op variants. |
| examples/models/gemma4_31b/main.cpp | Unifies CUDA/MLX runner; adds host-side logits_to_token for MLX, hard-codes EOS 106, simplifies error handling. |
| examples/models/gemma4_31b/quant/__init__.py | Re-exports MLX packers. |
| examples/models/gemma4_31b/quant/pack_mlx.py | Implements Int4 → IntxUnpackedToInt8Tensor conversion and embedding regrouping for MLX. |
| examples/models/gemma4_31b/quant/README.md | Documents pack_mlx.py. |
| examples/models/gemma4_31b/quant/tests/test_pack_mlx.py | Unit tests for MLX packing helpers. |
| examples/models/gemma4_31b/tests/test_cuda_pipeline.py | Updates chunked-prefill test to new (temp)-required signature returning token IDs. |
| examples/models/gemma4_31b/tests/test_mlx_pipeline.py | End-to-end MLX pack/forward/export tests on a tiny model. |
| backends/mlx/custom_ops.py | Handles 1-D freqs (partial-rotary) input to rope. |
| backends/mlx/runtime/MLXInterpreter.h | Passes base=nullopt to fast::rope when freqs is provided. |
| backends/mlx/test/test_ops.py | Adds RopeCustomFreqsTest for 1-D-freqs RoPE. |


Comment thread examples/models/gemma4_31b/main.cpp
Comment thread examples/models/gemma4_31b/main.cpp Outdated
Comment thread examples/models/gemma4_31b/inference.py
Copilot AI review requested due to automatic review settings May 19, 2026 19:38
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated no new comments.

@mergennachin
Contributor Author

@claude Review this PR

@claude

claude Bot commented May 19, 2026

Claude finished @mergennachin's task in 2m 1s.


Review: PR #19524 — Add MLX backend support for Gemma 4 31B

  • Read and analyze core implementation files
  • Review MLX source transformations
  • Review MLX packing logic
  • Review C++ runner changes
  • Review model/sampler/export changes
  • Review tests
  • Check MLX backend runtime changes
  • Provide comprehensive review feedback

Summary

This is a well-structured PR that adds Apple Silicon (MLX) backend support for Gemma 4 31B-IT. The architecture cleanly separates concerns: MLX-specific packing happens at load time, source transforms replace ops with MLX custom ops at export time, and the C++ runner uses #ifdef to dispatch between backends. The shared model/runner simplifications (last-logits-only, required temperature) are sensible.


Issues & Suggestions

1. Error handling for CUDA set_option — Already restored ✅

Copilot flagged missing error handling for set_option in main.cpp. Looking at the current code (lines 193–210), error handling is present — both backend_options.set_option(...) and executorch::runtime::set_option(...) check for Error::Ok and return with an error message. This appears to have been addressed in commit b5285d5. No action needed.

2. Hard-coded EOS token ID 106 removed ✅

Looking at main.cpp:241-246, the current implementation dynamically resolves the end-of-turn token via tokenizer->encode("<turn|>", ...) and inserts it into eos_ids. The hard-coded 106 from an earlier version has been replaced. Good.

3. _mlx_forward uses start_pos = input_pos[0].item() — potential tracing issue

mlx_source_transformations.py:35:

start_pos = input_pos[0].item()

Calling .item() inside a function that will be traced with torch.export produces a data-dependent (unbacked) SymInt rather than a constant. This is fine for the MLX ops that accept int positions, and the test at test_mlx_pipeline.py:201-207 confirms export succeeds with dynamic shapes. However, every downstream op that uses start_pos then carries this data-dependent value, so any logic that branches on its concrete value cannot be resolved at trace time. This is likely intentional for the MLX backend (which handles SymInts at runtime), but it is worth confirming that export does not regress if the model is ever exported with strict=False.

4. rope custom op: freqs interpretation mismatch risk

custom_ops.py:231-237 — When freqs is 1D, the Python reference computes:

inv_freq = (1.0 / freqs.to(torch.float32)).view(1, 1, 1, half)

And MLXInterpreter.h:247-248 passes base=nullopt when freqs is provided. This is correct for the MLX C++ fast::rope, which computes 1/freqs internally. The Python op and the C++ runtime therefore share the same convention: both receive raw frequencies and compute 1/freqs themselves. The Python op only serves as the reference during tracing, while the C++ runtime performs the actual computation. The docstring at line 64 of mlx_source_transformations.py says "MLX computes inv_freq = 1/freqs internally", which correctly documents this.

However, in mlx_source_transformations.py:67:

mlx_freqs = 1.0 / rotary_inv_freq

Here rotary_inv_freq is already 1/(theta^(2i/D)), so mlx_freqs = theta^(2i/D). This gets passed to the rope op which then computes 1/mlx_freqs = 1/theta^(2i/D) = inv_freq. The math checks out, but this double-inversion is subtle. A brief comment at line 67 clarifying the chain would help future readers (e.g., # Convert inv_freq back to raw freq; MLX's fast::rope will re-invert).
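
The double inversion can be checked numerically. This sketch uses an illustrative theta of 10000 and a rotary dimension of 8 (both assumptions, not the model's configuration) and follows the chain described above: inv_freq, then its reciprocal as passed to the rope op, then the reciprocal the kernel is described as taking internally.

```python
import math


def inv_freqs(theta: float, rotary_dim: int) -> list[float]:
    """Standard RoPE inverse frequencies theta^(-2i/D) for i in [0, D/2)."""
    return [theta ** (-(2 * i) / rotary_dim) for i in range(rotary_dim // 2)]


theta = 10000.0   # illustrative base, not the model's actual value
rotary_dim = 8    # illustrative rotary dimension

rotary_inv_freq = inv_freqs(theta, rotary_dim)       # 1 / theta^(2i/D)
mlx_freqs = [1.0 / f for f in rotary_inv_freq]       # theta^(2i/D), passed to the op
recovered = [1.0 / f for f in mlx_freqs]             # what the kernel is said to compute

# Rotation angles for one position use the recovered inverse frequencies.
position = 5
angles = [position * f for f in recovered]
print(mlx_freqs, angles)
```

The two reciprocals cancel, so the angles are the same ones the original inv_freq would have produced.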

5. _regroup_intx correctness for non-exact-divisible cases

pack_mlx.py:94-95 already validates:

if old_gs % new_gs != 0:
    raise ValueError(...)

Good. The review comment from @metascroy about raising an error when old_gs is not divisible by new_gs is already handled.
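
To illustrate why the divisibility check matters, here is a minimal single-row sketch of regrouping: each old scale is repeated for the smaller groups it now covers, which only works when the new group size divides the old one evenly. The helper names `regroup_scales` and `dequantize` are hypothetical and zero points are omitted; this is not the `_regroup_intx` implementation.

```python
def regroup_scales(scales: list[float], old_gs: int, new_gs: int) -> list[float]:
    """Split each group of old_gs elements into old_gs // new_gs smaller groups."""
    if old_gs % new_gs != 0:
        raise ValueError(
            f"old group size {old_gs} is not divisible by new group size {new_gs}"
        )
    repeat = old_gs // new_gs
    return [s for s in scales for _ in range(repeat)]


def dequantize(qvals: list[int], scales: list[float], group_size: int) -> list[float]:
    """Multiply each quantized value by the scale of its contiguous group."""
    return [q * scales[i // group_size] for i, q in enumerate(qvals)]


qvals = [-8, 0, -1, 7, 3, -2, 5, 1]   # one row of 8 quantized values
old_scales = [0.5, 2.0]               # group size 4
new_scales = regroup_scales(old_scales, old_gs=4, new_gs=2)

before = dequantize(qvals, old_scales, group_size=4)
after = dequantize(qvals, new_scales, group_size=2)
print(new_scales, before == after)
```

Because every new group inherits the scale of the old group that contained it, the dequantized values are unchanged.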

6. MLX export: min=1 for seq_dim vs CUDA min=5

export.py:311:

seq_dim = Dim("seq_len", min=1, max=max_prefill)

CUDA uses min=5 (line 186) because it has a separate decode method for T=1. MLX uses a unified method with min=1 which is correct since it handles both prefill and decode in forward. This is fine.

7. Temperature default change (inference.py:89)

The default temperature changed from 0.0 to 0.8. As Copilot noted, this is a user-visible behavioral change. The PR description documents the rationale (matching C++ runner default). The README should mention this for users upgrading, but this is a documentation concern rather than a code issue.

8. attn_mask handling in the sliding-window SDPA call

mlx_source_transformations.py:74-84:

sdpa_mask = self.kv_cache.create_sliding_window_mask(start_pos, T)
y = torch.ops.mlx.custom_sdpa(
    q, k_cache, v_cache,
    start_pos=self.kv_cache.buffer_size - T,
    attn_mask=sdpa_mask,
    ...
)

The sliding-window path passes is_causal=False with an explicit attn_mask. The full-attention path passes is_causal=True without a mask. This is correct — custom_sdpa (line 151 of custom_ops.py) builds its own lower-right causal mask when is_causal=True.
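
For readers unfamiliar with the distinction, the sketch below builds a generic sliding-window causal mask over absolute positions. It assumes each query may attend to itself and the `window - 1` positions before it; it does not model the ring-buffer layout or `buffer_size` offset used by the actual KV cache, and `sliding_window_mask` is a hypothetical name.

```python
def sliding_window_mask(start_pos: int, seq_len: int, window: int) -> list[list[bool]]:
    """Return mask[i][k], True when query i may attend to key position k.

    Query i sits at absolute position start_pos + i. Keys cover absolute
    positions 0 .. start_pos + seq_len - 1.
    """
    total_keys = start_pos + seq_len
    mask = []
    for i in range(seq_len):
        q = start_pos + i
        # Causal (k <= q) and within the window (k > q - window).
        mask.append([(q - window < k <= q) for k in range(total_keys)])
    return mask


mask = sliding_window_mask(start_pos=2, seq_len=2, window=3)
for row in mask:
    print(["x" if allowed else "." for allowed in row])
```

A plain causal mask would keep every earlier key visible; the window additionally hides keys that are too far in the past, which is why that path needs an explicit mask instead of is_causal=True.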

9. pack_embedding_for_mlx handling of non-quantized embeddings

pack_mlx.py:142-161: If the embedding weight is neither Int4Tensor nor IntxUnpackedToInt8Tensor (e.g., a plain float tensor), the function silently does nothing beyond the isinstance checks — it doesn't assign module.weight. This is actually fine because pack_model in pack.py falls through to pack_one's default handler for non-quantized weights. But if pack_embedding_for_mlx is called directly with a non-quantized tensor, it would be a no-op, which could be confusing. Consider adding a fallback module.weight = nn.Parameter(w, requires_grad=False) at the end, or documenting that it only handles quantized embeddings.

10. Test coverage is solid

The test suite covers:

  • Unit tests for Int4→IntxUnpacked conversion (test_pack_mlx.py)
  • End-to-end pipeline: quantize → pack → forward (test_mlx_pipeline.py)
  • Source transform verification: correct op counts in traced graph
  • Full export to .pte (skipped if MLX backend unavailable)
  • RoPE with custom 1D frequencies (test_ops.py:RopeCustomFreqsTest)
  • CUDA pipeline tests updated for new signature

Architecture Assessment

The design is clean:

  • Separation of concerns: MLX-specific code lives in dedicated files (mlx_source_transformations.py, pack_mlx.py, _export_mlx)
  • Backend-agnostic model: model.py stays backend-neutral; transforms are applied at export time
  • Shared checkpoint: Same quantized safetensors works for CUDA and MLX — only the packing differs
  • Unified runner: main.cpp cleanly separates CUDA/MLX paths with #ifdef

The simplification to always do last-logits-only and require temperature is a positive change that removes dead code paths.
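
The last-logits-only idea can be shown with tiny matrices: projecting only the final position through lm_head yields the same row as projecting every position and discarding the rest. The shapes and values below are illustrative, and `matmul` is a plain-Python stand-in for the real tensor op.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Naive matrix product of a (n x k) and b (k x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Hidden states for T=3 positions with hidden size 2 (illustrative values).
hidden = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
# lm_head weights mapping hidden size 2 to a vocabulary of 4.
lm_head = [[1.0, 0.0, 2.0, -1.0], [0.0, 1.0, 1.0, 0.5]]

full_logits = matmul(hidden, lm_head)         # shape (T, vocab)
last_logits = matmul([hidden[-1]], lm_head)   # shape (1, vocab)

print(full_logits[-1], last_logits[0])
```

The saving scales with the number of prompt tokens T: the projection cost drops from T rows to one, which matters most when the vocabulary dimension is as large as the 262144 mentioned in the commit message.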


Verdict

This PR is in good shape. The core logic is correct, the test coverage is comprehensive, and the architecture is well-designed. The main actionable item is the subtle freq inversion in the source transforms (point 4) which would benefit from a clarifying comment, though it's not a bug. The embedding packer edge case (point 9) is minor.

intx_w = _int4_to_intx_unpacked(int4_w)
self.assertEqual(intx_w.shape, torch.Size([64, 256]))

def test_matmul_approximates_original(self):
Contributor

Do we have a test case that covers comparing regrouped linear with original?

module.weight = nn.Parameter(w, requires_grad=False)


def pack_embedding_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None:
Contributor

Is there a reason why we need separate functions for linear and embedding?

It looks like both functions do the same conversion, except embedding also has the regrouping logic (which linear could also use)?

mergennachin and others added 4 commits May 19, 2026 17:07
- pack_mlx.py: converts Int4Tensor → IntxUnpackedToInt8Tensor at pack
  time (nibble unpack + scale transpose) so the default dispatch
  produces the dequantize_affine → linear pattern MLX expects.
  IntxUnpackedToInt8Tensor passes through unchanged. Embedding with
  incompatible per-axis group_size is regrouped to gs=128.
- export.py: add --backend mlx with single-method export (dynamic
  seq_len), sampler stripping, and MLXPartitioner lowering. No
  int4_dispatch import — MLX uses the standard dequantize_affine path.
- main.cpp: handle both CUDA (prefill+decode, on-device sampling) and
  MLX (single forward method, host-side argmax) via #ifdef.
- CMakeLists.txt / CMakePresets.json / Makefile: add gemma4_31b-mlx
  build target linking mlxdelegate.
- test_pack_mlx.py: 15 tests covering Int4→IntxUnpacked conversion
  correctness, passthrough, regrouping, error cases.
- test_mlx_pipeline.py: 4 e2e tests including export-to-pte.

Validated: same CUDA-quantized checkpoint packs for both backends,
100% op delegation to MLX, real 31B checkpoint packs at 4.0 GB RSS.

PR authored with Claude.

…LX packer

- Replace custom argmax_last_token with llm::logits_to_token for
  host-side sampling on non-CUDA builds, matching qwen3_5_moe runner.
  Supports temperature-controlled sampling (was greedy-only).
- Add --cuda_graph warning on non-CUDA builds.
- Support Int4Tensor embeddings in pack_embedding_for_mlx by converting
  to IntxUnpackedToInt8Tensor (same as linear path).
- Add divisibility guard in _regroup_intx.

Co-authored-by: Claude <noreply@anthropic.com>

… temp handling

- Always compute logits for the last position only (lm_head on x[:, -1, :]),
  avoiding the (1, T, 262144) matmul during prefill. Applies to both CUDA
  and MLX paths.
- Remove the temperature=None codepath from model.py forward and sampler.py.
  Temperature is now always required. MLX _strip_sampler_from_forward handles
  the no-sampler case independently.
- Add mlx_source_transformations.py: replaces generic PyTorch ops with
  mlx.rope, mlx.kv_cache_update, and mlx.custom_sdpa for optimized Metal
  kernels. Applied during MLX export before torch.export.
- Unify temperature clamping in main.cpp: compute temp_val once before the
  #ifdef, used by both CUDA (temp_tensor) and MLX (logits_to_token).
- Fix generate() default temperature to 0.8 (was 0.0, inconsistent with C++).

…layers)

- custom_ops.py: support 1D freqs in the Python fake op. When freqs is
  1D, compute inv_freq = 1/freqs and build angles from positions,
  matching the C++ runtime behavior. 2D freqs path unchanged.
- MLXInterpreter.h: pass base=nullopt when freqs is provided. MLX's
  fast::rope requires exactly one of base or freqs.
- mlx_source_transformations.py: pass dims=rotary_dim (not head_dim)
  with 1D freqs containing only the non-zero rotary frequencies. The
  old code passed 2D precomputed angles which was incorrect at the C++
  level.
- test_ops.py: add RopeCustomFreqsTest (3 configs) verifying export and
  MLX delegation with 1D custom frequencies.

Co-authored-by: Claude <noreply@anthropic.com>

mergennachin and others added 8 commits May 19, 2026 17:07
The MLX source transformation replaces each attention's forward with one
that builds its own masks via custom_sdpa (is_causal / sliding window
mask).  The model-level _build_masks was still being called and its
results passed through every layer — dead code that introduced
non-delegatable ops (arange, comparisons, boolean masks) into the traced
graph, potentially splitting the MLX delegate and causing CPU-GPU sync
points.

Strip _build_masks from _clean_forward, replace layer forward to remove
mask parameters, and drop the unused attn_mask argument from the MLX
attention forward.  Verified with tiny model: 1 delegate call,
0 non-delegated ops.

Co-authored-by: Claude <noreply@anthropic.com>

Previously _strip_sampler_from_forward (in export.py) and
mlx_source_transformations had to be called together — neither worked
alone after the layer/attention forwards were changed to drop mask
parameters. The model-level forward still called _build_masks and passed
4 args to layers, which now only accept 2.

Move the model-level forward replacement into mlx_source_transformations
as _replace_model_forward. After the transform, the model has signature
(tokens, input_pos) → (B, V) logits — no temperature, no sampler, no
masks. Remove the now-redundant _strip_sampler_from_forward from
export.py and update the test to use the 2-arg signature.

Co-authored-by: Claude <noreply@anthropic.com>

main.cpp:
- Fix stale top comment (was claiming MLX did greedy argmax with no
  temperature; now correctly describes host-side llm::logits_to_token).
- Print the first generated token (from the last prefill chunk) using
  the last prompt token as the streaming-decode prefix — previously it
  was dropped (off-by-one in the output).
- Set stats.first_token_ms after prefill (was set at decode step 0,
  which mis-named the TTFT metric).
- Add hit_eos check before decode loop in case the first generated
  token is already EOS.
- Remove redundant per-stat printfs — print_report covers them.

test_mlx_pipeline:
- Add test_source_transforms_use_mlx_ops to verify the traced graph
  contains the expected MLX custom ops (2× rope, 2× kv_cache_update,
  1× custom_sdpa per layer). Regression-protects the source transforms
  against silent reverts to PyTorch ops.

Co-authored-by: Claude <noreply@anthropic.com>

All other runners in examples/models/ use Mmap. The two LLM runners
(qwen3_5_moe, gemma4_31b) were the outliers using File, likely
inherited rather than intentional. File uses pread() syscalls per
access; Mmap maps the file lazily and benefits from the OS page cache.
For a 21 GB .pte this matters: per-access pread has measurable overhead
and prevents kernel-level read-ahead.

Co-authored-by: Claude <noreply@anthropic.com>
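
The difference between the two access styles described in the commit message above can be sketched with the Python standard library on a POSIX system. This only contrasts an explicit pread() per access with a memory mapping; it is not the ExecuTorch data loader, and it does not show the mlock() pinning discussed in the following commit.

```python
import mmap
import os
import tempfile

# Create a small stand-in file; a real .pte would be many gigabytes.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(bytes(range(256)) * 16)
    path = tmp.name

fd = os.open(path, os.O_RDONLY)
try:
    # File-style access: every read is an explicit pread() system call.
    via_pread = os.pread(fd, 4, 256)

    # Mmap-style access: map once, then index like memory. Pages are
    # faulted in lazily and served from the OS page cache.
    with mmap.mmap(fd, 0, access=mmap.ACCESS_READ) as mapped:
        via_mmap = mapped[256:260]
finally:
    os.close(fd)
    os.remove(path)

print(via_pread, via_mmap)
```

Both paths return the same bytes; the mapping simply avoids a system call per access after the initial setup.
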
On memory-constrained machines (e.g. 32 GB with a 21 GB model), plain
Mmap causes constant page eviction and re-fault — 0.2 tok/s decode.
MmapUseMlockIgnoreErrors calls mlock() to pin pages after mmap, keeping
weights resident once loaded. Falls back silently if mlock exceeds the
system limit.  Measured 4.7 tok/s decode on M3 Pro 32 GB (24x speedup).

Co-authored-by: Claude <noreply@anthropic.com>

The error checks were dropped during refactoring. Without them, a failed
set_option silently disables weight sharing, causing prefill and decode
to allocate separate KV-cache buffers (OOM at runtime with no diagnostic).

Also resolve <turn|> EOS token ID from the tokenizer at startup instead
of hardcoding token 106.

Co-authored-by: Claude <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 19, 2026 21:07
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 22 out of 22 changed files in this pull request and generated no new comments.

for layer in self.layers:
    x = layer(x, input_pos)
x = self.norm(x)
last = self.lm_head(x[:, -1, :]).float()
Contributor

Do we need to cast here? Does torch.tanh not use higher precision internally?

Contributor Author

The .float() cast is needed. Without it, the tanh would run in bf16 (the input dtype).
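
A small pure-Python experiment illustrates what is lost when tanh results are held in bf16. The `to_bf16` helper below emulates bfloat16 rounding by keeping the top 16 bits of a float32 with round-to-nearest-even; it is a sketch for illustration, not how PyTorch implements the dtype, and the input values are arbitrary.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a value to the nearest bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias for the dropped 16 bits
    bits &= 0xFFFF0000                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


a, b = math.tanh(3.0), math.tanh(3.1)

# In higher precision the two results remain distinct.
print(a, b, a != b)

# Rounded to bf16 they collapse onto the same representable value.
print(to_bf16(a), to_bf16(b), to_bf16(a) == to_bf16(b))
```

Near saturation, tanh outputs are packed closely below 1.0, where bf16 values are spaced 2^-8 apart, so distinct inputs can become indistinguishable unless the computation is carried out in float32.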

The embedding packer was a strict superset of the linear packer — both
convert Int4Tensor → IntxUnpacked, but embedding also regroups to an
MLX-compatible group_size. Regrouping is a no-op for standard linear
group sizes (32/64/128), so one function handles both.

Add test_regroup_preserves_dequant to verify regrouped linear weights
dequantize identically to the originals.

Co-authored-by: Claude <noreply@anthropic.com>
@mergennachin mergennachin merged commit f8cfc73 into main May 19, 2026
226 of 229 checks passed
@mergennachin mergennachin deleted the gemma4_mlx branch May 19, 2026 22:19

Labels

ciflow/cuda, CLA Signed


3 participants