diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py index 9ae87ddb..8ca323d0 100644 --- a/kernels/flash-attn/flash_attn_mma.py +++ b/kernels/flash-attn/flash_attn_mma.py @@ -259,9 +259,10 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor, out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2) out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1) out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2) - out_mma_share_kv, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1) - out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1) - out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2) + out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1) + out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage2)", o, stages=2) + out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1) + out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2) out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)") if args.run_torch_sdpa: out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)") @@ -269,5 +270,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor, torch.cuda.synchronize() if args.check: - check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all) - check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all) + check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all) + check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all) + check_all_close(out_flash, out_mma_share_kv1, "out_mma_share_kv1", args.show_all) + check_all_close(out_flash, out_mma_share_qkv1, "out_mma_share_qkv1", args.show_all) diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu index 5ad6bfa7..7f24e3c8 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu @@ -138,8 +138,12 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1 extern __shared__ half smem[]; constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M - constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M - constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M + // constexpr int KV_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M + // constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M + constexpr int KV_tile_size = ( + ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? 
+ ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) + ); // K multi-stages: currently, only apply multi stages for K across seq_len. half* Q_tile_smem = smem; // 8M/16M half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M @@ -164,7 +168,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. - // TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. + // Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. // Then we can load Q from smem only once and reuse it for // processes. This will reduce large io-access for Q smem while N is large. @@ -172,6 +176,9 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // FIXME(DefTruth): why can not get good performance for headdim >= 64 ? // Will enable it untill I have figure out the performance issues. constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64); + constexpr bool kCanPrefetchKVg2s = (kStage == 2); // whether prefetch KV g2s. + constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0. + constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1. constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1; uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4] uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2] @@ -180,7 +187,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2] // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. - // TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ. uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] @@ -206,39 +212,13 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, #pragma unroll 1 for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { // TODO: process last tile_K_seqlen ? pad to multiple of 8. - // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; - int smem_sel = (tile_K_seqlen) % kStage; - // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; - int smem_sel_next = (tile_K_seqlen + (kStage - 1)) % kStage; - - // multi stages pipeling gmem -> smem - // NOTE: kStage must be > 1 for pipeling. For s1, smem_sel - // and smem_sel_next will always equal 0, thus, we can not - // prefetch KV from gmem to smem before tile_K_seqlen MMA done. - - // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages) - if (tile_K_seqlen == 0) { - load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); - uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); - #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); - } - CP_ASYNC_COMMIT_GROUP(); - } - + + // : Load Q tile from smem -> regs, before Q@K^T. if constexpr (kCanPrefetchQs2r) { // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. // NOTE: we only need to load Q once from smem -> regs, and then reuse it. if (tile_K_seqlen == 0) { - // TODO: Full share QKV smem after Q is ready load to regs. - CP_ASYNC_WAIT_GROUP(1); + CP_ASYNC_WAIT_GROUP(0); __syncthreads(); #pragma unroll @@ -262,11 +242,61 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } } } // end if tile_K_seqlen == 0 - // Now, we have to wait curr K tile ready for Q@K^T MMA. - CP_ASYNC_WAIT_GROUP(0); - __syncthreads(); + } // end if kCanPrefetchQs2r + + // Load K tile from gmem -> smem, always use smem part 0. + if constexpr (kCanPrefetchKVg2s) { + if (tile_K_seqlen == 0) { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + // : Load V tile async from gmem -> smem 1, before Q@K^T + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } } else { - // Wait curr Q and K tile ready. + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } @@ -274,12 +304,14 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] // Matmul with NN layout, Q row major, K row major. // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + // fill_3D_regs(R_S, 0); #pragma unroll for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. // ldmatrix.x4 for Q_tile_smem. - if constexpr (!kCanPrefetchQs2r) { + if constexpr (!kCanPrefetchQs2r) { + // load Q from smem -> regs in each loop w/o prefetch Q s2r. #pragma unroll for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; @@ -302,7 +334,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) uint32_t lane_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * K_tile_size + + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + lane_smem_K_d * (Bc + kPad) + lane_smem_K_Bc) * sizeof(half) ); @@ -338,16 +370,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } // end loop over d, S=Q@K^T __syncthreads(); - // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages), - // before rowmax and rowsum, load V from gmem -> smem. - { + // : If kCanPrefetchKVg2s is not enable, + // we will load V g2s here, before rowmax and rowsum. + if constexpr (!kCanPrefetchKVg2s) { load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; int load_gmem_V_d = load_smem_V_d; int load_gmem_V_addr = ( V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); uint32_t load_smem_V_ptr = ( - smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + load_smem_V_Bc * (kHeadDim + kPad) + load_smem_V_d) * sizeof(half) ); #pragma unroll @@ -357,6 +390,25 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, CP_ASYNC_COMMIT_GROUP(); } + // : load next K tile from gmem -> smem 0, before P@V. + if constexpr (kCanPrefetchKVg2s) { + if ((tile_K_seqlen + 1) < Tc) { + load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps // | 64x64 | warp_KV 0 | // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | @@ -446,9 +498,12 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. // Here, we have to wait V ready before compute O = P @ V - if constexpr (kStage > 1) { - // NOTE: For kStage > 1, we have send V mem issues before K - CP_ASYNC_WAIT_GROUP(1); + if constexpr (kCanPrefetchKVg2s) { + if ((tile_K_seqlen + 1) < Tc) { + CP_ASYNC_WAIT_GROUP(1); // we have send V & K g2s, wait V and let K async. + } else { + CP_ASYNC_WAIT_GROUP(0); // we have only send V g2s. + } } else { CP_ASYNC_WAIT_GROUP(0); } @@ -476,6 +531,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // ... // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + // fill_3D_regs(R_O, 0); #pragma unroll for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { @@ -486,7 +542,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K int lane_smem_V_d = warp_smem_V_d; // 0 uint32_t lane_smem_V_ptr = ( - smem_V_base_ptr + (lane_smem_V_Bc * (kHeadDim + kPad) + + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + lane_smem_V_Bc * (kHeadDim + kPad) + lane_smem_V_d) * sizeof(half) ); LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V @@ -517,23 +574,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } // end for V Bc. __syncthreads(); - // NOTE: Load next K tile async before rescale O - if ((tile_K_seqlen + 1) < Tc) { - load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); - uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); - #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); - } - CP_ASYNC_COMMIT_GROUP(); - } - // Rescale O -> Update row sum Exp -> then, Update row max. #pragma unroll for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1 @@ -594,11 +634,12 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, lane_block_row_max_old[i][1] = block_row_max_new_1; } - // NOTE: After compute P @ V, we have to wait next K tile ready in smem. - // do not need to wait any things if kStage == 1. - if constexpr (kStage > 1) { - CP_ASYNC_WAIT_GROUP(0); - __syncthreads(); + if constexpr (kCanPrefetchKVg2s) { + if ((tile_K_seqlen + 1) < Tc) { + // now, we have to wait next K tile ready in smem. 
+ CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } } } // end loop over N @@ -690,7 +731,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, template void launch_flash_attn_mma_stages_split_q_shared_kv( torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { - // Tile BrxBc=128x64 + // Now: fixed tile BrxBc=128x64 + // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size. constexpr int kMmaAtomM = 16; constexpr int kMmaAtomN = 8; constexpr int kMmaAtomK = 16; @@ -712,18 +754,22 @@ void launch_flash_attn_mma_stages_split_q_shared_kv( // static int kMaxSramPerBlock; // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); - // Calculate SRAM size needed per block, Q,K/V smem size, KV shared the same smem. + constexpr int KV_tile_size = ( + ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? + ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) + ); const int smem_max_size = ((Br * (kHeadDim + kPad)) + - (kStage * kHeadDim * (Bc + kPad))) * sizeof(half); + (kStage * KV_tile_size)) * sizeof(half); const int QKV_batch = Q.size(0); const int QKV_head = Q.size(1); const int QKV_seqlen = Q.size(2); // QKV_seqlen - assert(QKV_seqlen % Bc == 0); // multiple of Bc=64 - + assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc) + + // TODO: How to apply block swizzle to improve L2 Cache hit rate? dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br) - dim3 block(kNumThreads); // 4 warps per block + dim3 block(kNumThreads); // 4/8 warps per block cudaFuncSetAttribute( flash_attn_mma_stages_split_q_shared_kv_kernel< @@ -783,8 +829,24 @@ void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q, const int d = Q.size(3); // B, H, N, d if (stages > 1) { - throw std::runtime_error( - "split_q_shared_kv not support stages>1 now!"); + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_shared_kv<32, 2>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_shared_kv<64, 2>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_shared_kv<96, 2>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_shared_kv<128, 2>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } } else { switch (d) { diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu index b4efc69f..64badcb3 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu @@ -165,17 +165,16 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. - // TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. + // Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. // Then we can load Q from smem only once and reuse it for // processes. This will reduce large io-access for Q smem while N is large. static_assert(kHeadDim <= 128, "shared_qkv only support headdim<=128"); static_assert(kHeadDim >= 32, "shared_qkv only support headdim>=32"); constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8); // always true. - // Use kStage and (Br / Bc) to control multi-stage policy for K g2s. 
+ // Use kStage and(Q_tile_size / KV_tile_size) to control multi-stage policy for K/V g2s. constexpr bool kCanPrefetchKVg2s = ( ((Q_tile_size / KV_tile_size) >= 2) && (kStage >= 2)); // for d<=64 is true. - // constexpr int kPrefetchStageKg2s = kCanPrefetchKVg2s ? 2 : 1; // only apply stage 2 for k prefetch. constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0. constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1. constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1; @@ -186,7 +185,6 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2] // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. - // TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ. uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] @@ -211,41 +209,39 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] #pragma unroll 1 for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { - // // TODO: process last tile_K_seqlen ? pad to multiple of 8. - // // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; - // int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s; - // // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; - // int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s; - - // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. - // NOTE: we only need to load Q once from smem -> regs, and then reuse it. + // TODO: process last tile_K_seqlen ? pad to multiple of 8. + + // : Load Q tile from smem -> regs, before Q@K^T. static_assert(kCanPrefetchQs2r); // always prefetch Q s2r. - if (tile_K_seqlen == 0) { - // TODO: Full share QKV smem after Q is ready load to regs. - CP_ASYNC_WAIT_GROUP(0); - __syncthreads(); + if constexpr (kCanPrefetchQs2r) { + // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. + // NOTE: we only need to load Q once from smem -> regs, and then reuse it. + if (tile_K_seqlen == 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); - #pragma unroll - for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { - // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs. - // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. - // Then we can load Q from smem only once and reuse it for - // processes. This will reduce large io-access for Q smem while N is large. #pragma unroll - for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] - int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; - int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 - int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 - uint32_t lane_smem_Q_ptr = ( - smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + - lane_smem_Q_d) * sizeof(half) - ); - LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], - R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], - lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4] + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs. 
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. + // Then we can load Q from smem only once and reuse it for + // processes. This will reduce large io-access for Q smem while N is large. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4] + } } - } - } + } // end if tile_K_seqlen == 0 + } // end if kCanPrefetchQs2r // Load K tile from gmem -> smem, always use smem part 0. if constexpr (kCanPrefetchKVg2s) { @@ -307,14 +303,31 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] // Matmul with NN layout, Q row major, K row major. // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] - // + // fill_3D_regs(R_S, 0); #pragma unroll for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. + // ldmatrix.x4 for Q_tile_smem. + if constexpr (!kCanPrefetchQs2r) { + // load Q from smem -> regs in each loop w/o prefetch Q s2r. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3], + lane_smem_Q_ptr); // now, R_Q[1][1][4] + } + } // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. 
// ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] - #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) @@ -322,20 +335,34 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + lane_smem_K_d * (Bc + kPad) + lane_smem_K_Bc) * sizeof(half) - ); + ); LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K } // end for kWarpTileSeqLenK - - // MMA compute - #pragma unroll - for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + + if constexpr (kCanPrefetchQs2r) { + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } else { + // MMA compute #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { - HMMA16816(R_S[i][j][0], R_S[i][j][1], - R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], - R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], - R_K[j][0], R_K[j][1], - R_S[i][j][0], R_S[i][j][1]); + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } } } } // end loop over d, S=Q@K^T @@ -700,7 +727,8 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, template void launch_flash_attn_mma_stages_split_q_shared_qkv( torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { - // Tile BrxBc=128x64 + // Now: fixed tile BrxBc=128x64 + // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size. constexpr int kMmaAtomM = 16; constexpr int kMmaAtomN = 8; constexpr int kMmaAtomK = 16; diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu index a5f96c23..60c23843 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu @@ -284,7 +284,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q, CP_ASYNC_COMMIT_GROUP(); } - // Then, prefetch curr K tile_K_seqlen [d,Bc] (no stages) + // Then, prefetch curr V tile_K_seqlen [d,Bc] (no stages) { load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
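
For reference, the shared K/V tile sizing introduced in flash_attn_mma_share_kv.cu above reduces to taking the larger of the two layouts that will occupy the same buffer: K is staged as [kHeadDim, Bc + kPad] and V as [Bc, kHeadDim + kPad]. Below is a minimal standalone sketch of that rule, assuming the kernel's template parameters (kHeadDim, Bc, kPad) and sizeof(half) == 2; the SharedKVTile helper and the concrete numbers are illustrative only, not part of the patch.

#include <cstdio>

// Illustrative helper: size one buffer so that either the K tile layout
// [kHeadDim, Bc + kPad] or the V tile layout [Bc, kHeadDim + kPad] fits,
// matching the constexpr ternary added to the kernel and the launcher.
template <int kHeadDim, int Bc, int kPad>
struct SharedKVTile {
  static constexpr int kKTileElems = kHeadDim * (Bc + kPad); // K^T staged as [d, Bc]
  static constexpr int kVTileElems = Bc * (kHeadDim + kPad); // V   staged as [Bc, d]
  static constexpr int kElems = (kKTileElems > kVTileElems) ? kKTileElems : kVTileElems;
  static constexpr int kBytes = kElems * 2;                  // sizeof(half) == 2
};

int main() {
  // d == Bc (64/64): both layouts have the same footprint, the max changes nothing.
  printf("d=64,  Bc=64, pad=8 -> %d bytes per K/V tile\n", SharedKVTile<64, 64, 8>::kBytes);
  // d != Bc (128/64): the K layout [d, Bc+pad] is the larger one and decides the size.
  printf("d=128, Bc=64, pad=8 -> %d bytes per K/V tile\n", SharedKVTile<128, 64, 8>::kBytes);
  return 0;
}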
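
The stage-2 path of flash_attn_mma_stages_split_q_shared_kv_kernel keeps K in smem slot 0 and V in slot 1 and leans on cp.async group counting: the current V tile and the next K tile are committed as two separate groups, CP_ASYNC_WAIT_GROUP(1) before P@V only waits for the older group (V) so the next-K copy keeps running in the background, and CP_ASYNC_WAIT_GROUP(0) after the rescale drains everything. The toy kernel below is a hedged sketch of just that bookkeeping on dummy 256-half tiles (sm_80+ assumed); double_buffer_demo and the buffer sizes are made up for illustration, and the CP_ASYNC_* wrappers are written out inline so the sketch is self-contained.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

#define CP_ASYNC_CG(dst, src, bytes) \
  asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" :: "r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" :: "n"(n))

// One warp: copy a "current V tile" and a "next K tile" as two commit groups,
// then consume V while the K copy may still be in flight.
__global__ void double_buffer_demo(const half* __restrict__ v_gmem,
                                   const half* __restrict__ k_next_gmem,
                                   half* __restrict__ out) {
  __shared__ half smem_v[256];       // stands in for the V tile (smem slot 1)
  __shared__ half smem_k_next[256];  // stands in for the next K tile (smem slot 0)
  const int tid = threadIdx.x;       // 32 threads x 8 halves = 16 bytes per thread

  uint32_t v_dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_v + tid * 8));
  CP_ASYNC_CG(v_dst, v_gmem + tid * 8, 16);
  CP_ASYNC_COMMIT_GROUP();           // group A: current V tile

  uint32_t k_dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_k_next + tid * 8));
  CP_ASYNC_CG(k_dst, k_next_gmem + tid * 8, 16);
  CP_ASYNC_COMMIT_GROUP();           // group B: next K tile

  CP_ASYNC_WAIT_GROUP(1);            // <=1 group still pending: V is ready, K may still copy
  __syncthreads();
  half acc = smem_v[tid * 8];        // "P@V" stand-in: consume V

  CP_ASYNC_WAIT_GROUP(0);            // drain all groups: next K is now ready
  __syncthreads();
  out[tid] = __hadd(acc, smem_k_next[tid * 8]); // "next Q@K^T" stand-in: consume K
}

int main() {
  half *v, *k, *o;
  cudaMalloc(&v, 256 * sizeof(half));
  cudaMalloc(&k, 256 * sizeof(half));
  cudaMalloc(&o, 32 * sizeof(half));
  double_buffer_demo<<<1, 32>>>(v, k, o); // contents are irrelevant, only the ordering matters
  cudaDeviceSynchronize();
  cudaFree(v); cudaFree(k); cudaFree(o);
  return 0;
}

This is also why the last tile falls back to CP_ASYNC_WAIT_GROUP(0) in the patch: when (tile_K_seqlen + 1) == Tc there is no "group B", so waiting for at most one pending group would return before V has actually landed in shared memory.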
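
On the host side, the launcher change boils down to three steps: size the dynamic shared memory as Q tile + kStage * max(K-layout, V-layout) tile, opt the kernel into that much dynamic smem with cudaFuncSetAttribute, and dispatch on head dim for both stage settings now that stages=2 is accepted. The sketch below is a trimmed, hedged restatement of that flow with a do-nothing stub kernel; shared_kv_kernel_stub, launch_shared_kv_stub, shared_kv_dispatch and the 128x64 / 8-warp numbers are illustrative stand-ins, and only d=64/128 are wired up.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

// Stand-in for flash_attn_mma_stages_split_q_shared_kv_kernel (does nothing).
template <int kHeadDim, int kStage>
__global__ void shared_kv_kernel_stub(half* /*Q*/, half* /*K*/, half* /*V*/, half* /*O*/) {
  extern __shared__ half smem[];
  (void)smem;
}

template <int kHeadDim, int kStage>
void launch_shared_kv_stub(half* Q, half* K, half* V, half* O,
                           int batch, int heads, int seqlen) {
  constexpr int Br = 128, Bc = 64, kPad = 8, kNumThreads = 256; // illustrative 128x64 tile, 8 warps
  // Same max-of-two-layouts rule as in the kernel: K is [d, Bc+pad], V is [Bc, d+pad].
  constexpr int KV_tile_size =
      (kHeadDim * (Bc + kPad) > Bc * (kHeadDim + kPad)) ? kHeadDim * (Bc + kPad)
                                                        : Bc * (kHeadDim + kPad);
  const int smem_max_size = (Br * (kHeadDim + kPad) + kStage * KV_tile_size) * sizeof(half);

  // With these numbers, d=128 and kStage=2 lands around 70KB, past the default
  // 48KB limit, so the kernel must be opted in to larger dynamic shared memory.
  cudaFuncSetAttribute(shared_kv_kernel_stub<kHeadDim, kStage>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max_size);

  dim3 grid(batch, heads, (seqlen + Br - 1) / Br); // batch x heads x Tr
  dim3 block(kNumThreads);
  shared_kv_kernel_stub<kHeadDim, kStage><<<grid, block, smem_max_size>>>(Q, K, V, O);
}

// Head-dim / stage dispatch mirroring the patch's switch (trimmed to 64/128).
void shared_kv_dispatch(half* Q, half* K, half* V, half* O,
                        int batch, int heads, int seqlen, int d, int stages) {
  if (stages > 1) {
    switch (d) {
      case 64:  launch_shared_kv_stub<64,  2>(Q, K, V, O, batch, heads, seqlen); break;
      case 128: launch_shared_kv_stub<128, 2>(Q, K, V, O, batch, heads, seqlen); break;
      default:  printf("headdim not supported in this sketch\n"); break;
    }
  } else {
    switch (d) {
      case 64:  launch_shared_kv_stub<64,  1>(Q, K, V, O, batch, heads, seqlen); break;
      case 128: launch_shared_kv_stub<128, 1>(Q, K, V, O, batch, heads, seqlen); break;
      default:  printf("headdim not supported in this sketch\n"); break;
    }
  }
}

int main() {
  const int B = 1, H = 1, N = 256, d = 64;
  half *Q, *K, *V, *O;
  const size_t bytes = (size_t)B * H * N * d * sizeof(half);
  cudaMalloc(&Q, bytes); cudaMalloc(&K, bytes); cudaMalloc(&V, bytes); cudaMalloc(&O, bytes);
  shared_kv_dispatch(Q, K, V, O, B, H, N, d, /*stages=*/2);
  cudaDeviceSynchronize();
  cudaFree(Q); cudaFree(K); cudaFree(V); cudaFree(O);
  return 0;
}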
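
In flash_attn_mma_share_qkv.cu (and behind kCanPrefetchQs2r in the shared-kv kernel) the Q operand is pulled from shared memory into registers exactly once, at tile_K_seqlen == 0, as R_Q[kHeadDim/kMmaAtomK][..][4] fragments, and then reused for every K tile instead of re-issuing ldmatrix inside the d-loop. The kernel below is a stripped-down, hedged sketch of just that load pattern for one warp and one 16 x kHeadDim Q slice (sm_80+ assumed); q_s2r_prefetch_demo, the padding value and the dump buffer are illustrative, the address math follows the kernel's m16k16 A-fragment layout, and the MMA itself is omitted.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

#define LDMATRIX_X4(R0, R1, R2, R3, addr)                                            \
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" \
               : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))

// One warp, one 16 x kHeadDim Q slice: load every m16k16 A-fragment into
// registers up front; the real kernel then reuses R_Q for all Tc K tiles.
template <int kHeadDim, int kPad>
__global__ void q_s2r_prefetch_demo(const half* __restrict__ q_gmem,
                                    uint32_t* __restrict__ dump) {
  constexpr int kMmaAtomK = 16;
  constexpr int kFrags = kHeadDim / kMmaAtomK;            // e.g. 4 fragments for d=64
  __shared__ __align__(16) half smem_Q[16][kHeadDim + kPad];
  const int lane_id = threadIdx.x % 32;

  // Naive g2s staging of the 16 x kHeadDim slice (row-major, padded rows).
  for (int idx = threadIdx.x; idx < 16 * kHeadDim; idx += blockDim.x) {
    smem_Q[idx / kHeadDim][idx % kHeadDim] = q_gmem[idx];
  }
  __syncthreads();

  // Prefetch Q smem -> regs once; R_Q[kFrags][4] then lives for the whole kernel.
  uint32_t R_Q[kFrags][4];
  #pragma unroll
  for (int tile_K_d = 0; tile_K_d < kFrags; ++tile_K_d) {
    int lane_smem_Q_Br = lane_id % 16;                               // fragment rows 0~15
    int lane_smem_Q_d  = tile_K_d * kMmaAtomK + (lane_id / 16) * 8;  // column 0 or +8
    uint32_t lane_smem_Q_ptr = static_cast<uint32_t>(
        __cvta_generic_to_shared(&smem_Q[lane_smem_Q_Br][lane_smem_Q_d]));
    LDMATRIX_X4(R_Q[tile_K_d][0], R_Q[tile_K_d][1],
                R_Q[tile_K_d][2], R_Q[tile_K_d][3], lane_smem_Q_ptr);
  }

  // Dump the fragments so the compiler keeps them; the real kernel feeds HMMA16816.
  #pragma unroll
  for (int f = 0; f < kFrags; ++f) {
    #pragma unroll
    for (int r = 0; r < 4; ++r) dump[(f * 4 + r) * 32 + lane_id] = R_Q[f][r];
  }
}

int main() {
  half* q; uint32_t* dump;
  cudaMalloc(&q, 16 * 64 * sizeof(half));
  cudaMalloc(&dump, (64 / 16) * 4 * 32 * sizeof(uint32_t));
  q_s2r_prefetch_demo<64, 8><<<1, 32>>>(q, dump);
  cudaDeviceSynchronize();
  cudaFree(q); cudaFree(dump);
  return 0;
}

The payoff is the one the kernel comments describe: for large N the Q tile is read from shared memory once per CTA rather than once per K tile, at the cost of holding kHeadDim/kMmaAtomK x 4 32-bit Q registers per thread (e.g. 16 for d=64).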