2 changes: 1 addition & 1 deletion .github/scripts/setup-env.sh
@@ -101,7 +101,7 @@ pip install --progress-bar=off -r requirements.txt
echo '::endgroup::'

echo '::group::Install extension-cpp'
python setup.py develop
pip install -e . --no-build-isolation
@mikaylagawarecki (Author) commented on Dec 16, 2025:

This change was needed to get CI to run properly.

Otherwise, CI installed the PyTorch nightly but used torch 2.7 to build the extension.

echo '::endgroup::'

echo '::group::Collect environment information'
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -19,7 +19,7 @@ jobs:
- python-version: 3.13
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.4"
gpu-arch-version: "12.9"
@mikaylagawarecki (Author):

We no longer have CUDA nightlies for 12.4.

fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
15 changes: 10 additions & 5 deletions README.md
@@ -1,11 +1,16 @@
# C++/CUDA Extensions in PyTorch
# C++/CUDA Extensions in PyTorch with LibTorch Stable ABI

An example of writing a C++/CUDA extension for PyTorch using the [LibTorch Stable ABI](https://pytorch.org/docs/main/notes/libtorch_stable_abi.html).
See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.

An example of writing a C++/CUDA extension for PyTorch. See
[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
custom op that has both custom CPU and CUDA kernels.
custom op that has both custom CPU and CUDA kernels. It leverages the LibTorch
Stable ABI to ensure that the built extension can run with any version of
PyTorch >= 2.10.0.

The examples in this repo work with PyTorch 2.10+. For an example of how to use
the non-stable subset of LibTorch, see [this previous commit](https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95).

The examples in this repo work with PyTorch 2.4+.

To build:
```
149 changes: 98 additions & 51 deletions extension_cpp/csrc/cuda/muladd.cu
@@ -1,10 +1,14 @@
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

#include <torch/csrc/stable/c/shim.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>

namespace extension_cpp {

@@ -13,21 +17,35 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
torch::stable::Tensor mymuladd_cuda(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
double c) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);

torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);

const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();

int numel = a_contig.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// For now, we rely on the raw shim API to get the current CUDA stream.
// This will be improved in a future release.
// When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
// check the error code and throw an appropriate runtime_error otherwise.
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);

muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}
@@ -37,20 +55,34 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
if (idx < numel) result[idx] = a[idx] * b[idx];
}

at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
torch::stable::Tensor mymul_cuda(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);

torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);

const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();

int numel = a_contig.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// For now, we rely on the raw shim API to get the current CUDA stream.
// This will be improved in a future release.
// When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
// check the error code and throw an appropriate runtime_error otherwise.
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);

mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
return result;
}
@@ -60,32 +92,47 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
if (idx < numel) result[idx] = a[idx] + b[idx];
}

void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(b.sizes() == out.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_CHECK(out.dtype() == at::kFloat);
TORCH_CHECK(out.is_contiguous());
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = out.data_ptr<float>();
// An example of an operator that mutates one of its inputs.
void myadd_out_cuda(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
torch::stable::Tensor& out) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CUDA);

torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);

const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = out.mutable_data_ptr<float>();

int numel = a_contig.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// For now, we rely on the raw shim API to get the current CUDA stream.
// This will be improved in a future release.
// When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
// check the error code and throw an appropriate runtime_error otherwise.
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);

add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
}


// Registers CUDA implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", &mymuladd_cuda);
m.impl("mymul", &mymul_cuda);
m.impl("myadd_out", &myadd_out_cuda);
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
m.impl("mymul", TORCH_BOX(&mymul_cuda));
m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
}

}
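
The stream lookup above is repeated verbatim in all three CUDA kernels. Below is a minimal sketch of how that boilerplate could be factored into a shared helper; it reuses only the shim call, macro, and includes already present in this file, and the helper name `current_stream_for` is hypothetical rather than part of the stable ABI.

```
// Hypothetical helper, not part of this PR: factors out the raw shim call
// used by mymuladd_cuda, mymul_cuda, and myadd_out_cuda above.
#include <cuda.h>
#include <cuda_runtime.h>

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/c/shim.h>

namespace extension_cpp {

inline cudaStream_t current_stream_for(const torch::stable::Tensor& t) {
  // The shim call returns an error code; TORCH_ERROR_CODE_CHECK converts a
  // failure into a thrown runtime_error, exactly as in the kernels above.
  void* stream_ptr = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_current_cuda_stream(t.get_device_index(), &stream_ptr));
  return static_cast<cudaStream_t>(stream_ptr);
}

} // namespace extension_cpp
```

With such a helper, each launch site would reduce to `cudaStream_t stream = current_stream_for(a_contig);`.
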
116 changes: 67 additions & 49 deletions extension_cpp/csrc/muladd.cpp
@@ -1,14 +1,15 @@
#include <Python.h>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <vector>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so consisting of this file
in this extension, so that the TORCH_LIBRARY static initializers
in this extension, so that the STABLE_TORCH_LIBRARY static initializers
below are run. */
PyObject* PyInit__C(void)
{
@@ -26,75 +27,92 @@ extern "C" {

namespace extension_cpp {

at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
torch::stable::Tensor mymuladd_cpu(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
double c) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);

torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);

const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();

for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
}
return result;
}

at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
torch::stable::Tensor mymul_cpu(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);

torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
torch::stable::Tensor result = torch::stable::empty_like(a_contig);

const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = result.mutable_data_ptr<float>();

for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i];
}
return result;
}

// An example of an operator that mutates one of its inputs.
void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(b.sizes() == out.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_CHECK(out.dtype() == at::kFloat);
TORCH_CHECK(out.is_contiguous());
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = out.data_ptr<float>();
void myadd_out_cpu(
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
torch::stable::Tensor& out) {
STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);

torch::stable::Tensor a_contig = torch::stable::contiguous(a);
torch::stable::Tensor b_contig = torch::stable::contiguous(b);

const float* a_ptr = a_contig.const_data_ptr<float>();
const float* b_ptr = b_contig.const_data_ptr<float>();
float* result_ptr = out.mutable_data_ptr<float>();

for (int64_t i = 0; i < out.numel(); i++) {
result_ptr[i] = a_ptr[i] + b_ptr[i];
}
}

// Defines the operators
TORCH_LIBRARY(extension_cpp, m) {
STABLE_TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}

// Registers CPU implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
m.impl("mymul", &mymul_cpu);
m.impl("myadd_out", &myadd_out_cpu);
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
m.impl("mymul", TORCH_BOX(&mymul_cpu));
m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
}

}
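
The CPU file above spells out the full stable-ABI recipe: a kernel written against `torch::stable::Tensor`, input validation with `STD_TORCH_CHECK`, a schema declared in `STABLE_TORCH_LIBRARY`, and a `TORCH_BOX`-wrapped implementation registered in `STABLE_TORCH_LIBRARY_IMPL`. As a quick reference, here is a minimal sketch of one more CPU op written against the same pattern; the op name `myneg` is hypothetical and not part of this PR.

```
// Hypothetical "myneg" op, shown only to condense the pattern of muladd.cpp.
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

namespace extension_cpp {

torch::stable::Tensor myneg_cpu(const torch::stable::Tensor& a) {
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);

  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
  torch::stable::Tensor result = torch::stable::empty_like(a_contig);

  const float* a_ptr = a_contig.const_data_ptr<float>();
  float* result_ptr = result.mutable_data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = -a_ptr[i];
  }
  return result;
}

}  // namespace extension_cpp

// Registration would be added to the existing blocks in muladd.cpp:
//   in STABLE_TORCH_LIBRARY(extension_cpp, m):          m.def("myneg(Tensor a) -> Tensor");
//   in STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m): m.impl("myneg", TORCH_BOX(&myneg_cpu));
```

After rebuilding, the op would surface in Python as `torch.ops.extension_cpp.myneg`.
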
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
requires = [
"setuptools",
"torch",
"torch>=2.10.0",
]
build-backend = "setuptools.build_meta"