examples/models/llama/static_attention.py (49 additions, 5 deletions)
@@ -768,6 +768,10 @@ def __init__(
         self.use_conv2d = False
         self.enable_qnn_masked_softmax = kwargs.get("enable_qnn_masked_softmax", False)
 
+        # This fixes numerics on iOS26 on Core ML
+        # Possibly disable in future, depending on bug fixes in Core ML runtime
+        self.decompose_sdpa_in_mha: bool = kwargs.get("decompose_sdpa_in_mha", False)
+
         if self.split_mha:
             self.wqs = nn.ModuleList(
                 [
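For readers skimming the diff: "decomposing" SDPA here means replacing the fused `F.scaled_dot_product_attention` call with an explicit matmul, scale, mask-add, softmax, matmul sequence, so the Core ML lowering only sees simple ops. A minimal single-head sketch of that equivalence (standalone illustration with made-up shapes, not code from this PR):

```python
import torch
import torch.nn.functional as F

# Made-up sizes for illustration only.
D = 64            # head dim
Tq, Tk = 16, 128  # query length vs. key/value (cache) length
q = torch.randn(1, 1, Tq, D)
k = torch.randn(1, 1, Tk, D)
v = torch.randn(1, 1, Tk, D)
mask = torch.zeros(Tq, Tk)  # additive attention mask

# Fused op.
y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Decomposed: matmul -> scale -> add mask -> softmax -> matmul.
attn = (q @ k.transpose(-2, -1)) * D**-0.5
y_decomposed = F.softmax(attn + mask, dim=-1) @ v

assert torch.allclose(y_fused, y_decomposed, atol=1e-5)
```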
@@ -1027,16 +1031,56 @@ def _forward_mha(
         k, out_cache_state = self.k_caches[0].update(k, in_cache_state, out_cache_state)
         v, out_cache_state = self.v_caches[0].update(v, in_cache_state, out_cache_state)
 
-        if self.n_rep > 1:
-            k = k.repeat_interleave(self.n_rep, dim=1)
-            v = v.repeat_interleave(self.n_rep, dim=1)
-
         mask = None
         masks = kwargs.get("masks")
        if masks:
             cache_len = k.size(-2) - seq_len
             mask = masks[cache_len]
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+        if not self.decompose_sdpa_in_mha:
+            if self.n_rep > 1:
+                k = k.repeat_interleave(self.n_rep, dim=1)
+                v = v.repeat_interleave(self.n_rep, dim=1)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        else:
+            # We remove bsz dim to keep matmuls on 4D tensors
+            # Core ML sometimes fails at runtime when given 5D tensors
+            assert bsz == 1, "Batch size > 1 not supported yet"
+
+            n_kv = self.n_kv_heads
+            n_rep = self.n_rep
+            D = self.head_dim
+
+            # Explicitly track lengths; they are NOT necessarily equal.
+            Tq = q.size(-2)  # query length (current step/window), e.g. 64
+            Tk = k.size(-2)  # key/value length (cache length), e.g. 2048
+
+            # Group Q to match KV layout
+            # q: (bsz=1, n_heads, Tq, D), with n_heads = n_kv * n_rep
+            # 1 * n_heads * Tq * D == n_kv * n_rep * Tq * D
+            # q_grouped: (n_kv, n_rep, Tq, D)
+            q_grouped = q.view(n_kv, n_rep, Tq, D)
+
+            # Prepare K for grouped KV matmul
+            # k: (1, n_kv, Tk, D) -> (n_kv, 1, Tk, D)
+            k_grouped = k.view(n_kv, 1, Tk, D)
+
+            # (n_kv, n_rep, Tq, Tk)
+            attn_grouped = q_grouped @ k_grouped.transpose(-2, -1)
+            attn_grouped = attn_grouped * self.inv_scale
+
+            # Ungroup, add mask, and regroup
+            attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk)
+            attn_grouped = attn_grouped + mask
+            attn_grouped = F.softmax(attn_grouped, dim=-1)
+            attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk)
+
+            # Group v
+            v_grouped = v.view(n_kv, 1, Tk, D)
+            y_grouped = attn_grouped @ v_grouped
+
+            # Ungroup y
+            y = y_grouped.view(1, self.n_heads, Tq, D)
 
         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
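The else branch above is the interesting part: it implements grouped-query attention without `repeat_interleave` by viewing the `n_heads = n_kv * n_rep` query heads as `(n_kv, n_rep)` groups and letting the `(n_kv, 1, Tk, D)` key/value tensors broadcast against them, so every intermediate stays 4D. Below is a standalone sketch checking that this grouping matches the `repeat_interleave` + SDPA path; sizes are illustrative, and `inv_scale` is assumed to be `1 / sqrt(head_dim)` (the diff does not show how `self.inv_scale` is defined):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; real values come from the model config.
n_kv, n_rep, D = 4, 2, 64
n_heads = n_kv * n_rep
Tq, Tk = 8, 32

q = torch.randn(1, n_heads, Tq, D)
k = torch.randn(1, n_kv, Tk, D)
v = torch.randn(1, n_kv, Tk, D)
mask = torch.zeros(1, 1, Tq, Tk)  # additive mask, broadcast over heads
inv_scale = D**-0.5               # assumed value of self.inv_scale

# Reference path: expand KV heads, then fused SDPA.
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)
y_ref = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=mask)

# Grouped path from the diff: heads viewed as (n_kv, n_rep), KV broadcast.
q_grouped = q.view(n_kv, n_rep, Tq, D)
k_grouped = k.view(n_kv, 1, Tk, D)
attn = (q_grouped @ k_grouped.transpose(-2, -1)) * inv_scale  # (n_kv, n_rep, Tq, Tk)
attn = F.softmax(attn.view(1, n_heads, Tq, Tk) + mask, dim=-1)
y = (attn.view(n_kv, n_rep, Tq, Tk) @ v.view(n_kv, 1, Tk, D)).view(1, n_heads, Tq, D)

assert torch.allclose(y_ref, y, atol=1e-5)
```

Keeping everything 4D, and folding the batch dimension away (hence the `bsz == 1` assert), is what avoids the 5D layout that the in-code comment says Core ML sometimes fails on at runtime.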
