36 changes: 18 additions & 18 deletions hgemm/README.md
@@ -66,6 +66,24 @@ python3 hgemm.py --mma-all --plot --topk 8

## Current Performance

### NVIDIA GeForce RTX 3080 Laptop

Tested on an NVIDIA GeForce RTX 3080 Laptop GPU under Windows WSL2, using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) together with thread block swizzle, most cases match or even beat cuBLAS; a minimal sketch of the swizzle idea follows the benchmark output below.

![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)

```bash
python3 hgemm.py --wmma-all
----------------------------------------------------------------------------------------------------------------------------------
M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
----------------------------------------------------------------------------------------------------------------------------------
(wmma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
(wmma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
(wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
(wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
(cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
----------------------------------------------------------------------------------------------------------------------------------
```
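
The thread block swizzle referred to above remaps the linear block index so that blocks resident on the GPU at the same time cover a narrow strip of C columns, improving L2 reuse of B. The snippet below is only a minimal sketch of that remapping under assumed names (`swizzle_tile`, `strip_n`); it is not the mapping used by the kernels in this repo.

```cuda
// Minimal thread-block-swizzle sketch (illustrative names, not this repo's kernels).
// Assumes the kernel is launched with gridDim.x == tiles_m * tiles_n. Instead of
// walking C tiles in plain row-major order, consecutive block ids fill one narrow
// strip of columns at a time, so co-resident blocks share a small window of B in L2.
__device__ __forceinline__ void swizzle_tile(int bid, int tiles_m, int strip_n,
                                             int &tile_m, int &tile_n) {
  int strip = bid / (tiles_m * strip_n);        // which strip of C columns
  int r     = bid - strip * tiles_m * strip_n;  // offset inside the strip
  tile_m = r % tiles_m;                         // rows advance fastest
  tile_n = strip * strip_n + r / tiles_m;       // column inside the strip
}
```

The `swizzle: 4096` / `NOOP` column in the log above presumably records the chosen remap stride (or that no remapping was applied).
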
### NVIDIA L20

With the current best implementation on the L20 (theoretical Tensor Core FP16 throughput: 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated: they are currently mitigated with shared-memory padding, which wastes shared memory and also hurts SM occupancy. Hand-written smem swizzle/permute is not implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a later revision will attempt smem swizzle/permute via MMA PTX.
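
To make the padding trade-off concrete, here is a minimal sketch under assumed tile sizes (`BM`, `BK`, `PAD` are illustrative, not this repo's configuration): padding each shared-memory row by a few half elements staggers row start addresses across banks so column-wise fragment loads stop lining up on the same bank, at the price of extra shared memory per stage, which is exactly what pressures occupancy.

```cuda
#include <cuda_fp16.h>

// Illustrative sizes only, not the tiling actually used in this repo.
#define BM  128   // rows of the A tile staged in shared memory
#define BK  32    // K elements staged per pipeline stage
#define PAD 8     // extra half elements per row to de-align banks

__global__ void padded_smem_sketch(/* ... */) {
  // Without PAD a row is 32 halfs = 64 bytes, so every other row starts in the
  // same bank and column-wise WMMA loads conflict; the 16-byte pad staggers row
  // start addresses across banks. The price is BM * PAD * sizeof(half) = 2 KiB
  // per tile per stage, which adds up under stage2/stage3 multi-buffering and
  // can reduce the number of blocks resident per SM.
  __shared__ half s_a[BM][BK + PAD];
  // ... global->shared loads, wmma::load_matrix_sync from s_a, etc. ...
  (void)s_a;  // body elided in this sketch
}
```
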
@@ -147,24 +165,6 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all
----------------------------------------------------------------------------------------------------------------------------------
```

### NVIDIA GeForce RTX 3080 Laptop

Tested on an NVIDIA GeForce RTX 3080 Laptop GPU, using mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) together with thread block swizzle, most cases match or even beat cuBLAS. However, since I ran these on the laptop under WSL, the numbers are not very stable, so treat them as a rough reference rather than hard data; a minimal sketch of the 64x64 warp-tile structure follows the benchmark output below.

![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)

```bash
python3 hgemm.py --wmma-all
----------------------------------------------------------------------------------------------------------------------------------
M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
----------------------------------------------------------------------------------------------------------------------------------
(wmma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
(wmma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
(wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
(wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
(cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
----------------------------------------------------------------------------------------------------------------------------------
```
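
For reference, the mma4x4_warp4x4 naming above means each warp holds a 4x4 grid of m16n16k16 accumulators, i.e. a 64x64 warp tile computed with 16 WMMA ops per K step. Below is a minimal sketch of just that inner structure, with illustrative function and buffer names and no cp.async staging or swizzling; accumulating across all K steps and then calling wmma::store_matrix_sync yields the per-warp 64x64 result.

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// 64x64 warp tile = 4x4 grid of m16n16k16 WMMA accumulators (sketch only).
// s_a is a 64x16 row-major A slice and s_b a 16x64 row-major B slice in shared
// memory; lda/ldb are their row strides. All names are illustrative.
constexpr int WM = 16, WN = 16, WK = 16;

__device__ void warp_tile_64x64_step(
    const half *s_a, int lda, const half *s_b, int ldb,
    wmma::fragment<wmma::accumulator, WM, WN, WK, half> C_frag[4][4]) {
  wmma::fragment<wmma::matrix_a, WM, WN, WK, half, wmma::row_major> A_frag[4];
  wmma::fragment<wmma::matrix_b, WM, WN, WK, half, wmma::row_major> B_frag[4];
#pragma unroll
  for (int i = 0; i < 4; ++i)    // 4 row fragments cover 64 A rows
    wmma::load_matrix_sync(A_frag[i], s_a + i * WM * lda, lda);
#pragma unroll
  for (int j = 0; j < 4; ++j)    // 4 column fragments cover 64 B columns
    wmma::load_matrix_sync(B_frag[j], s_b + j * WN, ldb);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
#pragma unroll
    for (int j = 0; j < 4; ++j)  // 16 wmma::mma_sync calls per K step
      wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);
  }
}
```
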

## Performance Optimization Notes

96 changes: 54 additions & 42 deletions hgemm/hgemm.py
@@ -7,14 +7,15 @@

torch.set_grad_enabled(False)


def get_args():
    parser = argparse.ArgumentParser(description="hgemm benchmark")
    parser.add_argument("--M", type=int, default=None, help="Matrix M size")
    parser.add_argument("--N", type=int, default=None, help="Matrix N size")
    parser.add_argument("--K", type=int, default=None, help="Matrix K size")
    parser.add_argument("--MNK", type=int, default=None, help="Matrix M=N=K size")
    parser.add_argument("--MMNK", type=int, default=16384, help="Matrix MAX M=M=N=K size")
    parser.add_argument("--SEP", type=int, default=512, help="Matrix MAX M=M=N=K size")
parser.add_argument("--MMNK", type=int, default=12800, help="Matrix MAX M=M=N=K size")
parser.add_argument("--SEP", '--sep', type=int, default=256, help="Matrix SEP M=M=N=K size")
parser.add_argument("--warmup", "--w", type=int, default=2, help="Warmup iters")
parser.add_argument("--iters", "--i", type=int, default=10, help="Benchmark iters")
parser.add_argument("--verbose", "--v", action="store_true", help="Verbose")
@@ -36,16 +37,30 @@ def get_args():
parser.add_argument("--no-default", action="store_true", help="Disable default tests")
parser.add_argument("--plot-flops", "--plot", action="store_true", help="Plot TFLOPS")
parser.add_argument("--plot-topk", "--topk", type=int, default=8, help="Plot top k TFLOPS")
parser.add_argument("--no-hint-top1", "--no-hint", action="store_true", help="Hint top 1 TFLOPS")
parser.add_argument("--no-plot-best", "--no-best", action="store_true", help="Not Plot best TFLOPS")
parser.add_argument("--exclude-tags", "--exclude", type=str, default=None, help="Exclude tag for plot, sperated by comma")
parser.add_argument("--save-tag", "--tag", type=str, default=None, help="Save tag for plot")
parser.add_argument("--save-dir", "--dir", type=str, default="./", help="Save dir for plot")
return parser.parse_args()

args = get_args()
print(args)


def get_device_name():
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    # the laptop GPU runs under WSL2, so add a WSL2 tag
    if "Laptop" in device_name:
        device_name += " WSL2"
    return device_name


def get_device_capability():
    return torch.cuda.get_device_capability(torch.cuda.current_device())


# Load the CUDA kernel as a python module
print("Loading hgemm lib ...")
print(f"Loading hgemm lib on device: {get_device_name()}, capability: {get_device_capability()} ...")

lib = load(name='hgemm_lib',
           sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
                    'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
@@ -184,36 +199,35 @@ def run_benchmark(perf_func: callable,
    return out, mean_time


def get_device_name():
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    # we will run GPU on WSL2, so add WSL2 tag.
    if "Laptop" in device_name:
        device_name += " WSL2"
    return device_name


def get_device_capability():
    return torch.cuda.get_device_capability(torch.cuda.current_device())


def get_topk_tflops():
    topk_tflops = sorted(TOATL_TFLOPS.items(), key=lambda x: x[1],
                         reverse=True)
    print("-" * 130)
    print(" " * 42 + f"HGEMM TOTAL TFLOPS, {get_device_name()}")
    print(" " * 32 + f"THE TOTAL TFLOPS OF {len(topk_tflops)} HGEMM ALGO ON {get_device_name()} DEVICE")
    print("-" * 130)
    for tag, tflops in list(topk_tflops)[::-1]:
        print(f"{tag:>42}: {tflops:<10.2f} TFLOPS")
    print(f"{'(cublas)':>42}: {CUBLAS_TOTAL_TFLOPS:<10.2f} TFLOPS")
        print(f"{tag:>45}: {tflops:>20.2f} TFLOPS")
    print(f"{'(cublas)':>45}: {CUBLAS_TOTAL_TFLOPS:>20.2f} TFLOPS")
    print("-" * 130)
    return dict(topk_tflops[:args.plot_topk]).keys()
    return list(dict(topk_tflops[:args.plot_topk]).keys())


def get_best_tflops():
    all_tflops = []
    for tag, tflops in STATIS_INFO.items():
        if "cublas" not in tag and "MNK" not in tag:
            all_tflops.append(tflops)
    # [N, NUM_MNK], reduce max on N dim
    all_tflops = torch.tensor(all_tflops, dtype=torch.float)
    best_tflops = torch.max(all_tflops, dim=0, keepdim=False)[0].tolist()
    return best_tflops


def plot_tflops():
    import matplotlib.pyplot as plt
    import numpy as np
    _, ax = plt.subplots(figsize=(16, 9))
    plt.subplots_adjust(left=0.03, right=0.99, top=0.95, bottom=0.05)
    ax: plt.Axes = plt.subplots(figsize=(16, 9))[1]  # fig, axs
    plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.05)
    ax.set_title(f"My HGEMM vs cuBLAS, {get_device_name()}, Warmup={args.warmup}, Iters={args.iters}")
    ax.set_xlabel("M=N=K")
    ax.set_ylabel("TFLOPS")
@@ -224,36 +238,37 @@ def plot_tflops():
exclude_tags.append("MNK")
exclude_tags = set(exclude_tags)

    def should_exclude(tag: str) -> bool:
    topk_tflops = get_topk_tflops()
    STATIS_INFO["(best)"] = get_best_tflops()
    draw_tags = topk_tflops
    draw_tags.append("(cublas)")
    draw_tags.append("(best)")

    def skip_it(tag: str) -> bool:
        for etag in exclude_tags:
            if etag in tag:
                return True
        if tag not in draw_tags:
            return True
        return False

    topk_tflops = get_topk_tflops()
    is_top_1 = True
    # draw by topk order
    for tag, tflops in STATIS_INFO.items():
        if (should_exclude(tag)) or (tag not in topk_tflops
                                     and "cublas" not in tag):
        if skip_it(tag):
            continue
        if "cublas" in tag:
            ax.plot(tflops, label=tag, linewidth=3)
        else:
            if is_top_1 and not args.no_hint_top1:
            if "best" in tag and not args.no_plot_best:
                ax.plot(tflops, label=tag, linewidth=4)
                is_top_1 = False
            else:
                ax.plot(tflops, label=tag, linestyle='--')

    ax.legend()
    if args.save_tag:
        plt.savefig(f"{args.save_tag}", dpi=300)
        print(f"plot hgemm TFLOPS done, saved as {args.save_tag}")
    else:
        device_name = get_device_name().replace(" ", "_")
        save_tag = f"{device_name}.png"
        plt.savefig(save_tag, dpi=300)
        print(f"plot hgemm TFLOPS done, saved as {save_tag}")
    device_name = get_device_name().replace(" ", "_")
    save_tag = f"{args.save_dir}/{device_name}.png"
    plt.savefig(save_tag, dpi=300)
    print(f"plot hgemm TFLOPS done, saved as {save_tag}")


def get_mnk(sep: int = args.SEP):
@@ -386,7 +401,4 @@ def get_mnk(sep: int = args.SEP):
print("-" * 130)

if args.plot_flops:
    try:
        plot_tflops()
    except Exception as e:
        print(f"plot hgemm TFLOPS failed, {e}")
    plot_tflops()