
Commit

wip
derdeljan-msft committed Mar 4, 2025
1 parent 30c6825 commit 27c655d
Showing 4 changed files with 323 additions and 124 deletions.
53 changes: 48 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -34,6 +34,8 @@ class GQAAttentionBase {
use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;

local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;

tree_attention_ = info.GetAttrOrDefault<int64_t>("tree_attention", 0) == 1;
}

int num_heads_; // number of attention heads of Q
@@ -43,13 +45,15 @@ class GQAAttentionBase {
bool do_rotary_; // whether or not to use rotary embeddings
bool rotary_interleaved_;
int local_window_size_;
bool tree_attention_;

bool use_smooth_softmax_;

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
const T* V, // V data with shape BxN_kvxSxH
const T* attention_mask, // Causal attention mask to apply before softmax
const Tensor* past_key, // past K input tensor (if not using past state)
const Tensor* past_value, // past V input tensor (if not using past state)
Tensor* output, // output tensor
@@ -92,8 +96,8 @@ class GQAAttentionBase {
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;

if (gqa_mlas_supported) {
ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_mask, batch_size,
sequence_length, parameters.total_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
@@ -104,9 +108,10 @@ class GQAAttentionBase {
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
} else {
ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_mask,
batch_size, sequence_length, parameters.total_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp,
allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
@@ -130,8 +135,10 @@ class GQAAttentionBase {
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const T* attention_mask,
const size_t batch_size, // batch size of self-attention
const size_t sequence_length, // sequence length of self-attention (S)
const size_t total_sequence_length, // max total sequence length in batch
const size_t past_buffer_sequence_length, // sequence length of past state
const size_t present_buffer_sequence_length, // sequence length of present state
const size_t head_size, // head size of self-attention
@@ -189,6 +196,11 @@ class GQAAttentionBase {
const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
U* output = attention_probs + output_offset;

const ptrdiff_t attention_mask_offset = SafeInt<ptrdiff_t>(batch_index) * sequence_length * total_sequence_length;
const T* attention_mask_batch = attention_mask != nullptr ? attention_mask + attention_mask_offset : nullptr;

// std::cout << "Batch_index: " << batch_index << ", sequence_length: " << sequence_length << ", present_buffer_sequence_length: " << present_buffer_sequence_length << std::endl;

const T* k;
if (packed_qkv) {
k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
@@ -242,6 +254,33 @@ class GQAAttentionBase {
U* output_softmax = output;
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;

// TODO: Vectorize this addition
if (nullptr != attention_mask_batch) {
if (batch_index == 0 && head_index == 0) {
std::cout << "attention mask offset: " << attention_mask_offset <<std::endl;
std::cout << "ptr diff " << (attention_mask_batch - attention_mask) << std::endl;
std::cout << "total sequence length batch " << total_seqlen << std::endl;
std::cout << "total sequence length " << total_sequence_length << std::endl;
std::cout << "past sequence length " << past_seqlen << std::endl;
std::cout << "sequence length: " << sequence_length << std::endl;
// std::cout << "==========================================" << std::endl;
// for (size_t i = 0; i < seq_causal_length; i++) {
// std::cout << attention_mask_batch[i] << " ";
// }
// std::cout << std::endl;
std::cout << "==========================================" << std::endl;
}

for (size_t i = 0; i < seq_causal_length; i++) {
if constexpr (std::is_same<U, float>::value) {
output_softmax[i] += static_cast<float>(attention_mask_batch[i]);
} else {
output_softmax[i] = MLFloat16(output_softmax[i].ToFloat() + attention_mask_batch[i].ToFloat());
}
}
}

if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
if constexpr (std::is_same<U, float>::value) {
@@ -283,6 +322,10 @@ class GQAAttentionBase {
}

output_softmax += present_buffer_sequence_length;

if (nullptr != attention_mask_batch) {
attention_mask_batch += total_seqlen;
}
}
}
});
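The hunk above threads an optional additive attention mask through ComputeAttentionProbs: when it is present, mask values are added to the raw QK^T scores row by row before the local-window masking and softmax run. Below is a minimal, single-threaded sketch of that additive-mask-then-softmax convention; buffer names are hypothetical and it is an illustration only, not the ORT kernel (no MLAS, no MLFloat16 path, no causal/local-window handling).

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Sketch: apply an additive mask to one head's raw scores and softmax each row.
// scores and mask are both [sequence_length x total_sequence_length]; the mask
// holds 0 for visible positions and a large negative value for masked ones.
void MaskedSoftmax(std::vector<float>& scores, const std::vector<float>& mask,
                   size_t sequence_length, size_t total_sequence_length) {
  for (size_t seq = 0; seq < sequence_length; ++seq) {
    float* row = scores.data() + seq * total_sequence_length;
    const float* mask_row = mask.data() + seq * total_sequence_length;
    float row_max = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < total_sequence_length; ++i) {
      row[i] += mask_row[i];                    // additive mask before softmax
      row_max = std::max(row_max, row[i]);
    }
    float sum = 0.0f;
    for (size_t i = 0; i < total_sequence_length; ++i) {
      row[i] = std::exp(row[i] - row_max);      // numerically stable softmax
      sum += row[i];
    }
    for (size_t i = 0; i < total_sequence_length; ++i) {
      row[i] /= sum;
    }
  }
}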
34 changes: 31 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -53,6 +53,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);

const Tensor* tree_pos_ids = context->Input<Tensor>(9);
const Tensor* tree_causal_attention_mask = context->Input<Tensor>(10);

GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
@@ -130,7 +133,12 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
// Generate position ids
const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length;
std::vector<int64_t> pos_ids(pos_ids_size);
if (parameters.is_first_prompt) {
const int64_t* pos_ids_data = pos_ids.data();

if (tree_attention_) {
ORT_RETURN_IF_NOT(pos_ids_size == tree_pos_ids->Shape()[0] * tree_pos_ids->Shape()[1]);
pos_ids_data = tree_pos_ids->Data<int64_t>();
} else if (parameters.is_first_prompt) {
pos_ids[0] = static_cast<int64_t>(0);
} else {
// Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1.
@@ -146,6 +154,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
}
}
}

// Initialize separate buffers for rotary embeddings
const T* q_input;
const T* k_input;
@@ -165,7 +174,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
}
// Run rotary embedding for Q and K
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
pos_ids.data(), cos_cache->Data<T>(),
pos_ids_data, cos_cache->Data<T>(),
sin_cache->Data<T>(), q_rotary, rotary_interleaved_));

rotary_params.num_heads = kv_num_heads_;
@@ -174,7 +183,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride;
}
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
pos_ids.data(), cos_cache->Data<T>(),
pos_ids_data, cos_cache->Data<T>(),
sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
// Pack V into rotary QKV buffer
if (packed_qkv) {
@@ -192,8 +201,27 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
}

ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

if constexpr (std::is_same<T, float>::value) {
if (tree_attention_) {
const T* ptr = tree_causal_attention_mask->Data<T>();

const size_t total_els = tree_causal_attention_mask->Shape()[0] * tree_causal_attention_mask->Shape()[1] * tree_causal_attention_mask->Shape()[2];

std::cout << "Total elements: " << total_els << std::endl;

bool all_zeros = true;
for (size_t i = 0; i < total_els; i++) {
all_zeros &= (ptr[i] == 0);
}

std::cout << "All elements are zeros: " << all_zeros << std::endl;
}
}

// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
tree_attention_ ? tree_causal_attention_mask->Data<T>() : nullptr,
past_key, past_value, output, present_k, present_v,
seqlens_k, parameters, allocator, context);
}
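With tree_attention enabled, this kernel no longer generates sequential position ids; it reads them from the new tree_pos_ids input, so a caller can assign tree-structured positions (for example, a node's depth in a speculative-decoding draft tree) instead of flat indices, and the tree_causal_attention_mask input is forwarded to ApplyAttention as the additive mask used above. The following is a hedged caller-side sketch of how both inputs could be built; the parent-array representation, the kMaskedOut value, and the batch size of 1 are assumptions, not part of the commit.

#include <cstdint>
#include <vector>

// Sketch: derive tree_pos_ids and an additive tree_causal_attention_mask from a
// draft tree given as a parent array (parent[i] < i, parent of the root is -1).
// Node i may attend to the cached prefix, to itself, and to its ancestors;
// every other position gets a large negative bias. Hypothetical helper.
constexpr float kMaskedOut = -10000.0f;

void BuildTreeInputs(const std::vector<int>& parent, int64_t past_len,
                     std::vector<int64_t>& tree_pos_ids,
                     std::vector<float>& tree_mask) {
  const size_t n = parent.size();  // sequence_length of the draft tree
  const size_t total_len = static_cast<size_t>(past_len) + n;
  tree_pos_ids.assign(n, 0);
  tree_mask.assign(n * total_len, kMaskedOut);
  for (size_t i = 0; i < n; ++i) {
    // Position id = depth in the tree, offset by the cached prefix length.
    int64_t depth = 0;
    for (int p = parent[i]; p >= 0; p = parent[p]) ++depth;
    tree_pos_ids[i] = past_len + depth;
    float* row = tree_mask.data() + i * total_len;
    for (size_t j = 0; j < static_cast<size_t>(past_len); ++j) row[j] = 0.0f;   // prefix visible
    for (int p = static_cast<int>(i); p >= 0; p = parent[p]) {
      row[static_cast<size_t>(past_len) + p] = 0.0f;                            // self and ancestors visible
    }
  }
}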
14 changes: 14 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1082,6 +1082,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Use a smooth factor in softmax.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Attr("tree_attention",
"Provide custom position IDs and causal tree attention mask. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape"
@@ -1128,6 +1132,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"2D tensor with shape (max_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
.Input(9,
"tree_pos_ids",
"2D tensor with shape (batch_size, sequence_length).",
"tensor(int64)",
OpSchema::Optional)
.Input(10,
"tree_causal_attention_mask",
"3D tensor with shape (batch_size, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
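For reference, here is a sketch of how the two new optional inputs could be bound through the ONNX Runtime C++ API. Shapes follow the schema above; the concrete values, the batch size of 1, and the -10000.0f mask bias are placeholders, and session creation and the Run call are omitted.

#include <onnxruntime_cxx_api.h>
#include <cstdint>
#include <vector>

int main() {
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  const int64_t batch_size = 1, sequence_length = 4, total_sequence_length = 4;

  // Input 9: tree_pos_ids, shape (batch_size, sequence_length). Placeholder depths.
  std::vector<int64_t> pos_ids = {0, 1, 1, 2};
  std::vector<int64_t> pos_shape = {batch_size, sequence_length};
  Ort::Value tree_pos_ids = Ort::Value::CreateTensor<int64_t>(
      mem_info, pos_ids.data(), pos_ids.size(), pos_shape.data(), pos_shape.size());

  // Input 10: tree_causal_attention_mask, shape (batch_size, sequence_length,
  // total_sequence_length). Additive: 0 where a token may attend, large negative elsewhere.
  std::vector<float> mask(batch_size * sequence_length * total_sequence_length, -10000.0f);
  std::vector<int64_t> mask_shape = {batch_size, sequence_length, total_sequence_length};
  Ort::Value tree_mask = Ort::Value::CreateTensor<float>(
      mem_info, mask.data(), mask.size(), mask_shape.data(), mask_shape.size());

  // ... fill mask rows for prefix/self/ancestor visibility, then pass both
  // tensors as inputs 9 and 10 of GroupQueryAttention via session.Run(...).
  return 0;
}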
