From 1595b734b6c90300f5a34f59b67fd4207ab60561 Mon Sep 17 00:00:00 2001 From: Shen Xu Date: Mon, 4 Aug 2025 13:52:45 -0700 Subject: [PATCH] Static attention: support local-global attention (#13043) Summary: Runtime: support different cache lengths for different layer. Python: add sliding window cache update which was already in the runtime. Reviewed By: billmguo Differential Revision: D79267644 --- .../runner/static_attention_io_manager.h | 202 ++++++++++-------- examples/models/llama/static_attention.py | 177 ++++++++++----- .../llama/tests/test_static_attention.py | 38 +++- 3 files changed, 277 insertions(+), 140 deletions(-) diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h index 2c700324486..41c826773fa 100644 --- a/examples/models/llama/runner/static_attention_io_manager.h +++ b/examples/models/llama/runner/static_attention_io_manager.h @@ -6,10 +6,11 @@ * LICENSE file in the root directory of this source tree. */ +#pragma once + #include #include #include -#include #include #include @@ -38,14 +39,13 @@ class StaticKVCache { * caches. */ StaticKVCache( - size_t n_caches, - size_t cache_len, + const std::vector& cache_lengths, size_t head_dim, - size_t max_input_len = 1, - size_t n_heads_per_cache = 1, + size_t max_input_len, + size_t n_heads_per_cache, StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK) - : n_caches_(n_caches), - cache_len_(n_caches_, cache_len), + : n_caches_(cache_lengths.size()), + cache_lengths_(cache_lengths), cache_pos_(n_caches_, 0), max_input_len_(max_input_len), n_heads_per_cache_(n_heads_per_cache), @@ -54,7 +54,7 @@ class StaticKVCache { input_ptrs_(n_caches_), output_ptrs_(n_caches_) { size_t total_cache_len = - std::accumulate(cache_len_.begin(), cache_len_.end(), 0); + std::accumulate(cache_lengths_.begin(), cache_lengths_.end(), 0); cache_data_size_ = total_cache_len * n_heads_per_cache_ * head_dim_; update_data_size_ = n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_; @@ -83,12 +83,12 @@ class StaticKVCache { */ void prepare( torch::executor::Method& method, - const std::vector& inputIndices, + const std::vector& input_indices, const std::vector& output_indices) { - ET_CHECK(inputIndices.size() == output_indices.size()); + ET_CHECK(input_indices.size() == output_indices.size()); auto methodMeta = method.method_meta(); for (size_t i = 0; i < n_caches_; i++) { - auto inIdx = inputIndices[i]; + auto inIdx = input_indices[i]; auto outIdx = output_indices[i]; auto inMeta = methodMeta.input_tensor_meta(inIdx); auto outMeta = methodMeta.output_tensor_meta(outIdx); @@ -113,6 +113,7 @@ class StaticKVCache { ET_CHECK_MSG( outSizes[1] == n_heads_per_cache_, "Number of heads per cache mismatch."); + ET_CHECK_MSG(inSizes[2] == cache_lengths_[i], "Cache length mismatch."); } else { // 1 head per cache, meaning MHA is split up into multiple SHAs for QNN. // Tensor shape is (1, seq_len, head_dim). 
@@ -121,12 +122,18 @@ class StaticKVCache { ET_CHECK_MSG( outSizes.size() == 3, "Cache input tensor expected to have rank 3."); + ET_CHECK_MSG(inSizes[1] == cache_lengths_[i], "Cache length mismatch."); + if (i < n_caches_ - 1) { + ET_CHECK_MSG( + inSizes[1] * head_dim_ == (input_ptrs_[i + 1] - input_ptrs_[i]), + "Cache length mismatch."); + } } auto ndim = inSizes.size(); ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch."); ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch."); ET_CHECK_MSG( - inSizes[ndim - 2] == cache_len_[i], "Cache length dim mismatch."); + inSizes[ndim - 2] == cache_lengths_[i], "Cache length dim mismatch."); auto impl = ::executorch::runtime::etensor::TensorImpl( inMeta->scalar_type(), @@ -167,7 +174,7 @@ class StaticKVCache { update_n, update_pos, input_ptrs_[i], - cache_len_[i], + cache_lengths_[i], cache_pos_[i]); } } @@ -187,7 +194,7 @@ class StaticKVCache { size_t cache_data_offset = 0; for (size_t i = 0; i < n_caches_; i++) { input_ptrs_[i] = cache_data_ + cache_data_offset; - cache_data_offset += cache_len_[i] * n_heads_per_cache_ * head_dim_; + cache_data_offset += cache_lengths_[i] * n_heads_per_cache_ * head_dim_; output_ptrs_[i] = update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_; } @@ -217,9 +224,10 @@ class StaticKVCache { update_head + (update_pos + update_n) * head_dim_, cache_head + cache_pos * head_dim_); } - cache_pos += update_n; + cache_pos = (cache_pos + update_n) % cache_len; if (wrap_n > 0) { + ET_CHECK(cache_pos == 0); return update_one_cache( update, update_len, @@ -227,14 +235,14 @@ class StaticKVCache { update_pos + contiguous_n, cache, cache_len, - 0); + cache_pos); } return cache_pos; } size_t n_caches_; - std::vector cache_len_; + std::vector cache_lengths_; std::vector cache_pos_; size_t max_input_len_; size_t n_heads_per_cache_; @@ -415,11 +423,11 @@ class StaticAttentionIOManager { public: struct StaticAttentionIOConfig { size_t n_caches{}; - size_t cache_len{}; + std::vector cache_lengths{}; size_t head_dim{}; size_t max_input_len{}; size_t n_heads_per_cache{}; - size_t attn_mask_input_index{}; + std::unordered_map cache_len_to_mask_idx; size_t rope_freqs_cos_input_index{}; size_t rope_freqs_sin_input_index{}; std::vector k_cache_input_indices; @@ -433,50 +441,55 @@ class StaticAttentionIOManager { StaticAttentionIOManager(StaticAttentionIOConfig config) : config_(std::move(config)), - kCaches_( - config_.n_caches, - config_.cache_len, + k_caches_( + config_.cache_lengths, config_.head_dim, config_.max_input_len, config_.n_heads_per_cache, config_.style), - vCaches_( - config_.n_caches, - config_.cache_len, + v_caches_( + config_.cache_lengths, config_.head_dim, config_.max_input_len, config_.n_heads_per_cache, config_.style) { ET_LOG( Info, - "Created StaticAttentionIOManager with" - " max input length = %zu cache length = %zu", - config_.max_input_len, - config_.cache_len); + "Created StaticAttentionIOManager with max input length = %zu", + config_.max_input_len); + for (auto cache_len : config_.cache_lengths) { + ET_LOG(Info, "Cache length = %zu", cache_len); + } } + using PerCacheLenMasks = std::vector>>>; + /** - * Create a new StaticAttentionMask that will be managed by this object. + * Create a new StaticAttentionMask for each cache length used. 
*/ - StaticAttentionMask& - add_mask(size_t input_len, MaskT zero_val, MaskT mask_val) { - auto it = attentionMasks_.emplace( - std::piecewise_construct, - std::forward_as_tuple(input_len), - std::forward_as_tuple( - config_.cache_len, - input_len, - config_.head_dim, - zero_val, - mask_val, - config_.style)); + PerCacheLenMasks& add_mask(size_t input_len, MaskT zero_val, MaskT mask_val) { + PerCacheLenMasks masks; + for (auto& pair : config_.cache_len_to_mask_idx) { + masks.emplace_back( + pair.first, + std::make_unique>( + pair.first, + input_len, + config_.head_dim, + zero_val, + mask_val, + config_.style)); + } + auto it = attentionMasks_.emplace(input_len, std::move(masks)); return it.first->second; } /** * Retrieve a mask suitable for given input length. */ - StaticAttentionMask& get_mask(size_t input_len) { + PerCacheLenMasks& get_mask(size_t input_len) { return attentionMasks_.at(input_len); } @@ -487,9 +500,9 @@ class StaticAttentionIOManager { torch::executor::Method& method, std::optional> pos_offsets = std::nullopt) { - kCaches_.prepare( + k_caches_.prepare( method, config_.k_cache_input_indices, config_.k_cache_output_indices); - vCaches_.prepare( + v_caches_.prepare( method, config_.v_cache_input_indices, config_.v_cache_output_indices); size_t rope_dim = config_.head_dim / 2; @@ -538,12 +551,14 @@ class StaticAttentionIOManager { size_t update_len, size_t cache_update_pos = 0) { input_pos_ += update_len; - kCaches_.update( + k_caches_.update( method, k_cache_output_indices, update_len, cache_update_pos); - vCaches_.update( + v_caches_.update( method, v_cache_output_indices, update_len, cache_update_pos); for (auto& it : attentionMasks_) { - it.second.unmask(update_len); + for (auto& mask : it.second) { + mask.second->unmask(update_len); + } } } @@ -552,10 +567,12 @@ class StaticAttentionIOManager { */ void reset() { input_pos_ = 0; - kCaches_.reset(); - vCaches_.reset(); + k_caches_.reset(); + v_caches_.reset(); for (auto& it : attentionMasks_) { - it.second.reset(); + for (auto& mask : it.second) { + mask.second->reset(); + } } } @@ -570,7 +587,12 @@ class StaticAttentionIOManager { executorch::runtime::Span input_buffer, executorch::runtime::Method& method) { size_t input_len = input_buffer.size(); - get_mask(input_buffer.size()).set_causal_mask(); + auto& masks = get_mask(input_buffer.size()); + for (auto& pair : masks) { + auto& mask = *pair.second; + mask.set_causal_mask(); + set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get()); + } size_t batch_len = 0; for (size_t i = 0; i < tokens.size(); i += input_len) { @@ -593,17 +615,20 @@ class StaticAttentionIOManager { * the sampled token. 
*/ template - std::vector decode( + void decode( TokenT prev_tok, executorch::runtime::Span input_buffer, executorch::runtime::Method& method, std::function& sample, - std::function& should_stop) { + std::function& token_callback) { set_input(method, 0, input_buffer.data()); - auto& mask = get_mask(input_buffer.size()); - set_input(method, config_.attn_mask_input_index, mask.get()); + auto& masks = get_mask(input_buffer.size()); + for (auto& pair : masks) { + auto& mask = *pair.second; + mask.set_causal_mask(); + set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get()); + } - std::vector generated_tokens; while (true) { input_buffer[0] = prev_tok; prepare(method); @@ -614,12 +639,10 @@ class StaticAttentionIOManager { config_.v_cache_output_indices, 1); prev_tok = sample(method); - generated_tokens.emplace_back(prev_tok); - if (should_stop(prev_tok)) { + if (!token_callback(prev_tok)) { break; } } - return generated_tokens; } /** @@ -628,12 +651,12 @@ class StaticAttentionIOManager { * output and return the sampled token for all output positions. */ template - std::vector lookahead_decode( + void lookahead_decode( TokenT prev_tok, executorch::runtime::Span input_buffer, executorch::runtime::Method& method, std::function(executorch::runtime::Method&)>& sample, - std::function& should_stop, + std::function& token_callback, size_t ngram_size, size_t window_size, size_t n_verifications, @@ -642,10 +665,18 @@ class StaticAttentionIOManager { size_t input_len = input_buffer.size(); // Set up attention mask for current input length. - auto& mask = get_mask(input_buffer.size()); - set_lookahead_decoding_mask( - mask, input_len, ngram_size, window_size, n_verifications); - set_input(method, config_.attn_mask_input_index, mask.get()); + auto& masks = get_mask(input_buffer.size()); + for (auto& pair : masks) { + auto& mask = *pair.second; + set_lookahead_decoding_mask( + mask, + input_len, + pair.first, + ngram_size, + window_size, + n_verifications); + set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get()); + } // Position offsets relative to current position, for indexing RoPE // frequence tensors. @@ -663,7 +694,7 @@ class StaticAttentionIOManager { n_verifications); // Decoding loop. - std::vector generated_tokens; + size_t n_generated = 0; size_t verification_offset = std::max(window_size * (ngram_size - 1), static_cast(1)); size_t n_inference = 0; @@ -743,40 +774,42 @@ class StaticAttentionIOManager { } } - bool generated_stop_tok = false; + bool should_stop = false; + // Count the number of accepted tokns in the matched branched, can be + // less than the match length due to callback request stopping. + size_t n_accepted = 0; for (auto tok : longest_match) { - generated_tokens.emplace_back(tok); - if (should_stop(tok)) { - generated_stop_tok = true; + n_generated++; + n_accepted++; + if (!token_callback(tok)) { + should_stop = true; break; } } // Update KV caches and mask for additional matches. 
- if (longest_match.size() > 1) { + if (n_accepted > 1) { size_t branch_offset = verification_offset + (ngram_size - 1) * matched_branch; update( method, config_.k_cache_output_indices, config_.v_cache_output_indices, - longest_match.size() - 1, + n_accepted - 1, branch_offset); } - if (generated_stop_tok) { + if (should_stop) { break; } - prev_tok = generated_tokens.back(); + prev_tok = longest_match.back(); } ET_LOG( Info, "Generated %zu tokens with %zu inferences(s).", - generated_tokens.size(), + n_generated, n_inference); - - return generated_tokens; } private: @@ -793,12 +826,14 @@ class StaticAttentionIOManager { const_cast( inputMeta->dim_order().data())); executorch::runtime::etensor::Tensor t(&impl); + ET_CHECK(data != nullptr); ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok); } void set_lookahead_decoding_mask( StaticAttentionMask& mask, size_t input_len, + size_t cache_len, size_t ngram_size, size_t window_size, size_t n_verifications) { @@ -815,8 +850,8 @@ class StaticAttentionIOManager { size_t stride_; }; - size_t stride = config_.cache_len + input_len; - auto input_submask = SubMask(mask.get() + config_.cache_len, stride); + size_t stride = cache_len + input_len; + auto input_submask = SubMask(mask.get() + cache_len, stride); input_submask.at(0, 0) = mask.zero_val(); // Fill entire input mask first. @@ -895,10 +930,9 @@ class StaticAttentionIOManager { StaticAttentionIOConfig config_; size_t input_pos_ = 0; - StaticKVCache kCaches_; - StaticKVCache vCaches_; - std::unordered_map> - attentionMasks_; + StaticKVCache k_caches_; + StaticKVCache v_caches_; + std::unordered_map attentionMasks_; std::vector rope_freqs_cos_override_; std::vector rope_freqs_sin_override_; }; diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 21ad6c837ed..e3859b98210 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict, deque -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -23,6 +23,11 @@ _OutputCacheState = Tuple[_CacheMap, _CacheMap] +def none_throws(x: Optional[Any]) -> Any: + assert x is not None + return x + + class StaticKVCache(nn.Module, ABC): def __init__(self, layer_id: int, head_id: int): super().__init__() @@ -57,6 +62,19 @@ def apply_update( After inference, update the cache state for next iteration. The runtime needs to implement the same operation. 
""" + seq_dim = -1 if transpose else -2 + cache_len = cache.size(seq_dim) + if cache_len == 0: + return + if cache_len < update.size(seq_dim): + update = torch.narrow( + update, + seq_dim, + update.size(seq_dim) - cache_len, + cache_len, + ) + assert update.size(seq_dim) == cache_len + if style == "shift_pointer": if transpose: update_len = update_len or update.size(-1) @@ -72,17 +90,32 @@ def apply_update( ] if style == "smart_mask": + available = cache.size(-2) - pos + update_len = update_len or update.size(-1 if transpose else -2) + if update_len > available: + wrap = update_len - available + update_len = available + else: + wrap = 0 + updated = torch.clone(cache) if transpose: - update_len = update_len or update.size(-1) - updated[..., :, pos : pos + update_len] = update[ - ..., :, update_pos : update_pos + update_len + updated[..., pos : pos + update_len] = update[ + ..., update_pos : update_pos + update_len ] + if wrap > 0: + update_pos += update_len + updated[..., :wrap] = update[..., update_pos : update_pos + wrap] + else: - update_len = update_len or update.size(-2) updated[..., pos : pos + update_len, :] = update[ ..., update_pos : update_pos + update_len, : ] + if wrap > 0: + update_pos += update_len + updated[..., :wrap, :] = update[ + ..., update_pos : update_pos + wrap, : + ] return updated @@ -108,12 +141,13 @@ def update( new_data = new_data.transpose(-1, -2) if in_cache_state is None: return new_data, None + cache = in_cache_state[0].get(self.cache_key()) + if cache is None: + return new_data, None if out_cache_state is None: out_cache_state = ({}, {}) - all_data = torch.cat( - [in_cache_state[0][self.cache_key()], new_data], dim=seq_dim - ) + all_data = torch.cat([cache, new_data], dim=seq_dim) out_k_cache, out_v_cache = out_cache_state out_k_cache[self.cache_key()] = new_data return all_data, (out_k_cache, out_v_cache) @@ -128,10 +162,13 @@ def update( ) -> Tuple[torch.Tensor, Optional[_OutputCacheState]]: if in_cache_state is None: return new_data, None + cache = in_cache_state[1].get(self.cache_key()) + if cache is None: + return new_data, None if out_cache_state is None: out_cache_state = ({}, {}) - all_data = torch.cat([in_cache_state[1][self.cache_key()], new_data], dim=-2) + all_data = torch.cat([cache, new_data], dim=-2) out_k_cache, out_v_cache = out_cache_state out_v_cache[self.cache_key()] = new_data return all_data, (out_k_cache, out_v_cache) @@ -154,6 +191,9 @@ def reset(self): self.unmasked_len = 0 self.tensor[:, :, : self.cache_len] = self.mask_val + def set_input_mask(self, input_mask): + self.tensor[:, :, self.cache_len :] = input_mask + def unmask(self, new_unmasked_len): if new_unmasked_len <= 0: return @@ -162,9 +202,9 @@ def unmask(self, new_unmasked_len): self.tensor[ :, :, - self.cache_len - - self.unmasked_len - - new_unmasked_len : self.cache_len + max( + 0, self.cache_len - self.unmasked_len - new_unmasked_len + ) : self.cache_len - self.unmasked_len, ] = 0 @@ -201,14 +241,21 @@ def __init__( self, config: ModelArgs, input_len: int, - cache_len: int, + cache_lens: Union[int, List[int]], dtype=torch.float32, style: str = "shift_pointer", mask_val: float = float("-inf"), ): - self.mask = StaticAttentionMask( - input_len, cache_len, style=style, mask_val=mask_val, dtype=dtype - ) + if isinstance(cache_lens, int): + cache_lens = [cache_lens] * config.n_layers + assert len(cache_lens) == config.n_layers + + self._masks = { + cl: StaticAttentionMask( + input_len, cl, style=style, mask_val=mask_val, dtype=dtype + ) + for cl in set(cache_lens) + } rope = 
Rope(config) freqs = rope.get_freqs(None, config.max_seq_len) @@ -219,44 +266,59 @@ def __init__( if split_mha: self.k_caches = { StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros( - 1, cache_len, config.head_dim, dtype=dtype + 1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype ) for layer_id in range(config.n_layers) - for head_id in range(config.n_kv_heads) + for head_id in range(none_throws(config.n_kv_heads)) + if cache_lens[layer_id] > 0 } self.v_caches = { StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros( - 1, cache_len, config.head_dim, dtype=dtype + 1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype ) for layer_id in range(config.n_layers) - for head_id in range(config.n_kv_heads) + for head_id in range(none_throws(config.n_kv_heads)) + if cache_lens[layer_id] > 0 } else: self.k_caches = { StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros( - 1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype + 1, + none_throws(config.n_kv_heads), + cache_lens[layer_id], + none_throws(config.head_dim), + dtype=dtype, ) for layer_id in range(config.n_layers) } self.v_caches = { StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros( - 1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype + 1, + none_throws(config.n_kv_heads), + cache_lens[layer_id], + none_throws(config.head_dim), + dtype=dtype, ) for layer_id in range(config.n_layers) } self.config = config self.input_len = input_len - self.cache_len = cache_len + self.cache_lens = cache_lens self.style = style self.mask_val = mask_val self.pos = 0 self.cache_full = False + @property + def masks(self): + return {cache_len: mask.tensor for cache_len, mask in self._masks.items()} + def reset(self): self.pos = 0 self.cache_full = False - self.mask.reset() + for mask in self._masks.values(): + mask.reset() def prefill( self, @@ -266,10 +328,13 @@ def prefill( if self.cache_full: raise RuntimeError("KV cache is full.") - self.mask.tensor[:, :, self.cache_len :] = torch.triu( - torch.full((1, self.input_len, self.input_len), self.mask_val), - diagonal=1, - ) + for mask in self._masks.values(): + mask.set_input_mask( + torch.triu( + torch.full((1, self.input_len, self.input_len), self.mask_val), + diagonal=1, + ) + ) logits = None all_logits = None @@ -296,10 +361,13 @@ def decode( if self.cache_full: raise RuntimeError("KV cache is full.") - self.mask.tensor[:, :, self.cache_len :] = torch.triu( - torch.full((1, self.input_len, self.input_len), self.mask_val), - diagonal=1, - ) + for mask in self._masks.values(): + mask.set_input_mask( + torch.triu( + torch.full((1, self.input_len, self.input_len), self.mask_val), + diagonal=1, + ) + ) stop_tokens = stop_tokens or [] new_tokens = [init_token] @@ -340,15 +408,10 @@ def lookahead_decode( # noqa: C901 lambda: StaticAttentionIOManager.NGramCache(n_verifications) ) - self.mask.tensor[:, :, self.cache_len :] = self._get_lookahead_decoding_mask( - ngram_size, window_size, n_verifications - ) - logger.debug("Lookahead decoding mask: ") - for i in range(self.input_len): - logger.debug( - " ".join( - ("X" if x == 0.0 else " ") - for x in self.mask.tensor[0][i][self.cache_len :] + for mask in self._masks.values(): + mask.set_input_mask( + self._get_lookahead_decoding_mask( + ngram_size, window_size, n_verifications ) ) @@ -455,7 +518,7 @@ def _run_once( n_tokens = len(tokens) if n_tokens < self.input_len: tokens += [0] * (self.input_len - n_tokens) - tokens = torch.tensor([tokens], dtype=torch.int32) + tokens = 
torch.tensor([tokens], dtype=torch.int32) # pyre-ignore[9] if freqs_cos_override is None: freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len] if freqs_sin_override is None: @@ -463,24 +526,20 @@ def _run_once( y, attn_updates = model( tokens, { - "mask": self.mask.tensor, + "masks": self.masks, "freqs_cos_override": freqs_cos_override, "freqs_sin_override": freqs_sin_override, "in_cache_state": (self.k_caches, self.v_caches), }, ) non_padded_len = non_padded_len or n_tokens - if self.pos + non_padded_len <= self.cache_len: - self._update_states(attn_updates, 0, non_padded_len) - else: - self.cache_full = True + self._update_states(attn_updates, 0, non_padded_len) return y, attn_updates def _update_states(self, attn_updates, update_pos, update_len): - assert self.pos + update_len <= self.cache_len - - self.mask.unmask(update_len) + for mask in self._masks.values(): + mask.unmask(update_len) k_cache_updates, v_cache_updates = attn_updates["out_cache_state"] for cache_id, update in k_cache_updates.items(): self.k_caches[cache_id] = StaticKVCache.apply_update( @@ -724,6 +783,7 @@ def from_conv2ds(ts): new_vs, freqs_cos, freqs_sin, + seq_len, **kwargs, ) else: @@ -756,9 +816,9 @@ def _forward_sha( new_vs, freqs_cos, freqs_sin, + seq_len, **kwargs: ForwardOptions, ): - mask = kwargs.get("mask") if (freqs_cos_override := kwargs.get("freqs_cos_override")) is not None: freqs_cos = freqs_cos_override # pyre-ignore if (freqs_sin_override := kwargs.get("freqs_sin_override")) is not None: @@ -789,6 +849,9 @@ def _forward_sha( ) all_vs.append(vs) + cache_len = all_ks[0].size(-2) - seq_len + mask = kwargs["masks"][cache_len] + heads = [] for i in range(self.n_heads): kv_idx = i // self.n_heads_per_kv_group @@ -811,7 +874,6 @@ def _forward_mha( seq_len, **kwargs: ForwardOptions, ): - mask = kwargs.get("mask") in_cache_state = kwargs.get("in_cache_state") out_cache_state = kwargs.get("out_cache_state") @@ -836,6 +898,12 @@ def _forward_mha( if self.n_rep > 1: k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) + + mask = None + masks = kwargs.get("masks") + if masks: + cache_len = k.size(-2) - seq_len + mask = masks[cache_len] y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state @@ -846,14 +914,17 @@ def load_weights_from_attention_mha( if self.split_mha: for i in range(self.n_heads): self.wqs[i].weight.data.copy_( + # pyre-ignore[29] other.wq.weight[i * self.head_dim : (i + 1) * self.head_dim, :] ) for i in range(self.n_kv_heads): self.wks[i].weight.data.copy_( + # pyre-ignore[29] other.wk.weight[i * self.head_dim : (i + 1) * self.head_dim, :] ) self.wvs[i].weight.data.copy_( + # pyre-ignore[29] other.wv.weight[i * self.head_dim : (i + 1) * self.head_dim, :] ) else: @@ -861,7 +932,7 @@ def load_weights_from_attention_mha( self.wks[0].load_state_dict(other.wk.state_dict()) self.wvs[0].load_state_dict(other.wv.state_dict()) - self.wo.weight.data.copy_(other.wo.weight) + self.wo.weight.data.copy_(other.wo.weight) # pyre-ignore[6] if other.use_qk_norm: self.use_qk_norm = True diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py index 0f7f412bd91..2461732db5a 100644 --- a/examples/models/llama/tests/test_static_attention.py +++ b/examples/models/llama/tests/test_static_attention.py @@ -1,7 +1,7 @@ import copy import itertools import unittest -from collections import defaultdict +from collections import 
Counter, defaultdict import torch from executorch.examples.models.llama.attention import AttentionMHA @@ -12,6 +12,7 @@ StaticAttention, StaticAttentionIOManager, StaticAttentionMask, + StaticKCache, StaticKVCache, ) @@ -20,6 +21,37 @@ class StaticAttentionTest(unittest.TestCase): def setUp(self): torch.manual_seed(42) + def test_sliding_window_cache_and_mask(self): + def test(style): + cache_len = 16 + + # Cache initialized to -128, mask to 64, integers from 0 are added to cache, + # check the set of positive values in cache + mask. + cache = StaticKCache(0, 0) + cache_data = torch.full((1, cache_len, 1), -128, dtype=torch.int64) + mask = StaticAttentionMask( + 1, cache_len, style=style, mask_val=64, dtype=torch.int64 + ) + for i in range(0, 3 * cache_len, 3): + update = torch.tensor([i, i + 1, i + 2], dtype=torch.int64).view( + 1, 3, 1 + ) + cache_data = cache.apply_update( + cache_data, + update, + i % cache_len, + style, + ) + mask.unmask(3) + unmasked_cache_data = cache_data.flatten() + mask.tensor.flatten()[:-1] + self.assertEqual( + Counter([x for x in unmasked_cache_data.tolist() if x >= 0]), + Counter(list(range(i + 2, -1, -1))[:cache_len]), + ) + + test("shift_pointer") + test("smart_mask") + def test_without_cache(self): def test( use_qk_norm, qk_norm_before_rope, split_mha, adopt_hf_rope, use_conv2d @@ -75,7 +107,7 @@ def test( x, freqs_cos, freqs_sin, - mask=mask, + masks={0: mask}, ) self.assertTrue( torch.isclose(y, expected, rtol=1e-3).all(), @@ -139,7 +171,7 @@ def test_with_style(style): x[:, i * chunk_len : (i + 1) * chunk_len, :], hf_freqs_cos[i * chunk_len : (i + 1) * chunk_len], hf_freqs_sin[i * chunk_len : (i + 1) * chunk_len], - mask=mask.tensor, + masks={cache_len: mask.tensor}, in_cache_state=(k_caches, v_caches), out_cache_state=({}, {}), )
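
For reference, below is a minimal standalone sketch of the sliding-window (ring-buffer) cache update semantics that this patch adds to the Python `StaticKVCache.apply_update` and that the runtime's `update_one_cache` already implements: new entries are written at the current position and any overflow wraps around to the front of the cache, mirroring `cache_pos = (cache_pos + update_n) % cache_len`. The helper name `sliding_window_update`, its shapes, and the demo values are illustrative assumptions, not code from this diff.

```python
# Sketch only (not part of the patch): ring-buffer write into a fixed-length
# KV cache, analogous to the "smart_mask" style update with wrap-around.
import torch


def sliding_window_update(
    cache: torch.Tensor,   # (1, cache_len, head_dim)
    update: torch.Tensor,  # (1, update_len, head_dim), update_len <= cache_len
    pos: int,              # current write position within the cache
) -> tuple[torch.Tensor, int]:
    """Hypothetical helper; returns the updated cache and the new write position."""
    cache_len = cache.size(-2)
    update_len = update.size(-2)
    assert update_len <= cache_len
    updated = cache.clone()
    # Contiguous part: fits between pos and the end of the cache.
    contiguous = min(update_len, cache_len - pos)
    updated[:, pos : pos + contiguous, :] = update[:, :contiguous, :]
    # Wrapped part: spills over to the front of the cache.
    wrap = update_len - contiguous
    if wrap > 0:
        updated[:, :wrap, :] = update[:, contiguous:, :]
    return updated, (pos + update_len) % cache_len


if __name__ == "__main__":
    cache = torch.full((1, 4, 1), -1.0)
    pos = 0
    for step in range(3):
        update = torch.arange(3 * step, 3 * step + 3, dtype=torch.float).view(1, 3, 1)
        cache, pos = sliding_window_update(cache, update, pos)
    # After writing 9 values into a length-4 cache, only the last 4 survive,
    # stored in ring-buffer order; pos points at the oldest remaining entry.
    print(cache.flatten().tolist(), pos)  # [8.0, 5.0, 6.0, 7.0] 1
```

This is the same behavior the new `test_sliding_window_cache_and_mask` test exercises: local (sliding-window) layers reuse a short cache with wrap-around writes, while global layers keep a full-length cache, which is why `cache_lengths` is now a per-layer vector rather than a single `cache_len`.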