383 changes: 325 additions & 58 deletions examples/qualcomm/oss_scripts/llama/decoder_utils.py

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -223,6 +223,7 @@ def quantize(
custom_annotations=(),
scales_state_dict=None,
chat_template=None,
lookahead_config=None,
):
self.quant_dtype = quant_dtype
quantizer = make_custom_quantizer(
@@ -290,6 +291,7 @@ def quantize(
prompt=prompt,
use_i64_token=args.embedding_quantize is not None,
event_name="prepare_pt2e_prompt",
lookahead_config=lookahead_config,
)
if scales_state_dict:
set_scales(
@@ -336,6 +338,7 @@ def quantize(
prompt=prompt,
use_i64_token=args.embedding_quantize is not None,
event_name="convert_pt2e_prompt",
lookahead_config=lookahead_config,
)

def save_logits_quant_attrs(self):
@@ -497,13 +500,6 @@ def compile(
)
)
elif args.model_mode == "lookahead":
# TODO: Lookahead decoding is not yet supported for gemma3-1b.
# This will be implemented once the model architecture and KV update logic are adapted.
if args.decoder_model == "gemma3-1b":
raise NotImplementedError(
"gemma3-1b does not currently support lookahead decoding."
)

llama_instance_list.append(
LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)(
kv_config,
@@ -697,13 +693,19 @@ def permute(w, heads):
custom_annotations = decoder_model_config.custom_annotation
kv_quant_attrs = {}
for i, llama_instance in enumerate(llama_instance_list):
lookahead_config = (
(args.window, args.ngram, args.gcap)
if i == 0 and args.model_mode == "lookahead"
else None
)
llama_instance.quantize(
quant_dtype=quant_dtype,
args=args,
tokenizer=tokenizer,
custom_annotations=custom_annotations,
scales_state_dict=scales_state_dict,
chat_template=chat_template,
lookahead_config=lookahead_config,
)
# If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
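A minimal sketch of how the (args.window, args.ngram, args.gcap) tuple passed as lookahead_config relates to the number of lookahead token slots the calibration pass has to cover per decode step; the helper name lookahead_ar_len is hypothetical, and the sizing mirrors the position_offset_ initialization added to LhdTokenGenerator later in this diff.

def lookahead_ar_len(window: int, ngram: int, gcap: int) -> int:
    # Lookahead branches contribute (ngram - 1) rows of `window` tokens each;
    # verification branches contribute gcap rows of (ngram - 1) tokens each.
    return (ngram - 1) * window + gcap * (ngram - 1)

# Example: window=8, ngram=4, gcap=8 -> 24 + 24 = 48 lookahead tokens per step.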
47 changes: 27 additions & 20 deletions examples/qualcomm/oss_scripts/llama/masking_utils.py
@@ -93,24 +93,26 @@ def mask(self) -> torch.Tensor:
pass

@abstractmethod
def smart_mask_update(self, pos, n_updates):
def smart_mask_update(self, pos, n_updates, lade_pos_offset):
"""
Update the attention mask by smart mask update method after model forward.

Args:
pos (int): Current position in the sequence.
n_updates (int): Number of new tokens to update.
lade_pos_offset (List[int]): Position offset of lookahead attention mask.
"""
pass

@abstractmethod
def shift_pointer_update(self, pos, n_updates):
def shift_pointer_update(self, pos, n_updates, lade_pos_offset):
"""
Update the attention mask by shift pointer update method after model forward.

Args:
pos (int): Current position in the sequence.
n_updates (int): Number of tokens to shift.
lade_pos_offset (List[int]): Position offset of lookahead attention mask.
"""
pass

@@ -124,7 +126,7 @@ def __init__(self, max_batch_size: int, ar_len: int, max_seq_len: int):
def mask(self):
return self._mask

def smart_mask_update(self, pos, n_updates):
def smart_mask_update(self, pos, n_updates, _):
"""
Smart Mask mechanism for attention mask updating

@@ -159,7 +161,7 @@ def smart_mask_update(self, pos, n_updates):
end_pos = pos + n_updates
self.mask[:, :, start_pos:end_pos] = 0

def shift_pointer_update(self, pos, n_updates):
def shift_pointer_update(self, pos, n_updates, _):
"""
Shift Pointer mechanism for attention mask updating

@@ -173,7 +175,7 @@ def shift_pointer_update(self, pos, n_updates):
3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○
4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ●

After 1st update (e.g., pos=0, n_updates=5, sliding_window=3):
After 1st update (e.g., pos=0, n_updates=5):
Newly added tokens are unmasked (set to 0).

0 ○ ○ ○ ○ ○ ● ● ● ● ● ● ○ ○ ○ ○
@@ -213,7 +215,7 @@ def __init__(
def mask(self):
return self._mask

def smart_mask_update(self, pos, n_updates):
def smart_mask_update(self, pos, n_updates, lade_pos_offset):
"""
Smart Mask mechanism for attention mask updating

@@ -237,7 +239,8 @@ def smart_mask_update(self, pos, n_updates):
3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○
4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ●

After 2nd update (e.g., pos=5, n_updates=5):

After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3):
Sliding window shifts again, masking older positions and activating new positions.

0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○
@@ -252,16 +255,18 @@ def shift_pointer_update(self, pos, n_updates):
self.mask[:, :, start_pos:end_pos] = 0

for i in range(self.ar_len):
# Calculate how many cached tokens are still avalible for this row
avalible_cache_len = self.sliding_window - (i + 1)
# Calculate how many cached tokens are still available for this row
available_cache_len = self.sliding_window - (
(i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1)
)

# If the current position exceeds available cache, mask the overflow
if end_pos > avalible_cache_len:
if end_pos > available_cache_len:
# Mask tokens that are no longer within the sliding window
# TODO: [Optional]: it can be optimized by computing the exact start index
self.mask[:, i, : end_pos - avalible_cache_len] = -255.0
self.mask[:, i, : end_pos - available_cache_len] = -255.0

def shift_pointer_update(self, pos, n_updates):
def shift_pointer_update(self, pos, n_updates, lade_pos_offset):
"""
Shift Pointer mechanism for attention mask updating

@@ -283,7 +288,7 @@ def shift_pointer_update(self, pos, n_updates):
3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○
4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ●

After 2nd update (e.g., pos=5, n_updates=5):
After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3):

0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○
1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○
@@ -297,28 +302,30 @@ def shift_pointer_update(self, pos, n_updates):
self.mask[:, :, start_pos:end_pos] = 0

for i in range(self.ar_len):
avalible_cache_len = self.sliding_window - (i + 1)
if abs(start_pos + self.ar_len) > avalible_cache_len:
available_cache_len = self.sliding_window - (
(i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1)
)
if abs(start_pos + self.ar_len) > available_cache_len:
self.mask[
:,
i,
start_pos : start_pos
+ abs(start_pos + self.ar_len)
- avalible_cache_len,
- available_cache_len,
] = -255.0


class AttentionMask:
def __init__(self, masks: Union[BaseAttentionMask, List[BaseAttentionMask]]):
self.masks = masks if isinstance(masks, list) else [masks]

def smart_mask_update(self, pos, n_updates):
def smart_mask_update(self, pos, n_updates, lade_pos_offset=None):
for mask in self.masks:
mask.smart_mask_update(pos, n_updates)
mask.smart_mask_update(pos, n_updates, lade_pos_offset)

def shift_pointer_update(self, pos, n_updates):
def shift_pointer_update(self, pos, n_updates, lade_pos_offset=None):
for mask in self.masks:
mask.shift_pointer_update(pos, n_updates)
mask.shift_pointer_update(pos, n_updates, lade_pos_offset)

def __iter__(self):
return iter([mask.mask for mask in self.masks])
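A condensed, standalone sketch of the sliding-window smart-mask path above, assuming a mask tensor of shape (batch, ar_len, context_len) that uses 0 for attend and -255.0 for masked; it is not the module's API, only an illustration of how lade_pos_offset replaces the row index when computing the remaining cache budget.

import torch

def sliding_window_smart_mask_update(
    mask: torch.Tensor, pos: int, n_updates: int, sliding_window: int,
    lade_pos_offset=None,
):
    # Unmask the cache slots written by the latest forward pass.
    start_pos, end_pos = pos, pos + n_updates
    mask[:, :, start_pos:end_pos] = 0
    for i in range(mask.shape[1]):  # iterate over the ar_len query rows
        # Without lookahead each row sits (i + 1) tokens into the window;
        # with lookahead its logical depth comes from lade_pos_offset.
        offset = (i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1)
        available_cache_len = sliding_window - offset
        if end_pos > available_cache_len:
            # Re-mask cache entries that have fallen out of the sliding window.
            mask[:, i, : end_pos - available_cache_len] = -255.0
    return mask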
29 changes: 17 additions & 12 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
@@ -122,7 +122,8 @@ void KVManager<T>::init_attention_mask(
const std::vector<int32_t>& attention_map,
int32_t ar_len,
int32_t n_past,
int32_t sliding_window) {
int32_t sliding_window,
const std::vector<int32_t>& position_offset) {
ET_CHECK_MSG(
attention_map.size() <= ar_len,
"The size of attention_map (%zu) doesn't match with ar_len (%d)",
@@ -154,11 +155,12 @@ void KVManager<T>::init_attention_mask(
}
// Attend to itself
new_ptr[i] = pos_val;

// mask by limitation of sliding_window
int32_t avalible_context_len = sliding_window - (i + 1) - n_past;
if (n_past > avalible_context_len) {
std::fill_n(past_ptr, n_past - avalible_context_len, neg_val);
int32_t available_context_len = position_offset.empty()
? sliding_window - (i + 1) - n_past
: sliding_window - (position_offset[i] + 1) - n_past;
if (n_past > available_context_len) {
std::fill_n(past_ptr, n_past - available_context_len, neg_val);
}

past_ptr += metadata_.context_len;
@@ -219,7 +221,8 @@ void KVManager<T>::update_attention_mask(
int32_t ar_len,
int32_t n_past,
int32_t n_update,
int32_t sliding_window) {
int32_t sliding_window,
const std::vector<int32_t>& position_offset) {
uint16_t pos_val = 65535;
uint16_t neg_val = 0;
uint16_t* cur_ptr = attention_mask;
@@ -230,17 +233,19 @@

for (int i = 0; i < ar_len; i++) {
std::fill_n(cur_ptr, n_update, pos_val);
int32_t avalible_cache_len = sliding_window - (i + 1);
int32_t available_cache_len = position_offset.empty()
? sliding_window - (i + 1)
: sliding_window - (position_offset[i] + 1);
if (kv_updater_ == KVManagerMode::SMART_MASK) {
if (n_past + n_update > avalible_cache_len) {
if (n_past + n_update > available_cache_len) {
std::fill_n(
cur_ptr - n_past, n_past + n_update - avalible_cache_len, neg_val);
cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val);
}
} else if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
if (std::abs(n_past + ar_len) > avalible_cache_len) {
int32_t n_invalid = n_past - avalible_cache_len;
if (std::abs(n_past + ar_len) > available_cache_len) {
int32_t n_invalid = n_past - available_cache_len;
std::fill_n(
cur_ptr, std::abs(n_past + ar_len) - avalible_cache_len, neg_val);
cur_ptr, std::abs(n_past + ar_len) - available_cache_len, neg_val);
}
}
cur_ptr += metadata_.context_len;
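The same per-row budget appears on the runner side: init_attention_mask also subtracts n_past, so rows that sit deeper in the lookahead layout keep fewer cached tokens visible. A rough sketch under that assumption, using the hypothetical helper name masked_past_per_row, of how many of the n_past cache entries each query row must keep masked:

def masked_past_per_row(ar_len, n_past, sliding_window, position_offset=None):
    masked = []
    for i in range(ar_len):
        # Row depth: plain causal rows use (i + 1); lookahead rows use their
        # logical position offset instead.
        offset = (i + 1) if not position_offset else (position_offset[i] + 1)
        available_context_len = sliding_window - offset - n_past
        masked.append(max(0, n_past - available_context_len))
    return masked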
11 changes: 9 additions & 2 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.h
@@ -95,13 +95,17 @@ class KVManager {
* of attention map should be [ar_len].
* @param ar_len Length of input tokens.
* @param n_past Number of past elements in the cache.
* @param sliding_window Length of sliding window for sliding window attention
* mask
* @param position_offset (optional) attention mask position offset of
* lookahead decoder
*/
void init_attention_mask(
uint16_t* attention_mask,
const std::vector<int32_t>& attention_map,
int32_t ar_len,
int32_t n_past,
int32_t sliding_window);
int32_t sliding_window,
const std::vector<int32_t>& position_offset = {});

/**
* @brief Update attention mask based on kv manager mode, and n_update.
@@ -126,13 +130,16 @@ class KVManager {
* @param n_update Number of elements to be updated.
* @param sliding_window Length of sliding window for sliding window attention
* mask
* @param position_offset (optional) attention mask position offset of
* lookahead decoder
*/
void update_attention_mask(
uint16_t* attention_mask,
int32_t ar_len,
int32_t n_past,
int32_t n_update,
int32_t sliding_window);
int32_t sliding_window,
const std::vector<int32_t>& position_offset = {});

/**
* @brief Reset the data pointer of the I/O cache tensor based on number of
examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
@@ -61,6 +61,16 @@ void LhdTokenGenerator<T>::init_attention_mask(int32_t n_past) {

this->kv_manager_->init_attention_mask(
this->attention_mask_.data, attention_map, metadata_.ar_len, n_past);
// Initialize window attention mask with current position
if (metadata_.cache_mode == CacheMode::HybridCache) {
this->kv_manager_->init_attention_mask(
this->window_attention_mask_.data,
attention_map,
metadata_.ar_len,
n_past,
metadata_.sliding_window,
position_offset_);
}
}

template <typename T>
@@ -378,6 +388,15 @@ Result<int64_t> LhdTokenGenerator<T>::generate(
// Update attention mask with current position
this->kv_manager_->update_attention_mask(
this->attention_mask_.data, metadata_.ar_len, prev_pos, n_update);
if (metadata_.cache_mode == CacheMode::HybridCache) {
this->kv_manager_->update_attention_mask(
this->window_attention_mask_.data,
metadata_.ar_len,
prev_pos,
n_update,
metadata_.sliding_window,
position_offset_);
}

// data-dependent terminating condition: we have n_eos_ number of EOS
if (this->eos_ids_->count(cur_token) > 0) {
examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h
@@ -29,6 +29,7 @@ class LhdTokenGenerator : public TokenGenerator<T> {
int32_t window;
int32_t gcap;
int sliding_window;
CacheMode cache_mode;
};
LhdTokenGenerator(
tokenizers::Tokenizer* tokenizer,
@@ -51,7 +52,8 @@ class LhdTokenGenerator : public TokenGenerator<T> {
metadata.ar_len,
metadata.vocab_size,
metadata.use_int64_token,
metadata.sliding_window},
metadata.sliding_window,
metadata.cache_mode},
stats),
metadata_(metadata),
lhd_branch_(metadata.ngram - 1, std::vector<int32_t>(metadata.window)),
@@ -63,6 +65,22 @@ class LhdTokenGenerator : public TokenGenerator<T> {
metadata.ngram,
metadata.window,
metadata.gcap);

// initialize position offset
position_offset_ = std::vector<int32_t>(metadata.ar_len);
int idx = 0;
// lookahead branches
for (int i = 0; i < metadata.ngram - 1; ++i) {
for (int j = 0; j < metadata.window; ++j) {
position_offset_[idx++] = i + j;
}
}
// verification branches
for (int i = 0; i < metadata.gcap; ++i) {
for (int j = 1; j < metadata.ngram; ++j) {
position_offset_[idx++] = j;
}
}
}

~LhdTokenGenerator() = default;
@@ -136,6 +154,9 @@ class LhdTokenGenerator : public TokenGenerator<T> {
// verification branch
std::vector<NgramData> v_branch_;

// position offset in attention mask
std::vector<int32_t> position_offset_;

// n-gram pools
NgramContainer ngrams_pool_;
};
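The constructor loop above fixes the per-row position offsets once, from ngram, window, and gcap. The same layout rendered in Python as a reference sketch (not part of the PR):

def build_position_offset(ngram: int, window: int, gcap: int):
    offsets = []
    # Lookahead branches: (ngram - 1) rows of `window` tokens, offset i + j.
    for i in range(ngram - 1):
        for j in range(window):
            offsets.append(i + j)
    # Verification branches: gcap candidate n-grams, offsets 1 .. ngram - 1.
    for _ in range(gcap):
        for j in range(1, ngram):
            offsets.append(j)
    return offsets

# Example: ngram=3, window=2, gcap=2 -> [0, 1, 1, 2, 1, 2, 1, 2]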