6 changes: 3 additions & 3 deletions README.md
@@ -50,11 +50,11 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)
|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than the official FA2 on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates. Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
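(The benchmark table for this example is collapsed in this diff view.) As a rough illustration only, a hypothetical standalone harness for timing such a configuration with CUDA events might look like the sketch below. It times PyTorch's built-in scaled-dot-product attention as a reference point; the repo's `launch_flash_attn_mma_stages_split_q_shared_qkv` launcher (whose exact parameter list is not shown in this diff) would be timed the same way. Shapes, warmup, and iteration counts are illustrative assumptions, not the repo's actual benchmark script.

```C++
// Hypothetical timing sketch -- not part of this PR or the repo's test suite.
// Requires libtorch built with CUDA support.
#include <torch/torch.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int B = 1, H = 8, N = 8192, D = 64;            // example shape from above
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto Q = torch::randn({B, H, N, D}, opts);
  auto K = torch::randn({B, H, N, D}, opts);
  auto V = torch::randn({B, H, N, D}, opts);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  for (int i = 0; i < 5; ++i) (void)at::scaled_dot_product_attention(Q, K, V); // warmup

  constexpr int iters = 100;
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) (void)at::scaled_dot_product_attention(Q, K, V);
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  printf("reference SDPA avg latency: %.3f ms\n", ms / iters);
  return 0;
}
```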
@@ -133,7 +133,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
<div id="mma-share-qkv"></div>

```C++
// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
// Q, K, V fully share the same shared memory and prefetch Q s2r, improving block occupancy & reducing Q SMEM IO access.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* K,
```
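As a back-of-the-envelope illustration of why the fully shared QKV layout helps occupancy, the sketch below mirrors the `smem_max_size = Br * (kHeadDim + kPad) * sizeof(half)` formula from the launcher further down and compares it against a non-shared layout with a separate Q tile plus double-buffered K/V tiles. The tile sizes (`Br=128`, `Bc=64`, `d=64`, `kPad=8`) and the "separate" layout model are illustrative assumptions, not exact figures from the repo.

```C++
// Standalone SMEM-budget sketch -- not part of this PR.
#include <cstdio>

// Per-block SMEM when the Q, K and V tiles all reuse one
// [Br, kHeadDim + kPad] half-precision buffer (fully shared QKV).
constexpr int shared_qkv_smem_bytes(int Br, int kHeadDim, int kPad) {
  return Br * (kHeadDim + kPad) * 2;               // sizeof(half) == 2 bytes
}

// Rough comparison point: separate Q tile plus 2-stage K^T and V tiles.
constexpr int separate_qkv_smem_bytes(int Br, int Bc, int kHeadDim, int kPad) {
  return Br * (kHeadDim + kPad) * 2                // Q tile
       + 2 * kHeadDim * (Bc + kPad) * 2            // K^T tiles, 2 stages
       + 2 * Bc * (kHeadDim + kPad) * 2;           // V tiles, 2 stages
}

int main() {
  // Assumed tile sizes: Br = 128, Bc = 64, d = 64, kPad = 8.
  printf("fully shared QKV: %d KB per block\n", shared_qkv_smem_bytes(128, 64, 8) / 1024);
  printf("separate  Q/K/V : %d KB per block\n", separate_qkv_smem_bytes(128, 64, 64, 8) / 1024);
  return 0;
}
```

With far less SMEM per block, more blocks can be resident per SM, which is the occupancy gain the comment above refers to.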
4 changes: 2 additions & 2 deletions kernels/flash-attn/README.md
@@ -9,7 +9,7 @@
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than the official FA2 on some devices, for example, the NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
@@ -107,7 +107,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
<div id="mma-share-qkv"></div>

```C++
// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
// Q, K, V fully share the same shared memory and prefetch Q s2r, improving block occupancy & reducing Q SMEM IO access.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* K,
```
88 changes: 55 additions & 33 deletions kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
@@ -173,9 +173,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
static_assert(kHeadDim >= 32, "shared_qkv only support headdim>=32");
constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8); // always true.
// Use kStage and (Br / Bc) to control multi-stage policy for K g2s.
constexpr bool kCanPrefetchKg2s = (
constexpr bool kCanPrefetchKVg2s = (
((Q_tile_size / KV_tile_size) >= 2) && (kStage >= 2)); // for d<=64 is true.
constexpr int kPrefetchStageKg2s = kCanPrefetchKg2s ? 2 : 1; // only apply stage 2 for k prefetch.
// constexpr int kPrefetchStageKg2s = kCanPrefetchKVg2s ? 2 : 1; // only apply stage 2 for k prefetch.
constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0.
constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1.
constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1;
uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4]
uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
@@ -209,11 +211,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc]
#pragma unroll 1
for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
// TODO: process last tile_K_seqlen ? pad to multiple of 8.
// s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s;
// s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s;
// // TODO: process last tile_K_seqlen ? pad to multiple of 8.
// // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
// int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s;
// // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
// int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s;

// Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
// NOTE: we only need to load Q once from smem -> regs, and then reuse it.
@@ -245,17 +247,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
}
}

// Load K tile from gmem -> smem
if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
// Load K tile from gmem -> smem, always use smem part 0.
if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (smem_sel * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
@@ -265,15 +267,33 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
// <Prefetch V g2s>: Load V tile async from gmem -> smem 1, before Q@K^T
{
load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
int load_gmem_V_d = load_smem_V_d;
int load_gmem_V_addr = (
V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
uint32_t load_smem_V_ptr = (
smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
load_smem_V_Bc * (kHeadDim + kPad) +
load_smem_V_d) * sizeof(half)
);
#pragma unroll
for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
}
} else {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (smem_sel * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
@@ -287,6 +307,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// <loop over K d>: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
// Matmul with NN layout, Q row major, K row major.
// S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
// <HGEMM in shared memory>
fill_3D_regs<uint32_t, kWarpTileSeqLenQ, kWarpTileSeqLenK, 2>(R_S, 0);
#pragma unroll
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
@@ -298,7 +319,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
uint32_t lane_smem_K_ptr = (
smem_K_base_ptr + (smem_sel * KV_tile_size +
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
lane_smem_K_d * (Bc + kPad) +
lane_smem_K_Bc) * sizeof(half)
);
@@ -320,17 +341,16 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
} // end loop over d, S=Q@K^T
__syncthreads();

// Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages),
// before rowmax and rowsum, load V from gmem -> smem.
// TODO: Can we support stages 2 for V g2s?
{
// <w/o Prefetch V g2s>: If kCanPrefetchKVg2s is not enabled,
// we will load V g2s here, before rowmax and rowsum.
if constexpr (!kCanPrefetchKVg2s) {
load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
int load_gmem_V_d = load_smem_V_d;
int load_gmem_V_addr = (
V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
uint32_t load_smem_V_ptr = (
smem_V_base_ptr + (smem_sel * KV_tile_size +
smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
load_smem_V_Bc * (kHeadDim + kPad) +
load_smem_V_d) * sizeof(half)
);
@@ -341,14 +361,15 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
CP_ASYNC_COMMIT_GROUP();
}

if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
// <Prefetch K g2s>: load next K tile from gmem -> smem 0, before P@V.
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (smem_sel_next * KV_tile_size +
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
load_smem_K_d * (Bc + kPad) +
load_smem_K_Bc) * sizeof(half));
#pragma unroll
@@ -407,7 +428,6 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
lane_row_max_new[i][0] = warp_reduce_max<float, 4>(lane_row_max_new[i][0]);
lane_row_max_new[i][1] = warp_reduce_max<float, 4>(lane_row_max_new[i][1]);
} // end for kWarpTileSeqLenQ
// __syncthreads();

// Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
#pragma unroll
@@ -447,11 +467,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,

// Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partition Attention.
// Here, we have to wait for V to be ready before computing O = P @ V
if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
CP_ASYNC_WAIT_GROUP(1);
CP_ASYNC_WAIT_GROUP(1); // both V and K g2s have been issued; wait for V and let K stay async.
} else {
CP_ASYNC_WAIT_GROUP(0);
CP_ASYNC_WAIT_GROUP(0); // only V g2s has been issued.
}
} else {
CP_ASYNC_WAIT_GROUP(0);
@@ -480,6 +500,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// ...
// 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}

// <HGEMM in registers>
fill_3D_regs<uint32_t, kWarpTileSeqLenP, kWarpTileHeadDimV, 2>(R_O, 0);
#pragma unroll
for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) {
@@ -490,7 +511,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K
int lane_smem_V_d = warp_smem_V_d; // 0
uint32_t lane_smem_V_ptr = (
smem_V_base_ptr + (smem_sel * KV_tile_size +
smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
lane_smem_V_Bc * (kHeadDim + kPad) +
lane_smem_V_d) * sizeof(half)
);
@@ -582,9 +603,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
lane_block_row_max_old[i][1] = block_row_max_new_1;
}

if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
CP_ASYNC_WAIT_GROUP(0);
// now, we have to wait for the next K tile to be ready in smem.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
}
Expand Down Expand Up @@ -619,7 +641,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
#pragma unroll
for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8

// reuse R_Q regs if kCanPrefetchQs2r is enabled, reducing register usage.
if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { // always true for shared qkv kernel
// reuse R_Q[4/8][1][4] for collective store.
R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4
@@ -703,15 +725,15 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv(

// static int kMaxSramPerBlock;
// cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);

// Calculate the SRAM size needed per block; Q, K and V fully share the same smem.
const int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32 KB

const int QKV_batch = Q.size(0);
const int QKV_head = Q.size(1);
const int QKV_seqlen = Q.size(2); // QKV_seqlen
assert(QKV_seqlen % Bc == 0); // multiple of Bc=64
assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)

// TODO: How to apply block swizzle to improve L2 Cache hit rate?
dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br)
dim3 block(kNumThreads); // 4/8 warps per block

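To make the K/V prefetch schedule above easier to follow in isolation, here is a standalone, self-contained sketch of the same cp.async double-buffer idea on a toy problem: while one SMEM slot is being consumed, the next tile is staged into the other slot, and `CP_ASYNC_WAIT_GROUP(1)` vs `CP_ASYNC_WAIT_GROUP(0)` controls how many copy groups may remain in flight. This is not the PR's kernel; the macro definitions, tile size, and reduction are illustrative, and it needs an Ampere-or-newer GPU (compile with `-arch=sm_80`).

```C++
// Standalone cp.async double-buffer sketch -- not this PR's kernel.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

#define CP_ASYNC_CG(dst, src, bytes)                           \
  asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"   \
               :: "r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_GROUP(n)  asm volatile("cp.async.wait_group %0;\n" :: "n"(n))

// One block sums `num_tiles` tiles of TILE floats; launch with TILE / 4 threads,
// each thread copying 16 bytes (4 floats) of every tile.
template <int TILE>
__global__ void double_buffer_sum(const float* __restrict__ in,
                                  float* __restrict__ out, int num_tiles) {
  __shared__ __align__(16) float smem[2][TILE];   // slot 0 / slot 1
  const int tid = threadIdx.x;
  float acc = 0.0f;

  // Prologue: stage tile 0 into slot 0.
  uint32_t dst0 = static_cast<uint32_t>(__cvta_generic_to_shared(&smem[0][tid * 4]));
  CP_ASYNC_CG(dst0, in + tid * 4, 16);
  CP_ASYNC_COMMIT_GROUP();

  for (int t = 0; t < num_tiles; ++t) {
    const int cur = t & 1, nxt = cur ^ 1;
    if (t + 1 < num_tiles) {
      // Stage the next tile into the other slot before consuming this one.
      uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(&smem[nxt][tid * 4]));
      CP_ASYNC_CG(dst, in + (t + 1) * TILE + tid * 4, 16);
      CP_ASYNC_COMMIT_GROUP();
      CP_ASYNC_WAIT_GROUP(1);  // current tile is ready; next tile may stay in flight
    } else {
      CP_ASYNC_WAIT_GROUP(0);  // last tile: drain every outstanding copy group
    }
    __syncthreads();
    for (int i = 0; i < 4; ++i) acc += smem[cur][tid * 4 + i];
    __syncthreads();           // slot `cur` may be overwritten in the next iteration
  }
  atomicAdd(out, acc);         // toy block-level reduction
}

int main() {
  constexpr int TILE = 256, NUM_TILES = 8;
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, TILE * NUM_TILES * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemset(d_out, 0, sizeof(float));

  float ones[TILE * NUM_TILES];
  for (int i = 0; i < TILE * NUM_TILES; ++i) ones[i] = 1.0f;
  cudaMemcpy(d_in, ones, sizeof(ones), cudaMemcpyHostToDevice);

  double_buffer_sum<TILE><<<1, TILE / 4>>>(d_in, d_out, NUM_TILES);
  float result = 0.0f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.1f (expected %d)\n", result, TILE * NUM_TILES);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

The same accounting appears in the kernel above: after issuing the next K tile (and, earlier, the current V tile), `CP_ASYNC_WAIT_GROUP(1)` lets P@V start as soon as V is resident while the K copy continues in the background.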