Skip to content

metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×)#2320

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/metal-flash-sdpa
Open

metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×)#2320
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/metal-flash-sdpa

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

@kali — opening this as much as an RFC as a PR, so please read the design question below before the diff.

As I've seen you've been actively building out the GPU path lately — mirroring the CUDA kernels into Metal (gather, diag_gather), landing scaled_masked_softmax (bool mask + post-softmax mask) on both backends, and adding the transform pre-check + the CPU-fallback-when-a-gpu-op-rejects-a-shape — I wanted to surface something adjacent that's been sitting unused in the Metal crate and check the direction with you before investing further.

The opportunity

The vendored libMetalFlashAttention.metallib (Apple / Philip Turner's MetalFlashAttention, MIT — already in-tree) ships four functions: sgemm, hgemm, convolution, and attention. We only ever dispatch sgemm/hgemm. The attention entry is a fully-implemented fused flash-attention kernel (online softmax, never materializes the score matrix) — and it's simply never called. (For completeness: convolution in this metallib is an empty stub, so attention is the only unused-but-real kernel here.)

Meanwhile MetalTransform explodes every Sdpa into einsum → softmax → einsum, which materializes the full (B,H,Sq,Sk) score buffer in device memory and round-trips it through three kernels — the middle one being the scaled_masked_softmax you just landed.

This PR wires the fused kernel and routes Sdpa to it.

The design question (why this is an RFC)

Dispatching a vendored, 2023-era macosx13 metallib that we don't own is a real commitment — and I noticed you just dropped the unused GgmlFlashAttn kernel library on the CUDA side (73dada812), which cuts the other way. So I'd rather ask directly than assume:

Is wiring this vendored Metal kernel the direction you want — or would you prefer a fresh, owned port (e.g. translating the MLX / ggml-metal flash-attention kernel into a .metal source we control)?

The case for wiring it now: it's already in-tree, fully implemented, MIT, and measures ~2× — a low-risk way to close the fused-Metal-SDPA gap today. The owned-.metal-port hedge stays open as a follow-up if metallib longevity on M3/M4 worries you. This PR is the "wire what's already there" option, fully validated, for you to accept or redirect.

What it does

  • dispatch_metal_mfa_attention — drives the vendored attention function. ABI (buffers / function-constants / grid geometry) reconstructed from the MFA v1.0.1 source + on-GPU pipeline reflection. f32 + f16.
  • mfa_attention_head_major — adapts tract's native [B,H,S,D] to MFA's layout (Q/O=[R,H,D], K=[H,D,C], V=[C,H,D]) on-device via copy_nd. The one unavoidable copy is the K transpose (candidate to fold later).
  • MetalMfaSdpa op + register_metal_op!(Sdpa) translator — routes a real Sdpa node to the fused kernel. Unsupported shapes return None and fall through to the CPU-fallback path you just added (85255fdb9). A Metal-local rewire_sdpa_metal flattens only the Sdpa nodes the kernel can't take, leaving fusable ones intact (CUDA keeps the shared rewire_sdpa untouched).
  • causal via an additive [Sq,Sk] mask — the metallib's triangular function-constant alone is a no-op (it computes full attention), pinned by a regression test.

Numbers (Apple M-series, f32, B=1 H=8 S=512 D=64)

measurement result
fused kernel vs explode path (both preallocated, dispatch_eval) ~2×, growing with S (2.4× at S=2048)
8-layer all-attention stack (amortizes host sync) 2.70× on the attention portion (fused 6.6 ms/layer vs explode 17.9)
projected real-model e2e (2.70× × attention compute-share) ~1.2–1.4×
(B,H,Sq,Sk) score buffer eliminated (≈128 MB at S=2048/H=8)

(A single-op model bench reports 3.9×, but that's overhead-inflated — the 2.70× multi-layer figure is the honest one, and it's consistent with the kernel-level ~2×.)

Correctness & gates

  • Bit-close to a CPU reference across f32/f16, head dims 16–128, masked, causal, multi-head, and the head-major adapter.
  • An e2e test builds a real Sdpa model, runs MetalTransform, asserts it routes to MetalMfaSdpa, and matches the CPU FlashSdpa output.
  • Full tract-metal suite 71/0; cargo build --workspace clean; fmt + clippy clean on the new code.

Orthogonal to #2319 (that one is the CPU FlashSdpa path; this is Metal).

Credits / prior art

🤖 Generated with Claude Code

tract's metal crate already vendors libMetalFlashAttention (Apple / Philip
Turner's flash-attention kernels, MIT) but only used its sgemm/hgemm entry
points -- the fused `attention` kernel shipped inside the metallib was never
dispatched. The MetalTransform instead explodes every Sdpa into
einsum + softmax + einsum, materializing the full (B,H,Sq,Sk) score matrix in
device memory and round-tripping it through three separate kernels.

This wires the fused kernel:

  * `dispatch_metal_mfa_attention` drives the vendored `attention` function
    (online softmax / flash attention; ABI reconstructed from the v1.0.1
    source + on-GPU pipeline reflection). f32 and f16.
  * `mfa_attention_head_major` adapts tract's native [B,H,S,D] layout to MFA's
    (Q/O=[R,H,D], K=[H,D,C], V=[C,H,D]) on-device via copy_nd.
  * `MetalMfaSdpa` op + a `register_metal_op!(Sdpa)` translator route a real
    Sdpa node to the fused kernel; unsupported shapes fall back to the existing
    explode path. The new `rewire_sdpa_metal` only flattens the Sdpa nodes the
    kernel can't take, leaving fusable ones intact (cuda keeps the shared
    `rewire_sdpa`).
  * causal masking via an additive [Sq,Sk] mask -- the metallib's `triangular`
    function-constant alone is a no-op, pinned by a regression test.

Eliminates the (B,H,Sq,Sk) intermediate and collapses three kernels to one.
Measured on M-series (f32, B=1 H=8 S=512 D=64): the kernel is ~2x the explode
path, and an 8-layer all-attention stack (amortizing host sync) runs 2.70x
faster on the attention portion -- so a real model's end-to-end gain is 2.70x
scaled by attention's compute share.

Correctness: bit-close to a CPU reference across f32/f16, head dims 16..128,
masked, causal, multi-head, and head-major layout; an e2e test builds a real
Sdpa model, runs the MetalTransform, asserts it routes to MetalMfaSdpa, and
matches the CPU FlashSdpa output.

Apple MetalFlashAttention: https://github.com/philipturner/metal-flash-attention
Prior art (fused-attention dispatch): llama.cpp ggml-metal flash-attn kernel;
candle-metal-kernels.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
czoli1976 added a commit to czoli1976/tract that referenced this pull request Jun 3, 2026
…rite + NNEF + resume

tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
Apple Core ML "stateful in-place KV" lever. Pieces:

1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
   `axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
   any axis), double only when capacity is exceeded -> O(T) amortized copy.
   valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
   that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
   capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.

2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
   SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
   cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
   the buffers inside the consuming op is what makes the saving real. Drop-in for
   {kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.

3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
   (fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
   so existing decode models adopt the in-place cache transparently.

4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).

5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
   snapshot the running state in-process. Both bit-exact resume; snapshot is O(len).

Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.

Benched (release, B=1 H=8 D=128):
  - cache-update only:      21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
  - end-to-end via the op:  1.10x (256) -> 1.63x (2048), 39% faster decode @2k
  - resume checkpoint:      O(len), 0.10ms (256) -> 1.76ms (4096), one-time

Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant