In [16]:
import torch

def smooth_relu_ref(x, alpha=1.0):
    zero = torch.zeros_like(x)
    mid = x * x / (2.0*alpha)
    return torch.where(x<0, zero, torch.where(x>alpha, x, mid))

In [17]:
x = torch.tensor([-1., 0.2, 2.0], requires_grad=True)
y = smooth_relu_ref(x, 1.0).sum()
y.backward()
print(x.grad)

tensor([0.0000, 0.2000, 1.0000])


In [18]:
from torch.utils.cpp_extension import load_inline
import torch

src = r"""
#include <torch/extension.h>
torch::Tensor add_one(torch::Tensor x) { return x + 1; }
"""

ext = load_inline(
    name="demo_v1",
    cpp_sources=[src],      # just the function
    functions=["add_one"],  # let PyTorch autogenerate the pybind11 module
    verbose=True
)

print(ext.add_one(torch.tensor([1, 2, 3])))

tensor([2, 3, 4])


In [20]:
from torch.utils.cpp_extension import load_inline
import torch

src = r"""
#include <torch/extension.h>
torch::Tensor add_one(torch::Tensor x) { return x + 1; }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add_one", &add_one);
}
"""

ext = load_inline(
    name="demo_v2",
    cpp_sources=[src],      # only one source that defines the module
    functions=[],           # <- DO NOT pass functions here
    verbose=True
)

print(ext.add_one(torch.tensor([1, 2, 3])))

[1/2] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=demo_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /Users/regansong/Documents/github/pytorch-cpp-ext/venv/lib/python3.12/site-packages/torch/include -isystem /Users/regansong/Documents/github/pytorch-cpp-ext/venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /opt/homebrew/opt/python@3.12/Frameworks/Python.framework/Versions/3.12/include/python3.12 -fPIC -std=c++17 -c /Users/regansong/Library/Caches/torch_extensions/py312_cpu/demo_v2/main.cpp -o main.o 
[31mFAILED: [code=1] [0mmain.o 
c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=demo_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /Users/regansong/Documents/github/pytorch-cpp-ext/venv/lib/python3.12/site-packages/torch/include -isystem /Users/regansong/Documents/github/pytorch-cpp-ext/venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /opt/homebrew/opt/python@3.12/Frameworks/Python.framework/Versions/3.12/include/python3.12 -fPIC -st

RuntimeError: Error building extension 'demo_v2'

Two approaches - let PyTorch autogenerate pybind11 bindings (first cell), or manually write them (second cell).