@@ -34,6 +34,8 @@ class GQAAttentionBase {
    use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;

    local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;
+
+   tree_attention_ = info.GetAttrOrDefault<int64_t>("tree_attention", 0) == 1;
  }

  int num_heads_;     // number of attention heads of Q
@@ -43,13 +45,15 @@ class GQAAttentionBase {
  bool do_rotary_;    // whether or not to use rotary embeddings
  bool rotary_interleaved_;
  int local_window_size_;
+ bool tree_attention_;

  bool use_smooth_softmax_;

  template <typename T>
  Status ApplyAttention(const T* Q,                   // Q data with shape BxNxSxH
                        const T* K,                   // K data with shape BxN_kvxSxH
                        const T* V,                   // V data with shape BxN_kvxSxH
+                       const T* attention_mask,      // Causal attention mask to apply before softmax
                        const Tensor* past_key,       // past K input tensor (if not using past state)
                        const Tensor* past_value,     // past V input tensor (if not using past state)
                        Tensor* output,               // output tensor
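For reference, the way the new mask is consumed further down in this diff (a per-batch offset of sequence_length * total_sequence_length, with the values added onto the QK logits before softmax) suggests an additive mask laid out as [batch, sequence_length, total_sequence_length], where 0 keeps a key position and a large negative value suppresses it. A minimal sketch of building such a buffer, for illustration only; the helper name and the finite negative constant are assumptions, not part of this change:

#include <cstddef>
#include <vector>

// Sketch: additive causal mask with layout [batch_size, sequence_length, total_sequence_length].
// 0.0f keeps a key position visible; a large negative value removes it after softmax.
std::vector<float> MakeAdditiveCausalMask(size_t batch_size, size_t sequence_length,
                                          size_t total_sequence_length, size_t past_sequence_length) {
  constexpr float kMaskedOut = -10000.0f;  // assumption: finite stand-in for -inf
  std::vector<float> mask(batch_size * sequence_length * total_sequence_length, 0.0f);
  for (size_t b = 0; b < batch_size; ++b) {
    for (size_t q = 0; q < sequence_length; ++q) {
      // query row q may attend to key positions [0, past_sequence_length + q]
      for (size_t k = past_sequence_length + q + 1; k < total_sequence_length; ++k) {
        mask[(b * sequence_length + q) * total_sequence_length + k] = kMaskedOut;
      }
    }
  }
  return mask;
}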
@@ -92,8 +96,8 @@ class GQAAttentionBase {
    const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;

    if (gqa_mlas_supported) {
-     ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
-                           sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
+     ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_mask, batch_size,
+                           sequence_length, parameters.total_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
                            present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

      // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
@@ -104,9 +108,10 @@ class GQAAttentionBase {
                              hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
                              is_prompt, tp, allocator);
    } else {
-     ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
-                           sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
-                           present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
+     ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_mask,
+                           batch_size, sequence_length, parameters.total_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                           past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp,
+                           allocator);

      // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
      const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
@@ -130,8 +135,10 @@ class GQAAttentionBase {
                               const T* Q,                                   // Q data. Its size is BxNxSxH
                               const T* K,                                   // k data. Its size is BxNxLxH
                               const int32_t* seqlens_k,                     // total - 1 sequence lengths tensor
+                              const T* attention_mask,                      // causal attention mask to apply before softmax
                               const size_t batch_size,                      // batch size of self-attention
                               const size_t sequence_length,                 // sequence length of self-attention (S)
+                              const size_t total_sequence_length,           // max total sequence length in batch
                               const size_t past_buffer_sequence_length,     // sequence length of past state
                               const size_t present_buffer_sequence_length,  // sequence length of present state
                               const size_t head_size,                       // head size of self-attention
@@ -189,6 +196,11 @@ class GQAAttentionBase {
      const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
      U* output = attention_probs + output_offset;

+     const ptrdiff_t attention_mask_offset = SafeInt<ptrdiff_t>(batch_index) * sequence_length * total_sequence_length;
+     const T* attention_mask_batch = attention_mask != nullptr ? attention_mask + attention_mask_offset : nullptr;
+
      const T* k;
      if (packed_qkv) {
        k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
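One detail worth making explicit: the offset above depends only on batch_index, so every head of a batch reads the same mask slice. Combined with the per-row advance added near the end of this diff (attention_mask_batch += total_seqlen, where total_seqlen is seqlens_k[batch_index] + 1), the pointer a given query row is read from works out as below; this is a sketch for illustration only, with a hypothetical helper name, restating what this change computes rather than prescribing a layout:

#include <cstddef>

// Sketch: mask pointer for (batch_index, query row seq) as this change computes it.
// The batch stride uses the padded total_sequence_length, while rows within a batch
// advance by that batch's own total_seqlen (= seqlens_k[batch_index] + 1).
template <typename T>
const T* MaskRowPtr(const T* attention_mask, size_t batch_index, size_t seq,
                    size_t sequence_length, size_t total_sequence_length, size_t total_seqlen) {
  const T* batch_base = attention_mask + batch_index * sequence_length * total_sequence_length;
  return batch_base + seq * total_seqlen;
}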
@@ -242,6 +254,33 @@ class GQAAttentionBase {
        U* output_softmax = output;
        for (size_t seq = 0; seq < sequence_length; seq++) {
          size_t seq_causal_length = past_seqlen + seq + 1;
+
+         // TODO: Vectorize this addition.
+         if (nullptr != attention_mask_batch) {
+           for (size_t i = 0; i < seq_causal_length; i++) {
+             if constexpr (std::is_same<U, float>::value) {
+               output_softmax[i] += static_cast<float>(attention_mask_batch[i]);
+             } else {
+               output_softmax[i] = MLFloat16(output_softmax[i].ToFloat() + attention_mask_batch[i].ToFloat());
+             }
+           }
+         }
+
          if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
            for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
              if constexpr (std::is_same<U, float>::value) {
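On the TODO above: for the U == float, T == float case the mask add is a contiguous elementwise add over seq_causal_length values, so it vectorizes readily. One possible shape for such a helper is sketched below; it is not part of this change, and the helper name and the use of Eigen are assumptions, chosen only because mapping the two raw pointers and adding them is a straightforward way to get SIMD code:

#include <cstddef>
#include <Eigen/Core>

// Sketch: vectorized in-place add of one mask row onto one row of logits (float-only case).
inline void AddMaskRowInplace(float* logits, const float* mask_row, size_t length) {
  Eigen::Map<Eigen::ArrayXf> dst(logits, static_cast<Eigen::Index>(length));
  Eigen::Map<const Eigen::ArrayXf> src(mask_row, static_cast<Eigen::Index>(length));
  dst += src;  // compiles to packed adds; the MLFloat16 path would still need a conversion step
}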
@@ -283,6 +322,10 @@ class GQAAttentionBase {
          }

          output_softmax += present_buffer_sequence_length;
+
+         if (nullptr != attention_mask_batch) {
+           attention_mask_batch += total_seqlen;
+         }
        }
      }
    });