7 changes: 4 additions & 3 deletions README.md
@@ -52,8 +52,8 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|

Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)`, it can run faster than FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) achieves **55 TFLOPS (D=64)**, roughly **1.5x** 🎉 faster than FA2. On an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) achieves **81 TFLOPS (D=512)**, roughly **1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, a performance gap remains. Stay tuned for updates ~ (👇Benchmark)

@@ -66,7 +66,7 @@ Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run
|SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
|mma(split-q+tiling-qk+stage2)|(1,48,8192,512)|**23 TFLOPS**|**81 TFLOPS**|**120 TFLOPS**|

The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K, and V across MMA (Warps), is slower than the `Split Q` method, which splits only Q across MMA (Warps) and keeps K and V accessible to all MMA (Warps).

- 📚 Split KV (Basic, FlashAttention-1)
<div id="mma-split-kv"></div>
@@ -427,6 +427,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| [[cute系列详解][Swizzle]📖cute Swizzle细谈](https://zhuanlan.zhihu.com/p/684250988)|@进击的Killua|
| [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(一)](https://zhuanlan.zhihu.com/p/710337546)|@Titus|
| [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(二)](https://zhuanlan.zhihu.com/p/711398930)|@Titus|
| [[cute系列详解][Swizzle]📖CUDA避免bank conflict的swizzle机制解析](https://zhuanlan.zhihu.com/p/4746910252)|@frankshi|
| [[cute系列详解][GEMM]📖cute 之 简单GEMM实现](https://zhuanlan.zhihu.com/p/667521327)|@reed|
| [[cute系列详解][GEMM]📖cute 之 GEMM流水线](https://zhuanlan.zhihu.com/p/665082713)|@reed|
| [[cute系列详解][GEMM]📖cute 之 高效GEMM实现](https://zhuanlan.zhihu.com/p/675308830)|@reed|
8 changes: 4 additions & 4 deletions kernels/flash-attn/README.md
@@ -7,14 +7,14 @@
|✔️|✔️|✔️|✔️|
|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)`, it can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, a performance gap remains. Performance is continuously being optimized. Stay tuned for updates ~ (👇Benchmark)

|Algorithm| (B,H,N,D) | NVIDIA GeForce RTX 3080 Laptop | NVIDIA L20 | NVIDIA RTX 4090 |
|Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 |
|:---:|:---:|:---:|:---:|:---:|
|FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
|mma(split-q+share-qkv+stage2)|(1,8,8192,64)|**55 TFLOPS**|96 TFLOPS|**218 TFLOPS**|
85 changes: 85 additions & 0 deletions kernels/nvidia-nsight/bank_conflicts.md
@@ -0,0 +1,85 @@
## Check Bank Conflicts via NCU

- Check which metrics the device supports
```bash
# ncu check bank conflicts
# First, list the metrics supported by the current device
ncu --query-metrics | grep data | grep bank | grep l1tex
```
metrics:
```bash
ncu --query-metrics | grep data | grep bank | grep l1tex
l1tex__data_bank_conflicts_pipe_lsu Counter # of data bank conflicts generated by LSU pipe
l1tex__data_bank_conflicts_pipe_lsu_cmd_read Counter # of data bank conflicts generated by LSU reads
l1tex__data_bank_conflicts_pipe_lsu_cmd_write Counter # of data bank conflicts generated by LSU writes
l1tex__data_bank_conflicts_pipe_lsu_mem_global Counter # of data bank conflicts generated by global ops
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_atom Counter # of data bank conflicts generated by global atomics
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_ld Counter # of data bank conflicts generated by global loads
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_red Counter # of data bank conflicts generated by global reductions
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_st Counter # of data bank conflicts generated by global stores
l1tex__data_bank_conflicts_pipe_lsu_mem_shared Counter # of shared memory data bank conflicts generated by LDS, LD, 3D
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_atom Counter # of shared memory data bank conflicts generated by ATOMS, ATOM
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld Counter # of shared memory data bank conflicts generated by LDS, LD, 3D
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of data bank conflicts generated by shared ldgsts ops
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST, 3D
l1tex__data_bank_reads Counter # of data bank reads
l1tex__data_bank_writes Counter # of data bank writes
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST
sm__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST
smsp__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes
```

- Bank conflicts caused by LD instructions
```bash
# profile l1tex smem data bank conflicts
# Bank conflicts caused by LDS and LD instructions
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_mma_stage.89.bin
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_cute.89.debug.bin
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld \
python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
```
log:
```bash
void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<64, 64, 256, 4, 0, 0, cutlass::half_t, Flash_kernel_traits<64, 64, 256, 4, cutlass::half_t>>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.avg 11.18
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.max 13
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.min 10
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 1029
-------------------------------------------------------- ----------- ------------
```
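
To help interpret a non-zero count like the one above, here is a minimal Python sketch of how a shared-memory bank conflict arises. It is a simplified model, not how Nsight Compute computes the metric: it assumes 32 banks of 4-byte words and counts how many distinct words of one warp-wide access land in the same bank. The helper names `conflict_degree` and `column_read` are hypothetical and exist only for this illustration.

```python
from collections import defaultdict

NUM_BANKS = 32   # assumption: 32 shared-memory banks
BANK_WIDTH = 4   # assumption: each bank serves one 4-byte word per cycle


def conflict_degree(byte_addresses):
    """Return the largest number of distinct words that map to a single bank.

    1 means conflict-free; k means the access is serialized into k passes
    in this simplified model. Lanes that read the same word are broadcast
    and therefore do not count as a conflict.
    """
    words_per_bank = defaultdict(set)
    for addr in byte_addresses:
        word = addr // BANK_WIDTH
        words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())


def column_read(row_stride_words, lanes=32, elem_bytes=4):
    """Byte addresses touched when lane i reads element [i][0] of a row-major tile."""
    return [lane * row_stride_words * elem_bytes for lane in range(lanes)]


# A tile with 32 floats per row: every lane hits bank 0.
unpadded = conflict_degree(column_read(32))
# The same tile with one float of padding per row: lanes spread over all banks.
padded = conflict_degree(column_read(33))
print(unpadded, padded)
```

In this model the unpadded column read is a 32-way conflict while the padded one is conflict-free, which is the effect that SMEM padding (and, alternatively, swizzling) aims for. The hardware counter reported by `ncu` may count conflicts differently, so treat the numbers as qualitative.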

- Bank conflicts caused by LDSM instructions

```bash
# Bank conflicts caused by LDSM (ldmatrix) instructions
ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
ncu --metrics smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
```
log:
```bash
void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<64, 64, 256, 4, 0, 0, cutlass::half_t, Flash_kernel_traits<64, 64, 256, 4, cutlass::half_t>>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
Section: Command line profiler metrics
------------------------------------------------------------------ ----------- ------------
Metric Name Metric Unit Metric Value
------------------------------------------------------------------ ----------- ------------
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0
------------------------------------------------------------------ ----------- ------------
```
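
When comparing several kernels, it can be handy to pull the values out of logs like the ones above. The following is a minimal Python sketch under stated assumptions: it expects the plain-text table layout shown in this file, treats any row whose first token contains `__` as a metric row, and takes the last token as the value. The function name `parse_ncu_metrics` is hypothetical; it is not part of Nsight Compute.

```python
def parse_ncu_metrics(log_text):
    """Map metric name -> numeric value from an ncu text table.

    Assumes each data row looks like "<metric_name> [unit] <value>" and that
    metric names contain "__" (true for the l1tex/sm/smsp metrics used here).
    """
    metrics = {}
    for line in log_text.splitlines():
        parts = line.split()
        if len(parts) >= 2 and "__" in parts[0]:
            try:
                # Strip thousands separators in case the value is printed as "1,029".
                metrics[parts[0]] = float(parts[-1].replace(",", ""))
            except ValueError:
                continue  # not a numeric row, skip it
    return metrics


# Rows copied from the LD log earlier in this file.
sample_log = """\
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.avg                    11.18
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.max                       13
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.min                       10
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                     1029
    -------------------------------------------------------- ----------- ------------
"""

metrics = parse_ncu_metrics(sample_log)
for name, value in metrics.items():
    print(name, value)
```

A text parser like this is fragile against layout changes, so it is only meant for quick comparisons of captured logs.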
32 changes: 32 additions & 0 deletions kernels/swizzle/.gitignore
@@ -0,0 +1,32 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
__pycache__
*.onnx
*.engine
*.pt
*.pth
*.nsys*
*.ncu*
*.sqlite*
*.engine
*.bin
*.out
*bin
bin
output
*.egg-info
*.whl
dist
*.pdf
*.tex
*.log
*.md5
*.aux*
*.dpth