Add MQA attention in experimental/gen_ai (pytorch#2504)
Summary:

This diff open-sources `mqa_attn_splitk`, a multi-query attention
(MQA) operator for the decoding phase of LLM inference. The operator
supports the following:

- BF16 input query and output types
- BF16/INT4 KV cache types
- Maximum context length of 16,384
- Fixed head dimension of 128
- Arbitrary number of query heads

The INT4 KV variant of `mqa_attn_splitk` is ~1.7x faster than its
BF16 KV counterpart.
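
A minimal usage sketch from Python follows. The tensor shapes shown
([B, 1, H_q, D] for the query, [B, MAX_T, 1, D] for the single-head KV
cache), the `int32` dtype of `seq_positions`, the meaning of the two
auxiliary outputs, and the import path used to load the ops are
assumptions for illustration, not taken from this diff:

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed module path; loads the op registrations)

B, H_q, D, MAX_T = 2, 8, 128, 16384  # D = 128 is the fixed head dimension

# BF16 query for one decoding step and a BF16 single-head (MQA) KV cache.
XQ = torch.randn(B, 1, H_q, D, dtype=torch.bfloat16, device="cuda")
cache_K = torch.randn(B, MAX_T, 1, D, dtype=torch.bfloat16, device="cuda")
cache_V = torch.randn_like(cache_K)

# Current position of each sequence in the cache (hypothetical values).
seq_positions = torch.full((B,), 1024, dtype=torch.int32, device="cuda")

out, split_max, split_sum = torch.ops.fbgemm.mqa_attn_splitk(
    XQ,
    cache_K,
    cache_V,
    seq_positions,
    qk_scale=1.0 / D**0.5,
    num_split_ks=8,         # number of splits along the context (K) dimension
    num_int4_kv_groups=1,   # schema default; relevant only for INT4 KV caches
)
```

Here `out` is presumed to be the attention output, with the other two
tensors holding the per-split softmax statistics that the split-K
reduction combines.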

Differential Revision: D56110657
sryap authored and facebook-github-bot committed Apr 16, 2024
1 parent 4b1174c commit 773c7e6
Showing 3 changed files with 1,276 additions and 0 deletions.
fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp (45 additions, 0 deletions)
@@ -0,0 +1,45 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace fbgemm_gpu::gen_ai::attention {

std::tuple<at::Tensor, at::Tensor, at::Tensor> mqa_attn_splitk_cuda(
    const at::Tensor& XQ,
    const at::Tensor& cache_K,
    const at::Tensor& cache_V,
    const at::Tensor& seq_positions,
    const double qk_scale,
    const int64_t num_split_ks,
    const int64_t num_groups);

} // namespace fbgemm_gpu::gen_ai::attention

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
      "mqa_attn_splitk("
      " Tensor XQ, "
      " Tensor cache_K, "
      " Tensor cache_V, "
      " Tensor seq_positions, "
      " float qk_scale, "
      " int num_split_ks, "
      " int num_int4_kv_groups=1"
      ") -> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl(
      "mqa_attn_splitk",
      torch::dispatch(
          c10::DispatchKey::CUDA,
          TORCH_FN(fbgemm_gpu::gen_ai::attention::mqa_attn_splitk_cuda)));
}
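
The two blocks above follow the standard PyTorch custom-op pattern:
`TORCH_LIBRARY_FRAGMENT` registers the `mqa_attn_splitk` schema in the
`fbgemm` namespace, and `TORCH_LIBRARY_IMPL` binds the CUDA dispatch
key to the `mqa_attn_splitk_cuda` entry point declared earlier
(presumably defined in one of the other changed files in this diff).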
