Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion kernels/softmax/softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,15 @@ online_safe_softmax_f32x4_pack_per_token_kernel<(H/4)><<<grid, block>>>( \
case 1024: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(1024) \
break; \
case 2048: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(2048) \
break; \
case 4096: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(4096) \
break; \
default: \
throw std::runtime_error( \
"only support H: 128/256/512/1024; raise error if warp_num*4 > H"); \
"only support H: 128/256/.../4096;"); \
break; \
}

Expand Down
20 changes: 11 additions & 9 deletions kernels/softmax/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def run_benchmark(perf_func: callable, x: torch.Tensor,
tag: str, out: Optional[torch.Tensor] = None,
warmup: int = 10, iters: int = 1000,
warmup: int = 10, iters: int = 100,
show_all: bool = False):
if out is not None:
out.fill_(0)
Expand Down Expand Up @@ -60,7 +60,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
N = 128 * 128
print(" " * 45 + f"N={N}")
print("-" * 100)
x = torch.randn((N)).cuda().float()
x = torch.randn((N), device="cuda").cuda().float()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.softmax_f32, x, "f32(fence)", out)
run_benchmark(lib.softmax_f32x4, x, "f32x4(fence)", out)
Expand All @@ -71,7 +71,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 4096, 256
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
Expand All @@ -95,7 +95,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 4096, 512
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
Expand All @@ -119,7 +119,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 4096, 1024
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
Expand All @@ -143,10 +143,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 4096, 2048
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
Expand All @@ -162,10 +163,11 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 4096, 4096
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
Expand All @@ -180,7 +182,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 4096, 8192
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
x_f16 = x.half().contiguous()
out_f16 = out.half().contiguous()
Expand All @@ -192,7 +194,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
S, H = 8192, 8192
print(" " * 45 + f"S={S}, H={H}")
print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
x = torch.randn((S, H), device="cuda").cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
x_f16 = x.half().contiguous()
out_f16 = out.half().contiguous()
Expand Down