13 changes: 8 additions & 5 deletions kernels/flash-attn/flash_attn_mma.py
@@ -259,15 +259,18 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
out_mma_share_kv, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1)
out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1)
out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage2)", o, stages=2)
out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
print("-" * 120)

torch.cuda.synchronize()
if args.check:
check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all)
check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all)
check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all)
check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all)
check_all_close(out_flash, out_mma_share_kv1, "out_mma_share_kv1", args.show_all)
check_all_close(out_flash, out_mma_share_qkv1, "out_mma_share_qkv1", args.show_all)
216 changes: 139 additions & 77 deletions kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
@@ -138,8 +138,12 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// Shared memory for Q,K,V,O, d=64->24KB, d=128->48KB, kStage 1
extern __shared__ half smem[];
constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8KB
constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8KB, KV shared 8KB
constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8KB, KV shared 8KB
// constexpr int KV_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8KB, KV shared 8KB
// constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8KB, KV shared 8KB
constexpr int KV_tile_size = (
((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
);
// K multi-stages: currently, multi-stage is only applied to K across seq_len.
half* Q_tile_smem = smem; // 8KB/16KB
half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8KB/16KB
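The new KV_tile_size takes the larger of the two candidate footprints because K is staged as [kHeadDim, Bc+kPad] while V is staged as [Bc, kHeadDim+kPad]; the pad falls on different dimensions, so the element counts differ whenever kHeadDim != Bc. A minimal sketch of the same pick as a reusable helper (illustrative only, not part of the PR):

constexpr int kv_tile_elems(int head_dim, int bc, int pad) {
  // K tile: [head_dim, bc+pad]; V tile: [bc, head_dim+pad] (both row-major).
  // The shared K/V buffer must hold whichever layout needs more half elements.
  return (head_dim * (bc + pad) > bc * (head_dim + pad))
             ? head_dim * (bc + pad)
             : bc * (head_dim + pad);
}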
Expand All @@ -164,14 +168,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,

// ---------------------- Registers for S=Q@K^T/O=P@V ----------------------------
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d].
// TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs.
// Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g. R_Q[4][1][4] 16 regs.
// Note that we then have to reduce R_Z to 0 regs and reuse R_Q for the collective store.
// This way, Q is loaded from smem only once and reused across the <loop over K seqlen>
// iterations, which removes a large amount of smem io-access for Q when N is large.
// constexpr bool kCanPrefetchQs2r = false; // d <= 128
// FIXME(DefTruth): why can't we get good performance for headdim >= 64?
// Will enable it once I have figured out the performance issues.
constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64);
constexpr bool kCanPrefetchKVg2s = (kStage == 2); // whether prefetch KV g2s.
constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0.
constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1.
constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1;
uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4]
uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
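With kCanPrefetchKVg2s (kStage == 2) the shared K/V region is double-buffered: K always lands in buffer 0 and V in buffer 1; with kStage == 1 both ids collapse to 0 and K/V time-share a single buffer. A sketch of the address math this implies (hypothetical helper, mirroring the pointer arithmetic used later in the kernel):

__device__ __forceinline__ uint32_t kv_buffer_base(
    uint32_t smem_base_ptr, int smem_id, int kv_tile_size) {
  // smem_id is kPrefetchKg2sSmemId (0) for K and kPrefetchVg2sSmemId (0/1) for V;
  // kv_tile_size is KV_tile_size in half elements.
  return smem_base_ptr + smem_id * kv_tile_size * (uint32_t)sizeof(half);
}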
@@ -180,7 +187,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs.
uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2]
// registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs.
// TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ.
uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
// registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs.
uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
@@ -206,39 +212,13 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
#pragma unroll 1
for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
// TODO: process last tile_K_seqlen ? pad to multiple of 8.
// s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
int smem_sel = (tile_K_seqlen) % kStage;
// s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
int smem_sel_next = (tile_K_seqlen + (kStage - 1)) % kStage;

// multi-stage pipelining gmem -> smem
// NOTE: kStage must be > 1 for pipelining. For s1, smem_sel
// and smem_sel_next will always equal 0; thus, we can not
// prefetch KV from gmem to smem before the tile_K_seqlen MMA is done.

// First, prefetch curr K tile_K_seqlen [d,Bc] (no stages)
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (smem_sel * K_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
}


// <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
if constexpr (kCanPrefetchQs2r) {
// Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
// NOTE: we only need to load Q once from smem -> regs, and then reuse it.
if (tile_K_seqlen == 0) {
// TODO: Full share QKV smem after Q is ready load to regs.
CP_ASYNC_WAIT_GROUP(1);
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();

#pragma unroll
@@ -262,24 +242,76 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
}
}
} // end if tile_K_seqlen == 0
// Now, wait until the curr K tile is ready for the Q@K^T MMA.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
} // end if kCanPrefetchQs2r

// Load K tile from gmem -> smem, always use smem part 0.
if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
// Now, wait until the curr K tile is ready for the Q@K^T MMA.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
// <Prefetch V g2s>: Load V tile async from gmem -> smem 1, before Q@K^T
{
load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
int load_gmem_V_d = load_smem_V_d;
int load_gmem_V_addr = (
V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
uint32_t load_smem_V_ptr = (
smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
load_smem_V_Bc * (kHeadDim + kPad) +
load_smem_V_d) * sizeof(half)
);
#pragma unroll
for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
}
} else {
// <w/o Prefetch KV g2s>: load curr K tile gmem -> smem, then wait until it is ready.
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
// Now, wait until the curr K tile is ready for the Q@K^T MMA.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
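Both branches above go through the cp.async macros. A sketch of what CP_ASYNC_CG, CP_ASYNC_COMMIT_GROUP and CP_ASYNC_WAIT_GROUP presumably wrap on sm_80+ (an assumption; the repo's actual utils may differ):

#define CP_ASYNC_CG(dst, src, bytes)                                  \
  asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" \
               :: "r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" :: "n"(n))

cp.async.wait_group N blocks until at most N committed copy groups are still in flight, which is what makes the stage-2 group accounting further down work.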

// <loop over K d>: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
// Matmul with NN layout, Q row major, K row major.
// S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
// <HGEMM in shared memory>
fill_3D_regs<uint32_t, kWarpTileSeqLenQ, kWarpTileSeqLenK, 2>(R_S, 0);
#pragma unroll
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
// smem -> reg, load m16k16 smem Q, offset d according tile_K_d.
// ldmatrix.x4 for Q_tile_smem.
if constexpr (!kCanPrefetchQs2r) {
if constexpr (!kCanPrefetchQs2r) {
// load Q from smem -> regs in each loop w/o prefetch Q s2r.
#pragma unroll
for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
@@ -302,7 +334,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
uint32_t lane_smem_K_ptr = (
smem_K_base_ptr + (smem_sel * K_tile_size +
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
lane_smem_K_d * (Bc + kPad) +
lane_smem_K_Bc) * sizeof(half)
);
@@ -338,16 +370,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
} // end loop over d, S=Q@K^T
__syncthreads();

// Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages),
// before rowmax and rowsum, load V from gmem -> smem.
{
// <w/o Prefetch V g2s>: If kCanPrefetchKVg2s is not enabled,
// we will load V g2s here, before rowmax and rowsum.
if constexpr (!kCanPrefetchKVg2s) {
load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
int load_gmem_V_d = load_smem_V_d;
int load_gmem_V_addr = (
V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
uint32_t load_smem_V_ptr = (
smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) +
smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
load_smem_V_Bc * (kHeadDim + kPad) +
load_smem_V_d) * sizeof(half)
);
#pragma unroll
@@ -357,6 +390,25 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
CP_ASYNC_COMMIT_GROUP();
}

// <Prefetch K g2s>: load next K tile from gmem -> smem 0, before P@V.
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
}
}

// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// | 64x64 | warp_KV 0 |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
@@ -446,9 +498,12 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,

// Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention.
// Here, we have to wait V ready before compute O = P @ V
if constexpr (kStage > 1) {
// NOTE: For kStage > 1, we have send V mem issues before K
CP_ASYNC_WAIT_GROUP(1);
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
CP_ASYNC_WAIT_GROUP(1); // V & K g2s groups are in flight; wait until only 1 is pending (V landed), let K copy async.
} else {
CP_ASYNC_WAIT_GROUP(0); // only the V g2s group was issued; drain it.
}
} else {
CP_ASYNC_WAIT_GROUP(0);
}
@@ -476,6 +531,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// ...
// 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}

// <HGEMM in registers>
fill_3D_regs<uint32_t, kWarpTileSeqLenP, kWarpTileHeadDimV, 2>(R_O, 0);
#pragma unroll
for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) {
@@ -486,7 +542,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K
int lane_smem_V_d = warp_smem_V_d; // 0
uint32_t lane_smem_V_ptr = (
smem_V_base_ptr + (lane_smem_V_Bc * (kHeadDim + kPad) +
smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
lane_smem_V_Bc * (kHeadDim + kPad) +
lane_smem_V_d) * sizeof(half)
);
LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V
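LDMATRIX_X2_T is the transposed ldmatrix variant: V sits in smem row-major as [Bc, kHeadDim+kPad], but the m16n8k16 MMA wants its B-operand fragment laid out along K=Bc, so the transposed x2 load delivers it directly. A sketch of the PTX it presumably wraps (assumption; the repo's macro may differ):

#define LDMATRIX_X2_T(R0, R1, addr)                              \
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 " \
               "{%0, %1}, [%2];\n"                               \
               : "=r"(R0), "=r"(R1) : "r"(addr))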
@@ -517,23 +574,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
} // end for V Bc.
__syncthreads();

// NOTE: Load next K tile async before rescale O
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (smem_sel * K_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
}

// Rescale O -> Update row sum Exp -> then, Update row max.
#pragma unroll
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1
@@ -594,11 +634,12 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
lane_block_row_max_old[i][1] = block_row_max_new_1;
}

// NOTE: After compute P @ V, we have to wait next K tile ready in smem.
// do not need to wait any things if kStage == 1.
if constexpr (kStage > 1) {
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
// now, wait until the next K tile is ready in smem.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
}

} // end loop over N
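Putting the pieces together, the stage-2 schedule this kernel now runs per query tile can be summarized by the following host-style sketch (illustrative names, condensing the cp.async group accounting above):

void schedule_sketch(int Tc) {
  auto load_K_g2s  = [](int i) { /* cp.async K tile i -> smem buffer 0, commit group */ };
  auto load_V_g2s  = [](int i) { /* cp.async V tile i -> smem buffer 1, commit group */ };
  auto wait_groups = [](int n) { /* cp.async.wait_group n; __syncthreads(); */ };

  load_K_g2s(0); wait_groups(0);          // K(0) resident before the first Q@K^T
  for (int i = 0; i < Tc; ++i) {
    load_V_g2s(i);                        // V(i) copies during S = Q@K^T
    /* S = Q@K^T on K(i) in smem; online-softmax rowmax */
    if (i + 1 < Tc) load_K_g2s(i + 1);    // K(i+1) copies during P@V
    wait_groups(i + 1 < Tc ? 1 : 0);      // V(i) landed; K(i+1) may still fly
    /* O += P@V on V(i) in smem; rescale and update rowsum/rowmax */
    if (i + 1 < Tc) wait_groups(0);       // K(i+1) resident for the next iteration
  }
}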
@@ -690,7 +731,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
template<const int kHeadDim, const int kStage>
void launch_flash_attn_mma_stages_split_q_shared_kv(
torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
// Tile BrxBc=128x64
// For now, the tile size is fixed at BrxBc=128x64.
// TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
constexpr int kMmaAtomM = 16;
constexpr int kMmaAtomN = 8;
constexpr int kMmaAtomK = 16;
@@ -712,18 +754,22 @@

// static int kMaxSramPerBlock;
// cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);

// Calculate SRAM size needed per block: Q and K/V smem sizes; K and V share the same smem.
constexpr int KV_tile_size = (
((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
);
const int smem_max_size = ((Br * (kHeadDim + kPad)) +
(kStage * kHeadDim * (Bc + kPad))) * sizeof(half);
(kStage * KV_tile_size)) * sizeof(half);
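For concreteness, a worked instance of this size computation, assuming Br=Bc=64, kHeadDim=64 and kPad=8 for illustration (the real values come from the template parameters):

constexpr int Br = 64, Bc = 64, kHeadDim = 64, kPad = 8, kStage = 2;
constexpr int Q_tile  = Br * (kHeadDim + kPad);                    // 4608 halfs
constexpr int KV_tile = (kHeadDim * (Bc + kPad) > Bc * (kHeadDim + kPad))
                            ? kHeadDim * (Bc + kPad)
                            : Bc * (kHeadDim + kPad);              // 4608 halfs (the two shapes tie at d=Bc)
constexpr int smem_bytes = (Q_tile + kStage * KV_tile) * 2;        // * sizeof(half) = 27648 B, ~27 KB
static_assert(smem_bytes <= 48 * 1024, "d=64 stage-2 fits the default 48KB smem limit");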

const int QKV_batch = Q.size(0);
const int QKV_head = Q.size(1);
const int QKV_seqlen = Q.size(2); // QKV_seqlen
assert(QKV_seqlen % Bc == 0); // multiple of Bc=64

assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)

// TODO: How to apply block swizzle to improve L2 Cache hit rate?
dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br)
dim3 block(kNumThreads); // 4 warps per block
dim3 block(kNumThreads); // 4/8 warps per block

cudaFuncSetAttribute(
flash_attn_mma_stages_split_q_shared_kv_kernel<
@@ -783,8 +829,24 @@ void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q,
const int d = Q.size(3); // B, H, N, d

if (stages > 1) {
throw std::runtime_error(
"split_q_shared_kv not support stages>1 now!");
switch (d)
{
case 32:
launch_flash_attn_mma_stages_split_q_shared_kv<32, 2>(Q, K, V, O);
break;
case 64:
launch_flash_attn_mma_stages_split_q_shared_kv<64, 2>(Q, K, V, O);
break;
case 96:
launch_flash_attn_mma_stages_split_q_shared_kv<96, 2>(Q, K, V, O);
break;
case 128:
launch_flash_attn_mma_stages_split_q_shared_kv<128, 2>(Q, K, V, O);
break;
default:
throw std::runtime_error("headdim not support!");
break;
}
} else {
switch (d)
{