Skip to content

Commit

Permalink
Adds GQA support for SDPA op.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627235612
  • Loading branch information
tensorflower-gardener committed Apr 23, 2024
1 parent bf065bf commit 177e447
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 91 deletions.
24 changes: 22 additions & 2 deletions tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc
Expand Up @@ -30,7 +30,7 @@ TEST(ODMLSDPA, MQA) {

const auto batch = 1;
const auto input_seq_len = 1;
const auto max_seq_len = 500;
const auto max_seq_len = 64;
const auto q_heads = 32;
const auto kv_heads = 1;
const auto head_dim = 4; // embedding_dim//q_heads
Expand All @@ -50,7 +50,7 @@ TEST(ODMLSDPA, MHA) {

const auto batch = 1;
const auto input_seq_len = 1;
const auto max_seq_len = 500;
const auto max_seq_len = 64;
const auto q_heads = 32;
const auto kv_heads = 32;
const auto head_dim = 4; // embedding_dim//q_heads
Expand All @@ -63,5 +63,25 @@ TEST(ODMLSDPA, MHA) {
.Test(xnnpack_delegate.get());
}

TEST(ODMLSDPA, GQA) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
TfLiteXNNPackDelegateDelete);

const auto batch = 1;
const auto input_seq_len = 1;
const auto max_seq_len = 64;
const auto q_heads = 32;
const auto kv_heads = 4;
const auto head_dim = 4; // embedding_dim//q_heads

ODMLSDPATester()
.QueryShape({batch, input_seq_len, q_heads, head_dim}) // q
.KeyShape({batch, max_seq_len, kv_heads, head_dim}) // k
.ValueShape({batch, max_seq_len, kv_heads, head_dim}) // v
.MaskShape({batch, 1, input_seq_len, max_seq_len}) // mask
.Test(xnnpack_delegate.get());
}

} // namespace xnnpack
} // namespace tflite

0 comments on commit 177e447

Please sign in to comment.