diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index ce3b01b6d68..69ee4e192e1 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -47,29 +47,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         return updated
 
@@ -163,6 +173,164 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len
 
 
+class StaticAttentionIOManager:
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.config = config
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ) -> torch.Tensor:
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        logits = None
+        all_logits = None
+        for i in range(0, len(tokens), self.input_len):
+            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+            if self.config.generate_full_logits:
+                if all_logits is None:
+                    all_logits = logits
+                else:
+                    all_logits = torch.cat([all_logits, logits], dim=1)
+
+        if self.config.generate_full_logits:
+            return all_logits[:, : len(tokens), :]
+
+        return logits
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 77b8be5d401..e40643299ef 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -1,12 +1,13 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
+from executorch.examples.models.llama.attention import AttentionMHA
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionIOManager,
     StaticAttentionMask,
     StaticKVCache,
 )
@@ -171,8 +172,6 @@ def test_within_transformer(self):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
 
         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
-        rope = Rope(config)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
         expected = mha_transformer(x)
 
         n_chunks = 3
@@ -180,53 +179,14 @@
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
-            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
-            mask.tensor[:, :, cache_len:] = torch.triu(
-                torch.full((1, chunk_len, chunk_len), float("-inf")),
-                diagonal=1,
-            )
-            k_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
-            v_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
+            mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
-                y_i, attn_update = static_transformer(
-                    x[:, i * chunk_len : (i + 1) * chunk_len],
-                    attn_options=ForwardOptions(
-                        mask=mask.tensor,
-                        freqs_cos_override=freqs_cos[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        freqs_sin_override=freqs_sin[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        in_cache_state=(k_caches, v_caches),
-                        out_cache_state=({}, {}),
-                    ),
+                y_i = mgr.prefill(
+                    static_transformer,
+                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
                 )
                 ys.append(y_i)
-                mask.unmask(chunk_len)
-                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-                if i < n_chunks - 1:
-                    for cache_id, update in k_cache_updates.items():
-                        k_caches[cache_id] = StaticKVCache.apply_update(
-                            k_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
-                    for cache_id, update in v_cache_updates.items():
-                        v_caches[cache_id] = StaticKVCache.apply_update(
-                            v_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
 
             self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
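
Usage note (not part of the patch): a minimal sketch of how the new StaticAttentionIOManager is meant to drive a static-attention transformer, mirroring what the updated test does. The ModelArgs values, attention_type="static", generate_full_logits=True, and the construct_transformer(config) call are assumptions modeled on the test setup; only the manager API (prefill, decode, reset) comes from this diff.

import torch

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

# Assumed configuration: a small model using the static attention path.
# generate_full_logits=True so prefill() returns logits for every prompt position;
# attention_type="static" and construct_transformer(config) mirror the test setup.
config = ModelArgs(
    dim=64,
    n_heads=4,
    n_kv_heads=2,
    n_layers=4,
    vocab_size=128,
    max_seq_len=128,
    attention_type="static",
    generate_full_logits=True,
)
model = construct_transformer(config)

input_len = 32                              # tokens fed to the model per forward pass
cache_len = config.max_seq_len - input_len  # length of the static KV cache
mgr = StaticAttentionIOManager(config, input_len, cache_len, style="shift_pointer")

prompt = [3, 25, 99]                 # placeholder token ids; use a real tokenizer in practice
logits = mgr.prefill(model, prompt)  # (1, len(prompt), vocab_size) with generate_full_logits=True

# The first generated token is the argmax at the last prompt position; decode() then feeds it
# through the model and greedily generates up to n more tokens, updating mask, caches, and pos.
init_token = logits[0, -1].argmax().item()
generated = mgr.decode(model, init_token, n=16)  # returns [init_token, tok_1, ..., tok_n]

mgr.reset()  # clear pos, mask, and cache-full state before starting a new prompt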