diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 57b5796cbb3..8f3486353f2 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -1,4 +1,6 @@
+import logging
 from abc import ABC, abstractmethod
+from collections import defaultdict, deque
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
@@ -14,6 +16,7 @@
 from executorch.examples.models.llama.rope import Rope
 
 
+logger = logging.getLogger(__name__)
 _CacheMap = Dict[str, torch.Tensor]
 # Key and value caches are kept separate so the key caches can be kept transposed.
 _InputCacheState = Tuple[_CacheMap, _CacheMap]
@@ -174,6 +177,24 @@ def unmask(self, new_unmasked_len):
 
 
 class StaticAttentionIOManager:
+    class NGramCache:
+        def __init__(self, max_size):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, x):
+            if x in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(x)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
     def __init__(
         self,
         config: ModelArgs,
@@ -266,12 +287,143 @@ def decode(
         new_tokens = [init_token]
         for _ in range(n):
             y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, :].argmax().item())
+            new_tokens.append(y[:, :1, ...].argmax().item())
             if new_tokens[-1] in stop_tokens:
                 break
 
         return new_tokens
 
+    def lookahead_decode(  # noqa: C901
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "StaticAttentionIOManager.NGramCache"]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.input_len:
+            raise RuntimeError(
+                "Lookahead decoding setting not compatible with input length."
+                f" input_len = {self.input_len},"
+                f" ngram_size = {ngram_size},"
+                f" window_size = {window_size},"
+                f" n_verifications = {n_verifications}"
+            )
+
+        stop_tokens = stop_tokens or []
+        if ngram_caches is None:
+            ngram_caches = defaultdict(
+                lambda: StaticAttentionIOManager.NGramCache(n_verifications)
+            )
+
+        self.mask.tensor[:, :, self.cache_len :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.input_len):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.mask.tensor[0][i][self.cache_len :]
+                )
+            )
+
+        pos_offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        new_tokens = [init_token]
+        x = [init_token] * self.input_len
+        inference_cnt = 0
+        while len(new_tokens) < n + 1:
+            # Update verification branch with cached n-grams.
+            cache = ngram_caches[x[0]]
+            for i, ngram in enumerate(cache):
+                for j, token in enumerate(ngram):
+                    x[verification_offset + i * (ngram_size - 1) + j] = token
+
+            y, attn_updates = self._run_once(
+                model,
+                x,
+                non_padded_len=1,
+                freqs_cos_override=self.freqs_cos[pos_offsets + self.pos],
+                freqs_sin_override=self.freqs_sin[pos_offsets + self.pos],
+            )
+            inference_cnt += 1
+            # Only supports greedy decoding for now.
+            y = y[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams.
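+            # Each column i of the lookahead windows yields one n-gram: x[i] from
+            # the first window, the same column of the deeper windows, and the
+            # model's new prediction for that column in the last window. It is
+            # cached under x[i] and reused as a verification candidate later.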
+            for i in range(window_size):
+                key = x[i]
+                suffix = []
+                for j in range(1, ngram_size - 1):
+                    suffix.append(x[i + j * window_size])
+                suffix.append(y[i + window_size * (ngram_size - 2)])
+                ngram_caches[key].add(suffix)
+
+            # Verification.
+            longest_match = []
+            matched_branch = None
+            for i in range(n_verifications):
+                match = [y[0]]
+                j = 0
+                while (
+                    j < ngram_size - 1
+                    and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+                ):
+                    match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                    j += 1
+                if len(match) - 1 > len(longest_match):
+                    longest_match = match[1:]
+                    matched_branch = i
+
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+
+                # Update KV caches and attention mask for the additional matched tokens.
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self._update_states(
+                    attn_updates,
+                    update_pos=branch_offset,
+                    update_len=len(longest_match),
+                )
+
+            # Update lookahead branch.
+            for i in range(ngram_size - 2):
+                for j in range(window_size):
+                    x[window_size * i + j] = x[window_size * (i + 1) + j]
+            for j in range(window_size):
+                x[window_size * (ngram_size - 2) + j] = y[
+                    window_size * (ngram_size - 2) + j
+                ]
+
+            x[0] = new_tokens[-1]
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_cnt} inference(s)."
+        )
+        return new_tokens
+
     def _run_once(
         self,
         model: Callable[..., Any],
@@ -330,6 +482,67 @@ def _update_states(self, attn_updates, update_pos, update_len):
             )
         self.pos += update_len
 
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.input_len, self.input_len), self.mask_val)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.mask_val),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.mask_val),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        # Input position offsets, used for indexing RoPE frequencies.
+        pos_offsets = torch.zeros(self.input_len, dtype=torch.int32)
+        idx = 0
+        # Lookahead branches: [i + 0, i + 1, ..., i + window_size - 1] for time i.
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
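+        # All verification branches are candidate continuations of the token at
+        # offset 0, so their tokens share offsets 1..ngram_size - 1 regardless of
+        # which branch they belong to.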
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
 
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index a6eac24db1f..44a483fe981 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -1,4 +1,5 @@
 import unittest
+from collections import defaultdict
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA
@@ -164,15 +165,7 @@ def test_with_style(style):
         test_with_style("shift_pointer")
         test_with_style("smart_mask")
 
-    def test_within_transformer(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=24,
-            n_layers=4,
-            vocab_size=128,
-        )
+    def _get_test_transformers(self, config):
         mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
@@ -183,6 +176,18 @@ def test_within_transformer(self):
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
 
+        return mha_transformer, static_transformer
+
+    def test_within_transformer(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=24,
+            n_layers=4,
+            vocab_size=128,
+        )
+        mha_transformer, static_transformer = self._get_test_transformers(config)
+
         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
         expected = mha_transformer(x)
 
@@ -204,3 +209,52 @@ def test_with_style(style):
 
         test_with_style("shift_pointer")
         test_with_style("smart_mask")
+
+    def test_lookahead_decode(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=128,
+            n_layers=4,
+            vocab_size=128,
+            generate_full_logits=True,
+        )
+        _, static_transformer = self._get_test_transformers(config)
+
+        input_len = 32
+        cache_len = config.max_seq_len - input_len
+        prefill_input = torch.randint(config.vocab_size, (input_len,))
+        ref_mgr = StaticAttentionIOManager(config, input_len, cache_len)
+        lookahead_mgr = StaticAttentionIOManager(config, input_len, cache_len)
+
+        next_tok = (
+            ref_mgr.prefill(static_transformer, prefill_input.tolist())[0][-1]
+            .argmax()
+            .item()
+        )
+        ref_output = ref_mgr.decode(static_transformer, next_tok, 50)
+
+        ngram_size = 3
+        window_size = 8
+        n_verifications = 8
+        ngram_caches = defaultdict(
+            lambda: StaticAttentionIOManager.NGramCache(n_verifications)
+        )
+        for _ in range(2):  # run twice; the first run populates the n-gram caches
+            lookahead_mgr.reset()
+            next_tok = (
+                lookahead_mgr.prefill(static_transformer, prefill_input.tolist())[0][-1]
+                .argmax()
+                .item()
+            )
+            lookahead_output = lookahead_mgr.lookahead_decode(
+                static_transformer,
+                next_tok,
+                50,
+                ngram_size=ngram_size,
+                window_size=window_size,
+                n_verifications=n_verifications,
+                ngram_caches=ngram_caches,
+            )
+            self.assertEqual(lookahead_output[: len(ref_output)], ref_output)
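
A quick usage sketch of the new API, mirroring test_lookahead_decode: it assumes an existing config (a ModelArgs with generate_full_logits=True), a static_transformer built from it with attention_type="static", and prompt_tokens (a list of token ids); only StaticAttentionIOManager and its methods are taken from this diff.

from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

# `config`, `static_transformer`, and `prompt_tokens` are assumed to exist already.
input_len = 32
cache_len = config.max_seq_len - input_len
mgr = StaticAttentionIOManager(config, input_len, cache_len)

# Prefill the prompt and greedily pick the first generated token.
next_tok = mgr.prefill(static_transformer, prompt_tokens)[0][-1].argmax().item()

# Generate up to 50 tokens. Each forward pass scores (ngram_size - 1) lookahead
# windows of window_size tokens plus n_verifications cached n-gram candidates,
# so several tokens can be accepted per model call. Note the constraint
# (ngram_size - 1) * (window_size + n_verifications) <= input_len.
tokens = mgr.lookahead_decode(
    static_transformer,
    next_tok,
    50,
    ngram_size=3,
    window_size=8,
    n_verifications=8,
)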