84 changes: 84 additions & 0 deletions include/flashinfer/attention/persistent.cuh
@@ -218,6 +218,90 @@ cudaError_t BatchPagedAttentionPersistentHolistic(const Params params_1, const P
return cudaSuccess;
}

template <typename Params>
cudaError_t BatchAttentionScoreReductionPersisitent(const Params params, const uint32_t num_blks_x,
const uint32_t num_blks_y,
const cudaStream_t stream) {
using DTypeO = typename Params::DTypeO;
using IdType = typename Params::IdType;
constexpr uint32_t NUM_THREADS = 128;
auto kernel = BatchPagedAttentionPersistentHolisticKernel<Params, NUM_THREADS>;

dim3 nblks(num_blks_x, num_blks_y);
dim3 nthrs(NUM_THREADS);
void* args[] = {(void*)&params};

size_t smem_size = 16 * 1024;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(
cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
return cudaSuccess;
}
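
// Reviewer note (sketch, not part of this PR): judging from the usages in this diff, the Params
// type passed to the launcher above and to the kernel below is assumed to expose at least:
//   typedefs DTypeO and IdType;
//   DTypeO*  qk_ptr;          // [max_batch_size (x gqa_group_size), num_kv_heads, max_seqlen]
//   uint32_t qk_max_seqlen;   // row stride of qk_ptr
//   DTypeO*  reduced_o_ptr;   // [max_batch_size, num_kv_heads, max_seqlen]
//   uint32_t qk_stride;       // row stride of reduced_o_ptr
//   float*   lse;             // [nnz, num_qo_heads]
//   IdType*  work_indptr;     // CSR offsets: work items owned by CTA blockIdx.y
//   IdType*  o_indptr, *qk_indptr;  // per-work-item row offsets into reduced_o_ptr / qk_ptr
//   uint_fastdiv gqa_group_size;
//   uint32_t num_kv_heads;
// plus whatever get_block_coord(params, work_idx) reads. The exact struct may differ.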

template <typename Params, uint32_t NUM_THREADS>
__global__ __launch_bounds__(NUM_THREADS) void BatchPagedAttentionPersistentHolisticKernel(
const __grid_constant__ Params params) {
extern __shared__ uint8_t smem[];
using DTypeIn = typename Params::DTypeO;
using DTypeOut = typename Params::DTypeO;
using DTypeAccum = float;
using IdType = typename Params::IdType;

// [max_batch_size, num_kv_heads, max_seqlen]
// gqa_group_size is packed into the first dim
DTypeIn* qk_in_ptr = params.qk_ptr;
const uint32_t qk_stride = params.qk_max_seqlen;

// [max_batch_size, num_kv_heads, max_seqlen]
DTypeOut* reduced_o_ptr = params.reduced_o_ptr;
const uint32_t reduced_o_stride = params.qk_stride;

// [nnz, num_qo_heads]
float* lse = params.lse;
IdType* work_indptr = params.work_indptr;

const uint_fastdiv& gqa_group_size = params.gqa_group_size;
const uint32_t num_kv_heads = params.num_kv_heads;
const uint32_t num_qo_heads = num_kv_heads * gqa_group_size;
const uint32_t lane_idx = threadIdx.x;
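
// Persistent work loop: blockIdx.y selects this CTA's work queue; the CTA then iterates over
// every work item in work_indptr[blockIdx.y] .. work_indptr[blockIdx.y + 1).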

#pragma unroll 1
for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
++work_idx) {
const auto [q_indptr, kv_indptr, partial_indptr, q_len, kv_len, packed_qo_start, kv_start,
kv_end, kv_head_idx] = get_block_coord(params, work_idx);
if (packed_qo_start != 0) {
// Only the work item with packed_qo_start == 0 performs the reduction for this
// (request, kv_head) pair; it covers the full qo range, which keeps the code changes minimal.
continue;
}
// else packed_qo_start == 0
const auto o_indptr = params.o_indptr[work_idx];
const auto qk_indptr = params.qk_indptr[work_idx]; // packed
DTypeIn* qk_ptr_base = qk_in_ptr + ((qk_indptr * num_kv_heads) + kv_head_idx) * qk_stride;
DTypeOut* o_ptr_base =
reduced_o_ptr + ((o_indptr * num_kv_heads) + kv_head_idx) * reduced_o_stride;
float* lse_base = lse + q_indptr * num_qo_heads;

// Reduction step:
// each threadblock reads qk_ptr_base[0:q_len*gqa_group_size, kv_head_idx, kv_start:kv_end],
// reduces it into a tensor of shape (1, 1, kv_start:kv_end),
// and writes the result back to o_ptr_base[0, kv_head_idx, kv_start:kv_end].
{
const uint32_t total_rows = q_len * gqa_group_size;
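// Thread-strided reduction: lane t handles kv positions kv_start + t, kv_start + t + NUM_THREADS, ...
// For each position it accumulates, in fp32, the sum over all total_rows packed
// (query, gqa-head) rows of lse[row] * qk[row, kv_head_idx, kv_idx], one output value per position.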
for (uint32_t kv_idx = kv_start + lane_idx; kv_idx < kv_end; kv_idx += NUM_THREADS) {
DTypeAccum sum = 0.0f;
for (uint32_t i = 0; i < total_rows; ++i) {
DTypeAccum cur_lse = static_cast<DTypeAccum>(lse_base[i]);
sum += cur_lse * static_cast<DTypeAccum>(
qk_ptr_base[(i * num_kv_heads + kv_head_idx) * qk_stride + kv_idx]);
}
o_ptr_base[kv_idx] = static_cast<DTypeOut>(sum);
}
}
}
}
}; // namespace flashinfer

#endif // FLASHINFER_PERSISTENT_CUH_
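
For reviewers, a minimal host-side sketch of how the new launcher could be driven. MyParams and the grid-size choices are hypothetical and only illustrate the cooperative-launch constraint (the whole grid must be resident on the device, and work_indptr must be built for exactly num_blks_y queues); this is not an API added by this PR.

// Sketch only; MyParams stands for a concrete Params type with the members noted above,
// assumed to be defined and populated by the caller's scheduler.
#include <cuda_runtime.h>
#include "flashinfer/attention/persistent.cuh"

cudaError_t RunScoreReduction(const MyParams& params, cudaStream_t stream) {
  int device = 0, num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
  FLASHINFER_CUDA_CALL(
      cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device));
  // Cooperative launches require the full grid to fit on the device at once, so the grid is
  // sized from the SM count; params.work_indptr is assumed to hold num_blks_y + 1 offsets.
  const uint32_t num_blks_x = 1;
  const uint32_t num_blks_y = static_cast<uint32_t>(num_sms);
  return flashinfer::BatchAttentionScoreReductionPersisitent<MyParams>(params, num_blks_x,
                                                                       num_blks_y, stream);
}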