diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py index b6834593..b0ede84d 100644 --- a/hgemm/hgemm.py +++ b/hgemm/hgemm.py @@ -104,6 +104,22 @@ def get_device_capability(): TOATL_TFLOPS: dict[str, float] = {} CUBLAS_TOTAL_TFLOPS = 0 + +def make_block_swizzle_stride(N: int, K: int): + # default swizzle stride is N/2, N/4 or N/8, not rounded to a multiple of 256; strides below 256 become 1 (no swizzle) + if args.swizzle_factor is None: + swizzle_factor = 0.5 if N <= 4096 else 0.25 + if all((N >= 14848, K > 8192, N % 8 == 0)): + swizzle_factor = 0.125 + else: + swizzle_factor = args.swizzle_factor + + swizzle_stride = int(N * swizzle_factor) + swizzle_stride = swizzle_stride if swizzle_stride >= 256 else 1 + + return swizzle_stride + + def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: str, out: Optional[torch.Tensor] = None, @@ -121,13 +137,7 @@ def run_benchmark(perf_func: callable, if 'tn' in tag: N = b.size(0) if swizzle: - # make swizzle stride as N/4 or N/2 and multiples of 256 - if args.swizzle_factor is None: - swizzle_factor = 0.5 if N <= 4096 else 0.25 - else: - swizzle_factor = args.swizzle_factor - swizzle_stride = int((int(N * swizzle_factor) // 256) * 256) - swizzle_stride = swizzle_stride if swizzle_stride >= 256 else 1 + swizzle_stride = make_block_swizzle_stride(N, K) swizzle = swizzle if swizzle_stride >= 256 else False else: swizzle_stride = 1 # means no thread block swizzle @@ -187,7 +197,6 @@ def run_benchmark(perf_func: callable, print(f"{out_info:>42}: {out_val}, time:{mean_time}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}") if show_matrix: print(out) - time.sleep(args.sleep_duration) if args.plot_flops: STATIS_INFO[tag] = STATIS_INFO.get(tag, []) STATIS_INFO[tag].append(TFLOPS) @@ -196,6 +205,9 @@ def run_benchmark(perf_func: callable, else: global CUBLAS_TOTAL_TFLOPS CUBLAS_TOTAL_TFLOPS += TFLOPS + + torch.cuda.synchronize() + time.sleep(args.sleep_duration) return out, mean_time diff --git a/hgemm/hgemm_mma_stage.cu b/hgemm/hgemm_mma_stage.cu index 
78bfca57..72a8450c 100644 --- a/hgemm/hgemm_mma_stage.cu +++ b/hgemm/hgemm_mma_stage.cu @@ -2014,7 +2014,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages( // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~17KB @@ -2144,7 +2144,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem( // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~17KB @@ -2275,7 +2275,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem( // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~35KB @@ -2407,7 +2407,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4( // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~35KB @@ -2540,7 +2540,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr( // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~35KB diff --git a/hgemm/hgemm_mma_stage_tn.cu b/hgemm/hgemm_mma_stage_tn.cu index 853a157f..db8cf1a4 100644 --- a/hgemm/hgemm_mma_stage_tn.cu +++ b/hgemm/hgemm_mma_stage_tn.cu @@ -406,7 +406,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn( constexpr int BK = MMA_K; if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride 
% 256 == 0); switch (stages) { case 2: diff --git a/hgemm/hgemm_wmma_stage.cu b/hgemm/hgemm_wmma_stage.cu index b8131389..a8697908 100644 --- a/hgemm/hgemm_wmma_stage.cu +++ b/hgemm/hgemm_wmma_stage.cu @@ -1030,7 +1030,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages( // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~17KB @@ -1158,7 +1158,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem( // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB // s6: 6*128*(16)*2=24KB, 6*16*(128+16)*2=27KB, ~51KB > 48KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~17KB @@ -1293,7 +1293,7 @@ void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem( // s3: 3*256*(16)*2=24KB, 3*16*(256+16)*2=25.5KB, ~50KB > 48KB // s4: 4*256*(16)*2=32KB, 4*16*(256+16)*2=34KB, ~66KB if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: // ~33KB @@ -1418,7 +1418,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem( constexpr int BK = WMMA_K * WARP_TILE_K; if (swizzle) { - assert(swizzle_stride % 256 == 0); + // assert(swizzle_stride % 256 == 0); switch (stages) { case 2: @@ -1457,4 +1457,4 @@ void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem( break; } } -} \ No newline at end of file +}