<a href="https://colab.research.google.com/github/zslsally/CUDA-Learn-Notes/blob/main/elementwise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Show the attached GPU, driver version and CUDA version for this runtime.
!nvidia-smi

Sun Mar 16 22:32:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!pip install ninja

Collecting ninja
  Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 422.9/422.9 kB 10.8 MB/s eta 0:00:00
Installing collected packages: ninja
Successfully installed ninja-1.11.1.3


In [None]:
# Where torch.utils.cpp_extension caches compiled extensions.
# A shell `export` (e.g. via `!export`) runs in a throw-away subshell and never
# reaches this kernel's environment, so set the variable in-process instead.
import os
os.environ["TORCH_EXTENSIONS_DIR"] = "/content/torch_extensions"

In [None]:
# Standard library
import time
from functools import partial
from typing import Optional

# Third-party
import torch
from torch.utils.cpp_extension import load

# Sanity check: the benchmarks below need a visible CUDA device.
print(torch.cuda.is_available())

True


In [None]:
# Pure inference benchmarking: disable autograd tracking globally.
# The trailing `;` suppresses the context-manager object's repr in the output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x79be1dcdd690>

In [None]:
# Compile elementwise.cu and import the result as a Python extension module;
# the kernels are then reachable as attributes of `lib`
# (e.g. lib.elementwise_add_f32).
# NOTE(review): the -U flags undefine the __CUDA_NO_HALF*__ /
# __CUDA_NO_BFLOAT16_CONVERSIONS__ macros, presumably so the half/bfloat16
# operator and conversion overloads used by the kernels are available — confirm.
lib = load(
    name="elementwise_lib",
    sources=["elementwise.cu"],
    extra_cflags=["-std=c++17"],
    extra_cuda_cflags=[
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ],
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [None]:
def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: str,
                  out: Optional[torch.Tensor] = None, warmup: int = 10,
                  iters: int = 1000, show_all: bool = False):
    """Time a binary element-wise op and print a one-line summary.

    Args:
        perf_func: Op under test. Called as ``perf_func(a, b, out)`` when ``out``
            is given (the op is expected to write into ``out``), otherwise as
            ``perf_func(a, b)``, whose return value becomes the result.
        a, b: Input tensors, passed through unchanged.
        tag: Label printed as ``out_<tag>``.
        out: Optional pre-allocated output tensor; zero-filled before warmup.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; the reported time is their mean.
        show_all: Also print the whole result tensor.

    Returns:
        ``(result, mean_time_ms)`` where ``result`` is ``out`` if it was given,
        otherwise the value returned by the last timed call.

    Raises:
        ValueError: If ``iters`` < 1 (the mean time would be undefined).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # CUDA launches are asynchronous, so synchronize around the timed region to
    # measure execution rather than just launch. Skipped when CUDA is not
    # available, which also lets this harness time CPU ops.
    use_cuda = torch.cuda.is_available()

    if out is not None:
        out.fill_(0)

    # warmup (untimed)
    if out is not None:
        for _ in range(warmup):
            perf_func(a, b, out)
    else:
        for _ in range(warmup):
            perf_func(a, b)
    if use_cuda:
        torch.cuda.synchronize()

    # timed iterations; perf_counter is monotonic and high-resolution,
    # unlike time.time()
    start = time.perf_counter()
    if out is not None:
        for _ in range(iters):
            perf_func(a, b, out)
        result = out
    else:
        for _ in range(iters):
            result = perf_func(a, b)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    mean_time = (end - start) * 1000 / iters  # ms
    out_info = f"out_{tag}"
    # Slice before the host copy: only two values are printed, so there is no
    # need to transfer and listify the whole tensor. Using Tensor.tolist()
    # instead of .numpy() also works for bfloat16, which numpy cannot hold.
    out_val = [round(v, 8) for v in result.flatten()[:2].detach().cpu().tolist()]
    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
    if show_all:
        print(result)
    return result, mean_time

In [None]:
# Sweep input shapes and compare the custom element-wise add kernels with
# torch.add. Each custom kernel's output is also checked against a reference
# sum, so a fast-but-wrong kernel fails loudly instead of just printing numbers.
torch.manual_seed(0)  # reproducible inputs (and printed values) across runs

Ss = [1024, 2048, 4096]
Ks = [1024, 2048, 4096]
SKs = [(S, K) for S in Ss for K in Ks]

f32_kernels = [
    (lib.elementwise_add_f32,   "f32"),
    (lib.elementwise_add_f32x4, "f32x4"),
]
f16_kernels = [
    (lib.elementwise_add_f16,        "f16"),
    (lib.elementwise_add_f16x2,      "f16x2"),
    (lib.elementwise_add_f16x8,      "f16x8"),
    (lib.elementwise_add_f16x8_pack, "f16x8pack"),
]

for (S, K) in SKs:
    print("-" * 85)
    print(" " * 40 + f"S={S}, K={K}")
    # Generate inputs directly on the GPU rather than on the CPU followed by a
    # host-to-device copy.
    a = torch.randn((S, K), device="cuda", dtype=torch.float32)
    b = torch.randn((S, K), device="cuda", dtype=torch.float32)
    c = torch.zeros_like(a)
    ref = a + b
    for kernel, tag in f32_kernels:
        result, _ = run_benchmark(kernel, a, b, tag, c)
        # Check right away: `c` is reused (and overwritten) by the next kernel.
        torch.testing.assert_close(result, ref)
    run_benchmark(partial(torch.add, out=c), a, b, "f32_th")

    print("-" * 85)
    a_f16 = a.half()
    b_f16 = b.half()
    c_f16 = torch.zeros_like(a_f16)
    ref_f16 = a_f16 + b_f16
    for kernel, tag in f16_kernels:
        result, _ = run_benchmark(kernel, a_f16, b_f16, tag, c_f16)
        torch.testing.assert_close(result, ref_f16)
    run_benchmark(partial(torch.add, out=c_f16), a_f16, b_f16, "f16_th")
    print("-" * 85)
    print("-" * 85)

-------------------------------------------------------------------------------------
                                        S=1024, K=1024
           out_f32: [1.30998158, -1.69877255], time:0.06985998ms
         out_f32x4: [1.30998158, -1.69877255], time:0.05388236ms
        out_f32_th: [1.30998158, -1.69877255], time:0.05390501ms
-------------------------------------------------------------------------------------
           out_f16: [1.31054688, -1.69921875], time:0.04510260ms
         out_f16x2: [1.31054688, -1.69921875], time:0.02726936ms
         out_f16x8: [1.31054688, -1.69921875], time:0.03176880ms
     out_f16x8pack: [1.31054688, -1.69921875], time:0.02609944ms
        out_f16_th: [1.31054688, -1.69921875], time:0.02635026ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=1024, K=2048
           out_f32: [0.021