152 changes: 145 additions & 7 deletions examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -21,6 +21,7 @@

import argparse
import json
from typing import List, Optional

import coremltools as ct
import torch
@@ -170,6 +171,73 @@ def load_model(
    return model, args


def _resolve_cache_len(
    max_context_len: int, input_len: int, sliding_window: Optional[int] = None
) -> int:
    """Pick the per-layer KV cache length given context / input / window settings.

    Without sliding-window attention the cache must hold every past token the
    current step can attend to, i.e. ``max_context_len - input_len`` tokens.
    When the model is trained with sliding-window attention we instead cap the
    cache at ``sliding_window`` so longer contexts do not enlarge per-layer
    attention compute or KV cache memory.
    """
    cache_len = max_context_len - input_len
    if sliding_window is not None:
        if sliding_window <= 0:
            raise ValueError(
                f"sliding_window must be positive, got {sliding_window}"
            )
        if sliding_window < cache_len:
            cache_len = sliding_window
    return cache_len
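
# Illustrative sketch (not part of the script, values are hypothetical): with
# max_context_len=1024 and input_len=32 the default cache is 1024 - 32 = 992
# tokens; passing sliding_window=256 caps it at 256, while a window larger
# than the default leaves it unchanged.
#   _resolve_cache_len(1024, 32)        # -> 992
#   _resolve_cache_len(1024, 32, 256)   # -> 256
#   _resolve_cache_len(1024, 32, 2048)  # -> 992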


def _resolve_per_layer_cache_lens(
    n_layers: int,
    max_context_len: int,
    input_len: int,
    sliding_window: Optional[int] = None,
    sliding_window_pattern: Optional[int] = None,
) -> List[int]:
    """Compute per-layer KV cache lengths for hybrid sliding/full attention.

    Returns a list of length ``n_layers``. When ``sliding_window_pattern`` is
    ``P``, every ``P``-th layer (0-indexed: layers ``P-1, 2P-1, ...``) uses
    the full ``max_context_len - input_len`` cache; the remaining layers use
    ``sliding_window``. This matches HuggingFace's ``sliding_window_pattern``
    convention used by Gemma 3 (P=6: 5 sliding + 1 full) and Gemma 4 E2B
    (P=5: 4 sliding + 1 full).

    When ``sliding_window_pattern`` is ``None``, every layer uses the same
    cache length resolved by :func:`_resolve_cache_len`.
    """
    full_cache_len = max_context_len - input_len
    sliding_cache_len = _resolve_cache_len(
        max_context_len, input_len, sliding_window
    )

    if sliding_window_pattern is None:
        return [sliding_cache_len] * n_layers

    if sliding_window is None:
        raise ValueError(
            "sliding_window_pattern requires sliding_window to be set"
        )
    if sliding_window_pattern <= 1:
        raise ValueError(
            "sliding_window_pattern must be at least 2 (P=1 would make every "
            f"layer full attention); got {sliding_window_pattern}"
        )

    return [
        full_cache_len
        if (i + 1) % sliding_window_pattern == 0
        else sliding_cache_len
        for i in range(n_layers)
    ]
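
# Illustrative sketch (not part of the script): for a hypothetical 6-layer
# model exported with max_context_len=1024, input_len=32, sliding_window=256
# and sliding_window_pattern=3, every 3rd layer keeps the full 992-token
# cache and the remaining layers are capped at 256 tokens:
#   _resolve_per_layer_cache_lens(
#       n_layers=6, max_context_len=1024, input_len=32,
#       sliding_window=256, sliding_window_pattern=3,
#   )  # -> [256, 256, 992, 256, 256, 992]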


def _create_example_inputs(
    model_args, input_len, max_context_len, float_dtype, cache_len=None
):
@@ -273,6 +341,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)

    # Output indices are in the same order (after logits)
    # Logits is output 0, then k_caches, then v_caches
    # Read each cache's actual cache_len from the example tensor shape so
    # per-layer hybrid sliding/full attention (Gemma 3/4) reports the right
    # length per layer instead of a single uniform value.
    k_cache_tensors = example_inputs[1]["in_cache_state"][0]
    kv_cache_specs = []
    for i, cache_id in enumerate(sorted_k_cache_ids):
        k_in_idx = k_cache_in_indices[cache_id]
@@ -281,7 +353,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
        # v_caches come after k_caches (idx n_layers+1 to 2*n_layers)
        k_out_idx = 1 + i
        v_out_idx = 1 + len(sorted_k_cache_ids) + i
        kv_cache_specs.append([k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len])
        per_cache_len = k_cache_tensors[cache_id].size(-2)
        kv_cache_specs.append(
            [k_in_idx, k_out_idx, v_in_idx, v_out_idx, per_cache_len]
        )

    print(f"KV cache specs (k_in, k_out, v_in, v_out, cache_len): {kv_cache_specs}")

@@ -410,6 +485,31 @@ def main():
        default=32,
        help="Input sequence length per forward pass",
    )
    parser.add_argument(
        "--sliding_window",
        type=int,
        default=None,
        help=(
            "Sliding-window attention size. When set, each layer's KV cache is "
            "capped at this many tokens instead of (max_context_len - input_len), "
            "which lets the model serve longer contexts without growing per-layer "
            "attention compute or KV cache memory. Required for Mistral / Gemma3 / "
            "Gemma4 / Llama4-style models trained with sliding-window attention."
        ),
    )
    parser.add_argument(
        "--sliding_window_pattern",
        type=int,
        default=None,
        help=(
            "Period of the sliding/full attention pattern (HuggingFace's "
            "sliding_window_pattern). When set together with --sliding_window, "
            "every P-th layer (1-indexed) uses full attention while the rest use "
            "the sliding window. Use P=5 for Gemma 4 E2B (4 sliding + 1 full) "
            "and P=6 for Gemma 3 (5 sliding + 1 full). Without this flag every "
            "layer uses the sliding window."
        ),
    )
    parser.add_argument(
        "--dtype",
        type=str,
@@ -481,11 +581,20 @@ def main():
print(f"\tLinear quantize: {args.linear_quantize}")
print(f"\tDtype: {args.dtype}")

    cache_len = args.max_context_len - args.input_len
    if args.sliding_window_pattern is not None and args.sliding_window is None:
        parser.error("--sliding_window_pattern requires --sliding_window to be set")

    cache_len = _resolve_cache_len(
        args.max_context_len, args.input_len, args.sliding_window
    )
    print("\nGeneration configuration:")
    print(f"\tMax context length: {args.max_context_len}")
    print(f"\tInput length: {args.input_len}")
    print(f"\tCache length: {cache_len}")
    if args.sliding_window is not None:
        print(f"\tSliding window: {args.sliding_window}")
    if args.sliding_window_pattern is not None:
        print(f"\tSliding window pattern: every {args.sliding_window_pattern}-th layer is full")

print("\nLinear splitting:")
print(f"\tTarget split size: {args.target_split_size}")
@@ -513,11 +622,22 @@ def main():
        # the same cache buffer at runtime without any copying.
        decode_input_len = 1
        prefill_input_len = args.input_len  # default 32
        shared_cache_len = (
            args.max_context_len - decode_input_len
        )  # Use decode's cache size for both
        shared_cache_len = _resolve_per_layer_cache_lens(
            n_layers=model_args.n_layers,
            max_context_len=args.max_context_len,
            input_len=decode_input_len,
            sliding_window=args.sliding_window,
            sliding_window_pattern=args.sliding_window_pattern,
        )

        print(f"\nShared cache length for prefill/decode: {shared_cache_len}")
        if args.sliding_window_pattern is not None:
            n_full = sum(1 for cl in shared_cache_len if cl == args.max_context_len - decode_input_len)
            print(
                f"\nShared cache lengths for prefill/decode: {n_full} full + "
                f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)"
            )
        else:
            print(f"\nShared cache length for prefill/decode: {shared_cache_len[0]}")

print(f"\nCreating example inputs for decode (seqlen={decode_input_len})...")
decode_inputs, decode_cache_len = _create_example_inputs(
@@ -641,9 +761,27 @@ def main():
        )
    else:
        # Single method mode: fixed seqlen with generate_full_logits=True for lookahead
        per_layer_cache_lens = _resolve_per_layer_cache_lens(
            n_layers=model_args.n_layers,
            max_context_len=args.max_context_len,
            input_len=args.input_len,
            sliding_window=args.sliding_window,
            sliding_window_pattern=args.sliding_window_pattern,
        )
        if args.sliding_window_pattern is not None:
            full_cache_len = args.max_context_len - args.input_len
            n_full = sum(1 for cl in per_layer_cache_lens if cl == full_cache_len)
            print(
                f"\nCache length per layer: {n_full} full ({full_cache_len} tokens) + "
                f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)"
            )
        print(f"\nCreating example inputs (seqlen={args.input_len})...")
        example_inputs, example_cache_len = _create_example_inputs(
            model_args, args.input_len, args.max_context_len, float_dtype
            model_args,
            args.input_len,
            args.max_context_len,
            float_dtype,
            cache_len=per_layer_cache_lens,
        )

        # Test eager execution
3 changes: 3 additions & 0 deletions examples/apple/coreml/llama/readme.md
@@ -69,6 +69,8 @@ Key differences between the two modes:
| `--max_context_len` | 1024 | Maximum context length |
| `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. |
| `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. |
| `--sliding_window` | (off) | Sliding-window attention size. When set, each layer's KV cache is capped at this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma3 / Gemma4 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. |
| `--sliding_window_pattern` | (off) | Period of the sliding/full attention pattern (HuggingFace's `sliding_window_pattern`). When set together with `--sliding_window`, every P-th layer (1-indexed) uses full attention while the rest use the sliding window. Use `P=5` for Gemma 4 E2B (4 sliding + 1 full) and `P=6` for Gemma 3 (5 sliding + 1 full). |

### Quantization Options
| Option | Default | Description |
@@ -94,6 +96,7 @@ The static model has several ANE optimizations, including:
* Splitting linear layers for improved performance (controlled by target_split_size and max_splits args)
* Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks)
* Re-writing SDPA to avoid 5-D tensors, which improves performance. This also fixes an accuracy bug introduced in iOS 26 (see https://github.com/pytorch/executorch/issues/15833)
* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3 / Gemma 4 alternate sliding and full layers), this both matches training behavior and substantially reduces KV-cache memory and per-token attention FLOPs at long context. Pair with `--sliding_window_pattern P` to mix sliding and full layers in the HuggingFace pattern (every P-th layer is full attention): P=5 for Gemma 4 E2B (4 sliding + 1 full) and P=6 for Gemma 3 (5 sliding + 1 full).
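
For example (an illustrative invocation, not taken from this change; the window/context values and any model or checkpoint flags you normally pass are placeholders), a checkpoint trained with a 4096-token window and a 6-layer pattern could be exported with `--max_context_len 8192 --input_len 32 --sliding_window 4096 --sliding_window_pattern 6`: five of every six layers get a 4096-token KV cache, while every sixth layer keeps the full `max_context_len - input_len` cache.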

We are working on adding a C++ runner as well.
