From 433814f4004feedd5a819fb0cf184d61c937dc76 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Sat, 28 Dec 2024 20:21:05 +0800 Subject: [PATCH 1/2] Update README.md --- kernels/swizzle/README.md | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/kernels/swizzle/README.md b/kernels/swizzle/README.md index 625c7285..53ee0811 100644 --- a/kernels/swizzle/README.md +++ b/kernels/swizzle/README.md @@ -3,12 +3,12 @@ ## πŸ“š build bin ```bash -make +make # build all default binaries ``` ## πŸ“š ncu profile -Achieve 0 bank conflicts for LDSM via smem swizzle. +- πŸ“š Achieve 0 bank conflicts for LDSM via smem swizzle. ```bash ncu --metrics l1tex__data_bank_reads ./mat_trans_swizzle.bin @@ -20,7 +20,7 @@ ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_s ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1 ``` -log: (achieve 0 bank conflicts for LDSM via smem swizzle) +- πŸ“š log: (achieve 0 bank conflicts for LDSM via smem swizzle) ```bash ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1 @@ -72,16 +72,10 @@ ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./h ## πŸ“š performance -- NVIDIA TRX 3080 Laptop +- πŸ“š NVIDIA RTX 3080 Laptop ```bash ./hgemm_mma_swizzle.bin 4096 4096 4096 1 10 -ALGO = HGEMM MMA NAIVE -M N K = 4096 4096 4096, W = 1, R = 10, Time = 0.02986609 s, AVG Performance = 4.6018 Tflops - -ALGO = HGEMM MMA NAIVE + SMEM SWIZZLE -M N K = 4096 4096 4096, W = 1, R = 10, Time = 0.02860964 s, AVG Performance = 4.8039 Tflops - ALGO = HGEMM mma2x4_warp4x4 M N K = 4096 4096 4096, W = 1, R = 10, Time = 0.00392888 s, AVG Performance = 34.9817 Tflops @@ -92,7 +86,7 @@ M N K = 4096 4096 4096, W = 1, R = 10, Time = 0.00234496 s, AVG Performa ## πŸ“š print swizzle layout -- M16K16 +- πŸ“š M16K16 ```bash python3 print_swizzle_layout.py --logical-col 64 --show-logical-col @@ -147,7 +141,7 @@ smem col 0~16, step 8- ---------------------- ``` -- M16K64 +- πŸ“š M16K64 (Zigzag) ```bash python3 print_swizzle_layout.py --logical-col 64 --show-logical-col From 861ffd9cd09836a687ba620dca004b1804887566 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sat, 28 Dec 2024 22:07:05 +0800 Subject: [PATCH 2/2] add pad -> swizzle layout tools --- .../flash-attn/tools/print_swizzle_layout.py | 46 +++++++++++++++---- kernels/hgemm/tools/print_swizzle_layout.py | 46 +++++++++++++++---- kernels/swizzle/hgemm_mma_swizzle.cu | 16 +++++-- kernels/swizzle/print_swizzle_layout.py | 46 +++++++++++++++---- 4 files changed, 123 insertions(+), 31 deletions(-) diff --git a/kernels/flash-attn/tools/print_swizzle_layout.py b/kernels/flash-attn/tools/print_swizzle_layout.py index b08b1a2a..69617bfd 100644 --- a/kernels/flash-attn/tools/print_swizzle_layout.py +++ b/kernels/flash-attn/tools/print_swizzle_layout.py @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int, def print_smem_swizzle_layout(rows: int = 16, - logical_col_stride: int = 16, - num_elems_per_128b: int = 8, - show_logical_col_id: bool = False, - use_logical_col_stride: bool = False): + logical_col_stride: int = 16, + num_elems_per_128b: int = 8, + smem_pading: int = 0, + show_logical_col_id: bool = False, + use_logical_col_stride: bool = False): # ---------------------------------------------------------------- # [INFO] Assert smem store layout col_stride <= 16, prefer 16. | # [INFO] For logical_col_stride > 16, we have to permute the | @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16, # ---------------------------------------------------------------- str_len = 0 total_banks = 0 + assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8" # 4 bytes per bank banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4 if use_logical_col_stride: banks_per_col = int((logical_col_stride * 2) / 4) if logical_col_stride > 16: print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}") + if smem_pading == 8: + banks_per_col += 4 + print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}") + banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4) for i in range(rows): layout_str_len = 0 @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16, num_elems_per_128b) logical_col_ids.append(j) smem_layout_col_ids.append(layout_j) + smem_layout_str = f"|row {i:<2}|" + + r = 0 for c, l in zip(logical_col_ids, smem_layout_col_ids): - smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if - show_logical_col_id else f"{l:<2}"), - sep=" ", - width=max_bank_str_len-1, - return_str=True) + "|" + smem_layout_str += pretty_print_line( + (f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"), + sep=" ", + width=(max_bank_str_len-1), + return_str=True + ) + "|" + r += 1 + if logical_col_stride >= 16: + if smem_pading == 8 and (r > 1 and r % 2 == 0): + smem_layout_str += pretty_print_line( + (f"pad"), + sep=" ", width=max_bank_str_len-1, + return_str=True + ) + "|" + else: + if smem_pading == 8: + smem_layout_str += pretty_print_line( + (f"pad"), + sep=" ", width=max_bank_str_len-1, + return_str=True + ) + "|" + layout_str_len = len(smem_layout_str) str_len = max(layout_str_len, banks_str_len) @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16, def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--rows", type=int, default=16) + parser.add_argument("--smem-padding", "--pad", type=int, default=0) parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8) parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64) parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true") @@ -186,6 +213,7 @@ def get_args(): print_smem_swizzle_layout(rows=args.rows, logical_col_stride=args.logical_col_stride, num_elems_per_128b=args.num_elems_per_128b, + smem_pading=args.smem_padding, show_logical_col_id=args.show_logical_col_id, use_logical_col_stride=args.use_logical_col_stride) diff --git a/kernels/hgemm/tools/print_swizzle_layout.py b/kernels/hgemm/tools/print_swizzle_layout.py index b08b1a2a..69617bfd 100644 --- a/kernels/hgemm/tools/print_swizzle_layout.py +++ b/kernels/hgemm/tools/print_swizzle_layout.py @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int, def print_smem_swizzle_layout(rows: int = 16, - logical_col_stride: int = 16, - num_elems_per_128b: int = 8, - show_logical_col_id: bool = False, - use_logical_col_stride: bool = False): + logical_col_stride: int = 16, + num_elems_per_128b: int = 8, + smem_pading: int = 0, + show_logical_col_id: bool = False, + use_logical_col_stride: bool = False): # ---------------------------------------------------------------- # [INFO] Assert smem store layout col_stride <= 16, prefer 16. | # [INFO] For logical_col_stride > 16, we have to permute the | @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16, # ---------------------------------------------------------------- str_len = 0 total_banks = 0 + assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8" # 4 bytes per bank banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4 if use_logical_col_stride: banks_per_col = int((logical_col_stride * 2) / 4) if logical_col_stride > 16: print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}") + if smem_pading == 8: + banks_per_col += 4 + print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}") + banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4) for i in range(rows): layout_str_len = 0 @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16, num_elems_per_128b) logical_col_ids.append(j) smem_layout_col_ids.append(layout_j) + smem_layout_str = f"|row {i:<2}|" + + r = 0 for c, l in zip(logical_col_ids, smem_layout_col_ids): - smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if - show_logical_col_id else f"{l:<2}"), - sep=" ", - width=max_bank_str_len-1, - return_str=True) + "|" + smem_layout_str += pretty_print_line( + (f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"), + sep=" ", + width=(max_bank_str_len-1), + return_str=True + ) + "|" + r += 1 + if logical_col_stride >= 16: + if smem_pading == 8 and (r > 1 and r % 2 == 0): + smem_layout_str += pretty_print_line( + (f"pad"), + sep=" ", width=max_bank_str_len-1, + return_str=True + ) + "|" + else: + if smem_pading == 8: + smem_layout_str += pretty_print_line( + (f"pad"), + sep=" ", width=max_bank_str_len-1, + return_str=True + ) + "|" + layout_str_len = len(smem_layout_str) str_len = max(layout_str_len, banks_str_len) @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16, def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--rows", type=int, default=16) + parser.add_argument("--smem-padding", "--pad", type=int, default=0) parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8) parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64) parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true") @@ -186,6 +213,7 @@ def get_args(): print_smem_swizzle_layout(rows=args.rows, logical_col_stride=args.logical_col_stride, num_elems_per_128b=args.num_elems_per_128b, + smem_pading=args.smem_padding, show_logical_col_id=args.show_logical_col_id, use_logical_col_stride=args.use_logical_col_stride) diff --git a/kernels/swizzle/hgemm_mma_swizzle.cu b/kernels/swizzle/hgemm_mma_swizzle.cu index cbb4f068..8938c0ac 100644 --- a/kernels/swizzle/hgemm_mma_swizzle.cu +++ b/kernels/swizzle/hgemm_mma_swizzle.cu @@ -520,7 +520,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4( constexpr int MMA_TILE_N = 4; constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; - // bank conflicts free via pad = 8, ζ‹’η»εΉ»ζƒ³οΌŒη›ΈδΏ‘profile + // bank conflicts free via pad = 8. // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin // constexpr int A_PAD = 8; @@ -541,6 +541,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4( ); } +template void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle( half* a, half* b, half* c, int M, int N, int K) { constexpr int MMA_M = 16; @@ -551,7 +552,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle( constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; constexpr int A_PAD = 0; - constexpr int B_PAD = 8; + // B_PAD = 8, bank conflicts free via pad = 8. constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 dim3 block(NUM_THREADS); @@ -644,9 +645,16 @@ int main(int argc, char *argv[]) { avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops); + + printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 0\n"); + avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<0>, + M, N, K, W, R); + avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; + printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); + printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops); - printf("\nALGO = HGEMM mma2x4_warp4x4 + SMEM SWIZZLE\n"); - avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle, + printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 8\n"); + avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<8>, M, N, K, W, R); avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); diff --git a/kernels/swizzle/print_swizzle_layout.py b/kernels/swizzle/print_swizzle_layout.py index b08b1a2a..69617bfd 100644 --- a/kernels/swizzle/print_swizzle_layout.py +++ b/kernels/swizzle/print_swizzle_layout.py @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int, def print_smem_swizzle_layout(rows: int = 16, - logical_col_stride: int = 16, - num_elems_per_128b: int = 8, - show_logical_col_id: bool = False, - use_logical_col_stride: bool = False): + logical_col_stride: int = 16, + num_elems_per_128b: int = 8, + smem_pading: int = 0, + show_logical_col_id: bool = False, + use_logical_col_stride: bool = False): # ---------------------------------------------------------------- # [INFO] Assert smem store layout col_stride <= 16, prefer 16. | # [INFO] For logical_col_stride > 16, we have to permute the | @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16, # ---------------------------------------------------------------- str_len = 0 total_banks = 0 + assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8" # 4 bytes per bank banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4 if use_logical_col_stride: banks_per_col = int((logical_col_stride * 2) / 4) if logical_col_stride > 16: print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}") + if smem_pading == 8: + banks_per_col += 4 + print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}") + banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4) for i in range(rows): layout_str_len = 0 @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16, num_elems_per_128b) logical_col_ids.append(j) smem_layout_col_ids.append(layout_j) + smem_layout_str = f"|row {i:<2}|" + + r = 0 for c, l in zip(logical_col_ids, smem_layout_col_ids): - smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if - show_logical_col_id else f"{l:<2}"), - sep=" ", - width=max_bank_str_len-1, - return_str=True) + "|" + smem_layout_str += pretty_print_line( + (f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"), + sep=" ", + width=(max_bank_str_len-1), + return_str=True + ) + "|" + r += 1 + if logical_col_stride >= 16: + if smem_pading == 8 and (r > 1 and r % 2 == 0): + smem_layout_str += pretty_print_line( + (f"pad"), + sep=" ", width=max_bank_str_len-1, + return_str=True + ) + "|" + else: + if smem_pading == 8: + smem_layout_str += pretty_print_line( + (f"pad"), + sep=" ", width=max_bank_str_len-1, + return_str=True + ) + "|" + layout_str_len = len(smem_layout_str) str_len = max(layout_str_len, banks_str_len) @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16, def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--rows", type=int, default=16) + parser.add_argument("--smem-padding", "--pad", type=int, default=0) parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8) parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64) parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true") @@ -186,6 +213,7 @@ def get_args(): print_smem_swizzle_layout(rows=args.rows, logical_col_stride=args.logical_col_stride, num_elems_per_128b=args.num_elems_per_128b, + smem_pading=args.smem_padding, show_logical_col_id=args.show_logical_col_id, use_logical_col_stride=args.use_logical_col_stride)