215 changes: 214 additions & 1 deletion examples/models/llama/static_attention.py
@@ -1,4 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
@@ -14,6 +16,7 @@
from executorch.examples.models.llama.rope import Rope


logger = logging.getLogger(__name__)
_CacheMap = Dict[str, torch.Tensor]
# Key and value caches are kept separate so the key caches can be kept transposed.
_InputCacheState = Tuple[_CacheMap, _CacheMap]
@@ -174,6 +177,24 @@ def unmask(self, new_unmasked_len):


class StaticAttentionIOManager:
class NGramCache:
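        """Bounded FIFO collection of n-gram continuations used by lookahead decoding.

        Duplicate n-grams are ignored, and the oldest entry is evicted once
        `max_size` entries are stored.
        """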
def __init__(self, max_size):
self.cache = deque()
self.max_size = max_size

def add(self, x):
if x in self.cache:
return
if len(self.cache) == self.max_size:
self.cache.popleft()
self.cache.append(x)

def __iter__(self):
return iter(self.cache)

def __str__(self):
return str(self.cache)

def __init__(
self,
config: ModelArgs,
@@ -266,12 +287,143 @@ def decode(
new_tokens = [init_token]
for _ in range(n):
y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, :].argmax().item())
+            new_tokens.append(y[:, :1, ...].argmax().item())
if new_tokens[-1] in stop_tokens:
break

return new_tokens

def lookahead_decode( # noqa: C901
self,
model: Callable[..., Any],
init_token: int,
n: int,
ngram_size: int,
window_size: int,
n_verifications: int,
stop_tokens: Optional[List[int]] = None,
ngram_caches: Optional[Dict[int, "StaticAttentionIOManager.NGramCache"]] = None,
):
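        """Generate up to `n` tokens using lookahead decoding with n-gram verification.

        The fixed-length model input is laid out as `window_size * (ngram_size - 1)`
        lookahead-branch positions followed by `n_verifications * (ngram_size - 1)`
        verification positions holding candidate n-grams from `ngram_caches`.
        Each iteration runs the model once, appends the greedily decoded next token,
        refreshes the n-gram cache from the lookahead branches, and accepts any
        additional tokens from the longest matching verification branch. Decoding
        stops early when a token in `stop_tokens` is produced.
        """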
if self.cache_full:
raise RuntimeError("KV cache is full.")

if (ngram_size - 1) * (window_size + n_verifications) > self.input_len:
raise RuntimeError(
"Lookahead decoding setting not compatible with input length."
f" input_len = {self.input_len},"
f" ngram_size = {ngram_size},"
f" window_size = {window_size},"
f" n_verifications = {n_verifications}"
)

stop_tokens = stop_tokens or []
if ngram_caches is None:
ngram_caches = defaultdict(
lambda: StaticAttentionIOManager.NGramCache(n_verifications)
)

self.mask.tensor[:, :, self.cache_len :] = self._get_lookahead_decoding_mask(
ngram_size, window_size, n_verifications
)
logger.debug("Lookahead decoding mask: ")
for i in range(self.input_len):
logger.debug(
" ".join(
("X" if x == 0.0 else " ")
for x in self.mask.tensor[0][i][self.cache_len :]
)
)

pos_offsets = self._get_lookahead_position_offsets(
ngram_size, window_size, n_verifications
)

verification_offset = max(window_size * (ngram_size - 1), 1)
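        # new_tokens collects the accepted output; x is the fixed-length model input
        # buffer (lookahead-branch positions first, then the verification branches),
        # with x[0] always holding the token currently being extended.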
new_tokens = [init_token]
x = [init_token] * self.input_len
inference_cnt = 0
while len(new_tokens) < n + 1:
# Update verification branch with cached n-grams.
cache = ngram_caches[x[0]]
for i, ngram in enumerate(cache):
for j, token in enumerate(ngram):
x[verification_offset + i * (ngram_size - 1) + j] = token

y, attn_updates = self._run_once(
model,
x,
non_padded_len=1,
freqs_cos_override=self.freqs_cos[pos_offsets + self.pos],
freqs_sin_override=self.freqs_sin[pos_offsets + self.pos],
)
inference_cnt += 1
# Only supports greedy decoding for now.
y = y[0].argmax(dim=-1).tolist()
new_tokens.append(y[0])
logger.debug(f"{self.pos}: x = {x[0]}, y = {y[0]}")
if new_tokens[-1] in stop_tokens:
break

# Collect new n-grams.
for i in range(window_size):
key = x[i]
suffix = []
for j in range(1, ngram_size - 1):
suffix.append(x[i + j * window_size])
suffix.append(y[i + window_size * (ngram_size - 2)])
ngram_caches[key].add(suffix)

# Verification.
longest_match = []
matched_branch = None
for i in range(n_verifications):
match = [y[0]]
j = 0
# for j in range(ngram_size - 1):
while (
j < ngram_size - 1
and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
):
match.append(y[verification_offset + (ngram_size - 1) * i + j])
j += 1
if len(match) - 1 > len(longest_match):
longest_match = match[1:]
matched_branch = i

if matched_branch is not None:
logger.debug(
f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
)
for stop in stop_tokens:
if stop in longest_match:
longest_match = longest_match[: longest_match.index(stop) + 1]

new_tokens.extend(longest_match)

# Update KV caches and attention mask for the additional matched tokens.
branch_offset = verification_offset + (ngram_size - 1) * matched_branch
self._update_states(
attn_updates,
update_pos=branch_offset,
update_len=len(longest_match),
)

# Update lookahead branch.
for i in range(ngram_size - 2):
for j in range(window_size):
x[window_size * i + j] = x[window_size * (i + 1) + j]
for j in range(window_size):
x[window_size * (ngram_size - 2) + j] = y[
window_size * (ngram_size - 2) + j
]

x[0] = new_tokens[-1]

logger.info(
f"Generated {len(new_tokens) - 1} tokens with {inference_cnt} inference(s)."
)
return new_tokens

def _run_once(
self,
model: Callable[..., Any],
@@ -330,6 +482,67 @@ def _update_states(self, attn_updates, update_pos, update_len):
)
self.pos += update_len

def _get_lookahead_decoding_mask(
self, ngram_size: int, window_size: int, n_verifications: int
) -> torch.Tensor:
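        """Build the attention mask over the new input positions for lookahead decoding.

        Within the input window, position 0 (the current token) attends only to
        itself. Each lookahead branch level attends causally over the level-0 window
        and diagonally to the corresponding column of every level up to itself. Each
        verification branch is causal within its own n-gram and also attends to
        position 0.
        """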
mask = torch.full((self.input_len, self.input_len), self.mask_val)
mask[0][0] = 0.0

lookahead_submask = torch.triu(
torch.full((window_size, window_size), self.mask_val),
diagonal=1,
)
for i in range(ngram_size - 1):
offset = window_size * i
mask[offset : offset + window_size, :window_size] = lookahead_submask
for j in range(1, i + 1):
mask[
offset : offset + window_size,
window_size * j : window_size * (j + 1),
].fill_diagonal_(0.0)

verification_offset = max(window_size * (ngram_size - 1), 1)
verification_submask = torch.triu(
torch.full((ngram_size - 1, ngram_size - 1), self.mask_val),
diagonal=1,
)
for i in range(n_verifications):
mask[
verification_offset
+ i * (ngram_size - 1) : verification_offset
+ (i + 1) * (ngram_size - 1),
verification_offset
+ i * (ngram_size - 1) : verification_offset
+ (i + 1) * (ngram_size - 1),
] = verification_submask
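        # Every verification position can also attend to position 0, which holds
        # the current token.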
mask[verification_offset:, :1] = 0.0

return mask

def _get_lookahead_position_offsets(
self, ngram_size: int, window_size: int, n_verifications: int
) -> torch.Tensor:
# Input position offsets, used for indexing RoPE frequencies.
pos_offsets = torch.zeros(self.input_len, dtype=torch.int32)
idx = 0
# Lookahead branches: [i + 0, i + 1, ..., i + window_size - 1] for time i.
if window_size > 0:
for i in range(ngram_size - 1):
for j in range(window_size):
pos_offsets[idx] = i + j
idx += 1
else:
pos_offsets[0] = 0
idx += 1

# Verification branches: [1, 2, ..., ngram_size - 1].
for _ in range(n_verifications):
for j in range(1, ngram_size):
pos_offsets[idx] = j
idx += 1

return pos_offsets


class _Rope(nn.Module):
def __init__(self, use_hf_rope):
72 changes: 63 additions & 9 deletions examples/models/llama/tests/test_static_attention.py
@@ -1,4 +1,5 @@
import unittest
from collections import defaultdict

import torch
from executorch.examples.models.llama.attention import AttentionMHA
@@ -164,15 +165,7 @@ def test_with_style(style):
test_with_style("shift_pointer")
test_with_style("smart_mask")

-    def test_within_transformer(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=24,
-            n_layers=4,
-            vocab_size=128,
-        )
+    def _get_test_transformers(self, config):
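        """Build an MHA transformer and a static-attention transformer with weights copied from it."""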
mha_transformer = construct_transformer(config).eval()

config.attention_type = "static"
@@ -183,6 +176,18 @@ def test_within_transformer(self):
):
static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)

return mha_transformer, static_transformer

def test_within_transformer(self):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=24,
n_layers=4,
vocab_size=128,
)
mha_transformer, static_transformer = self._get_test_transformers(config)
x = torch.randint(config.vocab_size, (1, config.max_seq_len))
expected = mha_transformer(x)

@@ -204,3 +209,52 @@ def test_with_style(style):

test_with_style("shift_pointer")
test_with_style("smart_mask")

def test_lookahead_decode(self):
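        """Lookahead decoding must reproduce the reference greedy decode (it may emit extra tokens past the requested count)."""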
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=128,
n_layers=4,
vocab_size=128,
generate_full_logits=True,
)
_, static_transformer = self._get_test_transformers(config)

input_len = 32
cache_len = config.max_seq_len - input_len
prefill_input = torch.randint(config.vocab_size, (input_len,))
ref_mgr = StaticAttentionIOManager(config, input_len, cache_len)
lookahead_mgr = StaticAttentionIOManager(config, input_len, cache_len)

next_tok = (
ref_mgr.prefill(static_transformer, prefill_input.tolist())[0][-1]
.argmax()
.item()
)
ref_output = ref_mgr.decode(static_transformer, next_tok, 50)

ngram_size = 3
window_size = 8
n_verifications = 8
ngram_caches = defaultdict(
lambda: StaticAttentionIOManager.NGramCache(n_verifications)
)
for _ in range(2):  # run twice; the first run populates the n-gram cache
lookahead_mgr.reset()
next_tok = (
lookahead_mgr.prefill(static_transformer, prefill_input.tolist())[0][-1]
.argmax()
.item()
)
lookahead_output = lookahead_mgr.lookahead_decode(
static_transformer,
next_tok,
50,
ngram_size=ngram_size,
window_size=window_size,
n_verifications=n_verifications,
ngram_caches=ngram_caches,
)
self.assertEqual(lookahead_output[: len(ref_output)], ref_output)