diff --git a/README.md b/README.md index 4ff3db84..ed849bcf 100644 --- a/README.md +++ b/README.md @@ -52,8 +52,8 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh |✔️|✔️|✔️|✔️| |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**| |✔️|✔️|✔️|✔️| -|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle| -|✔️|✔️|✔️|?| +|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**| +|✔️|✔️|✔️|✔️| Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run faster than FA2/SDPA on some Devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)** that almost **~1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (👇Benchmark) @@ -66,7 +66,7 @@ Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run |SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS| |mma(split-q+tiling-qk+stage2)|(1,48,8192,512)|**23 TFLOPS**|**81 TFLOPS**|**120 TFLOPS**| -The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps). +The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` method, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps). - 📚 Split KV (Basic, FlashAttention-1)
@@ -427,6 +427,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi | [[cute系列详解][Swizzle]📖cute Swizzle细谈](https://zhuanlan.zhihu.com/p/684250988)|@进击的Killua| | [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(一)](https://zhuanlan.zhihu.com/p/710337546)|@Titus| | [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(二)](https://zhuanlan.zhihu.com/p/711398930)|@Titus| +| [[cute系列详解][Swizzle]📖CUDA避免bank conflict的swizzle机制解析](https://zhuanlan.zhihu.com/p/4746910252)|@frankshi| | [[cute系列详解][GEMM]📖cute 之 简单GEMM实现](https://zhuanlan.zhihu.com/p/667521327)|@reed| | [[cute系列详解][GEMM]📖cute 之 GEMM流水线](https://zhuanlan.zhihu.com/p/665082713)|@reed| | [[cute系列详解][GEMM]📖cute 之 高效GEMM实现](https://zhuanlan.zhihu.com/p/675308830)|@reed| diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md index e49eefe4..20c3ee42 100644 --- a/kernels/flash-attn/README.md +++ b/kernels/flash-attn/README.md @@ -7,14 +7,14 @@ |✔️|✔️|✔️|✔️| |Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads) |✔️|✔️|✔️|✔️| -|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**| +|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**| +|✔️|✔️|✔️|✔️| +|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**| |✔️|✔️|✔️|✔️| -|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle| -|✔️|✔️|✔️|?| This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run faster than offical FA2/SDPA on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (👇Benchmark) -|Algorithm| (B,H,N,D) | NVIDIA GeForce RTX 3080 Laptop | NVIDIA L20 | NVIDIA RTX 4090 | +|Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 | |:---:|:---:|:---:|:---:|:---:| |FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS| |mma(split-q+share-qkv+stage2)|(1,8,8192,64)|**55 TFLOPS**|96 TFLOPS|**218 TFLOPS**| diff --git a/kernels/nvidia-nsight/bank_conflicts.md b/kernels/nvidia-nsight/bank_conflicts.md new file mode 100644 index 00000000..ac18c1cf --- /dev/null +++ b/kernels/nvidia-nsight/bank_conflicts.md @@ -0,0 +1,85 @@ +## Check Bank Conflicts via NCU + +- 检查device支持的metrics +```bash +# ncu check bank conflicts +# 先查看当前devices支持的metrics有哪些 +ncu --query-metrics | grep data | grep bank | grep l1tex +``` +metrics: +```bash +ncu --query-metrics | grep data | grep bank | grep l1tex +l1tex__data_bank_conflicts_pipe_lsu Counter # of data bank conflicts generated by LSU pipe +l1tex__data_bank_conflicts_pipe_lsu_cmd_read Counter # of data bank conflicts generated by LSU reads +l1tex__data_bank_conflicts_pipe_lsu_cmd_write Counter # of data bank conflicts generated by LSU writes +l1tex__data_bank_conflicts_pipe_lsu_mem_global Counter # of data bank conflicts generated by global ops +l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_atom Counter # of data bank conflicts generated by global atomics +l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_ld Counter # of data bank conflicts generated by global loads +l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_red Counter # of data bank conflicts generated by global reductions +l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_st Counter # of data bank conflicts generated by global stores +l1tex__data_bank_conflicts_pipe_lsu_mem_shared Counter # of shared memory data bank conflicts generated by LDS, LD, 3D +l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_atom Counter # of shared memory data bank conflicts generated by ATOMS, ATOM +l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld Counter # of shared memory data bank conflicts generated by LDS, LD, 3D +l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of data bank conflicts generated by shared ldgsts ops +l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST, 3D +l1tex__data_bank_reads Counter # of data bank reads +l1tex__data_bank_writes Counter # of data bank writes +sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS +sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS +sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS +sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM +sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST +sm__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes +smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS +smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS +smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS +smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM +smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST +smsp__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes +``` + +- 由LD指令产生的bank conflicts +```bash +# profile l1tex smem data bank conflicts +# 由LDS, LD指令产生的bank conflicts +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_mma_stage.89.bin +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_cute.89.debug.bin +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld \ + python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1 +``` +log: +```bash +void flash_fwd_splitkv_combine_kernel>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: Command line profiler metrics + -------------------------------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + -------------------------------------------------------- ----------- ------------ + l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.avg 11.18 + l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.max 13 + l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.min 10 + l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 1029 + -------------------------------------------------------- ----------- ------------ +``` + +- 由LDSM指令产生的bank conflicts + +```bash +# 由LDSM(ldmatrix)指令产生的bank conflicts +ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \ + python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1 +ncu --metrics smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \ + python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1 +``` +log: +```bash +void flash_fwd_splitkv_combine_kernel>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: Command line profiler metrics + ------------------------------------------------------------------ ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------------------------------------------ ----------- ------------ + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0 + ------------------------------------------------------------------ ----------- ------------ +``` diff --git a/kernels/swizzle/.gitignore b/kernels/swizzle/.gitignore new file mode 100644 index 00000000..2ef2ddf9 --- /dev/null +++ b/kernels/swizzle/.gitignore @@ -0,0 +1,32 @@ +*.so +*.a +*.dylib +*.dll +*.lib +.DS_Store +build +*.whl +tmp +__pycache__ +*.onnx +*.engine +*.pt +*.pth +*.nsys* +*.ncu* +*.sqlite* +*.engine +*.bin +*.out +*bin +bin +output +*.egg-info +*.whl +dist +*.pdf +*.tex +*.log +*.md5 +*.aux* +*.dpth diff --git a/kernels/swizzle/hgemm_mma_swizzle.cu b/kernels/swizzle/hgemm_mma_swizzle.cu new file mode 100644 index 00000000..37fbd6e4 --- /dev/null +++ b/kernels/swizzle/hgemm_mma_swizzle.cu @@ -0,0 +1,333 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; + +#define WARP_SIZE 32 +#define DEVICE_INLINE __device__ inline +#define HOST_DEVICE_INLINE __device__ __host__ inline +#define INT4(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) +#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) +#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) +// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. +#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) + +HOST_DEVICE_INLINE +int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } + +// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. +template +__global__ void hgemm_mma_m16n8k16_naive_kernel(half* A, half* B, half* C, + int M, int N, int K) { + const int bx = blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M; // 16 + constexpr int BN = MMA_N; // 8 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[MMA_M][MMA_K]; // 16x16 + __shared__ half s_b[MMA_K][MMA_N]; // 16x8 + __shared__ half s_c[MMA_M][MMA_N]; // 16x8 + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int lane_id = tid % WARP_SIZE; // 0~31 + + // s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程 + const int load_smem_a_m = tid / 2; // row 0~15 + const int load_smem_a_k = (tid % 2) * 8; // col 0,8 + // s_b[16][8], 每行8,每线程load 8,需要1线程,共16行,需16线程,只需一半线程加载 + const int load_smem_b_k = tid; // row 0~31, but only use 0~15 + const int load_smem_b_n = 0; // col 0 + const int load_gmem_a_m = by * BM + load_smem_a_m; // global m + const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n + if (load_gmem_a_m >= M && load_gmem_b_n >= N) return; + + uint32_t RC[2] = {0, 0}; + + #pragma unroll + for (int k = 0; k < NUM_K_TILES; ++k) { + // gmem_a -> smem_a + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( + LDST128BITS(A[load_gmem_a_addr])); + + // gmem_b -> smem_b + if (lane_id < MMA_K) { + int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( + LDST128BITS(B[load_gmem_b_addr])); + } + __syncthreads(); + + uint32_t RA[4]; + uint32_t RB[2]; + + // ldmatrix for s_a, ldmatrix.trans for s_b. + // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)] + uint32_t load_smem_a_ptr = __cvta_generic_to_shared( + &s_a[lane_id % 16][(lane_id / 16) * 8]); + LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr); + uint32_t load_smem_b_ptr = __cvta_generic_to_shared( + &s_b[lane_id % 16][0]); + LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr); + + HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]); + + __syncthreads(); + } + + // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]); + LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]); + + __syncthreads(); + + // store s_c[16][8] + if (lane_id < MMA_M) { + // store 128 bits per memory issue. + int store_gmem_c_m = by * BM + lane_id; + int store_gmem_c_n = bx * BN; + int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n; + LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0])); + } +} + +// 128x128, mma2x4, warp4x4(64,32,16) +template +__global__ void __launch_bounds__(256) +hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel( + half* A, half* B, half* C, int M, int N, int K) { + const int bx = blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB + __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_m = warp_id % 2; // 0,1 + const int warp_n = warp_id / 2; // 0,1,2,3 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 + // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + if (load_gmem_a_m >= M || load_gmem_b_n >= N) return; + + uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + RC[i][j][0] = 0; + RC[i][j][1] = 0; + } + } + + #pragma unroll + for (int k = 0; k < NUM_K_TILES; ++k) { + // gmem -> smem + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( + LDST128BITS(B[load_gmem_b_addr])); + LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( + LDST128BITS(A[load_gmem_a_addr])); + __syncthreads(); + + // ldmatrix for s_a, ldmatrix.trans for s_b. + uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( + &s_a[lane_smem_a_m][lane_smem_a_k]); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( + &s_b[lane_smem_b_k][lane_smem_b_n]); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + __syncthreads(); + } + + // reg -> gmem, MMA_MxMMA_N=16x8 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + // mapping lane smem index -> global index. + // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; + int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; + int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; + int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; + // TODO: how to use LDST128BITS here ? reverse the loop order ? + LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); + LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); + } + } +} + + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ +if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. +void hgemm_mma_m16n8k16_naive( + torch::Tensor a, torch::Tensor b, torch::Tensor c) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + + dim3 block(WARP_SIZE); + dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M)); + + hgemm_mma_m16n8k16_naive_kernel< + MMA_M, MMA_N, MMA_K><<>>( + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(c.data_ptr()), + M, N, K + ); +} + +// 128x128, mma2x4, warp4x4(64,32,16) +void hgemm_mma_m16n8k16_mma2x4_warp4x4( + torch::Tensor a, torch::Tensor b, torch::Tensor c) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + constexpr int A_PAD = 0; + constexpr int B_PAD = 16; + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + + dim3 block(NUM_THREADS); + dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N), + div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M)); + + hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel< + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<>>( + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(c.data_ptr()), + M, N, K + ); +} diff --git a/kernels/swizzle/matrix_trans_swizzle.cu b/kernels/swizzle/matrix_trans_swizzle.cu new file mode 100644 index 00000000..105dbf23 --- /dev/null +++ b/kernels/swizzle/matrix_trans_swizzle.cu @@ -0,0 +1,35 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; + +// reference: https://zhuanlan.zhihu.com/p/4746910252 +__global__ void matrix_trans_swizzling(int* dev_A, int M, int N, int* dev_B) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int s_data[32][32]; + + if (row < M && col < N) { + // 从全局内存读取数据写入共享内存的逻辑坐标(row=x,col=y) + // 其映射的物理存储位置位置(row=x,col=x^y) + s_data[threadIdx.x][threadIdx.x ^ threadIdx.y] = dev_A[row * N + col]; + __syncthreads(); + int n_col = blockIdx.y * blockDim.y + threadIdx.x; + int n_row = blockIdx.x * blockDim.x + threadIdx.y; + if (n_row < N && n_col < M) { + // 从共享内存的逻辑坐标(row=y,col=x)读取数据 + // 其映射的物理存储位置(row=y,col=x^y) + dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x ^ threadIdx.y]; + } + } +}