6 changes: 6 additions & 0 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
// we might consider another approach
if (seq_len >= 768) {
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+ ctx,
output,
query,
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
nullopt);
} else if (seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+ ctx,
output,
query,
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
nullopt);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+ ctx,
output,
query,
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
// we might consider another approach
if (seq_len >= 768) {
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+ ctx,
output,
q,
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
num_keys_for_causal_attention);
} else if (seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+ ctx,
output,
q,
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
num_keys_for_causal_attention);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+ ctx,
output,
q,
k,
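Note on the change above: every call to sdpa::impl::cpu_flash_attention now receives the kernel's RuntimeContext (ctx) as its first argument so the implementation can request scratch memory from the runtime's temp allocator. The following minimal sketch (not part of the diff) mirrors that dispatch shape; Ctx, Tensor, and the stub function are simplified stand-ins rather than the real ExecuTorch types.

// Compilable sketch of the seq_len-based dispatch in op_sdpa.cpp after this
// change; the only signature difference from before is the leading ctx.
#include <cstdint>
#include <iostream>

struct Ctx {};     // stand-in for RuntimeContext
struct Tensor {};  // stand-in for the ExecuTorch Tensor

template <int64_t qSplitSize, int64_t kvSplitSize>
void cpu_flash_attention_stub(Ctx& ctx, Tensor& output) {
  (void)ctx;     // the real kernel would use ctx.allocate_temp() internally
  (void)output;
  std::cout << "qSplit=" << qSplitSize << " kvSplit=" << kvSplitSize << "\n";
}

void run_flash_attention(Ctx& ctx, Tensor& output, int64_t seq_len) {
  // Same split-size thresholds as the diff above.
  if (seq_len >= 768) {
    cpu_flash_attention_stub<256, 512>(ctx, output);
  } else if (seq_len >= 192) {
    cpu_flash_attention_stub<64, 512>(ctx, output);
  } else {
    cpu_flash_attention_stub<32, 512>(ctx, output);
  }
}

int main() {
  Ctx ctx;
  Tensor out;
  run_flash_attention(ctx, out, 1024);  // seq_len >= 768: 256/512 specialization
  return 0;
}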
39 changes: 26 additions & 13 deletions extension/llm/custom_ops/op_sdpa_impl.h
@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };

namespace sdpa::impl {

+ static std::vector<char> scratch_for_quant_dequant_vec;
struct MaybeQuantizedMatrixData {
const void* data{nullptr};
const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
*/
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
+ RuntimeContext& ctx,
Tensor& output,
const Tensor& query,
const Tensor& key,
@@ -766,26 +768,37 @@ void cpu_flash_attention(
int64_t size_of_intermediate_precision = sizeof(accum_t);
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
size_of_intermediate_precision;
- std::vector<char> buf_vec(size_bytes);
- void* buf = reinterpret_cast<void*>(buf_vec.data());
- // Need to double check the following
- size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
- std::vector<char> buf_reduced_vec(size_bytes);
- void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
- // at::Tensor buf_reduced = at::empty(
- // {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
- // query.options());
+ Result<void*> buff_res = ctx.allocate_temp(size_bytes);
+ std::unique_ptr<char[]> allocated_buf;
+ void* buf;
+ if (!buff_res.ok()) {
+ allocated_buf = std::make_unique<char[]>(size_bytes);
+ buf = reinterpret_cast<void*>(allocated_buf.get());
+ } else {
+ buf = buff_res.get();
+ }
+ void* buf_reduced = nullptr;
int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
// Let's align size_per_thread_qdq_vec to 64 bytes for coalesced cache reads,
// by padding with the right number of per-thread elements
constexpr int64_t kAlignment = 32;
size_per_thread_qdq_vec =
(size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1);
- int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
+ int64_t size_per_thread_qdq_bytes =
+ size_per_thread_qdq_vec * size_of_intermediate_precision;
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
- std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
- accum_t* scratch_for_quant_dequant =
- reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+ std::unique_ptr<char[]> allocated_buf_for_qdq;
+ Result<void*> scratch_for_quant_dequant_res =
+ ctx.allocate_temp(size_qdq_bytes);
+ accum_t* scratch_for_quant_dequant;
+ if (!scratch_for_quant_dequant_res.ok()) {
+ allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
+ scratch_for_quant_dequant =
+ reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
+ } else {
+ scratch_for_quant_dequant =
+ reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
+ }

// Data ptrs
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
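Both allocation sites in this hunk follow the same fallback pattern: ask the runtime context's temp allocator first (ctx.allocate_temp returning a Result<void*>), and fall back to an owning heap buffer only when that request fails, which keeps per-call heap allocations out of the hot path whenever a temp allocator is configured. The sketch below is a self-contained illustration of that pattern, not the library API: Result, FakeContext, and get_scratch are simplified stand-ins for the real ExecuTorch Result<void*> and RuntimeContext::allocate_temp.

// Standalone sketch of the allocate_temp-with-heap-fallback pattern.
#include <cstddef>
#include <memory>
#include <vector>

struct Result {
  void* ptr{nullptr};
  bool ok() const { return ptr != nullptr; }
  void* get() const { return ptr; }
};

// Stand-in for RuntimeContext: hands out chunks from a fixed temp arena and
// reports failure (instead of throwing) when the arena is exhausted.
class FakeContext {
 public:
  explicit FakeContext(size_t arena_bytes) : arena_(arena_bytes) {}
  Result allocate_temp(size_t nbytes) {
    if (used_ + nbytes > arena_.size()) {
      return Result{};  // allocation failed; caller must fall back
    }
    void* p = arena_.data() + used_;
    used_ += nbytes;
    return Result{p};
  }

 private:
  std::vector<char> arena_;
  size_t used_{0};
};

// Mirrors the diff: try the temp allocator first; fall back to an owning
// heap buffer (kept alive by `owned`) only when allocate_temp() fails.
void* get_scratch(FakeContext& ctx, size_t nbytes, std::unique_ptr<char[]>& owned) {
  Result res = ctx.allocate_temp(nbytes);
  if (res.ok()) {
    return res.get();
  }
  owned = std::make_unique<char[]>(nbytes);
  return owned.get();
}

int main() {
  FakeContext ctx(/*arena_bytes=*/1024);
  std::unique_ptr<char[]> owned_a, owned_b;
  void* a = get_scratch(ctx, 512, owned_a);   // served from the temp arena
  void* b = get_scratch(ctx, 2048, owned_b);  // too large: heap fallback
  (void)a;
  (void)b;
  return 0;
}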