diff --git a/README.md b/README.md
index 4ff3db84..ed849bcf 100644
--- a/README.md
+++ b/README.md
@@ -52,8 +52,8 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
-|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
-|✔️|✔️|✔️|?|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
+|✔️|✔️|✔️|✔️|
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, a performance gap remains. Stay tuned for updates ~ (👇Benchmark)
@@ -66,7 +66,7 @@ Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run
|SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
|mma(split-q+tiling-qk+stage2)|(1,48,8192,512)|**23 TFLOPS**|**81 TFLOPS**|**120 TFLOPS**|
-The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
+Both the `Split KV` and `Split Q` implementations are available in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of QKV across MMA (Warps), is slower than the `Split Q` method, which splits only Q across MMA (Warps) while keeping K/V accessible to all of them.
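+
+As a rough illustration (a sketch with hypothetical tile sizes, not the repo's exact launch code), the two policies differ mainly in how the kernel grid is carved up:
+
+```c++
+#include <cstdio>
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+int main() {
+  int B = 1, H = 8, N = 8192;  // batch, heads, seqlen
+  int Br = 64, kv_splits = 4;  // Q rows per block; number of KV splits (hypothetical)
+  // Split KV: an extra grid dimension over K/V splits; each split yields a
+  // partial online-softmax result that must be merged afterwards.
+  std::printf("split-kv grid: (%d, %d, %d)\n", div_ceil(N, Br), kv_splits, B * H);
+  // Split Q: only Q is partitioned; each block walks the entire K/V sequence,
+  // so no cross-block merge of the output is needed.
+  std::printf("split-q  grid: (%d, %d)\n", div_ceil(N, Br), B * H);
+  return 0;
+}
+```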
- 📚 Split KV (Basic, FlashAttention-1)
@@ -427,6 +427,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| [[cute系列详解][Swizzle]📖cute Swizzle细谈](https://zhuanlan.zhihu.com/p/684250988)|@进击的Killua|
| [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(一)](https://zhuanlan.zhihu.com/p/710337546)|@Titus|
| [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(二)](https://zhuanlan.zhihu.com/p/711398930)|@Titus|
+| [[cute系列详解][Swizzle]📖CUDA避免bank conflict的swizzle机制解析](https://zhuanlan.zhihu.com/p/4746910252)|@frankshi|
| [[cute系列详解][GEMM]📖cute 之 简单GEMM实现](https://zhuanlan.zhihu.com/p/667521327)|@reed|
| [[cute系列详解][GEMM]📖cute 之 GEMM流水线](https://zhuanlan.zhihu.com/p/665082713)|@reed|
| [[cute系列详解][GEMM]📖cute 之 高效GEMM实现](https://zhuanlan.zhihu.com/p/675308830)|@reed|
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md
index e49eefe4..20c3ee42 100644
--- a/kernels/flash-attn/README.md
+++ b/kernels/flash-attn/README.md
@@ -7,14 +7,14 @@
|✔️|✔️|✔️|✔️|
|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
|✔️|✔️|✔️|✔️|
-|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
+|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**|
+|✔️|✔️|✔️|✔️|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|
-|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
-|✔️|✔️|✔️|?|
This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, it can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, a performance gap remains. Performance is continuously being optimized. Stay tuned for updates ~ (👇Benchmark)
-|Algorithm| (B,H,N,D) | NVIDIA GeForce RTX 3080 Laptop | NVIDIA L20 | NVIDIA RTX 4090 |
+|Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 |
|:---:|:---:|:---:|:---:|:---:|
|FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
|mma(split-q+share-qkv+stage2)|(1,8,8192,64)|**55 TFLOPS**|96 TFLOPS|**218 TFLOPS**|
diff --git a/kernels/nvidia-nsight/bank_conflicts.md b/kernels/nvidia-nsight/bank_conflicts.md
new file mode 100644
index 00000000..ac18c1cf
--- /dev/null
+++ b/kernels/nvidia-nsight/bank_conflicts.md
@@ -0,0 +1,85 @@
+## Check Bank Conflicts via NCU
+
+- Check which metrics the current device supports
+```bash
+# Check bank conflicts with ncu.
+# First, list which bank-conflict-related metrics the current device supports.
+ncu --query-metrics | grep data | grep bank | grep l1tex
+```
+metrics:
+```bash
+ncu --query-metrics | grep data | grep bank | grep l1tex
+l1tex__data_bank_conflicts_pipe_lsu Counter # of data bank conflicts generated by LSU pipe
+l1tex__data_bank_conflicts_pipe_lsu_cmd_read Counter # of data bank conflicts generated by LSU reads
+l1tex__data_bank_conflicts_pipe_lsu_cmd_write Counter # of data bank conflicts generated by LSU writes
+l1tex__data_bank_conflicts_pipe_lsu_mem_global Counter # of data bank conflicts generated by global ops
+l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_atom Counter # of data bank conflicts generated by global atomics
+l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_ld Counter # of data bank conflicts generated by global loads
+l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_red Counter # of data bank conflicts generated by global reductions
+l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_st Counter # of data bank conflicts generated by global stores
+l1tex__data_bank_conflicts_pipe_lsu_mem_shared Counter # of shared memory data bank conflicts generated by LDS, LD, 3D
+l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_atom Counter # of shared memory data bank conflicts generated by ATOMS, ATOM
+l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld Counter # of shared memory data bank conflicts generated by LDS, LD, 3D
+l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of data bank conflicts generated by shared ldgsts ops
+l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST, 3D
+l1tex__data_bank_reads Counter # of data bank reads
+l1tex__data_bank_writes Counter # of data bank writes
+sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS
+sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS
+sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS
+sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM
+sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST
+sm__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes
+smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS
+smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS
+smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS
+smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM
+smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST
+smsp__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes
+```
+
+- Bank conflicts generated by LD instructions
+```bash
+# Profile l1tex shared memory data bank conflicts
+# generated by LDS/LD instructions.
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_mma_stage.89.bin
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_cute.89.debug.bin
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld \
+ python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
+```
+log:
+```bash
+void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<...>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
+ Section: Command line profiler metrics
+ -------------------------------------------------------- ----------- ------------
+ Metric Name Metric Unit Metric Value
+ -------------------------------------------------------- ----------- ------------
+ l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.avg 11.18
+ l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.max 13
+ l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.min 10
+ l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 1029
+ -------------------------------------------------------- ----------- ------------
+```
+
+- Bank conflicts generated by LDSM instructions
+
+```bash
+# Bank conflicts generated by LDSM (ldmatrix) instructions.
+ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
+ python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
+ncu --metrics smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
+ python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
+```
+log:
+```bash
+void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<...>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
+ Section: Command line profiler metrics
+ ------------------------------------------------------------------ ----------- ------------
+ Metric Name Metric Unit Metric Value
+ ------------------------------------------------------------------ ----------- ------------
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0
+ ------------------------------------------------------------------ ----------- ------------
+```
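+
+- A minimal kernel to try the metrics on
+
+The following is a small self-contained sketch (hypothetical file name `bank_conflict_demo.cu`, not one of this repo's kernels) whose column-wise shared memory read produces a 32-way bank conflict, so the `..._op_ld.sum` metric above should report a large non-zero value:
+
+```c++
+#include <cuda_runtime.h>
+#include <stdio.h>
+
+__global__ void conflict_kernel(float* out) {
+  __shared__ float smem[32][32];
+  int tx = threadIdx.x, ty = threadIdx.y;
+  smem[ty][tx] = (float)tx;  // row-wise write: one bank per lane, conflict-free
+  __syncthreads();
+  // Column-wise read: all 32 lanes of a warp access addresses tx*32+ty,
+  // which all fall into bank ty -> 32-way bank conflict.
+  out[ty * 32 + tx] = smem[tx][ty];
+}
+
+int main() {
+  float* d_out;
+  cudaMalloc(&d_out, 32 * 32 * sizeof(float));
+  conflict_kernel<<<1, dim3(32, 32)>>>(d_out);
+  cudaDeviceSynchronize();
+  cudaFree(d_out);
+  return 0;
+}
+```
+
+```bash
+nvcc -o bank_conflict_demo bank_conflict_demo.cu
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum ./bank_conflict_demo
+```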
diff --git a/kernels/swizzle/.gitignore b/kernels/swizzle/.gitignore
new file mode 100644
index 00000000..2ef2ddf9
--- /dev/null
+++ b/kernels/swizzle/.gitignore
@@ -0,0 +1,32 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+__pycache__
+*.onnx
+*.engine
+*.pt
+*.pth
+*.nsys*
+*.ncu*
+*.sqlite*
+*.engine
+*.bin
+*.out
+*bin
+bin
+output
+*.egg-info
+*.whl
+dist
+*.pdf
+*.tex
+*.log
+*.md5
+*.aux*
+*.dpth
diff --git a/kernels/swizzle/hgemm_mma_swizzle.cu b/kernels/swizzle/hgemm_mma_swizzle.cu
new file mode 100644
index 00000000..37fbd6e4
--- /dev/null
+++ b/kernels/swizzle/hgemm_mma_swizzle.cu
@@ -0,0 +1,333 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <mma.h>
+#include <torch/types.h>
+#include <torch/extension.h>
+using namespace nvcuda;
+
+#define WARP_SIZE 32
+#define DEVICE_INLINE __device__ inline
+#define HOST_DEVICE_INLINE __device__ __host__ inline
+#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
+#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
+#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
+#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
+// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
+#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
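+// Per-thread fragment layout for mma.m16n8k16 with f16 accumulators: each of
+// the 32 lanes holds 4 u32 registers of A (8 halfs), 2 u32 of B (4 halfs),
+// and 2 u32 of C/D (4 halfs); LDMATRIX_X4/X2(_T) fill exactly these fragments
+// from shared memory.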
+
+HOST_DEVICE_INLINE
+int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
+
+// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major.
+template<const int MMA_M=16, const int MMA_N=8, const int MMA_K=16>
+__global__ void hgemm_mma_m16n8k16_naive_kernel(half* A, half* B, half* C,
+ int M, int N, int K) {
+ const int bx = blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K);
+ constexpr int BM = MMA_M; // 16
+ constexpr int BN = MMA_N; // 8
+ constexpr int BK = MMA_K; // 16
+
+ __shared__ half s_a[MMA_M][MMA_K]; // 16x16
+ __shared__ half s_b[MMA_K][MMA_N]; // 16x8
+ __shared__ half s_c[MMA_M][MMA_N]; // 16x8
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+
+  // s_a[16][16]: 16 halfs per row, 8 per thread -> 2 threads per row; 16 rows need 2x16=32 threads.
+ const int load_smem_a_m = tid / 2; // row 0~15
+ const int load_smem_a_k = (tid % 2) * 8; // col 0,8
+  // s_b[16][8]: 8 halfs per row, 8 per thread -> 1 thread per row; 16 rows need 16 threads, so only half the warp loads.
+ const int load_smem_b_k = tid; // row 0~31, but only use 0~15
+ const int load_smem_b_n = 0; // col 0
+ const int load_gmem_a_m = by * BM + load_smem_a_m; // global m
+ const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n
+  if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+
+ uint32_t RC[2] = {0, 0};
+
+ #pragma unroll
+ for (int k = 0; k < NUM_K_TILES; ++k) {
+ // gmem_a -> smem_a
+ int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
+ LDST128BITS(A[load_gmem_a_addr]));
+
+ // gmem_b -> smem_b
+ if (lane_id < MMA_K) {
+ int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+ LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
+ LDST128BITS(B[load_gmem_b_addr]));
+ }
+ __syncthreads();
+
+ uint32_t RA[4];
+ uint32_t RB[2];
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)]
+ uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ &s_a[lane_id % 16][(lane_id / 16) * 8]);
+ LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr);
+ uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
+ &s_b[lane_id % 16][0]);
+ LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr);
+
+ HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]);
+
+ __syncthreads();
+ }
+
+ // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
+ LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]);
+ LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]);
+
+ __syncthreads();
+
+ // store s_c[16][8]
+ if (lane_id < MMA_M) {
+ // store 128 bits per memory issue.
+ int store_gmem_c_m = by * BM + lane_id;
+ int store_gmem_c_n = bx * BN;
+ int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
+ LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0]));
+ }
+}
+
+// 128x128, mma2x4, warp4x4(64,32,16)
+template<const int MMA_M=16, const int MMA_N=8, const int MMA_K=16,
+         const int MMA_TILE_M=2, const int MMA_TILE_N=4,
+         const int WARP_TILE_M=4, const int WARP_TILE_N=4,
+         const int A_PAD=0, const int B_PAD=0>
+__global__ void __launch_bounds__(256)
+hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel(
+ half* A, half* B, half* C, int M, int N, int K) {
+ const int bx = blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K);
+ constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
+ constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
+ constexpr int BK = MMA_K; // 16
+
+ __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB
+ __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB
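+  // A_PAD/B_PAD skew successive smem rows across different banks to reduce
+  // shared memory bank conflicts; the host binding below uses A_PAD=0, B_PAD=16.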
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+ const int warp_m = warp_id % 2; // 0,1
+ const int warp_n = warp_id / 2; // 0,1,2,3
+
+  // First compute the shared memory indices.
+  // Mapping from tid to the smem tile s_a[BM][BK], BM=128, BK=16; A is row-major, read row by row.
+  // Each s_a row has 16 halfs, 8 per thread -> 2 threads per row; 128 rows need exactly 128x2=256 threads.
+ int load_smem_a_m = tid / 2; // row 0~127
+ int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
+  // Mapping from tid to the smem tile s_b[BK][BN], BK=16, BN=128; B is row-major, read row by row.
+  // Each s_b row has 128 halfs, 8 per thread -> 16 threads per row; 16 rows need 16x16=256 threads.
+ int load_smem_b_k = tid / 16; // row 0~15
+ int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
+  // Then compute the global memory indices.
+  // Row in global A for the element loaded into s_a; each block produces a BM*BN tile of C.
+ int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
+ int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+
+ uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ RC[i][j][0] = 0;
+ RC[i][j][1] = 0;
+ }
+ }
+
+ #pragma unroll
+ for (int k = 0; k < NUM_K_TILES; ++k) {
+ // gmem -> smem
+ int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+ LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
+ LDST128BITS(B[load_gmem_b_addr]));
+ LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
+ LDST128BITS(A[load_gmem_a_addr]));
+ __syncthreads();
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ uint32_t RA[WARP_TILE_M][4];
+ uint32_t RB[WARP_TILE_N][2];
+
+ // smem -> reg
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_a_ptr = __cvta_generic_to_shared(
+ &s_a[lane_smem_a_m][lane_smem_a_k]);
+ LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = __cvta_generic_to_shared(
+ &s_b[lane_smem_b_k][lane_smem_b_n]);
+ LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr);
+ }
+
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ HMMA16816(RC[i][j][0], RC[i][j][1],
+ RA[i][0], RA[i][1], RA[i][2], RA[i][3],
+ RB[j][0], RB[j][1],
+ RC[i][j][0], RC[i][j][1]);
+ }
+ }
+ __syncthreads();
+ }
+
+ // reg -> gmem, MMA_MxMMA_N=16x8
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ // mapping lane smem index -> global index.
+ // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
+ int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
+ int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2;
+ int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
+ int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
+ // TODO: how to use LDST128BITS here ? reverse the loop order ?
+ LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]);
+ LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]);
+ }
+ }
+}
+
+
+// --------------------- PyTorch bindings for custom kernel -----------------------
+#define STRINGFY(str) #str
+#define TORCH_BINDING_COMMON_EXTENSION(func) \
+ m.def(STRINGFY(func), &func, STRINGFY(func));
+
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
+if(((T).options().dtype() != (th_type))) { \
+ std::cout << "Tensor Info:" << (T).options() << std::endl; \
+ throw std::runtime_error("values must be "#th_type); \
+}
+
+#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
+if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
+ throw std::runtime_error("Tensor size mismatch!"); \
+}
+
+// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major.
+void hgemm_mma_m16n8k16_naive(
+ torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+ CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
+ const int M = a.size(0);
+ const int K = a.size(1);
+ const int N = b.size(1);
+ CHECK_TORCH_TENSOR_SHAPE(a, M, K)
+ CHECK_TORCH_TENSOR_SHAPE(b, K, N)
+ CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+
+ dim3 block(WARP_SIZE);
+ dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M));
+
+ hgemm_mma_m16n8k16_naive_kernel<
+    MMA_M, MMA_N, MMA_K><<<grid, block>>>(
+    reinterpret_cast<half*>(a.data_ptr()),
+    reinterpret_cast<half*>(b.data_ptr()),
+    reinterpret_cast<half*>(c.data_ptr()),
+ M, N, K
+ );
+}
+
+// 128x128, mma2x4, warp4x4(64,32,16)
+void hgemm_mma_m16n8k16_mma2x4_warp4x4(
+ torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+ CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
+ const int M = a.size(0);
+ const int K = a.size(1);
+ const int N = b.size(1);
+ CHECK_TORCH_TENSOR_SHAPE(a, M, K)
+ CHECK_TORCH_TENSOR_SHAPE(b, K, N)
+ CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ constexpr int MMA_TILE_M = 2;
+ constexpr int MMA_TILE_N = 4;
+ constexpr int WARP_TILE_M = 4;
+ constexpr int WARP_TILE_N = 4;
+ constexpr int A_PAD = 0;
+ constexpr int B_PAD = 16;
+ constexpr int NUM_THREADS= (
+ MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
+
+ dim3 block(NUM_THREADS);
+ dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N),
+ div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M));
+
+ hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel<
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,
+    WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<<grid, block>>>(
+    reinterpret_cast<half*>(a.data_ptr()),
+    reinterpret_cast<half*>(b.data_ptr()),
+    reinterpret_cast<half*>(c.data_ptr()),
+ M, N, K
+ );
+}
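+
+// A plausible module-registration sketch (an assumption, following the
+// TORCH_BINDING_COMMON_EXTENSION pattern defined above):
+//
+// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+//   TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_naive)
+//   TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4)
+// }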
diff --git a/kernels/swizzle/matrix_trans_swizzle.cu b/kernels/swizzle/matrix_trans_swizzle.cu
new file mode 100644
index 00000000..105dbf23
--- /dev/null
+++ b/kernels/swizzle/matrix_trans_swizzle.cu
@@ -0,0 +1,35 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <mma.h>
+#include <torch/types.h>
+#include <torch/extension.h>
+using namespace nvcuda;
+
+// reference: https://zhuanlan.zhihu.com/p/4746910252
+__global__ void matrix_trans_swizzling(int* dev_A, int M, int N, int* dev_B) {
+ int row = blockIdx.y * blockDim.y + threadIdx.y;
+ int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+ __shared__ int s_data[32][32];
+
+  if (row < M && col < N) {
+    // Write the value loaded from global memory to logical coordinate
+    // (row=x, col=y) in shared memory; it is physically stored at
+    // (row=x, col=x^y), so each lane of a warp lands in a distinct bank.
+    s_data[threadIdx.x][threadIdx.x ^ threadIdx.y] = dev_A[row * N + col];
+  }
+  // Keep the barrier outside the guard: __syncthreads() inside a divergent
+  // branch is undefined behavior for partial boundary tiles.
+  __syncthreads();
+  int n_col = blockIdx.y * blockDim.y + threadIdx.x;
+  int n_row = blockIdx.x * blockDim.x + threadIdx.y;
+  if (n_row < N && n_col < M) {
+    // Read from logical coordinate (row=y, col=x) in shared memory, which is
+    // physically stored at (row=y, col=x^y): again one bank per lane.
+    dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x ^ threadIdx.y];
+  }
+}
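+
+// A minimal host-side usage sketch (an assumption, for illustration only):
+// transpose an M x N int matrix with 32x32 tiles.
+//
+// int M = 64, N = 128;
+// int *dev_A, *dev_B;
+// cudaMalloc(&dev_A, M * N * sizeof(int));
+// cudaMalloc(&dev_B, N * M * sizeof(int));
+// dim3 block(32, 32);
+// dim3 grid((N + 31) / 32, (M + 31) / 32);
+// matrix_trans_swizzling<<<grid, block>>>(dev_A, M, N, dev_B);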