From 5f455a24a1bc7e0b4064fdda7854792bf998d79c Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 23 Apr 2026 07:55:00 -0400 Subject: [PATCH 01/16] Add Gemma 4 MLX install-path support --- .github/workflows/mlx.yml | 13 +- backends/mlx/examples/llm/README.md | 26 +++- backends/mlx/examples/llm/export_llm_hf.py | 14 +- backends/mlx/examples/llm/run_llm_hf.py | 147 +++++++++++++++++++-- backends/mlx/llm/cache.py | 110 ++++++++++++--- backends/mlx/llm/hf_attention.py | 48 ++++++- backends/mlx/llm/source_transformation.py | 57 +++++--- backends/mlx/runtime/MLXBackend.cpp | 28 ++++ setup.py | 48 ++++++- 9 files changed, 421 insertions(+), 70 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 8c079e785e3..ebbfa6b3a68 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -470,6 +470,12 @@ jobs: name: "gemma3-1b" use-custom: [false, true] qconfig: ["4w", "nvfp4"] + include: + - model: + id: "google/gemma-4-E2B-it" + name: "gemma4-e2b" + use-custom: true + qconfig: "4w" uses: pytorch/test-infra/.github/workflows/macos_job.yml@main secrets: inherit with: @@ -493,6 +499,11 @@ jobs: CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache" fi + QEMBEDDING_ARGS="--qembedding ${QCONFIG}" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + QEMBEDDING_ARGS="" + fi + echo "::group::Install ExecuTorch and configure MLX build" ${CONDA_RUN} python install_executorch.py > /dev/null ${CONDA_RUN} cmake --preset mlx-release @@ -512,7 +523,7 @@ jobs: --model-id "${MODEL_ID}" \ --output /tmp/${MODEL_NAME}.pte \ --qlinear ${QCONFIG} \ - --qembedding ${QCONFIG} \ + ${QEMBEDDING_ARGS} \ ${CUSTOM_ARGS} echo "::endgroup::" diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index f860c4f1ce0..04bcd500651 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -9,6 +9,7 @@ This example demonstrates how to export and run LLMs using the MLX delegate for - **KV Cache**: Efficient KV cache implementation for autoregressive generation - **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX - **Pybindings**: Run inference using ExecuTorch Python bindings +- **Gemma 4**: Text-only export and run flow supports processor-backed checkpoints such as `google/gemma-4-E2B-it` ## Requirements @@ -52,8 +53,19 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --use-custom-kv-cache \ --qlinear 4w \ --qembedding 4w + +# Gemma 4 text-only export +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "google/gemma-4-E2B-it" \ + --output gemma4_hf_int4.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache \ + --qlinear 4w ``` +Gemma 4 support is currently validated for the text-only path using +`--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`. + ### Options | Option | Default | Description | @@ -81,12 +93,24 @@ python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --prompt "Explain quantum computing in simple terms" ``` +Gemma 4 checkpoints may use `AutoProcessor` instead of `AutoTokenizer`; `run_llm_hf` now supports both paths automatically for text-only prompts. + +Validated Gemma 4 run command: + +```bash +python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte gemma4_hf_int4.pte \ + --model-id google/gemma-4-E2B-it \ + --prompt "What is the capital of France?" 
\ + --max-new-tokens 50 +``` + ### Options | Option | Default | Description | |--------|---------|-------------| | `--pte` | `llama_hf.pte` | Path to .pte file | -| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) | +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer or processor) | | `--prompt` | `The quick brown fox` | Input prompt | | `--max-new-tokens` | `50` | Maximum tokens to generate | diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 39f13e434be..ec5bcf5abf8 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -166,9 +166,6 @@ def _export_with_custom_components( attn_implementation = "mlx" if use_custom_sdpa else None - # Detect sliding window models (e.g., gemma) - sliding_window = None - logger.info(f"Loading HuggingFace model: {model_id}") load_kwargs = { "torch_dtype": torch_dtype, @@ -178,8 +175,10 @@ def _export_with_custom_components( load_kwargs["attn_implementation"] = attn_implementation model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs) - # Check if model uses sliding window attention - sliding_window = getattr(model.config, "sliding_window", None) + # Check if model uses sliding window attention. Multimodal configs like + # Gemma 4 keep transformer attributes under text_config. + text_config = model.config.get_text_config() + sliding_window = getattr(text_config, "sliding_window", None) if sliding_window is not None: logger.info(f"Model has sliding_window={sliding_window}") # Cap max_seq_len to sliding window size for cache allocation @@ -188,11 +187,16 @@ def _export_with_custom_components( else: effective_cache_len = max_seq_len + # The HF ExecuTorch cache wrappers validate both generation_config.use_cache + # and the text config's use_cache flag before constructing static caches. + model.generation_config.use_cache = True model.generation_config.cache_implementation = "static" model.generation_config.cache_config = { "batch_size": 1, "max_cache_len": effective_cache_len, } + text_config = model.config.get_text_config() + text_config.use_cache = True model.eval() # Use HybridCache wrapper for sliding window models (stores cache as .cache), diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py index ca3d0468114..9c5d1d0bf5f 100644 --- a/backends/mlx/examples/llm/run_llm_hf.py +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -7,10 +7,11 @@ # LICENSE file in the root directory of this source tree. """ -Run exported Llama model (from HuggingFace) using ExecuTorch pybindings. +Run exported HuggingFace LLM using ExecuTorch pybindings. This script runs models exported using export_llm_hf.py. It loads the tokenizer -directly from HuggingFace using the same model ID used during export. +or processor directly from HuggingFace using the same model ID used during +export. 
Usage: python -m executorch.backends.mlx.examples.llm.run_llm_hf \ @@ -20,18 +21,89 @@ """ import argparse +import ctypes import logging +import os +import shutil import time +from pathlib import Path import torch from executorch.runtime import Runtime, Verification -from transformers import AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) +def _iter_mlx_backend_candidates(): + env_path = os.environ.get("ET_MLX_BACKEND_DYLIB") + if env_path: + yield Path(env_path) + + for parent in Path(__file__).resolve().parents: + pip_out = parent / "pip-out" + if pip_out.exists(): + yield from sorted( + pip_out.glob( + "temp.*/cmake-out/backends/mlx/libmlxdelegate_runtime.dylib" + ) + ) + yield from sorted( + pip_out.glob("temp.*/cmake-out/backends/mlx/libmlxdelegate.dylib") + ) + break + + +def _ensure_mlx_metallib(dylib_path: Path) -> None: + metallib_path = dylib_path.with_name("mlx.metallib") + if metallib_path.exists(): + return + + for parent in Path(__file__).resolve().parents: + pip_out = parent / "pip-out" + if not pip_out.exists(): + continue + matches = sorted( + pip_out.glob( + "temp.*/cmake-out/backends/mlx/mlx/mlx/backend/metal/kernels/mlx.metallib" + ) + ) + if not matches: + continue + shutil.copyfile(matches[0], metallib_path) + logger.info(f"Copied MLX metallib next to runtime library: {metallib_path}") + return + + +def _ensure_mlx_backend_registered() -> Runtime: + runtime = Runtime.get() + if runtime.backend_registry.is_available("MLXBackend"): + return runtime + + for candidate in _iter_mlx_backend_candidates(): + if not candidate.is_file(): + continue + try: + _ensure_mlx_metallib(candidate) + ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL) + except OSError as exc: + logger.info(f"Failed to load MLX backend library {candidate}: {exc}") + continue + + runtime = Runtime.get() + if runtime.backend_registry.is_available("MLXBackend"): + logger.info(f"Loaded MLX backend runtime library: {candidate}") + return runtime + + logger.warning( + "MLXBackend is not registered. If you built mlxdelegate locally, " + "set ET_MLX_BACKEND_DYLIB to the path of libmlxdelegate_runtime.dylib." + ) + return runtime + + def _get_max_input_seq_len(program) -> int: """Inspect the .pte program metadata to determine the max input_ids seq len. @@ -46,6 +118,50 @@ def _get_max_input_seq_len(program) -> int: return sizes[1] if len(sizes) >= 2 else 1 +def _load_text_processor(model_id: str): + """ + Load a text processor for the model. + + Prefer AutoProcessor for multimodal/text-hybrid models like Gemma 4, and + fall back to AutoTokenizer for text-only checkpoints. 
+ """ + try: + processor = AutoProcessor.from_pretrained(model_id) + if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"): + logger.info(f"Loaded processor from HuggingFace: {model_id}") + return processor, True + except Exception as exc: + logger.info(f"AutoProcessor unavailable for {model_id}: {exc}") + + logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") + tokenizer = AutoTokenizer.from_pretrained(model_id) + return tokenizer, False + + +def _apply_chat_template(text_processor, messages) -> str: + try: + return text_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + except TypeError: + return text_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + +def _get_eos_token_id(text_processor): + eos_token_id = getattr(text_processor, "eos_token_id", None) + if eos_token_id is not None: + return eos_token_id + tokenizer = getattr(text_processor, "tokenizer", None) + return getattr(tokenizer, "eos_token_id", None) + + def run_inference( pte_path: str, model_id: str, @@ -53,11 +169,10 @@ def run_inference( max_new_tokens: int = 50, ) -> str: """Run inference on the exported HuggingFace model.""" - logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") - tokenizer = AutoTokenizer.from_pretrained(model_id) + text_processor, uses_processor = _load_text_processor(model_id) logger.info(f"Loading model from {pte_path}...") - et_runtime = Runtime.get() + et_runtime = _ensure_mlx_backend_registered() program = et_runtime.load_program(pte_path, verification=Verification.Minimal) max_seq_len = _get_max_input_seq_len(program) @@ -67,14 +182,18 @@ def run_inference( logger.info(f"Encoding prompt: {prompt!r}") messages = [{"role": "user", "content": prompt}] - formatted_prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt") + formatted_prompt = _apply_chat_template(text_processor, messages) + if uses_processor: + input_ids = text_processor(text=formatted_prompt, return_tensors="pt")[ + "input_ids" + ] + else: + input_ids = text_processor.encode(formatted_prompt, return_tensors="pt") logger.info(f"Input shape: {input_ids.shape}") generated_tokens = input_ids[0].tolist() seq_len = input_ids.shape[1] + eos_token_id = _get_eos_token_id(text_processor) start_time = time.time() @@ -120,7 +239,7 @@ def run_inference( next_token = torch.argmax(next_token_logits).item() generated_tokens.append(next_token) - if next_token == tokenizer.eos_token_id: + if eos_token_id is not None and next_token == eos_token_id: logger.info(f"EOS token reached at position {i + 1}") break @@ -135,12 +254,12 @@ def run_inference( # Decode only the newly generated tokens (not the input prompt) new_tokens = generated_tokens[seq_len:] - generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + generated_text = text_processor.decode(new_tokens, skip_special_tokens=True) return generated_text def main(): - parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model") + parser = argparse.ArgumentParser(description="Run exported HuggingFace LLM") parser.add_argument( "--pte", type=str, @@ -151,7 +270,7 @@ def main(): "--model-id", type=str, default="unsloth/Llama-3.2-1B-Instruct", - help="HuggingFace model ID (used to load tokenizer)", + help="HuggingFace model ID (used to load tokenizer or processor)", ) parser.add_argument( 
"--prompt", diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py index 9709980689b..b0654adeb71 100644 --- a/backends/mlx/llm/cache.py +++ b/backends/mlx/llm/cache.py @@ -21,6 +21,62 @@ from executorch.backends.mlx import custom_ops as _mlx_custom_ops # noqa: F401 +def resolve_hf_text_config(config): + """Return the text config for multimodal HF models, or the config itself.""" + if hasattr(config, "get_text_config"): + return config.get_text_config() + return getattr(config, "text_config", config) + + +def resolve_hf_cache_layout(config): + """ + Return per-cache-layer metadata for HuggingFace hybrid/static caches. + + Some models such as Gemma 4 use different KV geometries depending on the + attention layer type. Match the upstream `transformers` hybrid cache layout + so our replacement cache allocates the same number of layers with the same + `(num_heads, head_dim)` for each backing cache entry. + """ + text_config = resolve_hf_text_config(config) + layer_types = getattr(text_config, "layer_types", None) + + if layer_types is None: + if getattr(text_config, "sliding_window", None) is not None: + layer_types = ["sliding_attention" for _ in range(text_config.num_hidden_layers)] + else: + layer_types = ["full_attention" for _ in range(text_config.num_hidden_layers)] + else: + layer_types = list(layer_types) + + if hasattr(text_config, "num_kv_shared_layers"): + layer_types = layer_types[: -text_config.num_kv_shared_layers] + + if hasattr(text_config, "global_head_dim"): + head_dims = [ + text_config.global_head_dim if layer_type == "full_attention" else text_config.head_dim + for layer_type in layer_types + ] + num_heads = [ + text_config.num_global_key_value_heads + if layer_type == "full_attention" and getattr(text_config, "attention_k_eq_v", False) + else text_config.num_key_value_heads + for layer_type in layer_types + ] + else: + head_dim = getattr( + text_config, + "head_dim", + text_config.hidden_size // text_config.num_attention_heads, + ) + num_head = getattr( + text_config, "num_key_value_heads", text_config.num_attention_heads + ) + head_dims = [head_dim for _ in layer_types] + num_heads = [num_head for _ in layer_types] + + return layer_types, num_heads, head_dims + + class KVCache(nn.Module): """ MLX-optimized KV cache with ExecutorTorch llama KVCache interface. @@ -326,14 +382,13 @@ def __init__( device: Device for cache tensors (default: None = CPU) dtype: Data type for cache tensors (default: torch.float32) """ - # Resolve dimensions from config BEFORE calling parent - num_layers = config.num_hidden_layers - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) + # Resolve dimensions from the text config before calling parent. Multimodal + # configs like Gemma 4 expose transformer dims under text_config. + text_config = resolve_hf_text_config(config) + layer_types, num_heads, head_dims = resolve_hf_cache_layout(config) + num_model_layers = text_config.num_hidden_layers actual_max_cache_len = max_cache_len or getattr( - config, "max_position_embeddings", 2048 + text_config, "max_position_embeddings", 2048 ) # Initialize parent StaticCache with required arguments @@ -348,15 +403,22 @@ def __init__( self.early_initialization( batch_size=max_batch_size, num_heads=num_heads, - head_dim=head_dim, + head_dim=head_dims, dtype=dtype, device=device, ) + # Some models (for example Gemma 4) only allocate cache entries for the + # non-shared KV layers. 
Mirror the parent StaticCache layout exactly so + # layer_idx values passed to update() line up with our backing cache. + num_cache_layers = len(self.layers) + # Store dimensions as instance attributes - self.num_layers = num_layers + self.num_model_layers = num_model_layers + self.num_layers = num_cache_layers + self.layer_types = layer_types self.num_heads = num_heads - self.head_dim = head_dim + self.head_dim = head_dims # Create KVCache wrappers for each layer - these use mlx::kv_cache_update # Named 'kv_cache' to match optimum-executorch's ETCustomStaticCache pattern @@ -365,12 +427,12 @@ def __init__( KVCache( max_batch_size=max_batch_size, max_context_length=actual_max_cache_len, - n_heads=num_heads, - head_dim=head_dim, + n_heads=layer_num_heads, + head_dim=layer_head_dim, enable_dynamic_shape=True, dtype=dtype, ) - for _ in range(num_layers) + for layer_num_heads, layer_head_dim in zip(num_heads, head_dims) ] ) @@ -394,18 +456,24 @@ def update( key_states: New key states [batch_size, num_heads, seq_len, head_dim] value_states: New value states [batch_size, num_heads, seq_len, head_dim] layer_idx: Index of the layer to update - cache_kwargs: Dictionary containing 'cache_position' tensor with start position + cache_kwargs: Optional dictionary containing 'cache_position' tensor + with start position. Newer HF StaticCache callers seed + `self.layers[layer_idx].cumulative_length` directly and do not + pass cache_kwargs. Returns: Tuple of (key_cache, value_cache) for the full cache after update """ - assert ( - cache_kwargs is not None - ), "cache_kwargs must be provided with 'cache_position'" - cache_position = cache_kwargs.get("cache_position") - assert ( - cache_position is not None - ), "cache_position must be provided in cache_kwargs" + if cache_kwargs is not None: + cache_position = cache_kwargs.get("cache_position") + else: + cache_position = None + + if cache_position is None: + # Current HF ExecuTorch wrappers copy the requested cache position + # into each StaticCache layer's cumulative_length before forward(). + cache_position = self.layers[layer_idx].cumulative_length + assert isinstance( cache_position, torch.Tensor ), "cache_position must be a tensor" diff --git a/backends/mlx/llm/hf_attention.py b/backends/mlx/llm/hf_attention.py index 9e3c864dce6..f2a01c9e653 100644 --- a/backends/mlx/llm/hf_attention.py +++ b/backends/mlx/llm/hf_attention.py @@ -89,8 +89,10 @@ def mlx_sdpa_with_start_pos_forward( def sdpa_mask_passthrough( batch_size: int, - cache_position: torch.Tensor, - kv_length: int, + cache_position: Optional[torch.Tensor] = None, + q_length: Optional[int] = None, + kv_length: Optional[int] = None, + q_offset: Optional[Union[int, torch.Tensor]] = None, kv_offset: int = 0, mask_function: Optional[Callable] = None, attention_mask: Optional[torch.Tensor] = None, @@ -139,6 +141,27 @@ def get_mlx_sliding_window_sdpa(exportable_module) -> Callable: Attention function compatible with HuggingFace's attention interface. """ + def _resolve_cache_layer_idx(module: torch.nn.Module, cache) -> Optional[int]: + """ + Map a transformer layer index to the backing cache slot index. + + Hybrid/shared-KV models like Gemma 4 only allocate cache entries for the + non-shared KV layers. Shared layers expose `kv_shared_layer_index`, which + points at the earlier cache-producing layer they reuse. 
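+
+        A worked sketch (the layer counts here are illustrative, not taken
+        from a real checkpoint): with 30 transformer layers backed by 20
+        cache entries, layers 0-19 resolve to their own index; a shared
+        layer whose kv_shared_layer_index is 15 resolves to cache slot 15;
+        and a shared layer without a usable index resolves to None.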
+ """ + layer_idx = getattr(module, "layer_idx", None) + if layer_idx is None: + return None + + if layer_idx < len(cache.kv_cache): + return layer_idx + + shared_layer_idx = getattr(module, "kv_shared_layer_index", None) + if shared_layer_idx is not None and shared_layer_idx < len(cache.kv_cache): + return shared_layer_idx + + return None + def _sliding_window_sdpa_forward( module: torch.nn.Module, query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD @@ -165,6 +188,7 @@ def _sliding_window_sdpa_forward( attn_mask = None start_pos = 0 + layer_cache = None if layer_idx is not None and position_ids is not None: start_pos = position_ids[0][0].item() @@ -173,7 +197,9 @@ def _sliding_window_sdpa_forward( cache = getattr(exportable_module, "cache", None) if cache is not None: - layer_cache = cache.kv_cache[layer_idx] + cache_layer_idx = _resolve_cache_layer_idx(module, cache) + if cache_layer_idx is not None: + layer_cache = cache.kv_cache[cache_layer_idx] if isinstance(layer_cache, RingBufferKVCache): attn_mask = layer_cache.create_sliding_window_mask( start_pos, seq_len @@ -182,11 +208,19 @@ def _sliding_window_sdpa_forward( # stop_pos = start_pos + seq_len = buffer_size start_pos = layer_cache.buffer_size - seq_len + # Hybrid models use one global HF attention implementation. Sliding + # layers need the ring-buffer mask path, while full-attention layers + # should keep the regular causal SDPA path even under the same hook. if attn_mask is None: - raise RuntimeError( - f"Sliding window attention at layer {layer_idx} requires a " - f"RingBufferKVCache, but none was found. Ensure the model's " - f"cache is set up with RingBufferKVCache for sliding window layers." + return mlx_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + position_ids=position_ids, + scaling=scaling, + **kwargs, ) output = torch.ops.mlx.custom_sdpa( diff --git a/backends/mlx/llm/source_transformation.py b/backends/mlx/llm/source_transformation.py index d90073c633e..fa727b650fd 100644 --- a/backends/mlx/llm/source_transformation.py +++ b/backends/mlx/llm/source_transformation.py @@ -19,7 +19,13 @@ import torch import torch.nn as nn -from executorch.backends.mlx.llm.cache import HFStaticCache, KVCache, RingBufferKVCache +from executorch.backends.mlx.llm.cache import ( + HFStaticCache, + KVCache, + RingBufferKVCache, + resolve_hf_cache_layout, + resolve_hf_text_config, +) logger = logging.getLogger(__name__) @@ -123,9 +129,16 @@ def replace_hf_cache_with_mlx( def _install_cache(attr_name): setattr(module, attr_name, mlx_cache) - for i, layer_cache in enumerate(mlx_cache.kv_cache): + for i, (cache_layer, layer_cache) in enumerate( + zip(mlx_cache.layers, mlx_cache.kv_cache) + ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -171,12 +184,6 @@ def replace_hf_cache_with_mlx_ring_buffer( """ from transformers.cache_utils import StaticCache - num_layers = config.num_hidden_layers - num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - # Create HFStaticCache with ring buffer layers mlx_cache = HFStaticCache( config=config, @@ -185,22 +192,38 @@ def replace_hf_cache_with_mlx_ring_buffer( dtype=dtype, ) - # Replace each layer's KVCache with RingBufferKVCache - 
for i in range(num_layers): - ring_cache = RingBufferKVCache( + # Replace only the sliding-window cache entries with ring buffers, while + # preserving full-attention entries as linear caches. Hybrid models like + # Gemma 4 mix both layouts and can also vary head_dim per cache layer. + layer_types, num_heads, head_dims = resolve_hf_cache_layout(config) + num_cache_layers = len(mlx_cache.layers) + num_ring_layers = 0 + for i, (layer_type, layer_num_heads, layer_head_dim) in enumerate( + zip(layer_types, num_heads, head_dims) + ): + if layer_type != "sliding_attention": + continue + mlx_cache.kv_cache[i] = RingBufferKVCache( max_batch_size=max_batch_size, max_context_length=window_size, - n_heads=num_kv_heads, - head_dim=head_dim, + n_heads=layer_num_heads, + head_dim=layer_head_dim, dtype=dtype, ) - mlx_cache.kv_cache[i] = ring_cache + num_ring_layers += 1 def _install_cache(attr_name): setattr(module, attr_name, mlx_cache) - for i, layer_cache in enumerate(mlx_cache.kv_cache): + for i, (cache_layer, layer_cache) in enumerate( + zip(mlx_cache.layers, mlx_cache.kv_cache) + ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -218,8 +241,8 @@ def _install_cache(attr_name): raise ValueError("Module must have 'static_cache' or 'cache' attribute") logger.info( - f"Installed RingBufferKVCache: {num_layers} layers, " - f"window_size={window_size}, heads={num_kv_heads}, head_dim={head_dim}" + f"Installed hybrid MLX cache: {num_ring_layers} ring-buffer layers / " + f"{num_cache_layers} total cache layers, window_size={window_size}" ) return module diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 99e20114ea7..5127a12d146 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -209,14 +209,18 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { } try { + std::cerr << "MLX init: constructing handle" << std::endl; new (handle) MLXHandle(); + std::cerr << "MLX init: handle constructed" << std::endl; if (!processed || !processed->data() || processed->size() == 0) { throw std::runtime_error("init: null or empty delegate payload"); } + std::cerr << "MLX init: parsing delegate payload" << std::endl; handle->program = loader::load_program( static_cast(processed->data()), processed->size()); + std::cerr << "MLX init: delegate payload parsed" << std::endl; // Validate schema version int schema_version = 1; @@ -244,27 +248,34 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // runtime const runtime::NamedDataMap* named_data_map = context.get_named_data_map(); + std::cerr << "MLX init: loading constants" << std::endl; load_constants( handle->program, named_data_map, handle->constants, handle->constant_buffers); + std::cerr << "MLX init: constants loaded" << std::endl; // Delegate payload no longer needed after constants are loaded processed->Free(); processed = nullptr; // Load mutable buffers (e.g., KV cache) + std::cerr << "MLX init: loading mutable buffers" << std::endl; load_mutable_buffers(handle->program, handle->mutable_buffers); + std::cerr << "MLX init: mutable buffers loaded" << std::endl; // Bind execution state (reused across execute() calls) + std::cerr << "MLX init: binding execution state" << std::endl; handle->state.bind( handle->program, handle->constants, 
handle->mutable_buffers); + std::cerr << "MLX init: execution state bound" << std::endl; // Run init chain if present. // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the // static_cast cannot produce UINT32_MAX from a -1 sentinel. if (handle->program.init_chain_idx >= 0) { + std::cerr << "MLX init: running init chain" << std::endl; handle->state.is_init_chain = true; handle->interpreter.run_chain( handle->program, @@ -276,10 +287,21 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // Evaluate any constants written by the init chain so the first // execute() doesn't pay the cost of materializing them. eval(handle->constants.tensors); + std::cerr << "MLX init: init chain complete" << std::endl; } } catch (const std::exception& e) { ET_LOG(Error, "Failed to load MLX program: %s", e.what()); + std::cerr << "Failed to load MLX program: " << e.what() << std::endl; + handle->~MLXHandle(); + if (processed != nullptr) { + processed->Free(); + } + return Error::InvalidProgram; + } catch (...) { + ET_LOG(Error, "Failed to load MLX program: unknown non-std exception"); + std::cerr << "Failed to load MLX program: unknown non-std exception" + << std::endl; handle->~MLXHandle(); if (processed != nullptr) { processed->Free(); @@ -415,6 +437,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { return Error::Ok; } catch (const std::exception& e) { ET_LOG(Error, "MLX execute failed: %s", e.what()); + std::cerr << "MLX execute failed: " << e.what() << std::endl; + return Error::Internal; + } catch (...) { + ET_LOG(Error, "MLX execute failed: unknown non-std exception"); + std::cerr << "MLX execute failed: unknown non-std exception" + << std::endl; return Error::Internal; } } diff --git a/setup.py b/setup.py index d219135992f..1de30c20117 100644 --- a/setup.py +++ b/setup.py @@ -173,7 +173,20 @@ def write_to_python_file(cls, path: str) -> None: # set to a non-empty value, the build type is Debug. Otherwise, the build type # is Release. def get_build_type(is_debug=None) -> str: - debug = int(os.environ.get("DEBUG", 0) or 0) if is_debug is None else is_debug + if is_debug is None: + raw_debug = os.environ.get("DEBUG", 0) + if isinstance(raw_debug, str): + normalized = raw_debug.strip().lower() + if normalized in ("", "0", "false", "off", "release"): + debug = 0 + elif normalized in ("1", "true", "on", "debug"): + debug = 1 + else: + debug = int(normalized) + else: + debug = int(raw_debug or 0) + else: + debug = is_debug return "Debug" if debug else "Release" @@ -193,6 +206,30 @@ def get_executable_name(name: str) -> str: return name +def get_cmake_command() -> str: + """ + Resolve the CMake executable to use for wheel builds. + + Prefer an explicit `CMAKE_COMMAND`, then a CMake binary colocated with the + current Python interpreter (common for virtualenv installs), and finally + fall back to PATH lookup. 
+ """ + env_cmake = os.environ.get("CMAKE_COMMAND", "").strip() + if env_cmake: + return env_cmake + + python_bin_dir = os.path.dirname(sys.executable) + venv_cmake = os.path.join(python_bin_dir, get_executable_name("cmake")) + if os.path.exists(venv_cmake): + return venv_cmake + + cmake_on_path = shutil.which("cmake") + if cmake_on_path: + return cmake_on_path + + return "cmake" + + class _BaseExtension(Extension): """A base class that maps an abstract source to an abstract destination.""" @@ -753,9 +790,10 @@ def run(self): # noqa C901 log.info(f"clearing {cmake_cache_dir}") shutil.rmtree(cmake_cache_dir) + cmake_command = get_cmake_command() subprocess.run( [ - "cmake", + cmake_command, *cmake_configuration_args, "--preset", "pybind", @@ -831,10 +869,12 @@ def run(self): # noqa C901 # Set PYTHONPATH to the location of the pip package. os.environ["PYTHONPATH"] = ( - site.getsitepackages()[0] + ";" + os.environ.get("PYTHONPATH", "") + site.getsitepackages()[0] + + os.pathsep + + os.environ.get("PYTHONPATH", "") ) # Build the system. - self.spawn(["cmake", "--build", cmake_cache_dir, *cmake_build_args]) + self.spawn([get_cmake_command(), "--build", cmake_cache_dir, *cmake_build_args]) # Share the cmake-out location with _BaseExtension. self.cmake_cache_dir = cmake_cache_dir # Finally, run the underlying subcommands like build_py, build_ext. From 0a822bd7778d63f33152182eb1f8e090ec2e9b58 Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 23 Apr 2026 21:21:59 -0400 Subject: [PATCH 02/16] Remove MLX runtime fallback and debug logging --- backends/mlx/examples/llm/run_llm_hf.py | 73 +------------------------ backends/mlx/runtime/MLXBackend.cpp | 19 ------- 2 files changed, 1 insertion(+), 91 deletions(-) diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py index 9c5d1d0bf5f..21f197fb564 100644 --- a/backends/mlx/examples/llm/run_llm_hf.py +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -21,12 +21,8 @@ """ import argparse -import ctypes import logging -import os -import shutil import time -from pathlib import Path import torch from executorch.runtime import Runtime, Verification @@ -37,73 +33,6 @@ logger = logging.getLogger(__name__) -def _iter_mlx_backend_candidates(): - env_path = os.environ.get("ET_MLX_BACKEND_DYLIB") - if env_path: - yield Path(env_path) - - for parent in Path(__file__).resolve().parents: - pip_out = parent / "pip-out" - if pip_out.exists(): - yield from sorted( - pip_out.glob( - "temp.*/cmake-out/backends/mlx/libmlxdelegate_runtime.dylib" - ) - ) - yield from sorted( - pip_out.glob("temp.*/cmake-out/backends/mlx/libmlxdelegate.dylib") - ) - break - - -def _ensure_mlx_metallib(dylib_path: Path) -> None: - metallib_path = dylib_path.with_name("mlx.metallib") - if metallib_path.exists(): - return - - for parent in Path(__file__).resolve().parents: - pip_out = parent / "pip-out" - if not pip_out.exists(): - continue - matches = sorted( - pip_out.glob( - "temp.*/cmake-out/backends/mlx/mlx/mlx/backend/metal/kernels/mlx.metallib" - ) - ) - if not matches: - continue - shutil.copyfile(matches[0], metallib_path) - logger.info(f"Copied MLX metallib next to runtime library: {metallib_path}") - return - - -def _ensure_mlx_backend_registered() -> Runtime: - runtime = Runtime.get() - if runtime.backend_registry.is_available("MLXBackend"): - return runtime - - for candidate in _iter_mlx_backend_candidates(): - if not candidate.is_file(): - continue - try: - _ensure_mlx_metallib(candidate) - ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL) - 
except OSError as exc: - logger.info(f"Failed to load MLX backend library {candidate}: {exc}") - continue - - runtime = Runtime.get() - if runtime.backend_registry.is_available("MLXBackend"): - logger.info(f"Loaded MLX backend runtime library: {candidate}") - return runtime - - logger.warning( - "MLXBackend is not registered. If you built mlxdelegate locally, " - "set ET_MLX_BACKEND_DYLIB to the path of libmlxdelegate_runtime.dylib." - ) - return runtime - - def _get_max_input_seq_len(program) -> int: """Inspect the .pte program metadata to determine the max input_ids seq len. @@ -172,7 +101,7 @@ def run_inference( text_processor, uses_processor = _load_text_processor(model_id) logger.info(f"Loading model from {pte_path}...") - et_runtime = _ensure_mlx_backend_registered() + et_runtime = Runtime.get() program = et_runtime.load_program(pte_path, verification=Verification.Minimal) max_seq_len = _get_max_input_seq_len(program) diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 5127a12d146..5bd3bf263d1 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include @@ -209,18 +208,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { } try { - std::cerr << "MLX init: constructing handle" << std::endl; new (handle) MLXHandle(); - std::cerr << "MLX init: handle constructed" << std::endl; if (!processed || !processed->data() || processed->size() == 0) { throw std::runtime_error("init: null or empty delegate payload"); } - std::cerr << "MLX init: parsing delegate payload" << std::endl; handle->program = loader::load_program( static_cast(processed->data()), processed->size()); - std::cerr << "MLX init: delegate payload parsed" << std::endl; // Validate schema version int schema_version = 1; @@ -248,34 +243,27 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // runtime const runtime::NamedDataMap* named_data_map = context.get_named_data_map(); - std::cerr << "MLX init: loading constants" << std::endl; load_constants( handle->program, named_data_map, handle->constants, handle->constant_buffers); - std::cerr << "MLX init: constants loaded" << std::endl; // Delegate payload no longer needed after constants are loaded processed->Free(); processed = nullptr; // Load mutable buffers (e.g., KV cache) - std::cerr << "MLX init: loading mutable buffers" << std::endl; load_mutable_buffers(handle->program, handle->mutable_buffers); - std::cerr << "MLX init: mutable buffers loaded" << std::endl; // Bind execution state (reused across execute() calls) - std::cerr << "MLX init: binding execution state" << std::endl; handle->state.bind( handle->program, handle->constants, handle->mutable_buffers); - std::cerr << "MLX init: execution state bound" << std::endl; // Run init chain if present. // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the // static_cast cannot produce UINT32_MAX from a -1 sentinel. if (handle->program.init_chain_idx >= 0) { - std::cerr << "MLX init: running init chain" << std::endl; handle->state.is_init_chain = true; handle->interpreter.run_chain( handle->program, @@ -287,12 +275,10 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // Evaluate any constants written by the init chain so the first // execute() doesn't pay the cost of materializing them. 
eval(handle->constants.tensors); - std::cerr << "MLX init: init chain complete" << std::endl; } } catch (const std::exception& e) { ET_LOG(Error, "Failed to load MLX program: %s", e.what()); - std::cerr << "Failed to load MLX program: " << e.what() << std::endl; handle->~MLXHandle(); if (processed != nullptr) { processed->Free(); @@ -300,8 +286,6 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { return Error::InvalidProgram; } catch (...) { ET_LOG(Error, "Failed to load MLX program: unknown non-std exception"); - std::cerr << "Failed to load MLX program: unknown non-std exception" - << std::endl; handle->~MLXHandle(); if (processed != nullptr) { processed->Free(); @@ -437,12 +421,9 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { return Error::Ok; } catch (const std::exception& e) { ET_LOG(Error, "MLX execute failed: %s", e.what()); - std::cerr << "MLX execute failed: " << e.what() << std::endl; return Error::Internal; } catch (...) { ET_LOG(Error, "MLX execute failed: unknown non-std exception"); - std::cerr << "MLX execute failed: unknown non-std exception" - << std::endl; return Error::Internal; } } From 0e002903f0fbe0330c4125ec26a451a1f95f8780 Mon Sep 17 00:00:00 2001 From: Zeel Date: Mon, 27 Apr 2026 21:51:32 -0400 Subject: [PATCH 03/16] Support old and new HF cache interfaces in MLX custom cache --- backends/mlx/llm/cache.py | 42 ++++++++++++++++++----- backends/mlx/llm/source_transformation.py | 22 ++++++------ 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py index b0654adeb71..36e6c459ec8 100644 --- a/backends/mlx/llm/cache.py +++ b/backends/mlx/llm/cache.py @@ -12,6 +12,7 @@ Provides reusable KV cache implementations optimized for the MLX backend: """ +import inspect from typing import Tuple import torch @@ -399,14 +400,30 @@ def __init__( device=device, dtype=dtype, ) - # Call early_initialization to ensure parent's layers are fully initialized - self.early_initialization( - batch_size=max_batch_size, - num_heads=num_heads, - head_dim=head_dims, - dtype=dtype, - device=device, - ) + # The HF cache API pinned in CI expects scalar num_heads/head_dim in + # early_initialization(). Gemma 4-style hybrid layouts need per-layer + # shapes, so initialize each cache layer directly using the resolved + # backing-cache geometry instead of relying on the helper. + for layer, layer_num_heads, layer_head_dim in zip( + self.layers, num_heads, head_dims + ): + fake_keys_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + lazy_init_sig = inspect.signature(layer.lazy_initialization) + # Older pinned HF caches take a single fake tensor, while newer + # versions expect both key_states and value_states separately. + if len(lazy_init_sig.parameters) == 1: + layer.lazy_initialization(fake_keys_tensor) + else: + fake_values_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + layer.lazy_initialization(fake_keys_tensor, fake_values_tensor) # Some models (for example Gemma 4) only allocate cache entries for the # non-shared KV layers. Mirror the parent StaticCache layout exactly so @@ -472,7 +489,14 @@ def update( if cache_position is None: # Current HF ExecuTorch wrappers copy the requested cache position # into each StaticCache layer's cumulative_length before forward(). 
- cache_position = self.layers[layer_idx].cumulative_length + if hasattr(self.layers[layer_idx], "cumulative_length"): + cache_position = self.layers[layer_idx].cumulative_length + else: + raise RuntimeError( + "cache_position was not provided and the pinned " + "transformers StaticCache layer does not expose " + "cumulative_length" + ) assert isinstance( cache_position, torch.Tensor diff --git a/backends/mlx/llm/source_transformation.py b/backends/mlx/llm/source_transformation.py index fa727b650fd..06a45b9e22b 100644 --- a/backends/mlx/llm/source_transformation.py +++ b/backends/mlx/llm/source_transformation.py @@ -134,11 +134,12 @@ def _install_cache(attr_name): ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) - setattr( - module, - f"cumulative_length_{i}", - cache_layer.cumulative_length, - ) + if hasattr(cache_layer, "cumulative_length"): + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -219,11 +220,12 @@ def _install_cache(attr_name): ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) - setattr( - module, - f"cumulative_length_{i}", - cache_layer.cumulative_length, - ) + if hasattr(cache_layer, "cumulative_length"): + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( From 3a26baaeec7ee3be5d37d3b59dede035771d98da Mon Sep 17 00:00:00 2001 From: Zeel Date: Mon, 27 Apr 2026 21:55:51 -0400 Subject: [PATCH 04/16] Revert setup.py changes from Gemma 4 MLX PR --- setup.py | 48 ++++-------------------------------------------- 1 file changed, 4 insertions(+), 44 deletions(-) diff --git a/setup.py b/setup.py index 1de30c20117..d219135992f 100644 --- a/setup.py +++ b/setup.py @@ -173,20 +173,7 @@ def write_to_python_file(cls, path: str) -> None: # set to a non-empty value, the build type is Debug. Otherwise, the build type # is Release. def get_build_type(is_debug=None) -> str: - if is_debug is None: - raw_debug = os.environ.get("DEBUG", 0) - if isinstance(raw_debug, str): - normalized = raw_debug.strip().lower() - if normalized in ("", "0", "false", "off", "release"): - debug = 0 - elif normalized in ("1", "true", "on", "debug"): - debug = 1 - else: - debug = int(normalized) - else: - debug = int(raw_debug or 0) - else: - debug = is_debug + debug = int(os.environ.get("DEBUG", 0) or 0) if is_debug is None else is_debug return "Debug" if debug else "Release" @@ -206,30 +193,6 @@ def get_executable_name(name: str) -> str: return name -def get_cmake_command() -> str: - """ - Resolve the CMake executable to use for wheel builds. - - Prefer an explicit `CMAKE_COMMAND`, then a CMake binary colocated with the - current Python interpreter (common for virtualenv installs), and finally - fall back to PATH lookup. 
- """ - env_cmake = os.environ.get("CMAKE_COMMAND", "").strip() - if env_cmake: - return env_cmake - - python_bin_dir = os.path.dirname(sys.executable) - venv_cmake = os.path.join(python_bin_dir, get_executable_name("cmake")) - if os.path.exists(venv_cmake): - return venv_cmake - - cmake_on_path = shutil.which("cmake") - if cmake_on_path: - return cmake_on_path - - return "cmake" - - class _BaseExtension(Extension): """A base class that maps an abstract source to an abstract destination.""" @@ -790,10 +753,9 @@ def run(self): # noqa C901 log.info(f"clearing {cmake_cache_dir}") shutil.rmtree(cmake_cache_dir) - cmake_command = get_cmake_command() subprocess.run( [ - cmake_command, + "cmake", *cmake_configuration_args, "--preset", "pybind", @@ -869,12 +831,10 @@ def run(self): # noqa C901 # Set PYTHONPATH to the location of the pip package. os.environ["PYTHONPATH"] = ( - site.getsitepackages()[0] - + os.pathsep - + os.environ.get("PYTHONPATH", "") + site.getsitepackages()[0] + ";" + os.environ.get("PYTHONPATH", "") ) # Build the system. - self.spawn([get_cmake_command(), "--build", cmake_cache_dir, *cmake_build_args]) + self.spawn(["cmake", "--build", cmake_cache_dir, *cmake_build_args]) # Share the cmake-out location with _BaseExtension. self.cmake_cache_dir = cmake_cache_dir # Finally, run the underlying subcommands like build_py, build_ext. From 0bf5fc43ed1e49954576f7805dab911ed80ee603 Mon Sep 17 00:00:00 2001 From: Zeel Date: Mon, 27 Apr 2026 22:46:25 -0400 Subject: [PATCH 05/16] Use newer Transformers for Gemma 4 MLX CI --- .github/workflows/mlx.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index e61f4007932..0f5fc41229b 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -533,6 +533,11 @@ jobs: ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + # Gemma 4 requires a newer Transformers build than the CI-wide + # optimum-executorch pin currently brings in. + ${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git" + fi echo "::endgroup::" ${CONDA_RUN} pip list From 90e5577749da64d727482a88e503c0d9e0bd5018 Mon Sep 17 00:00:00 2001 From: Zeel Date: Tue, 28 Apr 2026 16:48:32 -0400 Subject: [PATCH 06/16] Pin Gemma 4 MLX CI to validated Transformers commit --- .github/workflows/mlx.yml | 6 ++++-- backends/mlx/examples/llm/README.md | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 0f5fc41229b..023e61b1c7c 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -535,8 +535,10 @@ jobs: ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then # Gemma 4 requires a newer Transformers build than the CI-wide - # optimum-executorch pin currently brings in. - ${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git" + # optimum-executorch pin currently brings in. Keep this pinned to the + # locally validated commit instead of floating on Transformers HEAD. 
+ GEMMA4_TRANSFORMERS_COMMIT=61461a7bcb458db7cf6eeea49678b9ab776a7821 + ${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@${GEMMA4_TRANSFORMERS_COMMIT}" fi echo "::endgroup::" diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index 04bcd500651..738bbfb8c14 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -66,6 +66,13 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ Gemma 4 support is currently validated for the text-only path using `--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`. +Validated with `transformers` commit +`61461a7bcb458db7cf6eeea49678b9ab776a7821`: + +```bash +pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@61461a7bcb458db7cf6eeea49678b9ab776a7821" +``` + ### Options | Option | Default | Description | From ee272c3fb68569735430e0560e501ab0a8fd7eb0 Mon Sep 17 00:00:00 2001 From: Zeel Date: Wed, 29 Apr 2026 08:25:18 -0400 Subject: [PATCH 07/16] Prefer tokenizer for text-only Gemma 4 MLX runs --- backends/mlx/examples/llm/run_llm_hf.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py index 21f197fb564..cae670bb0dc 100644 --- a/backends/mlx/examples/llm/run_llm_hf.py +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -51,9 +51,18 @@ def _load_text_processor(model_id: str): """ Load a text processor for the model. - Prefer AutoProcessor for multimodal/text-hybrid models like Gemma 4, and - fall back to AutoTokenizer for text-only checkpoints. + Prefer AutoTokenizer for text-only prompting, even for checkpoints that + also ship an AutoProcessor. Some hybrid checkpoints (for example Gemma 4) + expose both, but the tokenizer path is the more stable interface for the + plain text generation flow exercised by this runner. 
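+
+    Returns a `(text_processor, uses_processor)` tuple; callers are expected
+    to branch on the flag to pick the matching encode path. As a sketch
+    (mirroring run_inference below, with illustrative variable names):
+
+        if uses_processor:
+            ids = text_processor(text=prompt, return_tensors="pt")["input_ids"]
+        else:
+            ids = text_processor.encode(prompt, return_tensors="pt")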
""" + logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") + try: + tokenizer = AutoTokenizer.from_pretrained(model_id) + return tokenizer, False + except Exception as exc: + logger.info(f"AutoTokenizer unavailable for {model_id}: {exc}") + try: processor = AutoProcessor.from_pretrained(model_id) if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"): @@ -62,9 +71,7 @@ def _load_text_processor(model_id: str): except Exception as exc: logger.info(f"AutoProcessor unavailable for {model_id}: {exc}") - logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") - tokenizer = AutoTokenizer.from_pretrained(model_id) - return tokenizer, False + raise RuntimeError(f"Could not load tokenizer or processor for {model_id}") def _apply_chat_template(text_processor, messages) -> str: From ca37250f08ac893b7e01a9d9a8e27f9b45f75172 Mon Sep 17 00:00:00 2001 From: Zeel Date: Wed, 29 Apr 2026 16:54:41 -0400 Subject: [PATCH 08/16] Pin Gemma 4 MLX flow to validated model revision --- .github/workflows/mlx.yml | 6 ++++++ backends/mlx/examples/llm/README.md | 2 ++ backends/mlx/examples/llm/export_llm_hf.py | 15 +++++++++++++++ backends/mlx/examples/llm/run_llm_hf.py | 16 ++++++++++++---- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 023e61b1c7c..e995e5e88f0 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -512,6 +512,10 @@ jobs: MODEL_NAME="${{ matrix.model.name }}" USE_CUSTOM="${{ matrix.use-custom }}" QCONFIG="${{ matrix.qconfig }}" + MODEL_REVISION="" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" + fi CUSTOM_ARGS="" if [ "${USE_CUSTOM}" = "true" ]; then @@ -547,6 +551,7 @@ jobs: echo "::group::Export ${MODEL_NAME}" ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "${MODEL_ID}" \ + ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \ --output /tmp/${MODEL_NAME}.pte \ --qlinear ${QCONFIG} \ ${QEMBEDDING_ARGS} \ @@ -557,6 +562,7 @@ jobs: OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --pte /tmp/${MODEL_NAME}.pte \ --model-id "${MODEL_ID}" \ + ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \ --prompt "What is the capital of France?" \ --max-new-tokens 50 2>&1) echo "$OUTPUT" diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index 738bbfb8c14..8def8c1f06a 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -57,6 +57,7 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ # Gemma 4 text-only export python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "google/gemma-4-E2B-it" \ + --revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \ --output gemma4_hf_int4.pte \ --use-custom-sdpa \ --use-custom-kv-cache \ @@ -108,6 +109,7 @@ Validated Gemma 4 run command: python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --pte gemma4_hf_int4.pte \ --model-id google/gemma-4-E2B-it \ + --revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \ --prompt "What is the capital of France?" 
\ --max-new-tokens 50 ``` diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index ec5bcf5abf8..3ba483142c5 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -50,6 +50,7 @@ def _export_with_optimum( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int, dtype: str, @@ -73,6 +74,7 @@ def _export_with_optimum( logger.info(f"Loading model using optimum-executorch: {model_id}") exportable = load_causal_lm_model( model_id, + revision=revision, dtype=dtype_str, max_seq_len=max_seq_len, ) @@ -124,6 +126,7 @@ def _export_with_optimum( def _export_with_custom_components( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int, dtype: str, @@ -171,6 +174,8 @@ def _export_with_custom_components( "torch_dtype": torch_dtype, "low_cpu_mem_usage": True, } + if revision is not None: + load_kwargs["revision"] = revision if attn_implementation: load_kwargs["attn_implementation"] = attn_implementation model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs) @@ -345,6 +350,7 @@ def _save_program(executorch_program, output_path: str) -> None: def export_llama_hf( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int = 1024, dtype: str = "bf16", @@ -376,6 +382,7 @@ def export_llama_hf( ) _export_with_custom_components( model_id=model_id, + revision=revision, output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, @@ -391,6 +398,7 @@ def export_llama_hf( logger.info("Using optimum-executorch pipeline (no custom components)") _export_with_optimum( model_id=model_id, + revision=revision, output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, @@ -412,6 +420,12 @@ def main(): default="unsloth/Llama-3.2-1B-Instruct", help="HuggingFace model ID", ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HuggingFace model revision/commit to pin", + ) parser.add_argument( "--output", type=str, @@ -451,6 +465,7 @@ def main(): export_llama_hf( model_id=args.model_id, + revision=args.revision, output_path=args.output, max_seq_len=args.max_seq_len, dtype=args.dtype, diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py index cae670bb0dc..c15bcd89c46 100644 --- a/backends/mlx/examples/llm/run_llm_hf.py +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -47,7 +47,7 @@ def _get_max_input_seq_len(program) -> int: return sizes[1] if len(sizes) >= 2 else 1 -def _load_text_processor(model_id: str): +def _load_text_processor(model_id: str, revision: str | None): """ Load a text processor for the model. 
@@ -58,13 +58,13 @@ def _load_text_processor(model_id: str): """ logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") try: - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) return tokenizer, False except Exception as exc: logger.info(f"AutoTokenizer unavailable for {model_id}: {exc}") try: - processor = AutoProcessor.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id, revision=revision) if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"): logger.info(f"Loaded processor from HuggingFace: {model_id}") return processor, True @@ -101,11 +101,12 @@ def _get_eos_token_id(text_processor): def run_inference( pte_path: str, model_id: str, + revision: str | None, prompt: str, max_new_tokens: int = 50, ) -> str: """Run inference on the exported HuggingFace model.""" - text_processor, uses_processor = _load_text_processor(model_id) + text_processor, uses_processor = _load_text_processor(model_id, revision) logger.info(f"Loading model from {pte_path}...") et_runtime = Runtime.get() @@ -208,6 +209,12 @@ def main(): default="unsloth/Llama-3.2-1B-Instruct", help="HuggingFace model ID (used to load tokenizer or processor)", ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HuggingFace model revision/commit to pin", + ) parser.add_argument( "--prompt", type=str, @@ -226,6 +233,7 @@ def main(): generated_text = run_inference( pte_path=args.pte, model_id=args.model_id, + revision=args.revision, prompt=args.prompt, max_new_tokens=args.max_new_tokens, ) From 818a51d058c6c0b0864af39a326a65fe70a2a710 Mon Sep 17 00:00:00 2001 From: Zeel Date: Wed, 29 Apr 2026 17:36:19 -0400 Subject: [PATCH 09/16] Prefer HF early cache init for Gemma 4 MLX path --- backends/mlx/llm/cache.py | 42 +++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py index 36e6c459ec8..2890e823499 100644 --- a/backends/mlx/llm/cache.py +++ b/backends/mlx/llm/cache.py @@ -400,30 +400,38 @@ def __init__( device=device, dtype=dtype, ) - # The HF cache API pinned in CI expects scalar num_heads/head_dim in - # early_initialization(). Gemma 4-style hybrid layouts need per-layer - # shapes, so initialize each cache layer directly using the resolved - # backing-cache geometry instead of relying on the helper. - for layer, layer_num_heads, layer_head_dim in zip( - self.layers, num_heads, head_dims - ): - fake_keys_tensor = torch.zeros( - (max_batch_size, layer_num_heads, 0, layer_head_dim), + # Newer HF cache implementations already support per-layer layouts in + # early_initialization(). Keep that path for Gemma 4, and only fall + # back to manual layer initialization for the older CI-pinned API. + try: + self.early_initialization( + batch_size=max_batch_size, + num_heads=num_heads, + head_dim=head_dims, dtype=dtype, device=device, ) - lazy_init_sig = inspect.signature(layer.lazy_initialization) - # Older pinned HF caches take a single fake tensor, while newer - # versions expect both key_states and value_states separately. 
- if len(lazy_init_sig.parameters) == 1: - layer.lazy_initialization(fake_keys_tensor) - else: - fake_values_tensor = torch.zeros( + except TypeError: + for layer, layer_num_heads, layer_head_dim in zip( + self.layers, num_heads, head_dims + ): + fake_keys_tensor = torch.zeros( (max_batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device, ) - layer.lazy_initialization(fake_keys_tensor, fake_values_tensor) + lazy_init_sig = inspect.signature(layer.lazy_initialization) + # Older pinned HF caches take a single fake tensor, while newer + # versions expect both key_states and value_states separately. + if len(lazy_init_sig.parameters) == 1: + layer.lazy_initialization(fake_keys_tensor) + else: + fake_values_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + layer.lazy_initialization(fake_keys_tensor, fake_values_tensor) # Some models (for example Gemma 4) only allocate cache entries for the # non-shared KV layers. Mirror the parent StaticCache layout exactly so From 6e520dde23cbf2964b82d93f4b79535110551777 Mon Sep 17 00:00:00 2001 From: Zeel Date: Wed, 29 Apr 2026 18:14:39 -0400 Subject: [PATCH 10/16] Stabilize Gemma 4 MLX constant slot ordering --- backends/mlx/builder/program_builder.py | 56 ++++++++++++++++++------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 0892476fedd..21ae8a3b6fa 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -444,26 +444,50 @@ def _make_io_slots(self): # noqa: C901 else: raise NotImplementedError(f"Support for input {arg} is not implemented") + placeholder_nodes = { + node.name: node for node in self.ep.graph.nodes if node.op == "placeholder" + } + + # Allocate placeholder-backed slots in graph-signature order instead of + # raw FX node traversal order. This keeps lifted constant tids stable + # across equivalent exports, which matters for models like Gemma 4 that + # carry multiple rotary constant placeholders with similar structure. + for name in constant_tensors: + node = placeholder_nodes.get(name) + if node is None or node.users == {}: + continue + self.make_or_get_slot(node, id_space=IdSpace.Constant) + + for name in user_inputs: + node = placeholder_nodes.get(name) + if node is None or node.users == {}: + continue + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and not val.is_contiguous(): + raise ValueError( + f"MLX backend requires contiguous input tensors, " + f"but input '{node.name}' has non-contiguous strides. " + f"shape={list(val.shape)}, stride={list(val.stride())}. " + f"Ensure example inputs passed to torch.export.export() " + f"are contiguous (call .contiguous() on them)." 
+ ) + self.make_or_get_slot(node, id_space=IdSpace.Input) + + for name in mutable_buffers: + node = placeholder_nodes.get(name) + if node is None or node.users == {}: + continue + self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) + + classified_placeholders = ( + set(constant_tensors) | set(user_inputs) | set(mutable_buffers) + ) + for node in self.ep.graph.nodes: if node.op == "placeholder": if node.users == {}: continue - if node.name in constant_tensors: - self.make_or_get_slot(node, id_space=IdSpace.Constant) - elif node.name in user_inputs: - val = node.meta.get("val", None) - if isinstance(val, torch.Tensor) and not val.is_contiguous(): - raise ValueError( - f"MLX backend requires contiguous input tensors, " - f"but input '{node.name}' has non-contiguous strides. " - f"shape={list(val.shape)}, stride={list(val.stride())}. " - f"Ensure example inputs passed to torch.export.export() " - f"are contiguous (call .contiguous() on them)." - ) - self.make_or_get_slot(node, id_space=IdSpace.Input) - elif node.name in mutable_buffers: - self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) - else: + if node.name not in classified_placeholders: raise NotImplementedError( f"Support for placeholder {node.name} is not implemented" ) From 391cde44b544ebedb44c49ed682cc5c483e54519 Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 30 Apr 2026 06:46:13 -0400 Subject: [PATCH 11/16] Keep Gemma 4 layer 22 down_proj in float for MLX export --- backends/mlx/examples/llm/export_llm_hf.py | 61 ++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 3ba483142c5..a0e09062d35 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -47,6 +47,53 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) +_GEMMA4_MODEL_ID = "google/gemma-4-E2B-it" +_GEMMA4_PROBLEM_LAYER_FQN = "model.language_model.layers.22.mlp.down_proj" + + +def _get_submodule_by_fqn(root: torch.nn.Module, fqn: str) -> torch.nn.Module: + cur = root + for part in fqn.split("."): + if part.isdigit(): + cur = cur[int(part)] # type: ignore[index] + else: + cur = getattr(cur, part) + return cur + + +def _capture_gemma4_float_fallback_weight( + model_id: str, + qlinear: Optional[str], + model: torch.nn.Module, +) -> Optional[torch.Tensor]: + if model_id != _GEMMA4_MODEL_ID or qlinear != "4w": + return None + + layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN) + weight = layer.weight.detach().clone() + logger.info( + "Saving %s in floating point to avoid the current Gemma 4 4w mismatch", + _GEMMA4_PROBLEM_LAYER_FQN, + ) + return weight + + +def _restore_gemma4_float_fallback_weight( + model_id: str, + qlinear: Optional[str], + model: torch.nn.Module, + weight: Optional[torch.Tensor], +) -> None: + if weight is None or model_id != _GEMMA4_MODEL_ID or qlinear != "4w": + return + + layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + logger.info( + "Restored %s in floating point after quantization", + _GEMMA4_PROBLEM_LAYER_FQN, + ) + def _export_with_optimum( model_id: str, @@ -81,6 +128,10 @@ def _export_with_optimum( from executorch.backends.mlx.llm.quantization import quantize_model_ + gemma4_float_weight = _capture_gemma4_float_fallback_weight( + model_id, qlinear, exportable.model + ) + quantize_model_( exportable.model, qlinear_config=qlinear, @@ -92,6 
+143,9 @@ def _export_with_optimum( ) and not no_tie_word_embeddings, ) + _restore_gemma4_float_fallback_weight( + model_id, qlinear, exportable.model, gemma4_float_weight + ) logger.info("Exporting model with torch.export...") exported_progs = exportable.export() @@ -277,6 +331,10 @@ def _export_with_custom_components( from executorch.backends.mlx.llm.quantization import quantize_model_ + gemma4_float_weight = _capture_gemma4_float_fallback_weight( + model_id, qlinear, exportable.model + ) + quantize_model_( exportable.model, qlinear_config=qlinear, @@ -286,6 +344,9 @@ def _export_with_custom_components( tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False) and not no_tie_word_embeddings, ) + _restore_gemma4_float_fallback_weight( + model_id, qlinear, exportable.model, gemma4_float_weight + ) logger.info("Exporting model with torch.export...") seq_length = 3 From 19d6f098d17b939200134a05318d5a086750e581 Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 30 Apr 2026 07:14:21 -0400 Subject: [PATCH 12/16] Use static cache for Gemma 4 MLX custom export --- backends/mlx/examples/llm/export_llm_hf.py | 56 ++++++---------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index a0e09062d35..abb66d49326 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -282,52 +282,26 @@ def _export_with_custom_components( ) if use_custom_kv_cache: - if sliding_window is not None: - # Use ring buffer cache for sliding window models - from executorch.backends.mlx.llm.source_transformation import ( - replace_hf_cache_with_mlx_ring_buffer, - ) + from executorch.backends.mlx.llm.source_transformation import ( + replace_hf_cache_with_mlx, + ) + if sliding_window is not None: logger.info( - f"Replacing StaticCache with RingBuffer KV cache " - f"(window_size={effective_cache_len})..." + "Replacing HuggingFace StaticCache with HFStaticCache " + f"(capped to sliding window: {effective_cache_len})..." 
) - replace_hf_cache_with_mlx_ring_buffer( - exportable, - model.config, - max_batch_size=1, - window_size=effective_cache_len, - dtype=torch_dtype, - ) - - if use_custom_sdpa: - # Re-register attention with sliding window closure - from executorch.backends.mlx.llm.hf_attention import ( - register_mlx_sliding_window_attention, - ) - - register_mlx_sliding_window_attention(exportable) - model.config._attn_implementation = "mlx_sliding_window" - logger.info( - " Registered sliding window attention (mlx_sliding_window)" - ) - - logger.info(" RingBuffer KV cache installed successfully") else: - # Use standard linear cache for non-sliding-window models - from executorch.backends.mlx.llm.source_transformation import ( - replace_hf_cache_with_mlx, - ) - logger.info("Replacing HuggingFace StaticCache with HFStaticCache...") - replace_hf_cache_with_mlx( - exportable, - model.config, - max_batch_size=1, - max_cache_len=effective_cache_len, - dtype=torch_dtype, - ) - logger.info(" HFStaticCache installed successfully") + + replace_hf_cache_with_mlx( + exportable, + model.config, + max_batch_size=1, + max_cache_len=effective_cache_len, + dtype=torch_dtype, + ) + logger.info(" HFStaticCache installed successfully") from executorch.backends.mlx.llm.quantization import quantize_model_ From 41e3a51af49c3120496aa0c771b3d5d13fedd85b Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 30 Apr 2026 07:37:09 -0400 Subject: [PATCH 13/16] Disable custom SDPA for Gemma 4 MLX export --- backends/mlx/examples/llm/export_llm_hf.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index abb66d49326..44a51c58768 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -215,13 +215,20 @@ def _export_with_custom_components( } torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16) - if use_custom_sdpa: + effective_use_custom_sdpa = use_custom_sdpa + if model_id == _GEMMA4_MODEL_ID and use_custom_sdpa: + logger.info( + "Disabling custom SDPA for Gemma 4 while keeping the custom cache path" + ) + effective_use_custom_sdpa = False + + if effective_use_custom_sdpa: from executorch.backends.mlx.llm.hf_attention import register_mlx_attention register_mlx_attention() logger.info("Registered MLX custom SDPA attention") - attn_implementation = "mlx" if use_custom_sdpa else None + attn_implementation = "mlx" if effective_use_custom_sdpa else None logger.info(f"Loading HuggingFace model: {model_id}") load_kwargs = { From 9d3f841be7d9409bfe5e740793b9572ddb5479de Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 30 Apr 2026 08:01:00 -0400 Subject: [PATCH 14/16] Disable custom cache path for Gemma 4 MLX export --- backends/mlx/examples/llm/export_llm_hf.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 44a51c58768..2189c4cfdc3 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -216,11 +216,15 @@ def _export_with_custom_components( torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16) effective_use_custom_sdpa = use_custom_sdpa + effective_use_custom_kv_cache = use_custom_kv_cache if model_id == _GEMMA4_MODEL_ID and use_custom_sdpa: logger.info( "Disabling custom SDPA for Gemma 4 while keeping the custom cache path" ) effective_use_custom_sdpa = False + if model_id == _GEMMA4_MODEL_ID and 
use_custom_kv_cache: + logger.info("Disabling custom KV cache for Gemma 4") + effective_use_custom_kv_cache = False if effective_use_custom_sdpa: from executorch.backends.mlx.llm.hf_attention import register_mlx_attention @@ -288,7 +292,7 @@ def _export_with_custom_components( max_cache_len=effective_cache_len, ) - if use_custom_kv_cache: + if effective_use_custom_kv_cache: from executorch.backends.mlx.llm.source_transformation import ( replace_hf_cache_with_mlx, ) From 719d2e837bba1dea4802876f302c95538b6ce298 Mon Sep 17 00:00:00 2001 From: Zeel Date: Thu, 30 Apr 2026 08:32:16 -0400 Subject: [PATCH 15/16] Route Gemma 4 MLX export through optimum fallback path --- backends/mlx/examples/llm/export_llm_hf.py | 24 +++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 2189c4cfdc3..48b8f3c19e6 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -421,10 +421,24 @@ def export_llama_hf( use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa) use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update) """ - if use_custom_sdpa or use_custom_kv_cache: + effective_use_custom_sdpa = use_custom_sdpa + effective_use_custom_kv_cache = use_custom_kv_cache + if model_id == _GEMMA4_MODEL_ID: + if effective_use_custom_sdpa: + logger.info( + "Disabling custom SDPA for Gemma 4 and falling back to the baseline export path" + ) + effective_use_custom_sdpa = False + if effective_use_custom_kv_cache: + logger.info( + "Disabling custom KV cache for Gemma 4 and falling back to the baseline export path" + ) + effective_use_custom_kv_cache = False + + if effective_use_custom_sdpa or effective_use_custom_kv_cache: logger.info( - f"Using custom components: sdpa={use_custom_sdpa}, " - f"kv_cache={use_custom_kv_cache}" + f"Using custom components: sdpa={effective_use_custom_sdpa}, " + f"kv_cache={effective_use_custom_kv_cache}" ) _export_with_custom_components( model_id=model_id, @@ -434,8 +448,8 @@ def export_llama_hf( dtype=dtype, qlinear=qlinear, qembedding=qembedding, - use_custom_sdpa=use_custom_sdpa, - use_custom_kv_cache=use_custom_kv_cache, + use_custom_sdpa=effective_use_custom_sdpa, + use_custom_kv_cache=effective_use_custom_kv_cache, no_tie_word_embeddings=no_tie_word_embeddings, qlinear_group_size=qlinear_group_size, qembedding_group_size=qembedding_group_size, From 065b50e5dd52d41bd143c0531af127f04b73754f Mon Sep 17 00:00:00 2001 From: Zeel Date: Fri, 1 May 2026 09:22:04 -0400 Subject: [PATCH 16/16] Clean up Gemma 4 MLX path for macOS 15 CI --- .github/workflows/mlx.yml | 17 ++-- backends/mlx/builder/program_builder.py | 56 ++++------- backends/mlx/examples/llm/README.md | 2 - backends/mlx/examples/llm/export_llm_hf.py | 102 ++------------------- 4 files changed, 33 insertions(+), 144 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index e995e5e88f0..cf0bc2bcfc0 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -489,17 +489,25 @@ jobs: name: "gemma3-1b" use-custom: [false, true] qconfig: ["4w", "nvfp4"] + runner: ["macos-14-xlarge"] include: - model: id: "google/gemma-4-E2B-it" name: "gemma4-e2b" use-custom: true qconfig: "4w" + runner: "macos-15-xlarge" + - model: + id: "google/gemma-4-E2B-it" + name: "gemma4-e2b" + use-custom: false + qconfig: "4w" + runner: "macos-15-xlarge" uses: pytorch/test-infra/.github/workflows/macos_job.yml@main secrets: inherit with: job-name: 
test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }}-${{ matrix.qconfig }} - runner: macos-14-xlarge + runner: ${{ matrix.runner }} python-version: "3.12" submodules: recursive ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -512,11 +520,6 @@ jobs: MODEL_NAME="${{ matrix.model.name }}" USE_CUSTOM="${{ matrix.use-custom }}" QCONFIG="${{ matrix.qconfig }}" - MODEL_REVISION="" - if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then - MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" - fi - CUSTOM_ARGS="" if [ "${USE_CUSTOM}" = "true" ]; then CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache" @@ -551,7 +554,6 @@ jobs: echo "::group::Export ${MODEL_NAME}" ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "${MODEL_ID}" \ - ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \ --output /tmp/${MODEL_NAME}.pte \ --qlinear ${QCONFIG} \ ${QEMBEDDING_ARGS} \ @@ -562,7 +564,6 @@ jobs: OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --pte /tmp/${MODEL_NAME}.pte \ --model-id "${MODEL_ID}" \ - ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \ --prompt "What is the capital of France?" \ --max-new-tokens 50 2>&1) echo "$OUTPUT" diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 21ae8a3b6fa..0892476fedd 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -444,50 +444,26 @@ def _make_io_slots(self): # noqa: C901 else: raise NotImplementedError(f"Support for input {arg} is not implemented") - placeholder_nodes = { - node.name: node for node in self.ep.graph.nodes if node.op == "placeholder" - } - - # Allocate placeholder-backed slots in graph-signature order instead of - # raw FX node traversal order. This keeps lifted constant tids stable - # across equivalent exports, which matters for models like Gemma 4 that - # carry multiple rotary constant placeholders with similar structure. - for name in constant_tensors: - node = placeholder_nodes.get(name) - if node is None or node.users == {}: - continue - self.make_or_get_slot(node, id_space=IdSpace.Constant) - - for name in user_inputs: - node = placeholder_nodes.get(name) - if node is None or node.users == {}: - continue - val = node.meta.get("val", None) - if isinstance(val, torch.Tensor) and not val.is_contiguous(): - raise ValueError( - f"MLX backend requires contiguous input tensors, " - f"but input '{node.name}' has non-contiguous strides. " - f"shape={list(val.shape)}, stride={list(val.stride())}. " - f"Ensure example inputs passed to torch.export.export() " - f"are contiguous (call .contiguous() on them)." 
- ) - self.make_or_get_slot(node, id_space=IdSpace.Input) - - for name in mutable_buffers: - node = placeholder_nodes.get(name) - if node is None or node.users == {}: - continue - self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) - - classified_placeholders = ( - set(constant_tensors) | set(user_inputs) | set(mutable_buffers) - ) - for node in self.ep.graph.nodes: if node.op == "placeholder": if node.users == {}: continue - if node.name not in classified_placeholders: + if node.name in constant_tensors: + self.make_or_get_slot(node, id_space=IdSpace.Constant) + elif node.name in user_inputs: + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and not val.is_contiguous(): + raise ValueError( + f"MLX backend requires contiguous input tensors, " + f"but input '{node.name}' has non-contiguous strides. " + f"shape={list(val.shape)}, stride={list(val.stride())}. " + f"Ensure example inputs passed to torch.export.export() " + f"are contiguous (call .contiguous() on them)." + ) + self.make_or_get_slot(node, id_space=IdSpace.Input) + elif node.name in mutable_buffers: + self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) + else: raise NotImplementedError( f"Support for placeholder {node.name} is not implemented" ) diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index 8def8c1f06a..738bbfb8c14 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -57,7 +57,6 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ # Gemma 4 text-only export python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "google/gemma-4-E2B-it" \ - --revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \ --output gemma4_hf_int4.pte \ --use-custom-sdpa \ --use-custom-kv-cache \ @@ -109,7 +108,6 @@ Validated Gemma 4 run command: python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --pte gemma4_hf_int4.pte \ --model-id google/gemma-4-E2B-it \ - --revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \ --prompt "What is the capital of France?" 
\ --max-new-tokens 50 ``` diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 48b8f3c19e6..fe6b8094f6b 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -47,53 +47,6 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -_GEMMA4_MODEL_ID = "google/gemma-4-E2B-it" -_GEMMA4_PROBLEM_LAYER_FQN = "model.language_model.layers.22.mlp.down_proj" - - -def _get_submodule_by_fqn(root: torch.nn.Module, fqn: str) -> torch.nn.Module: - cur = root - for part in fqn.split("."): - if part.isdigit(): - cur = cur[int(part)] # type: ignore[index] - else: - cur = getattr(cur, part) - return cur - - -def _capture_gemma4_float_fallback_weight( - model_id: str, - qlinear: Optional[str], - model: torch.nn.Module, -) -> Optional[torch.Tensor]: - if model_id != _GEMMA4_MODEL_ID or qlinear != "4w": - return None - - layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN) - weight = layer.weight.detach().clone() - logger.info( - "Saving %s in floating point to avoid the current Gemma 4 4w mismatch", - _GEMMA4_PROBLEM_LAYER_FQN, - ) - return weight - - -def _restore_gemma4_float_fallback_weight( - model_id: str, - qlinear: Optional[str], - model: torch.nn.Module, - weight: Optional[torch.Tensor], -) -> None: - if weight is None or model_id != _GEMMA4_MODEL_ID or qlinear != "4w": - return - - layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN) - layer.weight = torch.nn.Parameter(weight, requires_grad=False) - logger.info( - "Restored %s in floating point after quantization", - _GEMMA4_PROBLEM_LAYER_FQN, - ) - def _export_with_optimum( model_id: str, @@ -128,10 +81,6 @@ def _export_with_optimum( from executorch.backends.mlx.llm.quantization import quantize_model_ - gemma4_float_weight = _capture_gemma4_float_fallback_weight( - model_id, qlinear, exportable.model - ) - quantize_model_( exportable.model, qlinear_config=qlinear, @@ -143,9 +92,6 @@ def _export_with_optimum( ) and not no_tie_word_embeddings, ) - _restore_gemma4_float_fallback_weight( - model_id, qlinear, exportable.model, gemma4_float_weight - ) logger.info("Exporting model with torch.export...") exported_progs = exportable.export() @@ -215,24 +161,13 @@ def _export_with_custom_components( } torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16) - effective_use_custom_sdpa = use_custom_sdpa - effective_use_custom_kv_cache = use_custom_kv_cache - if model_id == _GEMMA4_MODEL_ID and use_custom_sdpa: - logger.info( - "Disabling custom SDPA for Gemma 4 while keeping the custom cache path" - ) - effective_use_custom_sdpa = False - if model_id == _GEMMA4_MODEL_ID and use_custom_kv_cache: - logger.info("Disabling custom KV cache for Gemma 4") - effective_use_custom_kv_cache = False - - if effective_use_custom_sdpa: + if use_custom_sdpa: from executorch.backends.mlx.llm.hf_attention import register_mlx_attention register_mlx_attention() logger.info("Registered MLX custom SDPA attention") - attn_implementation = "mlx" if effective_use_custom_sdpa else None + attn_implementation = "mlx" if use_custom_sdpa else None logger.info(f"Loading HuggingFace model: {model_id}") load_kwargs = { @@ -292,7 +227,7 @@ def _export_with_custom_components( max_cache_len=effective_cache_len, ) - if effective_use_custom_kv_cache: + if use_custom_kv_cache: from executorch.backends.mlx.llm.source_transformation import ( replace_hf_cache_with_mlx, ) @@ -316,10 +251,6 @@ def _export_with_custom_components( from 
executorch.backends.mlx.llm.quantization import quantize_model_ - gemma4_float_weight = _capture_gemma4_float_fallback_weight( - model_id, qlinear, exportable.model - ) - quantize_model_( exportable.model, qlinear_config=qlinear, @@ -329,9 +260,6 @@ def _export_with_custom_components( tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False) and not no_tie_word_embeddings, ) - _restore_gemma4_float_fallback_weight( - model_id, qlinear, exportable.model, gemma4_float_weight - ) logger.info("Exporting model with torch.export...") seq_length = 3 @@ -421,24 +349,10 @@ def export_llama_hf( use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa) use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update) """ - effective_use_custom_sdpa = use_custom_sdpa - effective_use_custom_kv_cache = use_custom_kv_cache - if model_id == _GEMMA4_MODEL_ID: - if effective_use_custom_sdpa: - logger.info( - "Disabling custom SDPA for Gemma 4 and falling back to the baseline export path" - ) - effective_use_custom_sdpa = False - if effective_use_custom_kv_cache: - logger.info( - "Disabling custom KV cache for Gemma 4 and falling back to the baseline export path" - ) - effective_use_custom_kv_cache = False - - if effective_use_custom_sdpa or effective_use_custom_kv_cache: + if use_custom_sdpa or use_custom_kv_cache: logger.info( - f"Using custom components: sdpa={effective_use_custom_sdpa}, " - f"kv_cache={effective_use_custom_kv_cache}" + f"Using custom components: sdpa={use_custom_sdpa}, " + f"kv_cache={use_custom_kv_cache}" ) _export_with_custom_components( model_id=model_id, @@ -448,8 +362,8 @@ def export_llama_hf( dtype=dtype, qlinear=qlinear, qembedding=qembedding, - use_custom_sdpa=effective_use_custom_sdpa, - use_custom_kv_cache=effective_use_custom_kv_cache, + use_custom_sdpa=use_custom_sdpa, + use_custom_kv_cache=use_custom_kv_cache, no_tie_word_embeddings=no_tie_word_embeddings, qlinear_group_size=qlinear_group_size, qembedding_group_size=qembedding_group_size,
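
A few standalone sketches of the patterns this series leans on follow; none of this code is taken verbatim from the patches. First, the `--revision` hunks above all follow one pattern: thread a pinned revision into every `from_pretrained` call, and fall back from `AutoTokenizer` to `AutoProcessor` for processor-backed checkpoints such as Gemma 4. A minimal sketch of that pattern; the helper name `load_text_frontend` is illustrative, not a name from this series:

```python
from typing import Optional, Tuple

from transformers import AutoProcessor, AutoTokenizer


def load_text_frontend(
    model_id: str, revision: Optional[str] = None
) -> Tuple[object, bool]:
    """Return (frontend, uses_processor) for a text-only prompt path."""
    try:
        # Pinning `revision` keeps exports and CI reproducible across
        # upstream pushes to the same model id.
        return AutoTokenizer.from_pretrained(model_id, revision=revision), False
    except Exception:
        processor = AutoProcessor.from_pretrained(model_id, revision=revision)
        # A text-only runner needs only these two methods from a processor.
        if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"):
            return processor, True
        raise
```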
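
The cache.py change in PATCH 09 has to work against whichever HF cache API is pinned in CI: it first attempts the batched `early_initialization()` call and treats a `TypeError` as the older API, then dispatches per layer on the arity of `lazy_initialization`. The arity check is the core trick, isolated below with toy stand-ins (`_OldLayer`, `_NewLayer`, and `init_layer` are illustrative names, not HF types):

```python
import inspect


class _OldLayer:
    """Stand-in for a CI-pinned HF cache layer: takes one fake tensor."""

    def lazy_initialization(self, fake_keys):
        self.initialized_with = 1


class _NewLayer:
    """Stand-in for a newer HF cache layer: takes keys and values separately."""

    def lazy_initialization(self, fake_keys, fake_values):
        self.initialized_with = 2


def init_layer(layer, fake_keys, fake_values):
    # Branch on the bound method's arity (self is excluded from the
    # signature), exactly the check the cache.py hunk performs.
    if len(inspect.signature(layer.lazy_initialization).parameters) == 1:
        layer.lazy_initialization(fake_keys)
    else:
        layer.lazy_initialization(fake_keys, fake_values)


for layer in (_OldLayer(), _NewLayer()):
    init_layer(layer, fake_keys=None, fake_values=None)
    print(type(layer).__name__, "->", layer.initialized_with, "tensor(s)")
```

Dispatching on the signature rather than on a transformers version string keeps the fallback working even for unreleased or locally patched pins.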
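
PATCH 10 stabilizes lifted-constant ids by allocating slots in graph-signature order rather than raw FX traversal order. A self-contained toy showing why that makes ids reproducible across equivalent exports; `assign_stable_ids` and the sample names are illustrative:

```python
def assign_stable_ids(signature_order, live_placeholders):
    """Assign ids by walking the export signature, not graph order.

    Skipping names with no live placeholder mirrors the
    `node is None or node.users == {}` guard in the hunk above.
    """
    ids = {}
    for name in signature_order:
        if name not in live_placeholders:
            continue
        ids[name] = len(ids)
    return ids


signature = ["rope_cos", "rope_sin", "q_proj_weight"]
# Two equivalent exports whose graphs emit placeholders in different orders:
export_a = {"q_proj_weight": ..., "rope_cos": ..., "rope_sin": ...}
export_b = {"rope_sin": ..., "rope_cos": ..., "q_proj_weight": ...}
assert assign_stable_ids(signature, export_a) == assign_stable_ids(signature, export_b)
```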
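
PATCH 11's capture/restore dance generalizes to any workaround that must keep a few layers in float through an in-place quantizer. A sketch under the assumption that the quantizer mutates the model in place, as `quantize_model_` does in this series; `quantize_except` is a made-up helper, and the standard `torch.nn.Module.get_submodule` (which also handles numeric path segments) stands in for the series' `_get_submodule_by_fqn`:

```python
import torch


def quantize_except(model, quantize_fn, skip_fqns):
    """Quantize `model` in place, keeping the listed submodules in float."""
    # Snapshot the float weights before the quantizer rewrites them.
    saved = {
        fqn: model.get_submodule(fqn).weight.detach().clone() for fqn in skip_fqns
    }
    quantize_fn(model)
    # Re-install the float weights over whatever the quantizer produced.
    for fqn, weight in saved.items():
        model.get_submodule(fqn).weight = torch.nn.Parameter(
            weight, requires_grad=False
        )


net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
quantize_except(net, lambda m: None, {"1"})  # identity "quantizer" for illustration
```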
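
Finally, PATCH 12 folds the sliding-window path into the single HFStaticCache path by capping the allocated cache length at the attention window. The capping rule itself is one line; `resolve_cache_len` below is an illustrative name:

```python
def resolve_cache_len(max_seq_len: int, sliding_window: int | None) -> int:
    # A sliding-window model never attends past its window, so allocating
    # more static cache than the window wastes memory without adding context.
    if sliding_window is None:
        return max_seq_len
    return min(max_seq_len, sliding_window)


assert resolve_cache_len(1024, sliding_window=512) == 512
assert resolve_cache_len(1024, sliding_window=None) == 1024
```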