# float32x4向量加法

## 任务理解

CUDA中的float4是由4个float组成的向量类型（分量为x、y、z、w）。本任务对两个形状相同的float32矩阵A和B做逐元素加法，得到C = A + B；每个线程一次读取并计算4个相邻元素。

## 具体实现

CUDA将一个block内的线程组织为三维结构，方便处理多维数据。

在CUDA代码：
```c
int idx = blockIdx.x * blockDim.x + threadIdx.x;
```
这行代码计算当前线程在x方向上的全局索引：前面所有block的线程总数（`blockIdx.x * blockDim.x`）加上本线程在block内的编号（`threadIdx.x`）。我们把矩阵展开成一维数组处理，因此只需要x方向这一维。

注意CUDA kernel的调用方法：kernel名称之后要跟上执行配置<<<blocks, threads>>>，分别表示grid中的block数量和每个block中的线程数。

在C++中，`reinterpret_cast`是一个类型转换操作符。它不改变指针指向的内存，只是告诉编译器“把这块内存当作另一种类型来看待”。注意：把`float*`转换为`float4*`后再访问，要求地址按16字节对齐，否则kernel会报misaligned address错误。

`data_ptr<T>()`是一个模板函数，返回张量数据首地址的原始指针（类型为`T*`）。对于CUDA张量，这个指针指向显存，不能在主机（CPU）代码中直接解引用；只能在kernel中访问，或者先拷贝回主机。


In [None]:
%%writefile elementwise_fp32x4.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include <climits>
#include <cstdint>

// float4-vectorized element-wise add: C[i] = A[i] + B[i] for i in [0, N4),
// where every element is a float4 (4 consecutive floats).
//
// Launch: 1D grid of 1D blocks with at least N4 threads in total; each thread
// handles exactly one float4. No shared memory.
// Preconditions: A, B and C are device pointers, 16-byte aligned (required for
// float4 loads/stores), each holding at least N4 float4 values.
// In-place use (C aliasing A or B) is fine: each thread reads and writes only
// its own index.
__global__ void elementwise_add_f32x4_kernel (const float4 *A, const float4 *B, float4 *C, int N4) {
    // Compute the global index in 64 bits: blockIdx.x * blockDim.x is unsigned
    // 32-bit arithmetic, and narrowing it to int could wrap negative (and pass
    // the bounds check) for the last block when N4 is close to INT_MAX.
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < N4) {
        const float4 va = A[idx];
        const float4 vb = B[idx];
        C[idx] = make_float4(va.x + vb.x, va.y + vb.y, va.z + vb.z, va.w + vb.w);
    }
}

// Scalar element-wise add over a sub-range: C[i] = A[i] + B[i] for i in [start, N).
// Used for the 0-3 trailing elements that do not fill a whole float4, and for
// the entire range when the pointers are not 16-byte aligned.
// Launch: 1D grid of 1D blocks; any grid size is valid because each thread
// strides by the total number of launched threads.
static __global__ void elementwise_add_f32_scalar_kernel(const float *A, const float *B, float *C,
                                                         int64_t start, int64_t N) {
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t i = start + static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < N; i += stride) {
        C[i] = A[i] + B[i];
    }
}

// Returns true when p satisfies the 16-byte alignment float4 accesses require.
static bool is_float4_aligned(const void *p) {
    return reinterpret_cast<std::uintptr_t>(p) % alignof(float4) == 0;
}

// Computes C = A + B element-wise on the GPU (C is written in place).
//
// A, B and C must be contiguous float32 CUDA tensors on the same device with
// the same number of elements; they are treated as flat 1D arrays.
// The bulk of the work is done four floats at a time by
// elementwise_add_f32x4_kernel. The 0-3 leftover elements are added by a scalar
// kernel on the device, because the data pointers are device memory and must
// not be dereferenced on the host. If any pointer is not 16-byte aligned (e.g. a
// slice with a storage offset), the whole range goes through the scalar kernel.
// Kernels are enqueued on PyTorch's current CUDA stream for A's device, so the
// call is asynchronous with respect to the host.
void elementwise_add_f32x4 (torch::Tensor A, torch::Tensor B, torch::Tensor C) {
    TORCH_CHECK(A.is_cuda() && B.is_cuda() && C.is_cuda(),
                "elementwise_add_f32x4: A, B and C must be CUDA tensors");
    TORCH_CHECK(A.device() == B.device() && A.device() == C.device(),
                "elementwise_add_f32x4: A, B and C must be on the same device");
    TORCH_CHECK(A.scalar_type() == torch::kFloat32 && B.scalar_type() == torch::kFloat32 &&
                    C.scalar_type() == torch::kFloat32,
                "elementwise_add_f32x4: A, B and C must be float32");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && C.is_contiguous(),
                "elementwise_add_f32x4: A, B and C must be contiguous");
    TORCH_CHECK(B.numel() == A.numel() && C.numel() == A.numel(),
                "elementwise_add_f32x4: A, B and C must have the same number of elements");

    // numel() is int64; keep it 64-bit so large tensors are not truncated.
    const int64_t N = A.numel();
    if (N == 0) {
        return;  // nothing to do, and a 0-block launch is an invalid configuration
    }

    const c10::cuda::CUDAGuard device_guard(A.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const float *a = A.data_ptr<float>();
    const float *b = B.data_ptr<float>();
    float *c = C.data_ptr<float>();

    constexpr int threads = 256;
    constexpr int64_t kMaxScalarBlocks = 65535;  // the scalar kernel strides over the rest

    const int64_t N4 = N / 4;
    // float4 access needs 16-byte alignment. The vector kernel takes an int count,
    // and blocks * threads must stay within int as well.
    const bool vectorize = N4 > 0 && N4 <= INT_MAX - threads &&
                           is_float4_aligned(a) && is_float4_aligned(b) && is_float4_aligned(c);

    int64_t vec_end = 0;  // elements [0, vec_end) are covered by the float4 kernel
    if (vectorize) {
        const int n4 = static_cast<int>(N4);
        const int blocks = (n4 + threads - 1) / threads;  // ceil(n4 / threads)
        elementwise_add_f32x4_kernel<<<blocks, threads, 0, stream>>>(
            reinterpret_cast<const float4*>(a),
            reinterpret_cast<const float4*>(b),
            reinterpret_cast<float4*>(c),
            n4
        );
        vec_end = N4 * 4;
    }

    // Tail (or the whole range when not vectorized), processed on the device.
    if (vec_end < N) {
        int64_t tail_blocks = (N - vec_end + threads - 1) / threads;
        if (tail_blocks > kMaxScalarBlocks) {
            tail_blocks = kMaxScalarBlocks;
        }
        elementwise_add_f32_scalar_kernel<<<static_cast<int>(tail_blocks), threads, 0, stream>>>(
            a, b, c, vec_end, N);
    }

    // Surface launch-configuration errors instead of failing silently.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "elementwise_add_f32x4: kernel launch failed: ",
                cudaGetErrorString(err));
}

// Python bindings: exposes the host launcher elementwise_add_f32x4 to Python.
// The module name comes from the TORCH_EXTENSION_NAME macro, which the
// extension build passes on the compiler command line.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("elementwise_add_f32x4",
            &elementwise_add_f32x4,
            "Elementwise Add f32x4 (float4 vectorized)");
}


Overwriting elementwise_fp32x4.cu


In [9]:
import time
import torch
from torch.utils.cpp_extension import load

torch.set_grad_enabled(False)

# ===============================
# Load the CUDA kernel (fp32x4)
# ===============================
# Compile for the GPU that is actually present instead of a hard-coded sm_89.
# Code built only as sm_89 SASS (no PTX) cannot run on other architectures such
# as sm_80/sm_86/sm_90 ("no kernel image is available for execution").
_cc_major, _cc_minor = torch.cuda.get_device_capability()
_arch = f"{_cc_major}{_cc_minor}"

lib = load(
    name="elementwise_fp32x4",
    sources=["elementwise_fp32x4.cu"],
    extra_cuda_cflags=[
        "-O3",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "-gencode", f"arch=compute_{_arch},code=sm_{_arch}",
    ],
    extra_cflags=["-std=c++17"],
    verbose=True,
)

# ===============================
# Benchmark function
# ===============================
def run_benchmark(func, a, b, out=None, warmup=10, iters=1000, tag="fp32x4"):
    """Time ``func(a, b, out)`` and print the first two outputs and the mean latency.

    Args:
        func: callable invoked as ``func(a, b, out)``; expected to write its
            result into ``out``.
        a, b: input tensors.
        out: output tensor, zero-filled before the run. If None, a zero tensor
            shaped like ``a`` is allocated.
        warmup: number of untimed calls made before measuring.
        iters: number of timed calls; the reported time is their mean.
        tag: label printed at the start of the result line.

    Returns:
        ``(out, mean_time)`` where ``mean_time`` is milliseconds per call.
    """
    if out is None:
        out = torch.zeros_like(a)
    else:
        out.fill_(0)

    # CUDA launches are asynchronous: synchronize before reading the clock so
    # the measurement covers kernel execution, not just launch overhead.
    # CPU tensors need no synchronization (and may run without CUDA).
    on_cuda = a.is_cuda or out.is_cuda

    def sync():
        if on_cuda:
            torch.cuda.synchronize()

    # Warmup
    for _ in range(warmup):
        func(a, b, out)
    sync()

    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(iters):
        func(a, b, out)
    sync()
    end = time.perf_counter()

    mean_time = (end - start) * 1000 / iters  # ms per iteration
    # Transfer only the two printed values to the host, not the whole tensor.
    out_val = [round(v, 8) for v in out.flatten()[:2].tolist()]
    print(f"{tag}: {out_val}, mean time: {mean_time:.8f} ms")
    return out, mean_time

# ===============================
# Test different sizes
# ===============================
Ss = [1024, 2048, 4096]
Ks = [1024, 2048, 4096]

# Every (S, K) combination, S in the outer position as in a nested loop.
shape_grid = [(S, K) for S in Ss for K in Ks]

for S, K in shape_grid:
    print("-" * 60)
    print(f"S={S}, K={K}")
    # Two random fp32 GPU inputs (drawn in order: a first, then b) and a zeroed output.
    a, b = (torch.randn((S, K), device="cuda", dtype=torch.float32).contiguous() for _ in range(2))
    c = torch.zeros_like(a)

    run_benchmark(lib.elementwise_add_f32x4, a, b, c)

    # Check the kernel against PyTorch's own addition.
    verdict = "PASS" if torch.allclose(a + b, c) else "FAIL"
    print(f"Result check: {verdict}")


Using /home/shaneyale/.cache/torch_extensions/py312_cu124 as PyTorch extensions root...
The input conditions for extension module elementwise_fp32x4 have changed. Bumping to version 2 and re-building as elementwise_fp32x4_v2...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/shaneyale/.cache/torch_extensions/py312_cu124/elementwise_fp32x4/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module elementwise_fp32x4_v2...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /usr/local/cuda-12.4/bin/nvcc --generate-dependencies-with-compile --dependency-output elementwise_fp32x4.cuda.o.d -DTORCH_EXTENSION_NAME=elementwise_fp32x4_v2 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/shaneyale/miniconda3/envs/cu124/lib/python3.12/site-packages/torch/include -isystem /home/shaneyale/miniconda3/envs/cu124/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/shaneyale/miniconda3/envs/cu124/lib/python3.12/site-packages/torch/include/TH -isystem /home/shaneyale/miniconda3/envs/cu124/lib/python3.12/site-packages/torch/include/THC -isystem /usr/local/cuda-12.4/include -isystem /home/shaneyale/miniconda3/envs/cu124/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,co

Loading extension module elementwise_fp32x4_v2...


fp32x4: [1.86473846, -0.87125796], mean time: 0.01842213 ms
Result check: PASS
------------------------------------------------------------
S=1024, K=2048
fp32x4: [-1.71387815, 0.54864913], mean time: 0.02921009 ms
Result check: PASS
------------------------------------------------------------
S=1024, K=4096
fp32x4: [-0.11894207, -1.35181189], mean time: 0.20104289 ms
Result check: PASS
------------------------------------------------------------
S=2048, K=1024
fp32x4: [1.36392868, 0.77033579], mean time: 0.01551175 ms
Result check: PASS
------------------------------------------------------------
S=2048, K=2048
fp32x4: [-0.62173986, 1.71409369], mean time: 0.20037961 ms
Result check: PASS
------------------------------------------------------------
S=2048, K=4096
fp32x4: [1.70619178, -2.45300078], mean time: 0.41244340 ms
Result check: PASS
------------------------------------------------------------
S=4096, K=1024
fp32x4: [-0.64906311, 0.39583477], mean time: 0.20041490 ms
Result che