From 8a2dfb53ab72a2fb5105f8d9cbed6427565f732f Mon Sep 17 00:00:00 2001 From: john-rocky Date: Fri, 1 May 2026 14:29:49 +0900 Subject: [PATCH 1/3] Add --sliding_window flag to CoreML static LLM export MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Models trained with sliding-window attention (Mistral 7B, Gemma 3, Gemma 4, Llama 4 Scout, …) only need each layer to attend to the last `W` tokens, but `export_static_llm_coreml.py` was always sizing the per-layer KV cache to `max_context_len - input_len`. That made longer contexts proportionally more expensive in both KV cache memory and per-token attention compute, even though the model was trained to ignore everything outside the window. Add a `--sliding_window` flag that caps the cache at the trained window. The downstream pieces — `StaticAttentionMask` invariants under cache eviction and the `StaticAttentionIOManager`'s per-layer `cache_lens` plumbing — already support this; the export script just needed to expose it. Per-layer mixed sliding/full attention (Gemma 3/4) is left for a follow-up; this PR uses one window for every layer. The cache_len computation is factored into `_resolve_cache_len` so it is unit-testable, and the README's ANE Optimizations section documents the new option. ### Memory savings example For a 32-layer / n_kv_heads=8 / head_dim=128 model exported with `max_context_len=8192` in fp16, dropping the cache from 8160 to 4096 cuts the per-method KV cache from ~1.07 GB to ~0.54 GB. --- .../coreml/llama/export_static_llm_coreml.py | 53 ++++++++++-- examples/apple/coreml/llama/readme.md | 2 + examples/apple/coreml/llama/test.py | 83 +++++++++++++++++++ 3 files changed, 133 insertions(+), 5 deletions(-) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 276ff6d193a..3a70ad02d53 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -21,6 +21,7 @@ import argparse import json +from typing import Optional import coremltools as ct import torch @@ -170,6 +171,28 @@ def load_model( return model, args +def _resolve_cache_len( + max_context_len: int, input_len: int, sliding_window: Optional[int] = None +) -> int: + """Pick the per-layer KV cache length given context / input / window settings. + + Without sliding-window attention the cache must hold every token that can + attend to the current step, i.e. ``max_context_len - input_len``. When the + model is trained with sliding-window attention we instead cap the cache at + ``sliding_window`` so longer contexts do not enlarge per-layer attention + compute or KV cache memory. + """ + cache_len = max_context_len - input_len + if sliding_window is not None: + if sliding_window <= 0: + raise ValueError( + f"sliding_window must be positive, got {sliding_window}" + ) + if sliding_window < cache_len: + cache_len = sliding_window + return cache_len + + def _create_example_inputs( model_args, input_len, max_context_len, float_dtype, cache_len=None ): @@ -410,6 +433,18 @@ def main(): default=32, help="Input sequence length per forward pass", ) + parser.add_argument( + "--sliding_window", + type=int, + default=None, + help=( + "Sliding window attention size. When set, every layer uses a KV cache " + "of this many tokens instead of (max_context_len - input_len), which " + "lets the model serve longer contexts without growing per-layer attention " + "compute or KV cache memory. 
Required for Mistral / Gemma3 / Gemma4 / " + "Llama4-style models that train with sliding-window attention." + ), + ) parser.add_argument( "--dtype", type=str, @@ -481,11 +516,15 @@ def main(): print(f"\tLinear quantize: {args.linear_quantize}") print(f"\tDtype: {args.dtype}") - cache_len = args.max_context_len - args.input_len + cache_len = _resolve_cache_len( + args.max_context_len, args.input_len, args.sliding_window + ) print("\nGeneration configuration:") print(f"\tMax context length: {args.max_context_len}") print(f"\tInput length: {args.input_len}") print(f"\tCache length: {cache_len}") + if args.sliding_window is not None: + print(f"\tSliding window: {args.sliding_window}") print("\nLinear splitting:") print(f"\tTarget split size: {args.target_split_size}") @@ -513,9 +552,9 @@ def main(): # the same cache buffer at runtime without any copying. decode_input_len = 1 prefill_input_len = args.input_len # default 32 - shared_cache_len = ( - args.max_context_len - decode_input_len - ) # Use decode's cache size for both + shared_cache_len = _resolve_cache_len( + args.max_context_len, decode_input_len, args.sliding_window + ) print(f"\nShared cache length for prefill/decode: {shared_cache_len}") @@ -643,7 +682,11 @@ def main(): # Single method mode: fixed seqlen with generate_full_logits=True for lookahead print(f"\nCreating example inputs (seqlen={args.input_len})...") example_inputs, example_cache_len = _create_example_inputs( - model_args, args.input_len, args.max_context_len, float_dtype + model_args, + args.input_len, + args.max_context_len, + float_dtype, + cache_len=cache_len, ) # Test eager execution diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index ae3852a7828..5f8aa1bcd78 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -69,6 +69,7 @@ Key differences between the two modes: | `--max_context_len` | 1024 | Maximum context length | | `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. | | `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. | +| `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma3 / Gemma4 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. | ### Quantization Options | Option | Default | Description | @@ -94,6 +95,7 @@ The static model has several ANE optimizations, including: * Splitting linear layers for improved performance (controlled by target_split_size and max_splits args) * Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks) * Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833) +* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3/Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. 
Per-layer mixed sliding/full attention is not yet exposed; today every layer shares the same window when the flag is set. We are working on adding a C++ runner as well. diff --git a/examples/apple/coreml/llama/test.py b/examples/apple/coreml/llama/test.py index 895cf2e1cce..73ca7712b1a 100644 --- a/examples/apple/coreml/llama/test.py +++ b/examples/apple/coreml/llama/test.py @@ -9,9 +9,13 @@ sys.path.insert(0, ".") import copy +import pytest import torch +from export_static_llm_coreml import _create_example_inputs, _resolve_cache_len from utils import replace_linear_with_split_linear +from executorch.examples.models.llama.model_args import ModelArgs + def get_split_model( model, @@ -44,5 +48,84 @@ def test_split_model(): assert torch.allclose(model(inputs), model3(inputs), atol=1e-5) +def test_resolve_cache_len_no_sliding_window(): + # Without --sliding_window the cache fills the rest of the context. + assert _resolve_cache_len(1024, 32) == 992 + assert _resolve_cache_len(1024, 1) == 1023 + + +def test_resolve_cache_len_with_sliding_window(): + # When the window is smaller than the remaining context the cache shrinks. + assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096 + assert _resolve_cache_len(8192, 1, sliding_window=4096) == 4096 + + +def test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op(): + # A user-provided window larger than the remaining context degenerates to + # the no-window case, so users can safely set --sliding_window to a value + # the model was trained with even when the export uses a shorter context. + assert _resolve_cache_len(1024, 32, sliding_window=4096) == 992 + + +def test_resolve_cache_len_rejects_non_positive_window(): + with pytest.raises(ValueError): + _resolve_cache_len(1024, 32, sliding_window=0) + with pytest.raises(ValueError): + _resolve_cache_len(1024, 32, sliding_window=-1) + + +def test_create_example_inputs_with_sliding_window_shrinks_kv_cache(): + # Build a tiny ModelArgs that does not need a checkpoint or torchao. + model_args = ModelArgs( + dim=32, + n_layers=2, + n_heads=4, + n_kv_heads=2, + head_dim=8, + vocab_size=128, + max_context_len=1024, + max_seq_len=1024, + ) + max_context_len = 1024 + input_len = 32 + sliding_window = 64 + + cache_len = _resolve_cache_len(max_context_len, input_len, sliding_window) + assert cache_len == sliding_window + + example_inputs, returned_cache_len = _create_example_inputs( + model_args, + input_len, + max_context_len, + float_dtype=torch.float32, + cache_len=cache_len, + ) + assert returned_cache_len == sliding_window + + # The KV cache tensors live inside the kwargs dict at index 1 under + # in_cache_state. Walking that structure should find caches whose + # sequence dimension equals the sliding window, not max_context_len. + kwargs = example_inputs[1] + in_cache_state = kwargs["in_cache_state"] + cache_seq_dims = set() + for per_kind in in_cache_state: # (k_caches, v_caches) + for cache_tensor in per_kind.values(): + cache_seq_dims.add(cache_tensor.size(-2)) + assert cache_seq_dims == {sliding_window}, ( + f"expected every KV cache to be sized to the sliding window {sliding_window}, " + f"got {cache_seq_dims}" + ) + + # The attention mask covers (input_len + cache_len) along the last dim. 
+ masks = kwargs["masks"] + assert sliding_window in masks + assert masks[sliding_window].shape[-1] == input_len + sliding_window + + if __name__ == "__main__": test_split_model() + test_resolve_cache_len_no_sliding_window() + test_resolve_cache_len_with_sliding_window() + test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op() + test_resolve_cache_len_rejects_non_positive_window() + test_create_example_inputs_with_sliding_window_shrinks_kv_cache() From 9bdf04eb2e097741ee0ab516bd233b7921aa9ca0 Mon Sep 17 00:00:00 2001 From: john-rocky Date: Fri, 1 May 2026 14:38:42 +0900 Subject: [PATCH 2/3] Add per-layer hybrid sliding/full attention to CoreML static LLM export Builds on the prior --sliding_window flag. Gemma 3, Gemma 4, and the Llama 4 Scout family interleave sliding and full attention layers rather than using one global setting: Gemma 4 E2B is '4 sliding + 1 full' repeated 7 times across 35 layers; Gemma 3 is '5 sliding + 1 full' repeated. HuggingFace expresses this as a single integer `sliding_window_pattern`, which is what the new `--sliding_window_pattern` flag mirrors. Implementation: - `_resolve_per_layer_cache_lens(...)` produces a per-layer cache_lens list using the HF rule (layer i is full iff (i+1) % P == 0); the IO manager and the model already accept per-layer cache_lens, so the attention mask dict and the per-layer KV cache shapes follow. - `_get_metadata` now reads each cache's cache_len from the example tensor's sequence dimension instead of receiving a single scalar, so the C++ runner metadata describes each layer correctly under hybrid attention. - Both single-method and multifunction export paths use the per-layer resolver. The previous PR's uniform-sliding behavior is preserved when `--sliding_window_pattern` is not set. Authored with Claude. --- .../coreml/llama/export_static_llm_coreml.py | 107 ++++++++++++- examples/apple/coreml/llama/readme.md | 3 +- examples/apple/coreml/llama/test.py | 149 +++++++++++++++++- 3 files changed, 251 insertions(+), 8 deletions(-) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 3a70ad02d53..ea0c6818bc8 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -21,7 +21,7 @@ import argparse import json -from typing import Optional +from typing import List, Optional import coremltools as ct import torch @@ -193,6 +193,51 @@ def _resolve_cache_len( return cache_len +def _resolve_per_layer_cache_lens( + n_layers: int, + max_context_len: int, + input_len: int, + sliding_window: Optional[int] = None, + sliding_window_pattern: Optional[int] = None, +) -> List[int]: + """Compute per-layer KV cache lengths for hybrid sliding/full attention. + + Returns a list of length ``n_layers``. When ``sliding_window_pattern`` is + ``P``, every ``P``-th layer (0-indexed: layers ``P-1, 2P-1, ...``) uses + the full ``max_context_len - input_len`` cache; the remaining layers use + ``sliding_window``. This matches HuggingFace's ``sliding_window_pattern`` + convention used by Gemma 3 (P=6: 5 sliding + 1 full) and Gemma 4 E2B + (P=5: 4 sliding + 1 full). + + When ``sliding_window_pattern`` is ``None``, every layer uses the same + cache length resolved by :func:`_resolve_cache_len`. 
+ """ + full_cache_len = max_context_len - input_len + sliding_cache_len = _resolve_cache_len( + max_context_len, input_len, sliding_window + ) + + if sliding_window_pattern is None: + return [sliding_cache_len] * n_layers + + if sliding_window is None: + raise ValueError( + "sliding_window_pattern requires sliding_window to be set" + ) + if sliding_window_pattern <= 1: + raise ValueError( + "sliding_window_pattern must be at least 2 (P=1 would make every " + f"layer full attention); got {sliding_window_pattern}" + ) + + return [ + full_cache_len + if (i + 1) % sliding_window_pattern == 0 + else sliding_cache_len + for i in range(n_layers) + ] + + def _create_example_inputs( model_args, input_len, max_context_len, float_dtype, cache_len=None ): @@ -296,6 +341,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) # Output indices are in the same order (after logits) # Logits is output 0, then k_caches, then v_caches + # Read each cache's actual cache_len from the example tensor shape so + # per-layer hybrid sliding/full attention (Gemma 3/4) reports the right + # length per layer instead of a single uniform value. + k_cache_tensors = example_inputs[1]["in_cache_state"][0] kv_cache_specs = [] for i, cache_id in enumerate(sorted_k_cache_ids): k_in_idx = k_cache_in_indices[cache_id] @@ -304,7 +353,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) # v_caches come after k_caches (idx n_layers+1 to 2*n_layers) k_out_idx = 1 + i v_out_idx = 1 + len(sorted_k_cache_ids) + i - kv_cache_specs.append([k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len]) + per_cache_len = k_cache_tensors[cache_id].size(-2) + kv_cache_specs.append( + [k_in_idx, k_out_idx, v_in_idx, v_out_idx, per_cache_len] + ) print(f"KV cache specs (k_in, k_out, v_in, v_out, cache_len): {kv_cache_specs}") @@ -445,6 +497,19 @@ def main(): "Llama4-style models that train with sliding-window attention." ), ) + parser.add_argument( + "--sliding_window_pattern", + type=int, + default=None, + help=( + "Period of the sliding/full attention pattern (HuggingFace's " + "sliding_window_pattern). When set together with --sliding_window, " + "every P-th layer (1-indexed) uses full attention while the rest use " + "the sliding window. Use P=5 for Gemma 4 E2B (4 sliding + 1 full) " + "and P=6 for Gemma 3 (5 sliding + 1 full). Without this flag every " + "layer uses the sliding window." + ), + ) parser.add_argument( "--dtype", type=str, @@ -516,6 +581,9 @@ def main(): print(f"\tLinear quantize: {args.linear_quantize}") print(f"\tDtype: {args.dtype}") + if args.sliding_window_pattern is not None and args.sliding_window is None: + parser.error("--sliding_window_pattern requires --sliding_window to be set") + cache_len = _resolve_cache_len( args.max_context_len, args.input_len, args.sliding_window ) @@ -525,6 +593,8 @@ def main(): print(f"\tCache length: {cache_len}") if args.sliding_window is not None: print(f"\tSliding window: {args.sliding_window}") + if args.sliding_window_pattern is not None: + print(f"\tSliding window pattern: every {args.sliding_window_pattern}-th layer is full") print("\nLinear splitting:") print(f"\tTarget split size: {args.target_split_size}") @@ -552,11 +622,22 @@ def main(): # the same cache buffer at runtime without any copying. 
decode_input_len = 1 prefill_input_len = args.input_len # default 32 - shared_cache_len = _resolve_cache_len( - args.max_context_len, decode_input_len, args.sliding_window + shared_cache_len = _resolve_per_layer_cache_lens( + n_layers=model_args.n_layers, + max_context_len=args.max_context_len, + input_len=decode_input_len, + sliding_window=args.sliding_window, + sliding_window_pattern=args.sliding_window_pattern, ) - print(f"\nShared cache length for prefill/decode: {shared_cache_len}") + if args.sliding_window_pattern is not None: + n_full = sum(1 for cl in shared_cache_len if cl == args.max_context_len - decode_input_len) + print( + f"\nShared cache lengths for prefill/decode: {n_full} full + " + f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)" + ) + else: + print(f"\nShared cache length for prefill/decode: {shared_cache_len[0]}") print(f"\nCreating example inputs for decode (seqlen={decode_input_len})...") decode_inputs, decode_cache_len = _create_example_inputs( @@ -680,13 +761,27 @@ def main(): ) else: # Single method mode: fixed seqlen with generate_full_logits=True for lookahead + per_layer_cache_lens = _resolve_per_layer_cache_lens( + n_layers=model_args.n_layers, + max_context_len=args.max_context_len, + input_len=args.input_len, + sliding_window=args.sliding_window, + sliding_window_pattern=args.sliding_window_pattern, + ) + if args.sliding_window_pattern is not None: + full_cache_len = args.max_context_len - args.input_len + n_full = sum(1 for cl in per_layer_cache_lens if cl == full_cache_len) + print( + f"\nCache length per layer: {n_full} full ({full_cache_len} tokens) + " + f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)" + ) print(f"\nCreating example inputs (seqlen={args.input_len})...") example_inputs, example_cache_len = _create_example_inputs( model_args, args.input_len, args.max_context_len, float_dtype, - cache_len=cache_len, + cache_len=per_layer_cache_lens, ) # Test eager execution diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 5f8aa1bcd78..f1011b105e7 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -70,6 +70,7 @@ Key differences between the two modes: | `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. | | `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. | | `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma3 / Gemma4 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. | +| `--sliding_window_pattern` | (off) | Period of the sliding/full attention pattern (HuggingFace's `sliding_window_pattern`). When set together with `--sliding_window`, every P-th layer (1-indexed) uses full attention while the rest use the sliding window. Use `P=5` for Gemma 4 E2B (4 sliding + 1 full) and `P=6` for Gemma 3 (5 sliding + 1 full). 
| ### Quantization Options | Option | Default | Description | @@ -95,7 +96,7 @@ The static model has several ANE optimizations, including: * Splitting linear layers for improved performance (controlled by target_split_size and max_splits args) * Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks) * Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833) -* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3/Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Per-layer mixed sliding/full attention is not yet exposed; today every layer shares the same window when the flag is set. +* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3 / Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Pair with `--sliding_window_pattern P` to mix sliding and full layers in the HuggingFace pattern (every P-th layer is full attention): P=5 for Gemma 4 E2B (4 sliding + 1 full) and P=6 for Gemma 3 (5 sliding + 1 full). We are working on adding a C++ runner as well. diff --git a/examples/apple/coreml/llama/test.py b/examples/apple/coreml/llama/test.py index 73ca7712b1a..49ac4ea9cc6 100644 --- a/examples/apple/coreml/llama/test.py +++ b/examples/apple/coreml/llama/test.py @@ -11,7 +11,11 @@ import pytest import torch -from export_static_llm_coreml import _create_example_inputs, _resolve_cache_len +from export_static_llm_coreml import ( + _create_example_inputs, + _resolve_cache_len, + _resolve_per_layer_cache_lens, +) from utils import replace_linear_with_split_linear from executorch.examples.models.llama.model_args import ModelArgs @@ -122,6 +126,142 @@ def test_create_example_inputs_with_sliding_window_shrinks_kv_cache(): assert masks[sliding_window].shape[-1] == input_len + sliding_window +def test_per_layer_cache_lens_uniform_when_no_pattern(): + # Without a pattern every layer gets the same cache length. + out = _resolve_per_layer_cache_lens( + n_layers=4, max_context_len=1024, input_len=32, sliding_window=64 + ) + assert out == [64, 64, 64, 64] + + +def test_per_layer_cache_lens_uniform_full_when_no_window(): + # No window at all is just `max_context_len - input_len` everywhere. + out = _resolve_per_layer_cache_lens( + n_layers=4, max_context_len=1024, input_len=32 + ) + assert out == [992, 992, 992, 992] + + +def test_per_layer_cache_lens_gemma4_e2b_pattern(): + # Gemma 4 E2B: 35 layers, P=5 → 4 sliding + 1 full repeated 7 times. + out = _resolve_per_layer_cache_lens( + n_layers=35, + max_context_len=8192, + input_len=32, + sliding_window=4096, + sliding_window_pattern=5, + ) + full = 8192 - 32 + sliding = 4096 + assert len(out) == 35 + # Layers at 1-indexed positions 5, 10, 15, …, 35 are full. 
+ assert [out[i] for i in range(35)] == [ + full if (i + 1) % 5 == 0 else sliding for i in range(35) + ] + assert sum(1 for cl in out if cl == full) == 7 + assert sum(1 for cl in out if cl == sliding) == 28 + + +def test_per_layer_cache_lens_gemma3_pattern(): + # Gemma 3 uses P=6 (5 sliding + 1 full). + out = _resolve_per_layer_cache_lens( + n_layers=12, + max_context_len=2048, + input_len=32, + sliding_window=512, + sliding_window_pattern=6, + ) + full = 2048 - 32 + sliding = 512 + assert out == [ + sliding, + sliding, + sliding, + sliding, + sliding, + full, + sliding, + sliding, + sliding, + sliding, + sliding, + full, + ] + + +def test_per_layer_cache_lens_pattern_requires_sliding_window(): + with pytest.raises(ValueError): + _resolve_per_layer_cache_lens( + n_layers=8, + max_context_len=1024, + input_len=32, + sliding_window=None, + sliding_window_pattern=5, + ) + + +def test_per_layer_cache_lens_rejects_pattern_le_one(): + # P=1 would make every layer full and is almost certainly a typo, so + # surface it rather than silently doing the no-pattern thing. + with pytest.raises(ValueError): + _resolve_per_layer_cache_lens( + n_layers=8, + max_context_len=1024, + input_len=32, + sliding_window=64, + sliding_window_pattern=1, + ) + + +def test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes(): + model_args = ModelArgs( + dim=32, + n_layers=10, + n_heads=4, + n_kv_heads=2, + head_dim=8, + vocab_size=128, + max_context_len=1024, + max_seq_len=1024, + ) + max_context_len = 1024 + input_len = 32 + sliding_window = 64 + pattern = 5 + + cache_lens = _resolve_per_layer_cache_lens( + n_layers=model_args.n_layers, + max_context_len=max_context_len, + input_len=input_len, + sliding_window=sliding_window, + sliding_window_pattern=pattern, + ) + + example_inputs, _ = _create_example_inputs( + model_args, + input_len, + max_context_len, + float_dtype=torch.float32, + cache_len=cache_lens, + ) + + in_cache_state = example_inputs[1]["in_cache_state"] + seen = set() + for per_kind in in_cache_state: + for tensor in per_kind.values(): + seen.add(tensor.size(-2)) + full = max_context_len - input_len + assert seen == {sliding_window, full}, ( + f"expected both {sliding_window} (sliding) and {full} (full) cache sizes, got {seen}" + ) + + # Both cache_len values get their own mask; (input_len + cache_len) per mask. 
+ masks = example_inputs[1]["masks"] + assert set(masks.keys()) == {sliding_window, full} + assert masks[sliding_window].shape[-1] == input_len + sliding_window + assert masks[full].shape[-1] == input_len + full + + if __name__ == "__main__": test_split_model() test_resolve_cache_len_no_sliding_window() @@ -129,3 +269,10 @@ def test_create_example_inputs_with_sliding_window_shrinks_kv_cache(): test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op() test_resolve_cache_len_rejects_non_positive_window() test_create_example_inputs_with_sliding_window_shrinks_kv_cache() + test_per_layer_cache_lens_uniform_when_no_pattern() + test_per_layer_cache_lens_uniform_full_when_no_window() + test_per_layer_cache_lens_gemma4_e2b_pattern() + test_per_layer_cache_lens_gemma3_pattern() + test_per_layer_cache_lens_pattern_requires_sliding_window() + test_per_layer_cache_lens_rejects_pattern_le_one() + test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes() From fe97e4367dbca6e1949e084a3fcf56bfd8950966 Mon Sep 17 00:00:00 2001 From: john-rocky Date: Fri, 1 May 2026 14:41:03 +0900 Subject: [PATCH 3/3] Add e2e forward-pass test for per-layer hybrid attention MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Verifies a tiny static-attention transformer accepts the heterogeneous cache shapes produced by _resolve_per_layer_cache_lens and runs a forward pass without errors — the strongest signal that the model and IO Manager really do route the right mask per layer under hybrid sliding/full attention. --- examples/apple/coreml/llama/test.py | 47 +++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/examples/apple/coreml/llama/test.py b/examples/apple/coreml/llama/test.py index 49ac4ea9cc6..fa2a6762c97 100644 --- a/examples/apple/coreml/llama/test.py +++ b/examples/apple/coreml/llama/test.py @@ -213,6 +213,52 @@ def test_per_layer_cache_lens_rejects_pattern_le_one(): ) +def test_per_layer_pattern_forward_pass_runs_end_to_end(): + """A tiny static-attention transformer must accept the heterogeneous + cache shapes produced by `_resolve_per_layer_cache_lens` and run a + forward pass without complaining about the mismatched cache sizes + between sliding and full layers.""" + from executorch.examples.models.llama.llama_transformer import ( + construct_transformer, + ) + from executorch.examples.models.llama.static_attention import ( + transform_attention_mha_to_static_attention, + ) + + args = ModelArgs( + dim=64, + n_layers=10, + n_heads=4, + n_kv_heads=2, + head_dim=16, + vocab_size=128, + max_context_len=512, + max_seq_len=512, + attention_type="static_mha", + attention_kwargs={"decompose_sdpa_in_mha": True}, + ) + model = construct_transformer(args).eval() + model = transform_attention_mha_to_static_attention( + model, split_mha=True, inplace=False + ).eval() + + cache_lens = _resolve_per_layer_cache_lens( + n_layers=10, + max_context_len=512, + input_len=32, + sliding_window=64, + sliding_window_pattern=5, + ) + # 4 sliding + 1 full, twice. 
+ assert cache_lens == [64, 64, 64, 64, 480, 64, 64, 64, 64, 480] + + example_inputs, _ = _create_example_inputs( + args, 32, 512, float_dtype=torch.float32, cache_len=cache_lens + ) + with torch.no_grad(): + model(*example_inputs) + + def test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes(): model_args = ModelArgs( dim=32, @@ -276,3 +322,4 @@ def test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes(): test_per_layer_cache_lens_pattern_requires_sliding_window() test_per_layer_cache_lens_rejects_pattern_le_one() test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes() + test_per_layer_pattern_forward_pass_runs_end_to_end()
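
For reference, a minimal standalone sketch of the per-layer cache-length rule and of the memory arithmetic behind the savings example in the first commit message. It assumes the fp16 layout described there (2 bytes per element, one K and one V cache per layer) and the hypothetical 32-layer / n_kv_heads=8 / head_dim=128 shape from that example rather than any specific checkpoint; it mirrors the layer rule of `_resolve_per_layer_cache_lens` with the argument validation omitted.

from typing import List, Optional


def per_layer_cache_lens(
    n_layers: int,
    max_context_len: int,
    input_len: int,
    sliding_window: Optional[int] = None,
    pattern: Optional[int] = None,
) -> List[int]:
    # Same layer rule as _resolve_per_layer_cache_lens (validation omitted):
    # layer i (0-indexed) uses full attention iff (i + 1) % pattern == 0;
    # every other layer's cache is capped at the sliding window.
    full = max_context_len - input_len
    sliding = min(full, sliding_window) if sliding_window is not None else full
    if pattern is None:
        return [sliding] * n_layers
    return [full if (i + 1) % pattern == 0 else sliding for i in range(n_layers)]


def kv_cache_bytes(
    cache_lens: List[int], n_kv_heads: int, head_dim: int, bytes_per_elem: int = 2
) -> int:
    # One K cache and one V cache per layer, fp16 (2 bytes) by default.
    return sum(2 * n_kv_heads * head_dim * cl * bytes_per_elem for cl in cache_lens)


# Uniform window (patch 1's savings example): cache drops from 8160 to 4096 tokens.
no_window = per_layer_cache_lens(32, 8192, 32)
windowed = per_layer_cache_lens(32, 8192, 32, sliding_window=4096)
print(kv_cache_bytes(no_window, n_kv_heads=8, head_dim=128) / 1e9)  # ~1.07 GB
print(kv_cache_bytes(windowed, n_kv_heads=8, head_dim=128) / 1e9)   # ~0.54 GB

# Hybrid pattern (patch 2, Gemma 4 E2B style: 35 layers, P=5): every 5th layer full.
hybrid = per_layer_cache_lens(35, 8192, 32, sliding_window=4096, pattern=5)
print(hybrid[:5])  # [4096, 4096, 4096, 4096, 8160]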