In [1]:
%%writefile add_kernel.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

// CUDA kernel: element-wise output[i] = input[i] + 1 over a flat buffer of
// `size` elements, one element per thread. Threads whose global index falls
// past the end of the buffer do nothing.
//
// Args:
//   input  - device pointer to `size` readable elements.
//   output - device pointer to `size` writable elements; must not alias
//            `input` (both are declared __restrict__).
//   size   - number of elements. int64_t so tensors with more than 2^31 - 1
//            elements overflow neither the count nor the computed index.
template <typename scalar_t>
__global__ void my_add_kernel(const scalar_t* __restrict__ input,
                              scalar_t* __restrict__ output,
                              int64_t size) {
    // Widen before multiplying: blockIdx.x * blockDim.x is evaluated in 32-bit
    // unsigned arithmetic and would wrap for very large grids.
    const int64_t index =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (index < size) {
        // Add a scalar_t one instead of the double literal 1.0 so float inputs
        // are not promoted to (slow) FP64 arithmetic on the device.
        output[index] = input[index] + static_cast<scalar_t>(1);
    }
}

// Host-side launcher called from the C++ wrapper: computes output = input + 1
// by launching my_add_kernel on input's device and PyTorch's current stream.
//
// Preconditions: input is a contiguous CUDA tensor (the wrapper my_add checks
// this). output must have the same element count and dtype as input; it is
// written element-by-element through its raw data pointer.
//
// Throws (TORCH_CHECK / AT_DISPATCH / C10_CUDA_KERNEL_LAUNCH_CHECK) if the
// element counts differ, the dtype is not float/double, or the launch fails.
void my_add_cuda_launcher(const torch::Tensor& input, torch::Tensor& output) {
    // The kernel writes input.numel() elements into output's raw buffer.
    TORCH_CHECK(output.numel() == input.numel(),
                "my_add_cuda_launcher: output must have the same number of elements as input");

    const int64_t numel = input.numel();
    // A grid of zero blocks is an invalid launch configuration; an empty
    // tensor needs no work at all.
    if (numel == 0) {
        return;
    }

    // Launch on the tensor's own device (not whichever device is current) and
    // on PyTorch's current stream so the kernel is ordered with surrounding ops.
    const c10::cuda::CUDAGuard device_guard(input.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int threads = 1024;
    // Ceiling division: enough blocks to give every element its own thread.
    const int blocks = static_cast<int>((numel + threads - 1) / threads);

    // AT_DISPATCH_FLOATING_TYPES instantiates the lambda for float and double,
    // binding scalar_t to the tensor's element type.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_add_kernel", ([&] {
        my_add_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            numel);
    }));
    // Surface launch errors here rather than at some later, unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

Writing add_kernel.cu


In [2]:
%%writefile add_op.cpp
#include <torch/extension.h>

// Declaration of the CUDA launcher defined in add_kernel.cu. The two source
// files are built into one extension by the torch.utils.cpp_extension.load
// call in the notebook, which resolves this symbol at link time.
void my_add_cuda_launcher(const torch::Tensor& input, torch::Tensor& output);

// C++ wrapper exposed to Python: validates the input, then returns a new
// tensor equal to input + 1 computed by the CUDA launcher.
//
// Throws (TORCH_CHECK) if input is not a CUDA tensor or is not contiguous.
torch::Tensor my_add(torch::Tensor input) {
    // The kernel dereferences raw device pointers, so the data must be on the GPU...
    TORCH_CHECK(input.is_cuda(), "Input tensor must be a CUDA tensor");
    // ...and densely packed, because the kernel indexes it as a flat array.
    TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");

    // The kernel overwrites every element of output, so zero-filling it first
    // (zeros_like) would be a wasted memset; allocate uninitialized storage.
    // For a contiguous input, empty_like also yields a contiguous output.
    auto output = torch::empty_like(input);
    my_add_cuda_launcher(input, output);
    return output;
}

// Python bindings. TORCH_EXTENSION_NAME is supplied by
// torch.utils.cpp_extension from the `name` given to load(), so the module
// imports under that name; my_add is exposed to Python as `add_one`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def("add_one",
               &my_add,
               "My custom add one operator");
}

Writing add_op.cpp


In [5]:
import torch
from torch.utils.cpp_extension import load

# JIT-compile and load the extension.
#   name:    name of the compiled Python module (becomes TORCH_EXTENSION_NAME)
#   sources: the C++ and CUDA files written by the %%writefile cells above
my_extension = load(
    name="add_op",
    sources=["add_op.cpp", "add_kernel.cu"],
    verbose=False,
)

# Fixed seed so the printed output is reproducible under Restart & Run All.
torch.manual_seed(0)

# Verify against PyTorch's own `+ 1`. assert_close uses dtype-appropriate
# tolerances and raises with a diagnostic report; the previous check
# (max abs diff < 0.1) was far looser and silently printed nothing on failure.

# Multi-block case: 4097 elements needs five 1024-thread blocks, so the
# partially filled last block exercises the kernel's bounds check.
large_input = torch.randn(4097, device="cuda")
torch.testing.assert_close(my_extension.add_one(large_input), large_input + 1)

# Small case, kept small so the result can be printed in full.
input_tensor = torch.randn(5, device="cuda")
pytorch_output = input_tensor + 1
output_tensor = my_extension.add_one(input_tensor)
torch.testing.assert_close(output_tensor, pytorch_output)

print("+1 算子功能正常")
print("Output:", output_tensor)

+1 算子功能正常
Output: tensor([ 0.6145, -0.3479,  1.6543,  0.3850,  0.5882], device='cuda:0')
