From f93b32382debfcfd43b78d6b382f0cca7a28c432 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:06:52 +0800 Subject: [PATCH 1/2] Update softmax.cu --- kernels/softmax/softmax.cu | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/kernels/softmax/softmax.cu b/kernels/softmax/softmax.cu index 3d03a33e..8617eeb9 100644 --- a/kernels/softmax/softmax.cu +++ b/kernels/softmax/softmax.cu @@ -600,9 +600,15 @@ online_safe_softmax_f32x4_pack_per_token_kernel<(H/4)><<>>( \ case 1024: \ LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(1024) \ break; \ + case 2048: \ + LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(2048) \ + break; \ + case 4096: \ + LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(4096) \ + break; \ default: \ throw std::runtime_error( \ - "only support H: 128/256/512/1024; raise error if warp_num*4 > H"); \ + "only support H: 128/256/.../4096;"); \ break; \ } From 52db58e777d5568eed22c6bf88f32727d8d98a5c Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:07:15 +0800 Subject: [PATCH 2/2] Update softmax.py --- kernels/softmax/softmax.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/kernels/softmax/softmax.py b/kernels/softmax/softmax.py index 976f97d9..87566b72 100644 --- a/kernels/softmax/softmax.py +++ b/kernels/softmax/softmax.py @@ -24,7 +24,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str, out: Optional[torch.Tensor] = None, - warmup: int = 10, iters: int = 1000, + warmup: int = 10, iters: int = 100, show_all: bool = False): if out is not None: out.fill_(0) @@ -60,7 +60,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, N = 128 * 128 print(" " * 45 + f"N={N}") print("-" * 100) -x = torch.randn((N)).cuda().float() +x = torch.randn((N), device="cuda").cuda().float() out = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.softmax_f32, x, "f32(fence)", out) run_benchmark(lib.softmax_f32x4, x, "f32x4(fence)", out) @@ -71,7 +71,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 4096, 256 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out) run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) @@ -95,7 +95,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 4096, 512 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out) run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) @@ -119,7 +119,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 4096, 1024 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out) run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) @@ -143,10 +143,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 4096, 2048 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out) run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") print("-" * 100) @@ -162,10 +163,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 4096, 4096 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out) run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") print("-" * 100) @@ -180,7 +182,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 4096, 8192 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() x_f16 = x.half().contiguous() out_f16 = out.half().contiguous() @@ -192,7 +194,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, S, H = 8192, 8192 print(" " * 45 + f"S={S}, H={H}") print("-" * 100) -x = torch.randn((S, H)).cuda().float().contiguous() +x = torch.randn((S, H), device="cuda").cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() x_f16 = x.half().contiguous() out_f16 = out.half().contiguous()