diff --git a/hgemm/README.md b/hgemm/README.md
index d76ac820..9f19366b 100755
--- a/hgemm/README.md
+++ b/hgemm/README.md
@@ -66,6 +66,24 @@ python3 hgemm.py --mma-all --plot --topk 8
 ## Current Performance
 
+### NVIDIA GeForce RTX 3080 Laptop
+
+Tested under Windows WSL2 on an NVIDIA GeForce RTX 3080 Laptop GPU, using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle; in most cases this matches or even beats cuBLAS.
+
+![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)
+
+```bash
+python3 hgemm.py --wmma-all
+----------------------------------------------------------------------------------------------------------------------------------
+                                        M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
+----------------------------------------------------------------------------------------------------------------------------------
+                (wmma4x4+warp4x4+stage3+dsmem): ['68.375   ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
+                (wmma4x4+warp4x4+stage2+dsmem): ['68.375   ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
+        (wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375   ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
+        (wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375   ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
+                                      (cublas): ['68.375   ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
+----------------------------------------------------------------------------------------------------------------------------------
+```
 
 ### NVIDIA L20
 
 With the current best implementation on the L20 (theoretical Tensor Core FP16 throughput: 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: the current mitigation, padding, wastes shared memory and also hurts SM occupancy. Hand-written smem swizzle/permute is not implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a later revision will try to implement smem swizzle/permute via MMA PTX.
@@ -147,24 +165,6 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all
 ----------------------------------------------------------------------------------------------------------------------------------
 ```
 
-### NVIDIA GeForce RTX 3080 Laptop
-
-Tested on an NVIDIA GeForce RTX 3080 Laptop GPU, using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle; in most cases this matches or even beats cuBLAS. However, since I ran the Laptop GPU under WSL, the performance numbers are unstable; treat this section as a rough reference and do not take it too literally.
-
-![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)
-
-```bash
-python3 hgemm.py --wmma-all
-----------------------------------------------------------------------------------------------------------------------------------
-                                        M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
-----------------------------------------------------------------------------------------------------------------------------------
-                (wmma4x4+warp4x4+stage3+dsmem): ['68.375   ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
-                (wmma4x4+warp4x4+stage2+dsmem): ['68.375   ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
-        (wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375   ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
-        (wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375   ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
-                                      (cublas): ['68.375   ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
-----------------------------------------------------------------------------------------------------------------------------------
-```
 
 ## Performance Optimization Notes
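Note: the TFLOPS figures in the tables above follow directly from the reported times, since the benchmark counts 2*M*N*K floating-point operations per GEMM (one multiply plus one add per MAC). A minimal sanity-check sketch (illustration only, not part of the patch; the times are copied from the RTX 3080 Laptop table):

```python
# Reproduce the table's TFLOPS from its reported times: TFLOPS = 2*M*N*K / time.
M, N, K = 16384, 16384, 8192
flops = 2 * M * N * K  # total FLOPs for one GEMM of this shape

for tag, time_ms in [("stage3+dsmem", 96.91984),
                     ("stage3+dsmem+swizzle", 85.65800),
                     ("cublas", 104.2092)]:
    tflops = flops / (time_ms * 1e-3) / 1e12
    print(f"{tag:>24}: {tflops:.2f} TFLOPS")
# -> 45.38, 51.34, 42.20, matching the table rows above.
```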
diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py
index c79a4df9..b6834593 100644
--- a/hgemm/hgemm.py
+++ b/hgemm/hgemm.py
@@ -7,14 +7,15 @@
 torch.set_grad_enabled(False)
 
+
 def get_args():
     parser = argparse.ArgumentParser(description="hgemm benchmark")
     parser.add_argument("--M", type=int, default=None, help="Matrix M size")
     parser.add_argument("--N", type=int, default=None, help="Matrix N size")
     parser.add_argument("--K", type=int, default=None, help="Matrix K size")
     parser.add_argument("--MNK", type=int, default=None, help="Matrix M=N=K size")
-    parser.add_argument("--MMNK", type=int, default=16384, help="Matrix MAX M=M=N=K size")
-    parser.add_argument("--SEP", type=int, default=512, help="Matrix MAX M=M=N=K size")
+    parser.add_argument("--MMNK", type=int, default=12800, help="Matrix MAX M=N=K size")
+    parser.add_argument("--SEP", "--sep", type=int, default=256, help="Matrix M=N=K step size")
     parser.add_argument("--warmup", "--w", type=int, default=2, help="Warmup iters")
     parser.add_argument("--iters", "--i", type=int, default=10, help="Benchmark iters")
     parser.add_argument("--verbose", "--v", action="store_true", help="Verbose")
@@ -36,16 +37,30 @@ def get_args():
     parser.add_argument("--no-default", action="store_true", help="Disable default tests")
     parser.add_argument("--plot-flops", "--plot", action="store_true", help="Plot TFLOPS")
     parser.add_argument("--plot-topk", "--topk", type=int, default=8, help="Plot top k TFLOPS")
-    parser.add_argument("--no-hint-top1", "--no-hint", action="store_true", help="Hint top 1 TFLOPS")
+    parser.add_argument("--no-plot-best", "--no-best", action="store_true", help="Do not plot best TFLOPS")
     parser.add_argument("--exclude-tags", "--exclude", type=str, default=None, help="Exclude tag for plot, sperated by comma")
-    parser.add_argument("--save-tag", "--tag", type=str, default=None, help="Save tag for plot")
+    parser.add_argument("--save-dir", "--dir", type=str, default="./", help="Save dir for plot")
     return parser.parse_args()
 
 
 args = get_args()
 print(args)
 
+
+def get_device_name():
+    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
+    # Since we run the GPU under WSL2, add a WSL2 tag.
+    if "Laptop" in device_name:
+        device_name += " WSL2"
+    return device_name
+
+
+def get_device_capability():
+    return torch.cuda.get_device_capability(torch.cuda.current_device())
+
+
 # Load the CUDA kernel as a python module
-print("Loading hgemm lib ...")
+print(f"Loading hgemm lib on device: {get_device_name()}, capability: {get_device_capability()} ...")
+
 lib = load(name='hgemm_lib',
            sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
                     'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
@@ -184,36 +199,35 @@ def run_benchmark(perf_func: callable,
     return out, mean_time
 
 
-def get_device_name():
-    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
-    # we will run GPU on WSL2, so add WSL2 tag.
-    if "Laptop" in device_name:
-        device_name += " WSL2"
-    return device_name
-
-
-def get_device_capability():
-    return torch.cuda.get_device_capability(torch.cuda.current_device())
-
-
 def get_topk_tflops():
     topk_tflops = sorted(TOATL_TFLOPS.items(), key=lambda x: x[1], reverse=True)
     print("-" * 130)
-    print(" " * 42 + f"HGEMM TOTAL TFLOPS, {get_device_name()}")
+    print(" " * 32 + f"THE TOTAL TFLOPS OF {len(topk_tflops)} HGEMM ALGOS ON {get_device_name()} DEVICE")
     print("-" * 130)
     for tag, tflops in list(topk_tflops)[::-1]:
-        print(f"{tag:>42}: {tflops:<10.2f} TFLOPS")
-    print(f"{'(cublas)':>42}: {CUBLAS_TOTAL_TFLOPS:<10.2f} TFLOPS")
+        print(f"{tag:>45}: {tflops:>20.2f} TFLOPS")
+    print(f"{'(cublas)':>45}: {CUBLAS_TOTAL_TFLOPS:>20.2f} TFLOPS")
     print("-" * 130)
-    return dict(topk_tflops[:args.plot_topk]).keys()
+    return list(dict(topk_tflops[:args.plot_topk]).keys())
+
+
+def get_best_tflops():
+    all_tflops = []
+    for tag, tflops in STATIS_INFO.items():
+        if "cublas" not in tag and "MNK" not in tag:
+            all_tflops.append(tflops)
+    # shape [N, NUM_MNK]; reduce max over the N (kernel) dim
+    all_tflops = torch.tensor(all_tflops, dtype=torch.float)
+    best_tflops = torch.max(all_tflops, dim=0, keepdim=False)[0].tolist()
+    return best_tflops
 
 
 def plot_tflops():
     import matplotlib.pyplot as plt
     import numpy as np
-    _, ax = plt.subplots(figsize=(16, 9))
-    plt.subplots_adjust(left=0.03, right=0.99, top=0.95, bottom=0.05)
+    ax: plt.Axes = plt.subplots(figsize=(16, 9))[1]  # fig, axs
+    plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.05)
     ax.set_title(f"My HGEMM vs cuBLAS, {get_device_name()}, Warmup={args.warmup}, Iters={args.iters}")
     ax.set_xlabel("M=N=K")
     ax.set_ylabel("TFLOPS")
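Note: the new `get_best_tflops()` above stacks the per-size TFLOPS lists of all non-cuBLAS kernels into an `[n_kernels, n_sizes]` tensor and reduces with `torch.max` over the kernel dimension, yielding the best achieved TFLOPS for each MNK size. A toy sketch of that reduction (illustration only; the tags and values below are invented):

```python
import torch

# Toy stand-in for STATIS_INFO: tag -> TFLOPS per MNK size (values invented).
statis_info = {
    "(wmma4x4+warp4x4+stage3)": [40.1, 45.3, 45.4],
    "(wmma4x4+warp4x4+stage3+swizzle)": [39.8, 48.0, 51.3],
    "(cublas)": [41.0, 44.0, 42.2],  # excluded from the "best" curve
}

all_tflops = [v for k, v in statis_info.items() if "cublas" not in k]
# shape [n_kernels, n_sizes]; max over dim 0 keeps the best kernel per size
best = torch.max(torch.tensor(all_tflops), dim=0)[0].tolist()
print(best)  # [40.1, 48.0, 51.3]
```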
- if "Laptop" in device_name: - device_name += " WSL2" - return device_name - - -def get_device_capability(): - return torch.cuda.get_device_capability(torch.cuda.current_device()) - - def get_topk_tflops(): topk_tflops = sorted(TOATL_TFLOPS.items(), key=lambda x: x[1], reverse=True) print("-" * 130) - print(" " * 42 + f"HGEMM TOTAL TFLOPS, {get_device_name()}") + print(" " * 32 + f"THE TOTAL TFLOPS OF {len(topk_tflops)} HGEMM ALGO ON {get_device_name()} DEVICE") print("-" * 130) for tag, tflops in list(topk_tflops)[::-1]: - print(f"{tag:>42}: {tflops:<10.2f} TFLOPS") - print(f"{'(cublas)':>42}: {CUBLAS_TOTAL_TFLOPS:<10.2f} TFLOPS") + print(f"{tag:>45}: {tflops:>20.2f} TFLOPS") + print(f"{'(cublas)':>45}: {CUBLAS_TOTAL_TFLOPS:>20.2f} TFLOPS") print("-" * 130) - return dict(topk_tflops[:args.plot_topk]).keys() + return list(dict(topk_tflops[:args.plot_topk]).keys()) + + +def get_best_tflops(): + all_tflops = [] + for tag, tflops in STATIS_INFO.items(): + if "cublas" not in tag and "MNK" not in tag: + all_tflops.append(tflops) + # [N, NUM_MNK], reduce max on N dim + all_tflops = torch.tensor(all_tflops, dtype=torch.float) + best_tflops = torch.max(all_tflops, dim=0, keepdim=False)[0].tolist() + return best_tflops def plot_tflops(): import matplotlib.pyplot as plt import numpy as np - _, ax = plt.subplots(figsize=(16, 9)) - plt.subplots_adjust(left=0.03, right=0.99, top=0.95, bottom=0.05) + ax: plt.Axes = plt.subplots(figsize=(16, 9))[1] # fig, axs + plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.05) ax.set_title(f"My HGEMM vs cuBLAS, {get_device_name()}, Warmup={args.warmup}, Iters={args.iters}") ax.set_xlabel("M=N=K") ax.set_ylabel("TFLOPS") @@ -224,36 +238,37 @@ def plot_tflops(): exclude_tags.append("MNK") exclude_tags = set(exclude_tags) - def should_exclude(tag: str) -> bool: + topk_tflops = get_topk_tflops() + STATIS_INFO["(best)"] = get_best_tflops() + draw_tags = topk_tflops + draw_tags.append("(cublas)") + draw_tags.append("(best)") + + def skip_it(tag: str) -> bool: for etag in exclude_tags: if etag in tag: return True + if tag not in draw_tags: + return True return False - topk_tflops = get_topk_tflops() - is_top_1 = True + # draw by topk order for tag, tflops in STATIS_INFO.items(): - if (should_exclude(tag)) or (tag not in topk_tflops - and "cublas" not in tag): + if skip_it(tag): continue if "cublas" in tag: ax.plot(tflops, label=tag, linewidth=3) else: - if is_top_1 and not args.no_hint_top1: + if "best" in tag and not args.no_plot_best: ax.plot(tflops, label=tag, linewidth=4) - is_top_1 = False else: ax.plot(tflops, label=tag, linestyle='--') ax.legend() - if args.save_tag: - plt.savefig(f"{args.save_tag}", dpi=300) - print(f"plot hgemm TFLOPS done, saved as {args.save_tag}") - else: - device_name = get_device_name().replace(" ", "_") - save_tag = f"{device_name}.png" - plt.savefig(save_tag, dpi=300) - print(f"plot hgemm TFLOPS done, saved as {save_tag}") + device_name = get_device_name().replace(" ", "_") + save_tag = f"{args.save_dir}/{device_name}.png" + plt.savefig(save_tag, dpi=300) + print(f"plot hgemm TFLOPS done, saved as {save_tag}") def get_mnk(sep: int = args.SEP): @@ -386,7 +401,4 @@ def get_mnk(sep: int = args.SEP): print("-" * 130) if args.plot_flops: - try: - plot_tflops() - except Exception as e: - print(f"plot hgemm TFLOPS failed, {e}") \ No newline at end of file + plot_tflops()