2 changes: 1 addition & 1 deletion README.md
@@ -159,7 +159,7 @@
| ✔️ [hgemv_k16_f16](./hgemv/hgemv.cu)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
| ✔️ [flash_attn_1_fwd_f32](./flash-attn/flash_attn.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
| ✔️ [flash_attn_2_fwd_f16_m16n8k16*](./flash-attn/flash_attn_mma.cu)|f16|f16|[link](./flash-attn)|⭐️⭐️⭐️|
| ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/|/|⭐️|
| ✔️ [nms_kernel](./nms/nms.cu)|f32|/|[link](./nms)|⭐️⭐️|
| ✔️ [notes v1(deprecated)](./notes-v1.cu)|f32|f32|/|⭐️|

👉TIPS: * means the kernel uses **Tensor Cores (MMA/WMMA)**; otherwise, it uses CUDA Cores by default.
2 changes: 1 addition & 1 deletion hgemm/README.md
@@ -33,7 +33,7 @@

- NVIDIA L20

The current best implementation: on the L20 (theoretical Tensor Core FP16 throughput of 119.5 TFLOPS), the WMMA API version reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API version reaches 115 TFLOPS, surpassing cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Manual smem swizzle has not been implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a future attempt will implement smem swizzle via MMA PTX and a col-major layout. [Click to view the performance data](#NV-L20).
The current best implementation: on the L20 (theoretical Tensor Core FP16 throughput of 119.5 TFLOPS), the WMMA API version reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API version reaches 115 TFLOPS, surpassing cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Manual smem swizzle/permute has not been implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a future attempt will implement smem swizzle/permute via MMA PTX and a col-major layout. [Click to view the performance data](#NV-L20).
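
To make the padding trade-off concrete, here is a minimal sketch of the padding-based mitigation described above; the tile sizes `BM`, `BK` and the padding width `PAD` are illustrative assumptions, not the actual kernel parameters.

```cuda
#include <cuda_fp16.h>

#define BM  128  // assumed block-tile rows (not the kernel's actual value)
#define BK   16  // assumed block-tile depth in halves
#define PAD   8  // assumed padding width in halves (16 bytes per row)

__global__ void padded_smem_sketch(const half* __restrict__ A, int lda) {
  // Without PAD, a column walk through s_a would hit the same banks row after
  // row; the extra PAD halves per row stagger each row's starting bank at the
  // cost of BM * PAD * sizeof(half) wasted bytes of shared memory per block.
  __shared__ half s_a[BM][BK + PAD];

  for (int idx = threadIdx.x; idx < BM * BK; idx += blockDim.x) {
    int r = idx / BK, c = idx % BK;
    s_a[r][c] = A[r * lda + c];
  }
  __syncthreads();
  // ... compute on s_a ...
}
```

With these assumed values the overhead is 128 * 8 * 2 = 2 KiB of shared memory per tile, which is exactly the waste and occupancy cost described above.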

- NVIDIA GeForce RTX 3080 Laptop

1 change: 1 addition & 0 deletions hgemm/hgemm_mma_stage.cu
@@ -1015,6 +1015,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
}
}

// collective store with reg reuse & warp shuffle
for (int i = 0; i < WARP_TILE_M; ++i) {
// reuse the RA[2][4][4] registers here; this may boost performance by 0.3~0.5 TFLOPS.
// avoid putting the 'if' inside the N loop; it may break the '#pragma unroll' hint.
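
The comment added in this hunk names a collective-store pattern. Below is a small, hedged sketch of the general idea only; the fragment layout, function name, and lane mapping are illustrative assumptions, not this kernel's actual `RA`/`RC` register mapping. Pairs of lanes swap one `half2` fragment via `__shfl_xor_sync` so each lane ends up with 4 contiguous halves and the warp's stores coalesce into wider transactions.

```cuda
#include <cuda_fp16.h>

// Each lane is assumed to hold two half2 fragments: frag_row0 covers columns
// [2*lane, 2*lane+1] of output row 0, frag_row1 the same columns of row 1.
// This layout is illustrative, not the actual MMA accumulator mapping.
__device__ __forceinline__ void collective_store_sketch(
    half2 frag_row0, half2 frag_row1, half* __restrict__ out, int ldc) {
  const unsigned full_mask = 0xffffffffu;
  const int lane = threadIdx.x & 31;

  // Even lanes send their row-1 fragment, odd lanes their row-0 fragment.
  // After the xor-shuffle with the neighbouring lane, even lanes own both
  // row-0 fragments of the pair and odd lanes own both row-1 fragments.
  half2 recv = __shfl_xor_sync(full_mask, (lane & 1) ? frag_row0 : frag_row1, 1);
  half2 v0 = (lane & 1) ? recv      : frag_row0;
  half2 v1 = (lane & 1) ? frag_row1 : recv;

  // Each lane now writes 4 contiguous halves; adjacent lanes cover adjacent
  // 8-byte chunks, so the warp's stores coalesce into wide transactions.
  // (Assumes 'out' and 'ldc' keep half2 alignment.)
  const int row = lane & 1;         // even lanes -> row 0, odd lanes -> row 1
  const int col = (lane >> 1) * 4;  // 4 contiguous halves per lane pair
  half2* dst = reinterpret_cast<half2*>(out + row * ldc + col);
  dst[0] = v0;
  dst[1] = v1;
}
```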