diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 276ff6d193a..ea0c6818bc8 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -21,6 +21,7 @@ import argparse import json +from typing import List, Optional import coremltools as ct import torch @@ -170,6 +171,73 @@ def load_model( return model, args +def _resolve_cache_len( + max_context_len: int, input_len: int, sliding_window: Optional[int] = None +) -> int: + """Pick the per-layer KV cache length given context / input / window settings. + + Without sliding-window attention the cache must hold every earlier token + the current step can attend to, i.e. ``max_context_len - input_len``. When the + model is trained with sliding-window attention we instead cap the cache at + ``sliding_window`` so longer contexts do not enlarge per-layer attention + compute or KV cache memory. + """ + cache_len = max_context_len - input_len + if sliding_window is not None: + if sliding_window <= 0: + raise ValueError( + f"sliding_window must be positive, got {sliding_window}" + ) + if sliding_window < cache_len: + cache_len = sliding_window + return cache_len + + +def _resolve_per_layer_cache_lens( + n_layers: int, + max_context_len: int, + input_len: int, + sliding_window: Optional[int] = None, + sliding_window_pattern: Optional[int] = None, +) -> List[int]: + """Compute per-layer KV cache lengths for hybrid sliding/full attention. + + Returns a list of length ``n_layers``. When ``sliding_window_pattern`` is + ``P``, every ``P``-th layer (0-indexed: layers ``P-1, 2P-1, ...``) uses + the full ``max_context_len - input_len`` cache; the remaining layers use + ``sliding_window``. This matches HuggingFace's ``sliding_window_pattern`` + convention used by Gemma 3 (P=6: 5 sliding + 1 full) and Gemma 4 E2B + (P=5: 4 sliding + 1 full). 
+ + When ``sliding_window_pattern`` is ``None``, every layer uses the same + cache length resolved by :func:`_resolve_cache_len`. + """ + full_cache_len = max_context_len - input_len + sliding_cache_len = _resolve_cache_len( + max_context_len, input_len, sliding_window + ) + + if sliding_window_pattern is None: + return [sliding_cache_len] * n_layers + + if sliding_window is None: + raise ValueError( + "sliding_window_pattern requires sliding_window to be set" + ) + if sliding_window_pattern <= 1: + raise ValueError( + "sliding_window_pattern must be at least 2 (P=1 would make every " + f"layer full attention); got {sliding_window_pattern}" + ) + + return [ + full_cache_len + if (i + 1) % sliding_window_pattern == 0 + else sliding_cache_len + for i in range(n_layers) + ] + + def _create_example_inputs( model_args, input_len, max_context_len, float_dtype, cache_len=None ): @@ -273,6 +341,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) # Output indices are in the same order (after logits) # Logits is output 0, then k_caches, then v_caches + # Read each cache's actual cache_len from the example tensor shape so + # per-layer hybrid sliding/full attention (Gemma 3/4) reports the right + # length per layer instead of a single uniform value. 
+ k_cache_tensors = example_inputs[1]["in_cache_state"][0] kv_cache_specs = [] for i, cache_id in enumerate(sorted_k_cache_ids): k_in_idx = k_cache_in_indices[cache_id] @@ -281,7 +353,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) # v_caches come after k_caches (idx n_layers+1 to 2*n_layers) k_out_idx = 1 + i v_out_idx = 1 + len(sorted_k_cache_ids) + i - kv_cache_specs.append([k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len]) + per_cache_len = k_cache_tensors[cache_id].size(-2) + kv_cache_specs.append( + [k_in_idx, k_out_idx, v_in_idx, v_out_idx, per_cache_len] + ) print(f"KV cache specs (k_in, k_out, v_in, v_out, cache_len): {kv_cache_specs}") @@ -410,6 +485,31 @@ def main(): default=32, help="Input sequence length per forward pass", ) + parser.add_argument( + "--sliding_window", + type=int, + default=None, + help=( + "Sliding window attention size. When set, every layer uses a KV cache " + "of this many tokens instead of (max_context_len - input_len), which " + "lets the model serve longer contexts without growing per-layer attention " + "compute or KV cache memory. Required for Mistral / Gemma3 / Gemma4 / " + "Llama4-style models that train with sliding-window attention." + ), + ) + parser.add_argument( + "--sliding_window_pattern", + type=int, + default=None, + help=( + "Period of the sliding/full attention pattern (HuggingFace's " + "sliding_window_pattern). When set together with --sliding_window, " + "every P-th layer (1-indexed) uses full attention while the rest use " + "the sliding window. Use P=5 for Gemma 4 E2B (4 sliding + 1 full) " + "and P=6 for Gemma 3 (5 sliding + 1 full). Without this flag every " + "layer uses the sliding window." 
+ ), + ) parser.add_argument( "--dtype", type=str, @@ -481,11 +581,20 @@ def main(): print(f"\tLinear quantize: {args.linear_quantize}") print(f"\tDtype: {args.dtype}") - cache_len = args.max_context_len - args.input_len + if args.sliding_window_pattern is not None and args.sliding_window is None: + parser.error("--sliding_window_pattern requires --sliding_window to be set") + + cache_len = _resolve_cache_len( + args.max_context_len, args.input_len, args.sliding_window + ) print("\nGeneration configuration:") print(f"\tMax context length: {args.max_context_len}") print(f"\tInput length: {args.input_len}") print(f"\tCache length: {cache_len}") + if args.sliding_window is not None: + print(f"\tSliding window: {args.sliding_window}") + if args.sliding_window_pattern is not None: + print(f"\tSliding window pattern: every {args.sliding_window_pattern}-th layer is full") print("\nLinear splitting:") print(f"\tTarget split size: {args.target_split_size}") @@ -513,11 +622,22 @@ def main(): # the same cache buffer at runtime without any copying. 
decode_input_len = 1 prefill_input_len = args.input_len # default 32 - shared_cache_len = ( - args.max_context_len - decode_input_len - ) # Use decode's cache size for both + shared_cache_len = _resolve_per_layer_cache_lens( + n_layers=model_args.n_layers, + max_context_len=args.max_context_len, + input_len=decode_input_len, + sliding_window=args.sliding_window, + sliding_window_pattern=args.sliding_window_pattern, + ) - print(f"\nShared cache length for prefill/decode: {shared_cache_len}") + if args.sliding_window_pattern is not None: + n_full = sum(1 for cl in shared_cache_len if cl == args.max_context_len - decode_input_len) + print( + f"\nShared cache lengths for prefill/decode: {n_full} full + " + f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)" + ) + else: + print(f"\nShared cache length for prefill/decode: {shared_cache_len[0]}") print(f"\nCreating example inputs for decode (seqlen={decode_input_len})...") decode_inputs, decode_cache_len = _create_example_inputs( @@ -641,9 +761,27 @@ def main(): ) else: # Single method mode: fixed seqlen with generate_full_logits=True for lookahead + per_layer_cache_lens = _resolve_per_layer_cache_lens( + n_layers=model_args.n_layers, + max_context_len=args.max_context_len, + input_len=args.input_len, + sliding_window=args.sliding_window, + sliding_window_pattern=args.sliding_window_pattern, + ) + if args.sliding_window_pattern is not None: + full_cache_len = args.max_context_len - args.input_len + n_full = sum(1 for cl in per_layer_cache_lens if cl == full_cache_len) + print( + f"\nCache length per layer: {n_full} full ({full_cache_len} tokens) + " + f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)" + ) print(f"\nCreating example inputs (seqlen={args.input_len})...") example_inputs, example_cache_len = _create_example_inputs( - model_args, args.input_len, args.max_context_len, float_dtype + model_args, + args.input_len, + args.max_context_len, + float_dtype, + 
cache_len=per_layer_cache_lens, ) # Test eager execution diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index ae3852a7828..f1011b105e7 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -69,6 +69,8 @@ Key differences between the two modes: | `--max_context_len` | 1024 | Maximum context length | | `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. | | `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. | +| `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma3 / Gemma4 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. | +| `--sliding_window_pattern` | (off) | Period of the sliding/full attention pattern (HuggingFace's `sliding_window_pattern`). When set together with `--sliding_window`, every P-th layer (1-indexed) uses full attention while the rest use the sliding window. Use `P=5` for Gemma 4 E2B (4 sliding + 1 full) and `P=6` for Gemma 3 (5 sliding + 1 full). | ### Quantization Options | Option | Default | Description | @@ -94,6 +96,7 @@ The static model has several ANE optimizations, including: * Splitting linear layers for improved performance (controlled by target_split_size and max_splits args) * Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks) * Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833) +* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. 
For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3 / Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Pair with `--sliding_window_pattern P` to mix sliding and full layers in the HuggingFace pattern (every P-th layer is full attention): P=5 for Gemma 4 E2B (4 sliding + 1 full) and P=6 for Gemma 3 (5 sliding + 1 full). We are working on adding a C++ runner as well. diff --git a/examples/apple/coreml/llama/test.py b/examples/apple/coreml/llama/test.py index 895cf2e1cce..fa2a6762c97 100644 --- a/examples/apple/coreml/llama/test.py +++ b/examples/apple/coreml/llama/test.py @@ -9,9 +9,17 @@ sys.path.insert(0, ".") import copy +import pytest import torch +from export_static_llm_coreml import ( + _create_example_inputs, + _resolve_cache_len, + _resolve_per_layer_cache_lens, +) from utils import replace_linear_with_split_linear +from executorch.examples.models.llama.model_args import ModelArgs + def get_split_model( model, @@ -44,5 +52,274 @@ def test_split_model(): assert torch.allclose(model(inputs), model3(inputs), atol=1e-5) +def test_resolve_cache_len_no_sliding_window(): + # Without --sliding_window the cache fills the rest of the context. + assert _resolve_cache_len(1024, 32) == 992 + assert _resolve_cache_len(1024, 1) == 1023 + + +def test_resolve_cache_len_with_sliding_window(): + # When the window is smaller than the remaining context the cache shrinks. + assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096 + assert _resolve_cache_len(8192, 1, sliding_window=4096) == 4096 + + +def test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op(): + # A user-provided window larger than the remaining context degenerates to + # the no-window case, so users can safely set --sliding_window to a value + # the model was trained with even when the export uses a shorter context. 
+ assert _resolve_cache_len(1024, 32, sliding_window=4096) == 992 + + +def test_resolve_cache_len_rejects_non_positive_window(): + with pytest.raises(ValueError): + _resolve_cache_len(1024, 32, sliding_window=0) + with pytest.raises(ValueError): + _resolve_cache_len(1024, 32, sliding_window=-1) + + +def test_create_example_inputs_with_sliding_window_shrinks_kv_cache(): + # Build a tiny ModelArgs that does not need a checkpoint or torchao. + model_args = ModelArgs( + dim=32, + n_layers=2, + n_heads=4, + n_kv_heads=2, + head_dim=8, + vocab_size=128, + max_context_len=1024, + max_seq_len=1024, + ) + max_context_len = 1024 + input_len = 32 + sliding_window = 64 + + cache_len = _resolve_cache_len(max_context_len, input_len, sliding_window) + assert cache_len == sliding_window + + example_inputs, returned_cache_len = _create_example_inputs( + model_args, + input_len, + max_context_len, + float_dtype=torch.float32, + cache_len=cache_len, + ) + assert returned_cache_len == sliding_window + + # The KV cache tensors live inside the kwargs dict at index 1 under + # in_cache_state. Walking that structure should find caches whose + # sequence dimension equals the sliding window, not max_context_len. + kwargs = example_inputs[1] + in_cache_state = kwargs["in_cache_state"] + cache_seq_dims = set() + for per_kind in in_cache_state: # (k_caches, v_caches) + for cache_tensor in per_kind.values(): + cache_seq_dims.add(cache_tensor.size(-2)) + assert cache_seq_dims == {sliding_window}, ( + f"expected every KV cache to be sized to the sliding window {sliding_window}, " + f"got {cache_seq_dims}" + ) + + # The attention mask covers (input_len + cache_len) along the last dim. + masks = kwargs["masks"] + assert sliding_window in masks + assert masks[sliding_window].shape[-1] == input_len + sliding_window + + +def test_per_layer_cache_lens_uniform_when_no_pattern(): + # Without a pattern every layer gets the same cache length. 
+ out = _resolve_per_layer_cache_lens( + n_layers=4, max_context_len=1024, input_len=32, sliding_window=64 + ) + assert out == [64, 64, 64, 64] + + +def test_per_layer_cache_lens_uniform_full_when_no_window(): + # No window at all is just `max_context_len - input_len` everywhere. + out = _resolve_per_layer_cache_lens( + n_layers=4, max_context_len=1024, input_len=32 + ) + assert out == [992, 992, 992, 992] + + +def test_per_layer_cache_lens_gemma4_e2b_pattern(): + # Gemma 4 E2B: 35 layers, P=5 → 4 sliding + 1 full repeated 7 times. + out = _resolve_per_layer_cache_lens( + n_layers=35, + max_context_len=8192, + input_len=32, + sliding_window=4096, + sliding_window_pattern=5, + ) + full = 8192 - 32 + sliding = 4096 + assert len(out) == 35 + # Layers at 1-indexed positions 5, 10, 15, …, 35 are full. + assert [out[i] for i in range(35)] == [ + full if (i + 1) % 5 == 0 else sliding for i in range(35) + ] + assert sum(1 for cl in out if cl == full) == 7 + assert sum(1 for cl in out if cl == sliding) == 28 + + +def test_per_layer_cache_lens_gemma3_pattern(): + # Gemma 3 uses P=6 (5 sliding + 1 full). + out = _resolve_per_layer_cache_lens( + n_layers=12, + max_context_len=2048, + input_len=32, + sliding_window=512, + sliding_window_pattern=6, + ) + full = 2048 - 32 + sliding = 512 + assert out == [ + sliding, + sliding, + sliding, + sliding, + sliding, + full, + sliding, + sliding, + sliding, + sliding, + sliding, + full, + ] + + +def test_per_layer_cache_lens_pattern_requires_sliding_window(): + with pytest.raises(ValueError): + _resolve_per_layer_cache_lens( + n_layers=8, + max_context_len=1024, + input_len=32, + sliding_window=None, + sliding_window_pattern=5, + ) + + +def test_per_layer_cache_lens_rejects_pattern_le_one(): + # P=1 would make every layer full and is almost certainly a typo, so + # surface it rather than silently doing the no-pattern thing. 
+ with pytest.raises(ValueError): + _resolve_per_layer_cache_lens( + n_layers=8, + max_context_len=1024, + input_len=32, + sliding_window=64, + sliding_window_pattern=1, + ) + + +def test_per_layer_pattern_forward_pass_runs_end_to_end(): + """A tiny static-attention transformer must accept the heterogeneous + cache shapes produced by `_resolve_per_layer_cache_lens` and run a + forward pass without complaining about the mismatched cache sizes + between sliding and full layers.""" + from executorch.examples.models.llama.llama_transformer import ( + construct_transformer, + ) + from executorch.examples.models.llama.static_attention import ( + transform_attention_mha_to_static_attention, + ) + + args = ModelArgs( + dim=64, + n_layers=10, + n_heads=4, + n_kv_heads=2, + head_dim=16, + vocab_size=128, + max_context_len=512, + max_seq_len=512, + attention_type="static_mha", + attention_kwargs={"decompose_sdpa_in_mha": True}, + ) + model = construct_transformer(args).eval() + model = transform_attention_mha_to_static_attention( + model, split_mha=True, inplace=False + ).eval() + + cache_lens = _resolve_per_layer_cache_lens( + n_layers=10, + max_context_len=512, + input_len=32, + sliding_window=64, + sliding_window_pattern=5, + ) + # 4 sliding + 1 full, twice. 
+ assert cache_lens == [64, 64, 64, 64, 480, 64, 64, 64, 64, 480] + + example_inputs, _ = _create_example_inputs( + args, 32, 512, float_dtype=torch.float32, cache_len=cache_lens + ) + with torch.no_grad(): + model(*example_inputs) + + +def test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes(): + model_args = ModelArgs( + dim=32, + n_layers=10, + n_heads=4, + n_kv_heads=2, + head_dim=8, + vocab_size=128, + max_context_len=1024, + max_seq_len=1024, + ) + max_context_len = 1024 + input_len = 32 + sliding_window = 64 + pattern = 5 + + cache_lens = _resolve_per_layer_cache_lens( + n_layers=model_args.n_layers, + max_context_len=max_context_len, + input_len=input_len, + sliding_window=sliding_window, + sliding_window_pattern=pattern, + ) + + example_inputs, _ = _create_example_inputs( + model_args, + input_len, + max_context_len, + float_dtype=torch.float32, + cache_len=cache_lens, + ) + + in_cache_state = example_inputs[1]["in_cache_state"] + seen = set() + for per_kind in in_cache_state: + for tensor in per_kind.values(): + seen.add(tensor.size(-2)) + full = max_context_len - input_len + assert seen == {sliding_window, full}, ( + f"expected both {sliding_window} (sliding) and {full} (full) cache sizes, got {seen}" + ) + + # Both cache_len values get their own mask; (input_len + cache_len) per mask. 
+ masks = example_inputs[1]["masks"] + assert set(masks.keys()) == {sliding_window, full} + assert masks[sliding_window].shape[-1] == input_len + sliding_window + assert masks[full].shape[-1] == input_len + full + + if __name__ == "__main__": test_split_model() + test_resolve_cache_len_no_sliding_window() + test_resolve_cache_len_with_sliding_window() + test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op() + test_resolve_cache_len_rejects_non_positive_window() + test_create_example_inputs_with_sliding_window_shrinks_kv_cache() + test_per_layer_cache_lens_uniform_when_no_pattern() + test_per_layer_cache_lens_uniform_full_when_no_window() + test_per_layer_cache_lens_gemma4_e2b_pattern() + test_per_layer_cache_lens_gemma3_pattern() + test_per_layer_cache_lens_pattern_requires_sliding_window() + test_per_layer_cache_lens_rejects_pattern_le_one() + test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes() + test_per_layer_pattern_forward_pass_runs_end_to_end()