Add Gemma 4 MLX install-path support #19065
base: main
Changes from all commits
5f455a2
0a822bd
fd78741
0e00290
3a26baa
0bf5fc4
90e5577
ee272c3
ca37250
818a51d
6e520dd
391cde4
19d6f09
41e3a51
9d3f841
719d2e8
065b50e
```diff
@@ -50,6 +50,7 @@

 def _export_with_optimum(
     model_id: str,
+    revision: Optional[str],
     output_path: str,
     max_seq_len: int,
     dtype: str,
```
```diff
@@ -73,6 +74,7 @@ def _export_with_optimum(
     logger.info(f"Loading model using optimum-executorch: {model_id}")
     exportable = load_causal_lm_model(
         model_id,
+        revision=revision,
         dtype=dtype_str,
         max_seq_len=max_seq_len,
     )
```
```diff
@@ -124,6 +126,7 @@ def _export_with_optimum(

 def _export_with_custom_components(
     model_id: str,
+    revision: Optional[str],
     output_path: str,
     max_seq_len: int,
     dtype: str,
```
```diff
@@ -166,20 +169,21 @@ def _export_with_custom_components(
     attn_implementation = "mlx" if use_custom_sdpa else None

-    # Detect sliding window models (e.g., gemma)
-    sliding_window = None
-
     logger.info(f"Loading HuggingFace model: {model_id}")
     load_kwargs = {
         "torch_dtype": torch_dtype,
         "low_cpu_mem_usage": True,
     }
+    if revision is not None:
+        load_kwargs["revision"] = revision
     if attn_implementation:
         load_kwargs["attn_implementation"] = attn_implementation
     model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

-    # Check if model uses sliding window attention
-    sliding_window = getattr(model.config, "sliding_window", None)
+    # Check if model uses sliding window attention. Multimodal configs like
+    # Gemma 4 keep transformer attributes under text_config.
+    text_config = model.config.get_text_config()
+    sliding_window = getattr(text_config, "sliding_window", None)
     if sliding_window is not None:
         logger.info(f"Model has sliding_window={sliding_window}")
         # Cap max_seq_len to sliding window size for cache allocation
```

**Contributor:** Does this regress gemma3?

**Author:** I don't expect this to regress Gemma 3. The change just switches the sliding-window lookup to the text config via `model.config.get_text_config()`, which resolves to the same config for text-only models like Gemma 3.

**Contributor:** Yeah, it would be great to try gemma3 as a smoke test. If you are unable to access the version from Google, try the unsloth version unsloth/gemma-3-1b-it (https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469C18-L469C39)
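As a smoke test along the lines suggested above, a minimal sketch, assuming the unsloth/gemma-3-1b-it checkpoint and the standard transformers behavior that `get_text_config()` falls back to the top-level config when there is no nested `text_config`:

```python
# Hedged check that the new lookup is a no-op for text-only Gemma 3:
# both the pre-PR and post-PR lookups should resolve to the same value.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("unsloth/gemma-3-1b-it")
text_config = config.get_text_config()

old_value = getattr(config, "sliding_window", None)       # pre-PR lookup
new_value = getattr(text_config, "sliding_window", None)  # post-PR lookup
assert old_value == new_value, "text_config lookup changed the detected window"
print(f"sliding_window={new_value}")
```

For a multimodal config that nests its transformer settings under `text_config`, only the new lookup finds the window, which is the case this PR targets.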
```diff
@@ -188,11 +192,16 @@ def _export_with_custom_components(
     else:
         effective_cache_len = max_seq_len

+    # The HF ExecuTorch cache wrappers validate both generation_config.use_cache
+    # and the text config's use_cache flag before constructing static caches.
     model.generation_config.use_cache = True
     model.generation_config.cache_implementation = "static"
     model.generation_config.cache_config = {
         "batch_size": 1,
         "max_cache_len": effective_cache_len,
     }
+    text_config = model.config.get_text_config()
+    text_config.use_cache = True
     model.eval()

     # Use HybridCache wrapper for sliding window models (stores cache as .cache),
```
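For context, a hypothetical sketch of the precondition the new comment describes; the real validation lives in the HF ExecuTorch cache wrappers, so the function name and messages here are illustrative only:

```python
# Illustrative only: mirrors the described check that both use_cache flags
# must agree before a static cache is constructed.
def check_static_cache_preconditions(model) -> None:
    text_config = model.config.get_text_config()
    if not getattr(model.generation_config, "use_cache", False):
        raise ValueError("generation_config.use_cache must be True for static cache export")
    if not getattr(text_config, "use_cache", False):
        raise ValueError("the text config's use_cache must be True for static cache export")
```

Setting both flags in the exporter, as the hunk does, keeps either check from failing.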
```diff
@@ -219,52 +228,26 @@ def _export_with_custom_components(
     )

     if use_custom_kv_cache:
-        if sliding_window is not None:
-            # Use ring buffer cache for sliding window models
-            from executorch.backends.mlx.llm.source_transformation import (
-                replace_hf_cache_with_mlx_ring_buffer,
-            )
+        from executorch.backends.mlx.llm.source_transformation import (
+            replace_hf_cache_with_mlx,
+        )

+        if sliding_window is not None:
             logger.info(
-                f"Replacing StaticCache with RingBuffer KV cache "
-                f"(window_size={effective_cache_len})..."
+                "Replacing HuggingFace StaticCache with HFStaticCache "
+                f"(capped to sliding window: {effective_cache_len})..."
             )
-            replace_hf_cache_with_mlx_ring_buffer(
-                exportable,
-                model.config,
-                max_batch_size=1,
-                window_size=effective_cache_len,
-                dtype=torch_dtype,
-            )
-
-            if use_custom_sdpa:
-                # Re-register attention with sliding window closure
-                from executorch.backends.mlx.llm.hf_attention import (
-                    register_mlx_sliding_window_attention,
-                )
-
-                register_mlx_sliding_window_attention(exportable)
-                model.config._attn_implementation = "mlx_sliding_window"
-                logger.info(
-                    "  Registered sliding window attention (mlx_sliding_window)"
-                )
-
-            logger.info("  RingBuffer KV cache installed successfully")
         else:
-            # Use standard linear cache for non-sliding-window models
-            from executorch.backends.mlx.llm.source_transformation import (
-                replace_hf_cache_with_mlx,
-            )
-
             logger.info("Replacing HuggingFace StaticCache with HFStaticCache...")
-            replace_hf_cache_with_mlx(
-                exportable,
-                model.config,
-                max_batch_size=1,
-                max_cache_len=effective_cache_len,
-                dtype=torch_dtype,
-            )
-            logger.info("  HFStaticCache installed successfully")
+
+        replace_hf_cache_with_mlx(
+            exportable,
+            model.config,
+            max_batch_size=1,
+            max_cache_len=effective_cache_len,
+            dtype=torch_dtype,
+        )
+        logger.info("  HFStaticCache installed successfully")

     from executorch.backends.mlx.llm.quantization import quantize_model_
```
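The deleted branch maintained a ring-buffer cache that overwrites slots modulo the window size; the unified path instead allocates a plain static cache whose length is capped at the sliding window. An illustrative sketch of the difference in write behavior (not the ExecuTorch implementation):

```python
import torch

def ring_buffer_write(cache: torch.Tensor, pos: int, value: torch.Tensor) -> None:
    # Removed path: the slot wraps, so the buffer always holds the most
    # recent window_size entries regardless of sequence length.
    cache[pos % cache.shape[0]] = value

def static_write(cache: torch.Tensor, pos: int, value: torch.Tensor) -> None:
    # Unified path: a linear cache of min(max_seq_len, sliding_window) slots;
    # each position is written once and never wraps.
    cache[pos] = value

ring = torch.zeros(4, 2)  # window_size=4, head_dim=2
for pos in range(6):
    ring_buffer_write(ring, pos, torch.full((2,), float(pos)))
print(ring)  # rows now hold positions 4, 5, 2, 3 (wrapped)
```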
```diff
@@ -341,6 +324,7 @@ def _save_program(executorch_program, output_path: str) -> None:

 def export_llama_hf(
     model_id: str,
+    revision: Optional[str],
     output_path: str,
     max_seq_len: int = 1024,
     dtype: str = "bf16",
```
```diff
@@ -372,6 +356,7 @@ def export_llama_hf(
         )
         _export_with_custom_components(
             model_id=model_id,
+            revision=revision,
             output_path=output_path,
             max_seq_len=max_seq_len,
             dtype=dtype,
```
```diff
@@ -387,6 +372,7 @@ def export_llama_hf(
         logger.info("Using optimum-executorch pipeline (no custom components)")
         _export_with_optimum(
             model_id=model_id,
+            revision=revision,
             output_path=output_path,
             max_seq_len=max_seq_len,
             dtype=dtype,
```
```diff
@@ -408,6 +394,12 @@ def main():
         default="unsloth/Llama-3.2-1B-Instruct",
         help="HuggingFace model ID",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        help="Optional HuggingFace model revision/commit to pin",
+    )
     parser.add_argument(
         "--output",
         type=str,
```
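In the custom-components path this flag is simply forwarded to `from_pretrained`; a brief sketch of what pinning buys, using the parser's default model ID (the revision value is a placeholder):

```python
from transformers import AutoModelForCausalLM

# A revision may be a branch, tag, or commit SHA; pinning an exact SHA makes
# repeated exports resolve to the same snapshot instead of tracking a branch.
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-1B-Instruct",
    revision="main",  # placeholder: substitute a commit SHA to pin exactly
)
```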
```diff
@@ -447,6 +439,7 @@ def main():

     export_llama_hf(
         model_id=args.model_id,
+        revision=args.revision,
         output_path=args.output,
         max_seq_len=args.max_seq_len,
         dtype=args.dtype,
```
**Contributor:** Why no embedding?