diff --git a/README.md b/README.md
index 97707790..ac19db06 100644
--- a/README.md
+++ b/README.md
@@ -50,11 +50,11 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
 |Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
 |:---:|:---:|:---:|:---:|
 |✔️|✔️|✔️|✔️|
-|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)
+|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
 |✔️|✔️|✔️|✔️|
 |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
 |✔️|✔️|✔️|✔️|
-|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
 |✔️|✔️|✔️|?|
 
 Currently, small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than official FA2 on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
@@ -133,7 +133,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
 ```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
 __global__ void
 flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
                                                 half* K,
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md
index 5717d994..eb2bc2ca 100644
--- a/kernels/flash-attn/README.md
+++ b/kernels/flash-attn/README.md
@@ -9,7 +9,7 @@
 |✔️|✔️|✔️|✔️|
 |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
 |✔️|✔️|✔️|✔️|
-|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
 |✔️|✔️|✔️|?|
 
 This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than official FA2 on some devices, for example, NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
@@ -107,7 +107,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++ -// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy. +// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access. __global__ void flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu index aad3d5f2..b4efc69f 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu @@ -173,9 +173,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, static_assert(kHeadDim >= 32, "shared_qkv only support headdim>=32"); constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8); // always true. // Use kStage and (Br / Bc) to control multi-stage policy for K g2s. - constexpr bool kCanPrefetchKg2s = ( + constexpr bool kCanPrefetchKVg2s = ( ((Q_tile_size / KV_tile_size) >= 2) && (kStage >= 2)); // for d<=64 is true. - constexpr int kPrefetchStageKg2s = kCanPrefetchKg2s ? 2 : 1; // only apply stage 2 for k prefetch. + // constexpr int kPrefetchStageKg2s = kCanPrefetchKVg2s ? 2 : 1; // only apply stage 2 for k prefetch. + constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0. + constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1. constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1; uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4] uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2] @@ -209,11 +211,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] #pragma unroll 1 for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { - // TODO: process last tile_K_seqlen ? pad to multiple of 8. - // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; - int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s; - // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; - int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s; + // // TODO: process last tile_K_seqlen ? pad to multiple of 8. + // // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; + // int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s; + // // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; + // int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s; // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. // NOTE: we only need to load Q once from smem -> regs, and then reuse it. @@ -245,17 +247,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, } } - // Load K tile from gmem -> smem - if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + // Load K tile from gmem -> smem, always use smem part 0. + if constexpr (kCanPrefetchKVg2s) { if (tile_K_seqlen == 0) { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); #pragma unroll for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); @@ -265,15 +267,33 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } + // : Load V tile async from gmem -> smem 1, before Q@K^T + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } } else { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); #pragma unroll for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); @@ -287,6 +307,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] // Matmul with NN layout, Q row major, K row major. // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + // fill_3D_regs(R_S, 0); #pragma unroll for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { @@ -298,7 +319,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) uint32_t lane_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * KV_tile_size + + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + lane_smem_K_d * (Bc + kPad) + lane_smem_K_Bc) * sizeof(half) ); @@ -320,17 +341,16 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, } // end loop over d, S=Q@K^T __syncthreads(); - // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages), - // before rowmax and rowsum, load V from gmem -> smem. - // TODO: Can we support stages 2 for V g2s? - { + // : If kCanPrefetchKVg2s is not enable, + // we will load V g2s here, before rowmax and rowsum. + if constexpr (!kCanPrefetchKVg2s) { load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; int load_gmem_V_d = load_smem_V_d; int load_gmem_V_addr = ( V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); uint32_t load_smem_V_ptr = ( - smem_V_base_ptr + (smem_sel * KV_tile_size + + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + load_smem_V_Bc * (kHeadDim + kPad) + load_smem_V_d) * sizeof(half) ); @@ -341,14 +361,15 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, CP_ASYNC_COMMIT_GROUP(); } - if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + // : load next K tile from gmem -> smem 0, before P@V. + if constexpr (kCanPrefetchKVg2s) { if ((tile_K_seqlen + 1) < Tc) { load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel_next * KV_tile_size + + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + load_smem_K_d * (Bc + kPad) + load_smem_K_Bc) * sizeof(half)); #pragma unroll @@ -407,7 +428,6 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); } // end for kWarpTileSeqLenQ - // __syncthreads(); // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block. #pragma unroll @@ -447,11 +467,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. // Here, we have to wait V ready before compute O = P @ V - if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + if constexpr (kCanPrefetchKVg2s) { if ((tile_K_seqlen + 1) < Tc) { - CP_ASYNC_WAIT_GROUP(1); + CP_ASYNC_WAIT_GROUP(1); // we have send V & K g2s, wait V and let K async. } else { - CP_ASYNC_WAIT_GROUP(0); + CP_ASYNC_WAIT_GROUP(0); // we have only send V g2s. } } else { CP_ASYNC_WAIT_GROUP(0); @@ -480,6 +500,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // ... // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + // fill_3D_regs(R_O, 0); #pragma unroll for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { @@ -490,7 +511,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K int lane_smem_V_d = warp_smem_V_d; // 0 uint32_t lane_smem_V_ptr = ( - smem_V_base_ptr + (smem_sel * KV_tile_size + + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + lane_smem_V_Bc * (kHeadDim + kPad) + lane_smem_V_d) * sizeof(half) ); @@ -582,9 +603,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, lane_block_row_max_old[i][1] = block_row_max_new_1; } - if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + if constexpr (kCanPrefetchKVg2s) { if ((tile_K_seqlen + 1) < Tc) { - CP_ASYNC_WAIT_GROUP(0); + // now, we have to wait next K tile ready in smem. + CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } } @@ -619,7 +641,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 #pragma unroll for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8 - + // reuse R_Q regs if kCanPrefetchQs2r is enable, reduce registers usage. 
if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { // always true for shared qkv kernel // reuse R_Q[4/8][1][4] for collective store. R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4 @@ -703,15 +725,15 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv( // static int kMaxSramPerBlock; // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); - // Calculate SRAM size needed per block, QKV smem size, QKV fully shared the same smem. const int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32M const int QKV_batch = Q.size(0); const int QKV_head = Q.size(1); const int QKV_seqlen = Q.size(2); // QKV_seqlen - assert(QKV_seqlen % Bc == 0); // multiple of Bc=64 + assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc) + // TODO: How to apply block swizzle to improve L2 Cache hit rate? dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br) dim3 block(kNumThreads); // 4/8 warps per block
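The short sketches below unpack a few of the techniques touched by the patch above. They are illustrative snippets written for this note, not code taken from the repository, and every tile size or helper name in them is an assumed example value. First, the shared-QKV SMEM idea: because Q is prefetched from SMEM into registers (s2r) once per block, its SMEM region can afterwards be reused as two K/V parts (smem ids 0 and 1 in the kernel), so the per-block footprint shrinks and occupancy goes up.

```C++
// Back-of-the-envelope SMEM accounting for the fully shared QKV layout.
// Br/Bc/kHeadDim/kPad are assumed example values (they match the d=64 comments
// in the patch, but the kernel's template parameters are authoritative).
#include <cstdio>

int main() {
  const int Br = 128, Bc = 64, kHeadDim = 64, kPad = 8;
  const size_t bytes_per_half = 2;
  const size_t Q_tile  = Br * (kHeadDim + kPad) * bytes_per_half;
  const size_t KV_tile = Bc * (kHeadDim + kPad) * bytes_per_half;

  // Separate buffers: Q, K and V each keep their own SMEM for the whole kernel.
  const size_t separate = Q_tile + 2 * KV_tile;
  // Fully shared: once Q lives in registers, the same region holds K in part 0
  // and V in part 1, so only the Q tile is ever allocated. Q_tile >= 2 * KV_tile
  // is part of the kCanPrefetchKVg2s condition in the patch (true for d <= 64).
  const size_t shared = Q_tile;

  printf("separate QKV smem: %zu bytes/block\n", separate); // 36864
  printf("shared   QKV smem: %zu bytes/block\n", shared);   // 18432
  return 0;
}
```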
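The K and V tiles are staged global-to-shared with `cp.async`. Below is a minimal sketch of the pattern that the `CP_ASYNC_CG` / `CP_ASYNC_COMMIT_GROUP` / `CP_ASYNC_WAIT_GROUP` macros are assumed to wrap (sm_80+, 16-byte `.cg` copies); the real kernel additionally computes padded SMEM offsets and splits K and V into parts 0 and 1 of the shared region.

```C++
#include <cuda_fp16.h>
#include <cstdint>

// Each thread copies one 16-byte chunk (8 halves) from global to shared memory
// asynchronously, then the block waits for the copy group to land.
__global__ void cp_async_tile_demo(const half* __restrict__ src,
                                   half* __restrict__ dst, int n) {
  extern __shared__ half smem[];
  const int i = threadIdx.x * 8;  // 8 halves = 16 bytes per thread
  if (i + 8 <= n) {
    // cp.async needs a 32-bit shared-state-space address for the destination.
    uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(smem + i);
    asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n"
                 :: "r"(smem_dst), "l"(src + i));
  }
  // commit_group closes the batch of copies issued so far; wait_group N blocks
  // until at most N committed groups remain in flight. N = 0 drains everything;
  // the patch uses wait_group 1 after committing V and then next-K, so V is
  // guaranteed ready while the next K tile keeps streaming.
  asm volatile("cp.async.commit_group;\n" ::);
  asm volatile("cp.async.wait_group 0;\n" ::);
  __syncthreads();
  if (i + 8 <= n) {
    #pragma unroll
    for (int k = 0; k < 8; ++k) dst[i + k] = smem[i + k];
  }
}
// Example: cp_async_tile_demo<<<1, 128, 128 * 8 * sizeof(half)>>>(d_src, d_dst, 1024);
```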
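The S = Q@K^T and O = P@V loops in the patch come down to two warp-level PTX primitives: `ldmatrix`, which moves 8x8 b16 tiles from SMEM into lane-distributed fragment registers (the `R_Q`/`R_K` arrays above), and `mma.sync.m16n8k16`, which accumulates into the S/O fragments. The wrappers below are a compilable sketch of what such macros typically expand to; the repository's actual macro names and fragment bookkeeping may differ.

```C++
#include <cstdint>

// Load four 8x8 b16 tiles from shared memory into one lane-distributed fragment
// (used for a 16x16 A operand, e.g. a Q fragment prefetched s2r into R_Q).
__device__ __forceinline__ void ldmatrix_x4(uint32_t& r0, uint32_t& r1, uint32_t& r2,
                                            uint32_t& r3, uint32_t smem_addr) {
  asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
               : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
               : "r"(smem_addr));
}

// Load two 8x8 b16 tiles, e.g. one K (or V) fragment per kMmaAtomK step.
__device__ __forceinline__ void ldmatrix_x2(uint32_t& r0, uint32_t& r1,
                                            uint32_t smem_addr) {
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(r0), "=r"(r1)
               : "r"(smem_addr));
}

// D = A @ B + C for one m16n8k16 tile, all operands in f16 (two packed halves
// per 32-bit register): D/C use 2 registers per lane, A uses 4, B uses 2.
__device__ __forceinline__ void hmma16816(uint32_t& d0, uint32_t& d1,
                                          uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
                                          uint32_t b0, uint32_t b1,
                                          uint32_t c0, uint32_t c1) {
  asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
               "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
               : "=r"(d0), "=r"(d1)
               : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
                 "r"(b0), "r"(b1), "r"(c0), "r"(c1));
}
```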
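Between Q@K^T and P@V the kernel maintains per-row running statistics (the `lane_row_max_new` / `warp_reduce_max` / rescale logic in the patch). The scalar model below shows the same online-softmax update rule outside the warp-level register layout; the toy S values and the d = 64 scale are assumed for illustration only.

```C++
// Scalar model of the online-softmax update used per output row: keep a running
// max m and exp-sum l, and rescale previous contributions by exp(m_old - m_new)
// whenever a new S tile raises the max. V is taken as all-ones so the final
// result is trivially checkable (softmax rows sum to 1).
#include <cmath>
#include <cstdio>

int main() {
  float m_old = -INFINITY, l_old = 0.0f, acc = 0.0f;  // acc stands in for one O row
  const float scale = 1.0f / std::sqrt(64.0f);        // 1/sqrt(d), d = 64 assumed

  const float S_tiles[3][2] = {{1.0f, 2.0f}, {5.0f, 3.0f}, {0.5f, 4.0f}};  // toy S rows
  for (const auto& s : S_tiles) {
    float m_tile  = std::fmax(s[0], s[1]) * scale;    // rowmax of this tile
    float m_new   = std::fmax(m_old, m_tile);
    float rescale = std::exp(m_old - m_new);          // shrink what was accumulated so far
    float p0 = std::exp(s[0] * scale - m_new);
    float p1 = std::exp(s[1] * scale - m_new);
    l_old = l_old * rescale + p0 + p1;                // rowsum update
    acc   = acc * rescale + p0 * 1.0f + p1 * 1.0f;    // "P @ V" with V == 1
    m_old = m_new;
  }
  printf("O = %f (expected 1.0 with all-ones V)\n", acc / l_old);
  return 0;
}
```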
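Finally, the launch-side arithmetic the updated launcher performs: only the Q tile's bytes are requested as dynamic SMEM (K and V reuse that region), the sequence length is asserted to be a multiple of max(Br, Bc), and one block is launched per (batch, head, Q row-tile). The sketch below is a trimmed, hypothetical stand-in with a stub kernel and assumed Br/thread counts, not the repository's `launch_flash_attn_mma_stages_split_q_shared_qkv`.

```C++
#include <cassert>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Stub standing in for the real shared-QKV kernel; body intentionally empty.
template <int kHeadDim>
__global__ void flash_attn_shared_qkv_stub(const half*, const half*, const half*,
                                           half*, int) {}

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

template <int kHeadDim>
void launch_shared_qkv_sketch(const half* Q, const half* K, const half* V, half* O,
                              int batch, int heads, int seqlen) {
  constexpr int Br = 128, Bc = 64, kPad = 8, kNumThreads = 256;  // assumed config
  // QKV fully share one SMEM region, so only the Q tile is counted:
  // 128 x (64 + 8) x 2 bytes = 18 KiB at kHeadDim = 64.
  const int smem_bytes = Br * (kHeadDim + kPad) * sizeof(half);
  assert(seqlen % (Br > Bc ? Br : Bc) == 0);          // mirrors the patched assert
  cudaFuncSetAttribute(flash_attn_shared_qkv_stub<kHeadDim>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  dim3 grid(batch, heads, div_ceil(seqlen, Br));      // batch x heads x Tr(=N/Br)
  dim3 block(kNumThreads);                            // 8 warps per block
  flash_attn_shared_qkv_stub<kHeadDim><<<grid, block, smem_bytes>>>(Q, K, V, O, seqlen);
}
```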