diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh
index f4cebc5..21ac274 100755
--- a/.github/scripts/setup-env.sh
+++ b/.github/scripts/setup-env.sh
@@ -101,7 +101,7 @@ pip install --progress-bar=off -r requirements.txt
 echo '::endgroup::'
 
 echo '::group::Install extension-cpp'
-python setup.py develop
+pip install -e . --no-build-isolation
 echo '::endgroup::'
 
 echo '::group::Collect environment information'
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1565937..5372bf9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -19,7 +19,7 @@ jobs:
           - python-version: 3.13
             runner: linux.g5.4xlarge.nvidia.gpu
             gpu-arch-type: cuda
-            gpu-arch-version: "12.4"
+            gpu-arch-version: "12.9"
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
diff --git a/README.md b/README.md
index d523814..f913239 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,16 @@
-# C++/CUDA Extensions in PyTorch
+# C++/CUDA Extensions in PyTorch with LibTorch Stable ABI
+
+An example of writing a C++/CUDA extension for PyTorch using the [LibTorch Stable ABI](https://pytorch.org/docs/main/notes/libtorch_stable_abi.html).
+See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
 
-An example of writing a C++/CUDA extension for PyTorch. See
-[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
 This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
-custom op that has both custom CPU and CUDA kernels.
+custom op that has both custom CPU and CUDA kernels. It leverages the LibTorch
+Stable ABI to ensure that the built extension can run with any version of
+PyTorch >= 2.10.0.
+
+The examples in this repo work with PyTorch 2.10+. For an example of how to use
+the non-stable subset of LibTorch, see [this previous commit](https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95).
 
-The examples in this repo work with PyTorch 2.4+.
 
 To build:
 
 ```
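Once the package builds with the updated `pip install -e . --no-build-isolation` step, the op can be exercised from Python. The snippet below is a minimal sketch, assuming the install above succeeded and that `extension_cpp.ops.mymuladd` wraps the `extension_cpp::mymuladd` operator registered later in this patch; shapes and values are arbitrary.

```python
# Minimal usage sketch (assumes `pip install -e . --no-build-isolation` succeeded).
import torch
import extension_cpp

a = torch.randn(3, 4)
b = torch.randn(3, 4)

# Python wrapper shipped with the package, as referenced in the README.
out = extension_cpp.ops.mymuladd(a, b, 2.0)

# The registered operator is also reachable through the torch.ops namespace.
out_ref = torch.ops.extension_cpp.mymuladd.default(a, b, 2.0)

torch.testing.assert_close(out, a * b + 2.0)
torch.testing.assert_close(out, out_ref)
```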
diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/csrc/cuda/muladd.cu
index 769513b..e30834c 100644
--- a/extension_cpp/csrc/cuda/muladd.cu
+++ b/extension_cpp/csrc/cuda/muladd.cu
@@ -1,10 +1,14 @@
-#include
-#include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
 
 #include
 #include
-#include
 
 namespace extension_cpp {
 
@@ -13,21 +17,35 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
   if (idx < numel) result[idx] = a[idx] * b[idx] + c;
 }
 
-at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
 
@@ -37,20 +55,34 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] * b[idx];
 }
 
-at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
 
@@ -60,32 +92,47 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] + b[idx];
 }
 
-void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }
-
 // Registers CUDA implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-  m.impl("mymuladd", &mymuladd_cuda);
-  m.impl("mymul", &mymul_cuda);
-  m.impl("myadd_out", &myadd_out_cuda);
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
+  m.impl("mymul", TORCH_BOX(&mymul_cuda));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
 }
 
 }
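The CUDA kernels themselves are untouched by this patch; only the tensor handling, checks, stream lookup, and registration move to the stable ABI. A hedged smoke-test sketch, assuming the extension was built with CUDA support and a GPU is visible, is to compare the custom kernels against eager PyTorch:

```python
# Sketch: compare the stable-ABI CUDA kernels against eager PyTorch.
# Assumes a CUDA-enabled build of the extension and at least one GPU.
import torch
import extension_cpp

assert torch.cuda.is_available()

a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")

torch.testing.assert_close(extension_cpp.ops.mymuladd(a, b, 3.0), a * b + 3.0)
torch.testing.assert_close(torch.ops.extension_cpp.mymul.default(a, b), a * b)
```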
diff --git a/extension_cpp/csrc/muladd.cpp b/extension_cpp/csrc/muladd.cpp
index d68332a..55ea10f 100644
--- a/extension_cpp/csrc/muladd.cpp
+++ b/extension_cpp/csrc/muladd.cpp
@@ -1,14 +1,15 @@
 #include
-#include
-#include
-#include
-#include
+#include
+#include
+#include
+#include
+#include
 
 extern "C" {
   /* Creates a dummy empty _C module that can be imported from
      Python. The import from Python will load the .so consisting of this file
-     in this extension, so that the TORCH_LIBRARY static initializers
+     in this extension, so that the STABLE_TORCH_LIBRARY static initializers
      below are run.
   */
   PyObject* PyInit__C(void) {
@@ -26,36 +27,47 @@ extern "C" {
 
 namespace extension_cpp {
 
-at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < result.numel(); i++) {
     result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
   }
   return result;
 }
 
-at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < result.numel(); i++) {
     result_ptr[i] = a_ptr[i] * b_ptr[i];
   }
@@ -63,38 +75,44 @@ at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
 }
 
 // An example of an operator that mutates one of its inputs.
-void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+void myadd_out_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < out.numel(); i++) {
     result_ptr[i] = a_ptr[i] + b_ptr[i];
   }
 }
 
 // Defines the operators
-TORCH_LIBRARY(extension_cpp, m) {
+STABLE_TORCH_LIBRARY(extension_cpp, m) {
   m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   m.def("mymul(Tensor a, Tensor b) -> Tensor");
   m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
 }
 
 // Registers CPU implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-  m.impl("mymuladd", &mymuladd_cpu);
-  m.impl("mymul", &mymul_cpu);
-  m.impl("myadd_out", &myadd_out_cpu);
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+  m.impl("mymul", TORCH_BOX(&mymul_cpu));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
 }
 
 }
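Because the operator schemas above are unchanged, the Python-facing behavior of the three ops stays the same; only the registration macros differ. One way to sanity-check the CPU registrations, sketched here under the assumption that float32 CPU tensors are used (matching the STD_TORCH_CHECK guards), is `torch.library.opcheck` plus a direct check of the in-place `myadd_out`, whose `Tensor(a!)` annotation marks `out` as mutated:

```python
# Sketch: sanity-check the CPU registrations defined above.
# Assumes float32 CPU tensors, matching the STD_TORCH_CHECK guards.
import torch
import extension_cpp  # noqa: F401  (loads _C, running the STABLE_TORCH_LIBRARY initializers)

a, b = torch.randn(8), torch.randn(8)

# opcheck exercises schema correctness and basic dispatcher invariants.
torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, (a, b, 1.5))
torch.library.opcheck(torch.ops.extension_cpp.mymul.default, (a, b))

# myadd_out writes into its third argument, per the "Tensor(a!) out" schema.
out = torch.empty_like(a)
torch.ops.extension_cpp.myadd_out.default(a, b, out)
torch.testing.assert_close(out, a + b)
```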
out) -> ()"); } // Registers CPU implementations for mymuladd, mymul, myadd_out -TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { - m.impl("mymuladd", &mymuladd_cpu); - m.impl("mymul", &mymul_cpu); - m.impl("myadd_out", &myadd_out_cpu); +STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu)); + m.impl("mymul", TORCH_BOX(&mymul_cpu)); + m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu)); } } diff --git a/pyproject.toml b/pyproject.toml index 918072e..ffef670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ "setuptools", - "torch", + "torch>=2.10.0", ] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 0dde1e4..ed2ffdf 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ library_name = "extension_cpp" + if torch.__version__ >= "2.6.0": py_limited_api = True else: @@ -39,9 +40,19 @@ def get_extensions(): "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", "-DPy_LIMITED_API=0x03090000", # min CPython version 3.9 + # define TORCH_TARGET_VERSION with min version 2.10 to expose only the + # stable API subset from torch + # Format: [MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes] + # 2.10.0 = 0x020A000000000000 + "-DTORCH_TARGET_VERSION=0x020a000000000000", ], "nvcc": [ "-O3" if not debug_mode else "-O0", + # NVCC also needs TORCH_TARGET_VERSION for stable ABI in CUDA code + "-DTORCH_TARGET_VERSION=0x020a000000000000", + # USE_CUDA is currently needed for aoti_torch_get_current_cuda_stream + # declaration in shim.h. This will be improved in a future release. + "-DUSE_CUDA", ], } if debug_mode: @@ -78,7 +89,7 @@ def get_extensions(): packages=find_packages(), ext_modules=get_extensions(), install_requires=["torch"], - description="Example of PyTorch C++ and CUDA extensions", + description="Example of PyTorch C++ and CUDA extensions using LibTorch Stable ABI", long_description=open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/pytorch/extension-cpp",