diff --git a/README.md b/README.md
index 97707790..ac19db06 100644
--- a/README.md
+++ b/README.md
@@ -50,11 +50,11 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
-|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)
+|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
-|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|
Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than official FA2 on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
@@ -133,7 +133,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
+// Q, K, V fully share the same shared memory, with Q prefetched s2r, to improve block occupancy & reduce Q SMEM IO-access.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* K,
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md
index 5717d994..eb2bc2ca 100644
--- a/kernels/flash-attn/README.md
+++ b/kernels/flash-attn/README.md
@@ -9,7 +9,7 @@
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
-|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|
This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than official FA2 on some devices, for example, an NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
@@ -107,7 +107,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
+// Q, K, V fully share the same shared memory, with Q prefetched s2r, to improve block occupancy & reduce Q SMEM IO-access.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* K,
diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
index aad3d5f2..b4efc69f 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
@@ -173,9 +173,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
static_assert(kHeadDim >= 32, "shared_qkv only support headdim>=32");
constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8); // always true.
// Use kStage and (Br / Bc) to control multi-stage policy for K g2s.
- constexpr bool kCanPrefetchKg2s = (
+ constexpr bool kCanPrefetchKVg2s = (
((Q_tile_size / KV_tile_size) >= 2) && (kStage >= 2)); // for d<=64 is true.
- constexpr int kPrefetchStageKg2s = kCanPrefetchKg2s ? 2 : 1; // only apply stage 2 for k prefetch.
+ // constexpr int kPrefetchStageKg2s = kCanPrefetchKVg2s ? 2 : 1; // only apply stage 2 for k prefetch.
+  constexpr int kPrefetchKg2sSmemId = 0; // smem buffer id for K g2s: always buffer 0.
+  constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem buffer id for V g2s: buffer 1 if K/V prefetch is enabled, else 0.
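+  // With kCanPrefetchKVg2s enabled, the shared smem is treated as two KV_tile_size
+  // buffers: K always lives in buffer 0 and V in buffer 1, so the next K tile can be
+  // copied into buffer 0 (K is consumed once S=Q@K^T is in regs) while V in buffer 1
+  // is still being read by P@V.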
constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1;
uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4]
uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
@@ -209,11 +211,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc]
#pragma unroll 1
for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
- // TODO: process last tile_K_seqlen ? pad to multiple of 8.
- // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
- int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s;
- // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
- int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s;
+ // // TODO: process last tile_K_seqlen ? pad to multiple of 8.
+ // // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
+ // int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s;
+ // // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
+ // int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s;
// Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
// NOTE: we only need to load Q once from smem -> regs, and then reuse it.
@@ -245,17 +247,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
}
}
- // Load K tile from gmem -> smem
- if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+      // Load K tile from gmem -> smem, always into smem buffer 0.
+ if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
@@ -265,15 +267,33 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
+      // Load V tile async from gmem -> smem buffer 1 before Q@K^T, so the copy overlaps with the S=Q@K^T MMAs.
+ {
+ load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPad) +
+ load_smem_V_d) * sizeof(half)
+ );
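+        // 128-bit (16B) cp.async per issue copies 8 halfs; i*2 converts the half index to a byte offset.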
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
} else {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
@@ -287,6 +307,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
// Matmul with NN layout, Q row major, K row major.
// S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+ //
fill_3D_regs(R_S, 0);
#pragma unroll
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
@@ -298,7 +319,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
uint32_t lane_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * KV_tile_size +
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
lane_smem_K_d * (Bc + kPad) +
lane_smem_K_Bc) * sizeof(half)
);
@@ -320,17 +341,16 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
} // end loop over d, S=Q@K^T
__syncthreads();
- // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages),
- // before rowmax and rowsum, load V from gmem -> smem.
- // TODO: Can we support stages 2 for V g2s?
- {
+      // If kCanPrefetchKVg2s is not enabled, load the current V tile g2s here,
+      // before rowmax and rowsum.
+ if constexpr (!kCanPrefetchKVg2s) {
load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
int load_gmem_V_d = load_smem_V_d;
int load_gmem_V_addr = (
V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
uint32_t load_smem_V_ptr = (
- smem_V_base_ptr + (smem_sel * KV_tile_size +
+ smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
load_smem_V_Bc * (kHeadDim + kPad) +
load_smem_V_d) * sizeof(half)
);
@@ -341,14 +361,15 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
CP_ASYNC_COMMIT_GROUP();
}
- if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+      // Load next K tile async from gmem -> smem buffer 0; the copy overlaps with rowmax/rowsum and P@V for the current tile.
+ if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel_next * KV_tile_size +
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
@@ -407,7 +428,6 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]);
lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]);
} // end for kWarpTileSeqLenQ
- // __syncthreads();
// Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
#pragma unroll
@@ -447,11 +467,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention.
// Here, we have to wait V ready before compute O = P @ V
- if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+ if constexpr (kCanPrefetchKVg2s) {
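+        // cp.async commit-group accounting: the current V tile was committed before
+        // Q@K^T and the next K tile (if any) after S was computed. WAIT_GROUP(1)
+        // leaves the newest group (next K) pending while ensuring V has landed.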
if ((tile_K_seqlen + 1) < Tc) {
- CP_ASYNC_WAIT_GROUP(1);
+          CP_ASYNC_WAIT_GROUP(1); // V & next K are in flight: wait for V, leave next K async.
} else {
- CP_ASYNC_WAIT_GROUP(0);
+          CP_ASYNC_WAIT_GROUP(0); // only V was issued g2s; wait for it.
}
} else {
CP_ASYNC_WAIT_GROUP(0);
@@ -480,6 +500,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// ...
// 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}
+ //
fill_3D_regs(R_O, 0);
#pragma unroll
for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) {
@@ -490,7 +511,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K
int lane_smem_V_d = warp_smem_V_d; // 0
uint32_t lane_smem_V_ptr = (
- smem_V_base_ptr + (smem_sel * KV_tile_size +
+ smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
lane_smem_V_Bc * (kHeadDim + kPad) +
lane_smem_V_d) * sizeof(half)
);
@@ -582,9 +603,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
lane_block_row_max_old[i][1] = block_row_max_new_1;
}
- if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+ if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
- CP_ASYNC_WAIT_GROUP(0);
+          // Now wait for the next K tile to be ready in smem before the next iteration.
+ CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
}
@@ -619,7 +641,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
#pragma unroll
for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
-
+        // Reuse R_Q regs if kCanPrefetchQs2r is enabled to reduce register usage.
if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { // always true for shared qkv kernel
// reuse R_Q[4/8][1][4] for collective store.
R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4
@@ -703,15 +725,15 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv(
// static int kMaxSramPerBlock;
// cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
-
// Calculate SRAM size needed per block, QKV smem size, QKV fully shared the same smem.
const int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32M
const int QKV_batch = Q.size(0);
const int QKV_head = Q.size(1);
const int QKV_seqlen = Q.size(2); // QKV_seqlen
- assert(QKV_seqlen % Bc == 0); // multiple of Bc=64
+ assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)
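+  // NOTE: partial tail tiles are not handled by this kernel, so seqlen must tile evenly by Br and Bc.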
+ // TODO: How to apply block swizzle to improve L2 Cache hit rate?
dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br)
dim3 block(kNumThreads); // 4/8 warps per block