Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions kernels/flash-attn/tools/print_swizzle_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,


def print_smem_swizzle_layout(rows: int = 16,
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
smem_pading: int = 0,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
# ----------------------------------------------------------------
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
# [INFO] For logical_col_stride > 16, we have to permute the |
Expand Down Expand Up @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
# ----------------------------------------------------------------
str_len = 0
total_banks = 0
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
# 4 bytes per bank
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
if use_logical_col_stride:
banks_per_col = int((logical_col_stride * 2) / 4)
if logical_col_stride > 16:
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
if smem_pading == 8:
banks_per_col += 4
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")

banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
for i in range(rows):
layout_str_len = 0
Expand Down Expand Up @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
num_elems_per_128b)
logical_col_ids.append(j)
smem_layout_col_ids.append(layout_j)

smem_layout_str = f"|row {i:<2}|"

r = 0
for c, l in zip(logical_col_ids, smem_layout_col_ids):
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
show_logical_col_id else f"{l:<2}"),
sep=" ",
width=max_bank_str_len-1,
return_str=True) + "|"
smem_layout_str += pretty_print_line(
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
sep=" ",
width=(max_bank_str_len-1),
return_str=True
) + "|"
r += 1
if logical_col_stride >= 16:
if smem_pading == 8 and (r > 1 and r % 2 == 0):
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"
else:
if smem_pading == 8:
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"

layout_str_len = len(smem_layout_str)
str_len = max(layout_str_len, banks_str_len)

Expand All @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=16)
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
Expand All @@ -186,6 +213,7 @@ def get_args():
print_smem_swizzle_layout(rows=args.rows,
logical_col_stride=args.logical_col_stride,
num_elems_per_128b=args.num_elems_per_128b,
smem_pading=args.smem_padding,
show_logical_col_id=args.show_logical_col_id,
use_logical_col_stride=args.use_logical_col_stride)

46 changes: 37 additions & 9 deletions kernels/hgemm/tools/print_swizzle_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,


def print_smem_swizzle_layout(rows: int = 16,
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
smem_pading: int = 0,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
# ----------------------------------------------------------------
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
# [INFO] For logical_col_stride > 16, we have to permute the |
Expand Down Expand Up @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
# ----------------------------------------------------------------
str_len = 0
total_banks = 0
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
# 4 bytes per bank
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
if use_logical_col_stride:
banks_per_col = int((logical_col_stride * 2) / 4)
if logical_col_stride > 16:
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
if smem_pading == 8:
banks_per_col += 4
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")

banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
for i in range(rows):
layout_str_len = 0
Expand Down Expand Up @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
num_elems_per_128b)
logical_col_ids.append(j)
smem_layout_col_ids.append(layout_j)

smem_layout_str = f"|row {i:<2}|"

r = 0
for c, l in zip(logical_col_ids, smem_layout_col_ids):
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
show_logical_col_id else f"{l:<2}"),
sep=" ",
width=max_bank_str_len-1,
return_str=True) + "|"
smem_layout_str += pretty_print_line(
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
sep=" ",
width=(max_bank_str_len-1),
return_str=True
) + "|"
r += 1
if logical_col_stride >= 16:
if smem_pading == 8 and (r > 1 and r % 2 == 0):
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"
else:
if smem_pading == 8:
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"

layout_str_len = len(smem_layout_str)
str_len = max(layout_str_len, banks_str_len)

Expand All @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=16)
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
Expand All @@ -186,6 +213,7 @@ def get_args():
print_smem_swizzle_layout(rows=args.rows,
logical_col_stride=args.logical_col_stride,
num_elems_per_128b=args.num_elems_per_128b,
smem_pading=args.smem_padding,
show_logical_col_id=args.show_logical_col_id,
use_logical_col_stride=args.use_logical_col_stride)

16 changes: 12 additions & 4 deletions kernels/swizzle/hgemm_mma_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
constexpr int MMA_TILE_N = 4;
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
// bank conflicts free via pad = 8, 拒绝幻想,相信profile
// bank conflicts free via pad = 8.
// ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin
// ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin
// constexpr int A_PAD = 8;
Expand All @@ -541,6 +541,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
);
}

template <const int B_PAD = 8>
void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
half* a, half* b, half* c, int M, int N, int K) {
constexpr int MMA_M = 16;
Expand All @@ -551,7 +552,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
constexpr int A_PAD = 0;
constexpr int B_PAD = 8;
// B_PAD = 8, bank conflicts free via pad = 8.
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
dim3 block(NUM_THREADS);
Expand Down Expand Up @@ -644,9 +645,16 @@ int main(int argc, char *argv[]) {
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);

printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 0\n");
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<0>,
M, N, K, W, R);
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);

printf("\nALGO = HGEMM mma2x4_warp4x4 + SMEM SWIZZLE\n");
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle,
printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 8\n");
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<8>,
M, N, K, W, R);
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
Expand Down
46 changes: 37 additions & 9 deletions kernels/swizzle/print_swizzle_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,


def print_smem_swizzle_layout(rows: int = 16,
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
smem_pading: int = 0,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
# ----------------------------------------------------------------
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
# [INFO] For logical_col_stride > 16, we have to permute the |
Expand Down Expand Up @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
# ----------------------------------------------------------------
str_len = 0
total_banks = 0
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
# 4 bytes per bank
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
if use_logical_col_stride:
banks_per_col = int((logical_col_stride * 2) / 4)
if logical_col_stride > 16:
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
if smem_pading == 8:
banks_per_col += 4
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")

banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
for i in range(rows):
layout_str_len = 0
Expand Down Expand Up @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
num_elems_per_128b)
logical_col_ids.append(j)
smem_layout_col_ids.append(layout_j)

smem_layout_str = f"|row {i:<2}|"

r = 0
for c, l in zip(logical_col_ids, smem_layout_col_ids):
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
show_logical_col_id else f"{l:<2}"),
sep=" ",
width=max_bank_str_len-1,
return_str=True) + "|"
smem_layout_str += pretty_print_line(
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
sep=" ",
width=(max_bank_str_len-1),
return_str=True
) + "|"
r += 1
if logical_col_stride >= 16:
if smem_pading == 8 and (r > 1 and r % 2 == 0):
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"
else:
if smem_pading == 8:
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"

layout_str_len = len(smem_layout_str)
str_len = max(layout_str_len, banks_str_len)

Expand All @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=16)
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
Expand All @@ -186,6 +213,7 @@ def get_args():
print_smem_swizzle_layout(rows=args.rows,
logical_col_stride=args.logical_col_stride,
num_elems_per_128b=args.num_elems_per_128b,
smem_pading=args.smem_padding,
show_logical_col_id=args.show_logical_col_id,
use_logical_col_stride=args.use_logical_col_stride)