In [None]:
# 编译
!/usr/local/bin/nvcc --version # 检查 nvcc 是否可用
import torch
from torch.utils.cpp_extension import load

fused_module = load(
    name="fused_op_ext",
    sources=["fused_op.cu", "fused_wrapper.cpp"],
    verbose=True # 打印编译信息
)

# Benchmark: fused extension op vs. the equivalent native PyTorch expression.
N = 10000000  # 10 million elements
A_torch = torch.rand(N, dtype=torch.float32, device='cuda')
B_torch = torch.rand(N, dtype=torch.float32, device='cuda')

WARMUP_ITERS = 5   # untimed calls that absorb one-time startup costs
TIMED_ITERS = 20   # timed calls averaged into the reported figure

# A single pair of CUDA events is reused for both measurements.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)


def _mean_cuda_time_ms(op, warmup=WARMUP_ITERS, iters=TIMED_ITERS):
    """Run ``op`` repeatedly and return ``(last_result, mean_ms_per_call)``.

    Args:
        op: Zero-argument callable that launches the GPU work to measure.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be at least 1.

    Returns:
        A tuple of the value returned by the final timed call and the mean
        elapsed time per call in milliseconds, as reported by CUDA events.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    # Warm-up: a first call can include lazy initialisation and fresh
    # allocations, which would otherwise be charged to the measured op.
    for _ in range(warmup):
        op()
    # Drain all previously queued work so it cannot overlap the timed region.
    torch.cuda.synchronize()

    start.record()
    for _ in range(iters):
        result = op()
    end.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / iters


# 1. Time the fused operator.
C_fused, fused_time = _mean_cuda_time_ms(
    lambda: fused_module.fused_add_tanh(A_torch, B_torch)
)

# 2. Time the native (unfused) PyTorch expression: tanh produces an
#    intermediate tensor that the add then reads back, i.e. an extra
#    round trip through global memory.
C_native, native_time = _mean_cuda_time_ms(
    lambda: A_torch + torch.tanh(B_torch)
)

# Verify both paths agree. An explicit raise (rather than a bare ``assert``)
# keeps the check active even when Python runs with -O.
if not torch.allclose(C_fused, C_native):
    raise AssertionError("Results do not match!")

print(f"\n--- Benchmark (N={N}) ---")
print(f"Native (A + tanh(B)): {native_time:.4f} ms")
print(f"Fused (A + tanh(B)):  {fused_time:.4f} ms")
print(f"Speedup: {native_time / fused_time:.2f}x")