From c063e189cc7f8120b551a1109b85037dc07cbf09 Mon Sep 17 00:00:00 2001 From: haowhsu-quic Date: Thu, 11 Sep 2025 10:53:29 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - sliding attention lookahead support - add lookahead decode to speed up prompt calibration - sliding attention lookahead support --- .../oss_scripts/llama/decoder_utils.py | 383 +++++++++++++++--- examples/qualcomm/oss_scripts/llama/llama.py | 16 +- .../oss_scripts/llama/masking_utils.py | 47 ++- .../oss_scripts/llama/runner/kv_manager.cpp | 32 +- .../oss_scripts/llama/runner/kv_manager.h | 11 +- .../llama/runner/lhd_token_generator.cpp | 19 + .../llama/runner/lhd_token_generator.h | 23 +- .../oss_scripts/llama/runner/runner.cpp | 9 +- 8 files changed, 431 insertions(+), 109 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 76cf85c6e9c..2d18fab1de1 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -7,7 +7,8 @@ import getpass import logging import os -from typing import Callable, Optional, Union +from collections import defaultdict, OrderedDict +from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -100,6 +101,155 @@ def _model_call(self, inps): return all_logits +class LookaheadDecoder: + """ + Lookahead decoding to speed up calibration + """ + + class NgramPool: + def __init__(self, num_verifications: int): + self.pool = defaultdict(OrderedDict) + # keep the amount of ngrams as number of verification branches for simplicity + self.num_verifications = num_verifications + + def add(self, ngram: Tuple[int]): + key = ngram[0] + # since there is no OrderedSet in python, use OrderedDict with dummy value 1 + self.pool[key][ngram[1:]] = 1 + if len(self.pool[key]) > self.num_verifications: + # remove cache in FIFO fashion + self.pool[key].popitem(last=False) + + def __getitem__(self, key): + return self.pool[key] + + def __iter__(self): + return iter(self.pool) + + def __init__( + self, + window_size: int, + ngram_size: int, + num_verifications: int, + ar_size: int, + mask_value: int, + ): + if ar_size < (ngram_size - 1) * (window_size + num_verifications): + raise ValueError( + "AR length is not enough to meet requirement. " + "Should be at least (ngram_size - 1) * (window_size + num_verifications)." 
+ ) + + self.window_size = window_size + self.ngram_size = ngram_size + self.ngram_pool = self.NgramPool(num_verifications) + self.num_verifications = num_verifications + self.verification_offset = window_size * (ngram_size - 1) + self.ar_size = ar_size + self.mask_value = mask_value + + @property + def attention_mask(self) -> torch.Tensor: + mask = torch.full((self.ar_size,) * 2, self.mask_value) + lookahead_branch_mask = torch.triu( + torch.full((self.window_size,) * 2, self.mask_value), + diagonal=1, + ) + for i in range(self.ngram_size - 1): + mask[ + i * self.window_size : (i + 1) * self.window_size, + : self.window_size, + ] = lookahead_branch_mask + for j in range(1, i + 1): + mask[ + i * self.window_size : (i + 1) * self.window_size, + j * self.window_size : (j + 1) * self.window_size, + ].fill_diagonal_(0) + + verification_branch_mask = torch.triu( + torch.full((self.ngram_size - 1,) * 2, self.mask_value), + diagonal=1, + ) + for i in range(self.num_verifications): + indices = [i * (self.ngram_size - 1), (i + 1) * (self.ngram_size - 1)] + slices = (slice(*[ind + self.verification_offset for ind in indices]),) * 2 + mask[slices] = verification_branch_mask + mask[ + : self.verification_offset + (self.ngram_size - 1) * self.num_verifications, + 0, + ] = 0 + + return mask + + @property + def position_offset(self) -> torch.Tensor: + offsets = torch.zeros(self.ar_size, dtype=torch.int32) + idx = 0 + # lookahead branches + for i in range(self.ngram_size - 1): + for j in range(self.window_size): + offsets[idx] = i + j + idx += 1 + + # verification branches + for _ in range(self.num_verifications): + for j in range(1, self.ngram_size): + offsets[idx] = j + idx += 1 + + return offsets + + def update_verification_branch(self, guess_token: int, inputs: List[int]) -> None: + for branch, ngram in enumerate(self.ngram_pool[guess_token]): + verification_offset = self.verification_offset + branch * ( + self.ngram_size - 1 + ) + for i, token in enumerate(ngram): + inputs[verification_offset + i] = token + + def update_lookahead_branch(self, inputs: List[int], outputs: List[int]) -> None: + # 1 level shifting + for i in range(self.ngram_size - 2): + for j in range(self.window_size): + inputs[self.window_size * i + j] = inputs[ + self.window_size * (i + 1) + j + ] + + last_ngram_offset = self.window_size * (self.ngram_size - 2) + for i in range(self.window_size): + inputs[last_ngram_offset + i] = outputs[last_ngram_offset + i] + + def update_ngram_pool(self, inputs: List[int], outputs: List[int]) -> None: + for i in range(self.window_size): + ngram = [inputs[i]] + for j in range(1, self.ngram_size - 1): + ngram.append(inputs[i + j * self.window_size]) + + ngram.append(outputs[i + self.window_size * (self.ngram_size - 2)]) + self.ngram_pool.add(tuple(ngram)) + + def verify( + self, inputs: List[int], outputs: List[int] + ) -> Tuple[List[int], Optional[int]]: + best_match, branch = [], None + for i in range(self.num_verifications): + current_match = [outputs[0]] + verification_branch_offset = ( + self.verification_offset + (self.ngram_size - 1) * i + ) + for j in range(self.ngram_size - 1): + if inputs[verification_branch_offset + j] == current_match[-1]: + current_match.append(outputs[verification_branch_offset + j]) + else: + break + + if len(current_match[1:]) > len(best_match): + best_match = current_match[1:] + branch = i + + return best_match, branch + + class QnnRunnerEvalWrapper(EagerEvalWrapper): """ A wrapper class to run PPL scores with QNN on device. 
@@ -248,18 +398,30 @@ def smart_mask_updater( v_caches, new_k_caches, new_v_caches, + # lookahead decoding related + lade_token_offset=None, + lade_pos_offset=None, ): # ar_len is unused in smart mask max_cache_len = k_caches[0].size(-1) + if pos + n_updates <= max_cache_len: - for i, k_cache in enumerate(k_caches): - k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates] + if lade_token_offset is not None: + # lookahead decode update + for i, offset in enumerate(lade_token_offset): + current_pos = pos + i + for j, (k_cache, v_cache) in enumerate(zip(k_caches, v_caches)): + k_cache[:, :, current_pos] = new_k_caches[j][:, :, offset] + v_cache[:, current_pos, :] = new_v_caches[j][:, offset, :] + else: + for i, k_cache in enumerate(k_caches): + k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates] + for i, v_cache in enumerate(v_caches): + v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :] - for i, v_cache in enumerate(v_caches): - v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :] - atten_mask.smart_mask_update(pos, n_updates) - pos += n_updates + atten_mask.smart_mask_update(pos, n_updates, lade_pos_offset) + pos += n_updates return pos, k_caches, v_caches @@ -271,29 +433,51 @@ def shift_pointer_updater( v_caches, new_k_caches, new_v_caches, + # lookahead decoding related + lade_token_offset=None, + lade_pos_offset=None, ): max_cache_len = k_caches[0].size(-1) if pos + n_updates <= max_cache_len: - k_caches = [ - torch.cat( - [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1 - ) - for i, k_cache in enumerate(k_caches) - ] - v_caches = [ - torch.cat( - [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1 - ) - for i, v_cache in enumerate(v_caches) - ] - atten_mask.shift_pointer_update(pos, n_updates) - pos += n_updates + if lade_token_offset is not None: + # lookahead decode update + for offset in lade_token_offset: + for i, (k_cache, v_cache) in enumerate(zip(k_caches, v_caches)): + k_caches[i] = torch.cat( + [ + k_cache[:, :, 1:], + new_k_caches[i][:, :, offset].unsqueeze(-1), + ], + dim=-1, + ) + v_caches[i] = torch.cat( + [v_cache[:, 1:, :], new_v_caches[i][:, offset, :].unsqueeze(1)], + dim=1, + ) + else: + k_caches = [ + torch.cat( + [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], + dim=-1, + ) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat( + [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], + dim=1, + ) + for i, v_cache in enumerate(v_caches) + ] + atten_mask.shift_pointer_update(pos, n_updates, lade_pos_offset) + + pos += n_updates return pos, k_caches, v_caches @register_inference(use_kv_cache=True) -def kv_inference( +def kv_inference( # noqa: C901 get_example_inputs, prompt: Union[str, list], module: torch.fx.GraphModule, @@ -304,6 +488,7 @@ def kv_inference( use_i64_token=False, collect_logits=False, seq_mse_candidates=0, + lookahead_config=None, ): _, atten_mask, _, k_caches, v_caches = get_example_inputs(use_kv_cache=True) @@ -393,46 +578,125 @@ def kv_inference( # When run on wikitext for ppl evaluation, this while-loop is not expected to run. max_cache_len = max_seq_len - ar_len num_tokens = len(total_token_list) - while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len: - chunk_start_idx = min(pos, max_cache_len) - # Take a chunk of generated tokens, up to ar_len length. 
- chunk_end_idx = num_tokens - actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx] - num_tokens_in_chunk = len(actual_chunk_tokens) - - # Prepare tmp_token_list (padded with zeros). - tmp_token_list = torch.zeros((1, ar_len), dtype=dtype) - tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor( - actual_chunk_tokens, dtype=dtype - ) - - # Prepare tmp_pos (padded with zeros). - tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32) - tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx] + if lookahead_config is None: + while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len: + chunk_start_idx = min(pos, max_cache_len) + # Take a chunk of generated tokens, up to ar_len length. + chunk_end_idx = num_tokens + actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx] + num_tokens_in_chunk = len(actual_chunk_tokens) + + # Prepare tmp_token_list (padded with zeros). + tmp_token_list = torch.zeros((1, ar_len), dtype=dtype) + tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor( + actual_chunk_tokens, dtype=dtype + ) - logits, new_k_caches, new_v_caches = module( - tmp_token_list, - *atten_mask, - tmp_pos, - *k_caches, - *v_caches, - ) - if collect_logits: - result_logits.append(logits[:, :num_tokens_in_chunk]) + # Prepare tmp_pos (padded with zeros). + tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32) + tmp_pos[0, :num_tokens_in_chunk] = all_pos[ + 0, chunk_start_idx:chunk_end_idx + ] + + logits, new_k_caches, new_v_caches = module( + tmp_token_list, + *atten_mask, + tmp_pos, + *k_caches, + *v_caches, + ) - pos, k_caches, v_caches = kv_updater( - 1, - atten_mask, - pos, - k_caches, - v_caches, - new_k_caches, - new_v_caches, + pos, k_caches, v_caches = kv_updater( + 1, + atten_mask, + pos, + k_caches, + v_caches, + new_k_caches, + new_v_caches, + ) + total_token_list.append( + torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item() + ) + num_tokens = len(total_token_list) + else: + # TODO: support batch decode if necessary + # variable declaration + window, ngram, gcap = lookahead_config + lade = LookaheadDecoder( + window_size=window, + ngram_size=ngram, + num_verifications=gcap, + ar_size=ar_len, + mask_value=next(iter(atten_mask)).min().item(), ) - total_token_list.append( - torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item() + generated_tokens, accepted_tokens = 0, 0 + input_tokens = [total_token_list[-1]] * ar_len + pos_offsets = lade.position_offset.unsqueeze(0) + pos_offsets_list = pos_offsets.flatten().tolist() + # replace ar attention mask to lookahead version + for mask in atten_mask: + mask[:, :, -ar_len:] = lade.attention_mask.unsqueeze(0) + # start decoding + while ( + total_token_list[-1] != tokenizer.eos_id + and len(total_token_list) < max_cache_len + ): + # populate verification branch + lade.update_verification_branch( + guess_token=input_tokens[0], + inputs=input_tokens, + ) + # inference + logits, new_k_caches, new_v_caches = module( + torch.tensor(input_tokens, dtype=dtype).unsqueeze(0), + *atten_mask, + pos_offsets + pos, + *k_caches, + *v_caches, + ) + # collect outputs + output_tokens = torch.argmax(logits, dim=-1).flatten().tolist() + # update ngram pool + lade.update_ngram_pool(inputs=input_tokens, outputs=output_tokens) + # try matching verification branches + best_match, branch_no = lade.verify( + inputs=input_tokens, outputs=output_tokens + ) + # check if any match was found + lade_token_offset, num_match = [0], len(best_match) + if num_match > 0: + accepted_tokens += 
num_match + lade_token_offset += [ + e + lade.verification_offset + branch_no * (ngram - 1) + for e in range(num_match) + ] + # update kv cache + pos, k_caches, v_caches = kv_updater( + len(lade_token_offset), + atten_mask, + pos, + k_caches, + v_caches, + new_k_caches, + new_v_caches, + lade_token_offset, + pos_offsets_list, + ) + generated_tokens += len(lade_token_offset) + # update lookahead branch + lade.update_lookahead_branch(inputs=input_tokens, outputs=output_tokens) + # update token list + for token in [output_tokens[0], *best_match]: + total_token_list.append(token) + if token == tokenizer.eos_id: + break + # fill next input token + input_tokens[0] = total_token_list[-1] + + logging.info( + f"lookahead accepted / total generated: {accepted_tokens} / {generated_tokens}" ) - num_tokens = len(total_token_list) logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}") if collect_logits: @@ -514,6 +778,7 @@ def graph_module_inference( use_i64_token=False, event_name: Optional[str] = None, seq_mse_candidates: int = 0, + lookahead_config: Optional[Tuple[int]] = None, ): """ This function supports model execution from static nn.Module decoder model @@ -529,6 +794,8 @@ def graph_module_inference( if use_kv_cache: kwargs["ar_len"] = ar_len kwargs["kv_updater"] = kv_updater + kwargs["lookahead_config"] = lookahead_config + INFERENCE_REGISTRY[use_kv_cache]( get_example_inputs, prompt, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 273829d214e..ae5ae63d509 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -223,6 +223,7 @@ def quantize( custom_annotations=(), scales_state_dict=None, chat_template=None, + lookahead_config=None, ): self.quant_dtype = quant_dtype quantizer = make_custom_quantizer( @@ -290,6 +291,7 @@ def quantize( prompt=prompt, use_i64_token=args.embedding_quantize is not None, event_name="prepare_pt2e_prompt", + lookahead_config=lookahead_config, ) if scales_state_dict: set_scales( @@ -336,6 +338,7 @@ def quantize( prompt=prompt, use_i64_token=args.embedding_quantize is not None, event_name="convert_pt2e_prompt", + lookahead_config=lookahead_config, ) def save_logits_quant_attrs(self): @@ -497,13 +500,6 @@ def compile( ) ) elif args.model_mode == "lookahead": - # TODO: Lookahead decoding is not yet supported for gemma3-1b. - # This will be implemented once the model architecture and KV update logic are adapted. - if args.decoder_model == "gemma3-1b": - raise NotImplementedError( - "gemma3-1b does not currently support lookahead decoding." 
- ) - llama_instance_list.append( LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)( kv_config, @@ -697,6 +693,11 @@ def permute(w, heads): custom_annotations = decoder_model_config.custom_annotation kv_quant_attrs = {} for i, llama_instance in enumerate(llama_instance_list): + lookahead_config = ( + (args.window, args.ngram, args.gcap) + if i == 0 and args.model_mode == "lookahead" + else None + ) llama_instance.quantize( quant_dtype=quant_dtype, args=args, @@ -704,6 +705,7 @@ def permute(w, heads): custom_annotations=custom_annotations, scales_state_dict=scales_state_dict, chat_template=chat_template, + lookahead_config=lookahead_config, ) # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later if i == 0 and args.model_mode in ["hybrid", "lookahead"]: diff --git a/examples/qualcomm/oss_scripts/llama/masking_utils.py b/examples/qualcomm/oss_scripts/llama/masking_utils.py index 8d9d9ead154..0031f468802 100644 --- a/examples/qualcomm/oss_scripts/llama/masking_utils.py +++ b/examples/qualcomm/oss_scripts/llama/masking_utils.py @@ -93,24 +93,26 @@ def mask(self) -> torch.Tensor: pass @abstractmethod - def smart_mask_update(self, pos, n_updates): + def smart_mask_update(self, pos, n_updates, lade_pos_offset): """ Update the attention mask by smart mask update method after model forward. Args: pos (int): Current position in the sequence. n_updates (int): Number of new tokens to update. + lade_pos_offset (List[int]): Position offset of lookahead attention mask. """ pass @abstractmethod - def shift_pointer_update(self, pos, n_updates): + def shift_pointer_update(self, pos, n_updates, lade_pos_offset): """ Update the attention mask by shift pointer update method after model forward. Args: pos (int): Current position in the sequence. n_updates (int): Number of tokens to shift. + lade_pos_offset (List[int]): Position offset of lookahead attention mask. """ pass @@ -124,7 +126,7 @@ def __init__(self, max_batch_size: int, ar_len: int, max_seq_len: int): def mask(self): return self._mask - def smart_mask_update(self, pos, n_updates): + def smart_mask_update(self, pos, n_updates, _): """ Smart Mask mechanism for attention mask updating @@ -159,7 +161,7 @@ def smart_mask_update(self, pos, n_updates): end_pos = pos + n_updates self.mask[:, :, start_pos:end_pos] = 0 - def shift_pointer_update(self, pos, n_updates): + def shift_pointer_update(self, pos, n_updates, _): """ Shift Pointer mechanism for attention mask updating @@ -173,7 +175,7 @@ def shift_pointer_update(self, pos, n_updates): 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● - After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): + After 1st update (e.g., pos=0, n_updates=5): Newly added tokens are unmasked (set to 0). 0 ○ ○ ○ ○ ○ ● ● ● ● ● ● ○ ○ ○ ○ @@ -213,7 +215,7 @@ def __init__( def mask(self): return self._mask - def smart_mask_update(self, pos, n_updates): + def smart_mask_update(self, pos, n_updates, lade_pos_offset): """ Smart Mask mechanism for attention mask updating @@ -237,7 +239,8 @@ def smart_mask_update(self, pos, n_updates): 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - After 2nd update (e.g., pos=5, n_updates=5): + + After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3): Sliding window shifts again, masking older positions and activate new postion. 
0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ @@ -252,16 +255,18 @@ def smart_mask_update(self, pos, n_updates): self.mask[:, :, start_pos:end_pos] = 0 for i in range(self.ar_len): - # Calculate how many cached tokens are still avalible for this row - avalible_cache_len = self.sliding_window - (i + 1) + # Calculate how many cached tokens are still available for this row + available_cache_len = self.sliding_window - ( + (i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1) + ) # If the current position exceeds available cache, mask the overflow - if end_pos > avalible_cache_len: + if end_pos > available_cache_len: # Mask tokens that are no longer within the sliding window # TODO: [Optional]: it can be optimized by computing the exact start index - self.mask[:, i, : end_pos - avalible_cache_len] = -255.0 + self.mask[:, i, : end_pos - available_cache_len] = -255.0 - def shift_pointer_update(self, pos, n_updates): + def shift_pointer_update(self, pos, n_updates, lade_pos_offset): """ Shift Pointer mechanism for attention mask updating @@ -283,7 +288,7 @@ def shift_pointer_update(self, pos, n_updates): 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - After 2nd update (e.g., pos=5, n_updates=5): + After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3): 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ @@ -297,14 +302,16 @@ def shift_pointer_update(self, pos, n_updates): self.mask[:, :, start_pos:end_pos] = 0 for i in range(self.ar_len): - avalible_cache_len = self.sliding_window - (i + 1) - if abs(start_pos + self.ar_len) > avalible_cache_len: + available_cache_len = self.sliding_window - ( + (i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1) + ) + if abs(start_pos + self.ar_len) > available_cache_len: self.mask[ :, i, start_pos : start_pos + abs(start_pos + self.ar_len) - - avalible_cache_len, + - available_cache_len, ] = -255.0 @@ -312,13 +319,13 @@ class AttentionMask: def __init__(self, masks: Union[BaseAttentionMask, List[BaseAttentionMask]]): self.masks = masks if isinstance(masks, list) else [masks] - def smart_mask_update(self, pos, n_updates): + def smart_mask_update(self, pos, n_updates, lade_pos_offset=None): for mask in self.masks: - mask.smart_mask_update(pos, n_updates) + mask.smart_mask_update(pos, n_updates, lade_pos_offset) - def shift_pointer_update(self, pos, n_updates): + def shift_pointer_update(self, pos, n_updates, lade_pos_offset=None): for mask in self.masks: - mask.shift_pointer_update(pos, n_updates) + mask.shift_pointer_update(pos, n_updates, lade_pos_offset) def __iter__(self): return iter([mask.mask for mask in self.masks]) diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index a049b54abb6..7a96882416e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -122,7 +122,8 @@ void KVManager::init_attention_mask( const std::vector& attention_map, int32_t ar_len, int32_t n_past, - int32_t sliding_window) { + int32_t sliding_window, + const std::vector& position_offset) { ET_CHECK_MSG( attention_map.size() <= ar_len, "The size of attention_map (%zu) doesn't match with ar_len (%d)", @@ -154,11 +155,12 @@ void KVManager::init_attention_mask( } // Attend to itself new_ptr[i] = pos_val; - // mask by limitation of sliding_window - int32_t avalible_context_len = sliding_window - (i + 1) - n_past; - if (n_past > avalible_context_len) { - std::fill_n(past_ptr, n_past - 
avalible_context_len, neg_val); + int32_t available_context_len = position_offset.empty() + ? sliding_window - (i + 1) - n_past + : sliding_window - (position_offset[i] + 1) - n_past; + if (n_past > available_context_len) { + std::fill_n(past_ptr, n_past - available_context_len, neg_val); } past_ptr += metadata_.context_len; @@ -219,7 +221,8 @@ void KVManager::update_attention_mask( int32_t ar_len, int32_t n_past, int32_t n_update, - int32_t sliding_window) { + int32_t sliding_window, + const std::vector& position_offset) { uint16_t pos_val = 65535; uint16_t neg_val = 0; uint16_t* cur_ptr = attention_mask; @@ -230,21 +233,22 @@ void KVManager::update_attention_mask( for (int i = 0; i < ar_len; i++) { std::fill_n(cur_ptr, n_update, pos_val); - int32_t avalible_cache_len = sliding_window - (i + 1); + int32_t available_cache_len = position_offset.empty() + ? sliding_window - (i + 1) + : sliding_window - (position_offset[i] + 1); if (kv_updater_ == KVManagerMode::SMART_MASK) { - if (n_past + n_update > avalible_cache_len) { + if (n_past + n_update > available_cache_len) { std::fill_n( - cur_ptr - n_past, n_past + n_update - avalible_cache_len, neg_val); + cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val); } } else if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { - if (std::abs(n_past + ar_len) > avalible_cache_len) { - int32_t n_invalid = n_past - avalible_cache_len; + if (std::abs(n_past + ar_len) > available_cache_len) { + int32_t n_invalid = n_past - available_cache_len; std::fill_n( - cur_ptr, std::abs(n_past + ar_len) - avalible_cache_len, neg_val); + cur_ptr, std::abs(n_past + ar_len) - available_cache_len, neg_val); } - - cur_ptr += metadata_.context_len; } + cur_ptr += metadata_.context_len; } } diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index af9cf49a34f..ca24166aa9c 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -95,13 +95,17 @@ class KVManager { * of attention map should be [ar_len]. * @param ar_len Length of input tokens. * @param n_past Number of past elements in the cache. + * @param sliding_window Length of sliding window for sliding window attention + * mask + * @param position_offset (optional) attention mask position offset of */ void init_attention_mask( uint16_t* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, - int32_t sliding_window); + int32_t sliding_window, + const std::vector& position_offset = {}); /** * @brief Update attention mask based on kv manager mode, and n_update. @@ -126,13 +130,16 @@ class KVManager { * @param n_update Number of elements to be updated. 
* @param sliding_window Length of sliding window for sliding window attention * mask + * @param position_offset (optional) attention mask position offset of + * lookahead decoder */ void update_attention_mask( uint16_t* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, - int32_t sliding_window); + int32_t sliding_window, + const std::vector& position_offset = {}); /** * @brief Reset the data pointer of the I/O cache tensor based on number of diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp index 1692caa2756..96a25e9c935 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp @@ -61,6 +61,16 @@ void LhdTokenGenerator::init_attention_mask(int32_t n_past) { this->kv_manager_->init_attention_mask( this->attention_mask_.data, attention_map, metadata_.ar_len, n_past); + // Initialize window attention mask with current position + if (metadata_.cache_mode == CacheMode::HybridCache) { + this->kv_manager_->init_attention_mask( + this->window_attention_mask_.data, + attention_map, + metadata_.ar_len, + n_past, + metadata_.sliding_window, + position_offset_); + } } template @@ -378,6 +388,15 @@ Result LhdTokenGenerator::generate( // Update attention mask with current position this->kv_manager_->update_attention_mask( this->attention_mask_.data, metadata_.ar_len, prev_pos, n_update); + if (metadata_.cache_mode == CacheMode::HybridCache) { + this->kv_manager_->update_attention_mask( + this->window_attention_mask_.data, + metadata_.ar_len, + prev_pos, + n_update, + metadata_.sliding_window, + position_offset_); + } // data-dependent terminating condition: we have n_eos_ number of EOS if (this->eos_ids_->count(cur_token) > 0) { diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h index fe5e4b49230..cf4c55d9f2c 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h @@ -29,6 +29,7 @@ class LhdTokenGenerator : public TokenGenerator { int32_t window; int32_t gcap; int sliding_window; + CacheMode cache_mode; }; LhdTokenGenerator( tokenizers::Tokenizer* tokenizer, @@ -51,7 +52,8 @@ class LhdTokenGenerator : public TokenGenerator { metadata.ar_len, metadata.vocab_size, metadata.use_int64_token, - metadata.sliding_window}, + metadata.sliding_window, + metadata.cache_mode}, stats), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), @@ -63,6 +65,22 @@ class LhdTokenGenerator : public TokenGenerator { metadata.ngram, metadata.window, metadata.gcap); + + // initialize position offset + position_offset_ = std::vector(metadata.ar_len); + int idx = 0; + // lookahead branches + for (int i = 0; i < metadata.ngram - 1; ++i) { + for (int j = 0; j < metadata.window; ++j) { + position_offset_[idx++] = i + j; + } + } + // verification branches + for (int i = 0; i < metadata.gcap; ++i) { + for (int j = 1; j < metadata.ngram; ++j) { + position_offset_[idx++] = j; + } + } } ~LhdTokenGenerator() = default; @@ -136,6 +154,9 @@ class LhdTokenGenerator : public TokenGenerator { // verification branch std::vector v_branch_; + // position offset in attention mask + std::vector position_offset_; + // n-gram pools NgramContainer ngrams_pool_; }; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp 
b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index fc4ff006a90..fe45d4b6a67 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -285,12 +285,6 @@ Error Runner::load() { sliding_window, cache_mode_}); if (eval_mode_ == EvalMode::kLookaheadDecoding) { - // TODO: sliding window attention will be supported in future. - if (sliding_window < context_len_) { - ET_CHECK_MSG( - false, - "Lookahead decoding (eval_mode == 2) is not yet supported for sliding window attention."); - } token_generator_ = std::make_unique>( tokenizer_.get(), decoder_runner_.get(), @@ -307,7 +301,8 @@ Error Runner::load() { ngram_, window_, gcap_, - sliding_window}, + sliding_window, + cache_mode_}, &stats_); } else { token_generator_ = std::make_unique>(
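
Note (illustrative, not part of the patch): the layout built by the new LookaheadDecoder in decoder_utils.py can be inspected directly. The sketch below is a minimal example only; the import path and the window/ngram/gcap values are assumptions (any values satisfying ar_size >= (ngram - 1) * (window + gcap) work), and -255.0 is the mask fill value used by masking_utils.py in this example.

    # Minimal sketch -- not part of the patch. The import path assumes an
    # ExecuTorch checkout with this patch applied; adjust it to your setup.
    from executorch.examples.qualcomm.oss_scripts.llama.decoder_utils import (
        LookaheadDecoder,
    )

    window, ngram, gcap = 2, 3, 2              # small values for readability
    ar_len = (ngram - 1) * (window + gcap)     # minimum AR length the ctor enforces (= 8)

    lade = LookaheadDecoder(
        window_size=window,
        ngram_size=ngram,
        num_verifications=gcap,
        ar_size=ar_len,
        mask_value=-255.0,                     # fill value used by masking_utils.py
    )

    # The first (ngram - 1) * window slots are the lookahead windows; the remaining
    # gcap * (ngram - 1) slots are the verification branches. Every slot also
    # attends to slot 0, which carries the last accepted token.
    print(lade.attention_mask)                 # [ar_len, ar_len]; 0 = attend, -255.0 = masked
    print(lade.position_offset.tolist())       # [0, 1, 1, 2, 1, 2, 1, 2]

During calibration these per-slot offsets are added to the running position (pos_offsets + pos), and the slots accepted by verify() are forwarded to the KV updaters as lade_token_offset / lade_pos_offset, which is also what the sliding-window masks use to decide how much cache remains available for each slot.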