diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index b8db850a3943..01075259e9fe 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -1250,6 +1250,97 @@ test_custom_script_ops() {
   assert_git_not_dirty
 }
 
+test_libtorch_agnostic_targetting() {
+  echo "Testing that a libtorch_agnostic extension built with TORCH_TARGET_VERSION 2.9 runs correctly on a PyTorch 2.9 runtime"
+
+  REPO_DIR=$(pwd)
+  WHEEL_DIR="${REPO_DIR}/test/cpp_extensions/.wheels"
+
+  # Build wheel with current PyTorch (this has TORCH_TARGET_VERSION 2_9_0)
+  echo "Building 2.9 extension wheel with current PyTorch..."
+  pushd test/cpp_extensions/libtorch_agnostic_2_9_extension
+  time python setup.py bdist_wheel
+
+  # Save the wheel
+  mkdir -p "$WHEEL_DIR"
+  cp dist/*.whl "$WHEEL_DIR/"
+  WHEEL_FILE=$(find "$WHEEL_DIR" -maxdepth 1 -name "*.whl" -type f | head -1)
+  echo "Built wheel: $(basename "$WHEEL_FILE")"
+  popd
+
+  # Create venv and install PyTorch 2.9
+  python -m venv venv_pytorch_2_9
+  # shellcheck disable=SC1091
+  . venv_pytorch_2_9/bin/activate
+
+  # Clear PYTHONPATH to avoid using the development PyTorch
+  echo "Clearing PYTHONPATH to use only venv packages..."
+  unset PYTHONPATH
+
+  # Upgrade pip to latest version
+  echo "Upgrading pip to latest version..."
+  pip install --upgrade pip
+  pip --version
+
+  echo "Installing PyTorch 2.9..."
+
+  # Install from release channel only
+  PYTORCH_VERSION="2.9.0"
+
+  # Extract CUDA version from BUILD_ENVIRONMENT (e.g., "cuda12.1" -> "cu121")
+  if [[ "$BUILD_ENVIRONMENT" =~ cuda([0-9]+)\.([0-9]+) ]]; then
+    CUDA_MAJOR="${BASH_REMATCH[1]}"
+    CUDA_MINOR="${BASH_REMATCH[2]}"
+    CUDA_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}"
+    echo " Detected CUDA ${CUDA_MAJOR}.${CUDA_MINOR} from BUILD_ENVIRONMENT, using ${CUDA_VERSION}"
+  else
+    # Default to CPU build
+    CUDA_VERSION="cpu"
+    echo " No CUDA detected in BUILD_ENVIRONMENT, using CPU build"
+  fi
+
+  if pip install torch=="${PYTORCH_VERSION}" --index-url https://download.pytorch.org/whl/${CUDA_VERSION}/; then
+    echo "Installed PyTorch ${PYTORCH_VERSION} from release channel (${CUDA_VERSION})"
+  else
+    echo " FAILED to install PyTorch 2.9.0 from release channel"
+    echo " URL: https://download.pytorch.org/whl/${CUDA_VERSION}/"
+    deactivate
+    rm -rf venv_pytorch_2_9
+    return 1
+  fi
+
+  INSTALLED_VERSION=$(python -c "import torch; print(torch.__version__)" 2>/dev/null || echo "unknown")
+  echo " Installed version: $INSTALLED_VERSION"
+
+  # Install test dependencies
+  echo "Installing test dependencies..."
+  pip install expecttest numpy unittest-xml-reporting
+
+  # Install the pre-built wheel
+  echo ""
+  echo "Installing pre-built 2.9 extension wheel (built with PyTorch 2.10)..."
+  pip install "$WHEEL_FILE"
+  echo "Installed $(basename "$WHEEL_FILE") into PyTorch 2.9 environment"
+
+  # Run tests with PyTorch 2.9 runtime (2.10 tests will be skipped automatically)
+  echo ""
+  echo "Running tests with PyTorch 2.9 runtime (using wheel built on PyTorch 2.10)..."
+  if time python test/cpp_extensions/test_libtorch_agnostic.py -v; then
+    echo ""
+    echo " Wheel built with current torch and TORCH_TARGET_VERSION 2_9_0 works with PyTorch 2.9 runtime!"
+  else
+    echo "libtorch_agnostic targeting test FAILED"
+    deactivate
+    rm -rf venv_pytorch_2_9 "$WHEEL_DIR"
+    return 1
+  fi
+
+  deactivate
+  rm -rf venv_pytorch_2_9 "$WHEEL_DIR"
+
+  assert_git_not_dirty
+}
+
 test_jit_hooks() {
   echo "Testing jit hooks in cpp"
   HOOK_BUILD="${CUSTOM_TEST_ARTIFACT_BUILD_DIR}/jit-hook-build"
@@ -1722,6 +1813,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]];
 elif [[ "${TEST_CONFIG}" == *backward* ]]; then
   test_forward_backward_compatibility
   # Do NOT add tests after bc check tests, see its comment.
+elif [[ "${TEST_CONFIG}" == *libtorch_agnostic_targetting* ]]; then
+  test_libtorch_agnostic_targetting
 elif [[ "${TEST_CONFIG}" == *xla* ]]; then
   install_torchvision
   build_xla
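For context on what this new test job exercises: the 2.9 extension wheel is built against the current (2.10-era) headers but with TORCH_TARGET_VERSION pinned to 2.9, so the resulting binary may only rely on stable-ABI surface that already existed in PyTorch 2.9. The sketch below illustrates the idea only; the TORCH_VERSION_2_10_0 constant and the comparison are hypothetical placeholders, not the actual PyTorch mechanism.

    /* Hypothetical illustration: gate newer stable-ABI usage on the targeted version.
     * TORCH_TARGET_VERSION is set by the extension build (see the 2.9 extension's setup.py);
     * TORCH_VERSION_2_10_0 is an assumed placeholder, not a real PyTorch macro. */
    #if defined(TORCH_TARGET_VERSION) && defined(TORCH_VERSION_2_10_0) && \
        (TORCH_TARGET_VERSION >= TORCH_VERSION_2_10_0)
      // May call stable APIs introduced in 2.10.
    #else
      // Must stick to APIs that were already stable in 2.9, which is why the wheel
      // built here still loads and runs on a 2.9 runtime.
    #endif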
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index e5fd10c70db6..51e211a5ad2a 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -70,6 +70,7 @@ jobs:
           { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
+          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
         ]}
     secrets: inherit
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 6ba810c3a958..667c37727045 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -83,6 +83,7 @@ jobs:
           { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
           { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
           { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
+          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
         ]}
     secrets: inherit
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/__init__.py
similarity index 100%
rename from test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py
rename to test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/__init__.py
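All of the new C++ sources below follow the same stable-ABI registration shape: declare the operator schema in a STABLE_TORCH_LIBRARY_FRAGMENT block for the libtorch_agnostic_2_10 namespace, then register a kernel wrapped in TORCH_BOX via STABLE_TORCH_LIBRARY_IMPL. A minimal sketch of that pattern follows; the my_identity op and the two header paths are illustrative assumptions, not part of this PR.

    #include <torch/csrc/stable/library.h>  // assumed header for the STABLE_TORCH_* macros
    #include <torch/csrc/stable/tensor.h>   // assumed header for torch::stable::Tensor

    using torch::stable::Tensor;

    // Trivial kernel: hands its input back unchanged.
    Tensor my_identity(Tensor t) {
      return t;
    }

    // Declare the schema in the extension's namespace.
    STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
      m.def("my_identity(Tensor t) -> Tensor");
    }

    // Register the boxed kernel for that schema.
    STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
      m.impl("my_identity", TORCH_BOX(&my_identity));
    }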
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp
new file mode 100644
index 000000000000..d3dbab589139
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp
@@ -0,0 +1,41 @@
+#include 
+#include 
+#include 
+
+#include 
+
+using torch::stable::Tensor;
+
+// Declare my__foreach_mul (defined in my__foreach_mul.cpp)
+extern std::vector<Tensor> my__foreach_mul(
+    torch::headeronly::HeaderOnlyArrayRef<Tensor> self,
+    torch::headeronly::HeaderOnlyArrayRef<Tensor> other);
+
+// Helper function for cloning
+Tensor my_clone(Tensor t) {
+  return clone(t);
+}
+
+std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
+  // This function tests that my__foreach_mul can take in std::initializer_lists
+  // in addition to std::vectors.
+  Tensor t1_1 = my_clone(t1);
+  Tensor t1_2 = my_clone(t1);
+  Tensor t2_1 = my_clone(t2);
+  Tensor t2_2 = my_clone(t2);
+  return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def(
+      "make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(
+    libtorch_agnostic_2_10,
+    CompositeExplicitAutograd,
+    m) {
+  m.impl(
+      "make_tensor_clones_and_call_foreach",
+      TORCH_BOX(&make_tensor_clones_and_call_foreach));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp
new file mode 100644
index 000000000000..705439efffe6
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp
@@ -0,0 +1,40 @@
+// This is duplicated from the libtorch_agnostic_2_9_extension
+// as a negative test for test_version_compatibility.py
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "tensor_accessor_kernel.h"
+
+using torch::stable::Tensor;
+
+Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) {
+  STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
+  STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
+  STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
+  STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
+  STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
+  Tensor res = new_empty(m, {m.size(0)});
+  THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
+    AT_WRAP(([&]() {
+      auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
+      auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
+      auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
+      mv_tensor_accessor_kernel(resa, ma, va);
+    })),
+    AT_FLOATING_TYPES);
+  return res;
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("mv_tensor_accessor_cpu(Tensor m, Tensor v) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("mv_tensor_accessor_cpu", TORCH_BOX(&mv_tensor_accessor_cpu));
+}
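The dispatch lambda above builds accessors from data_ptr(), sizes(), and strides() so that a shared kernel can index the tensors generically. For reference, the strided matrix-vector product those accessors express is equivalent to the plain-C++ sketch below; it is illustrative only and not part of the extension.

    #include <cstdint>

    // res[i] = sum_j m[i, j] * v[j], spelled out with the explicit strides an
    // accessor would resolve.
    template <typename T>
    void mv_strided(T* res, const T* m, const T* v,
                    const int64_t* m_sizes, const int64_t* m_strides,
                    int64_t v_stride, int64_t res_stride) {
      for (int64_t i = 0; i < m_sizes[0]; ++i) {
        T acc = T(0);
        for (int64_t j = 0; j < m_sizes[1]; ++j) {
          acc += m[i * m_strides[0] + j * m_strides[1]] * v[j * v_stride];
        }
        res[i * res_stride] = acc;
      }
    }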
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu
new file mode 100644
index 000000000000..7773210a089e
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu
@@ -0,0 +1,47 @@
+// This is duplicated from the libtorch_agnostic_2_9_extension
+// as a negative test for test_version_compatibility.py
+
+#include "tensor_accessor_kernel.h"
+
+#include 
+#include 
+#include 
+#include 
+
+using torch::stable::Tensor;
+
+Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
+  STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
+  STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
+  STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
+  STD_TORCH_CHECK(
+      m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
+  STD_TORCH_CHECK(
+      m.device() == v.device(), "m and v must be on the same device");
+  Tensor res = new_empty(m, {m.size(0)});
+  THO_DISPATCH_V2(
+      m.scalar_type(),
+      "mv_tensor_accessor_cuda",
+      AT_WRAP(([&]() {
+        auto resa = Accessor_cuda<scalar_t, 1>(
+            reinterpret_cast<scalar_t*>(res.data_ptr()),
+            res.sizes().data(),
+            res.strides().data());
+        auto ma = Accessor_cuda<scalar_t, 2>(
+            reinterpret_cast<scalar_t*>(m.data_ptr()),
+            m.sizes().data(),
+            m.strides().data());
+        auto va = Accessor_cuda<scalar_t, 1>(
+            reinterpret_cast<scalar_t*>(v.data_ptr()),
+            v.sizes().data(),
+            v.strides().data());
+        mv_tensor_accessor_kernel
+            <<<1, 1, 0, 0>>>(resa, ma, va);
+      })),
+      AT_FLOATING_TYPES);
+  return res;
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CUDA, m) {
+  m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp
new file mode 100644
index 000000000000..834a63afea64
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp
@@ -0,0 +1,20 @@
+#include 
+#include 
+#include 
+#include 
+
+using torch::stable::Tensor;
+
+std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
+  std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
+  aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
+  return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp
new file mode 100644
index 000000000000..8409e6890bdd
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp
@@ -0,0 +1,19 @@
+#include 
+#include 
+#include 
+#include 
+
+using torch::stable::Tensor;
+
+void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
+  std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
+  aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
+}
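my__foreach_mul and my__foreach_mul_ above reach ATen through the boxed convention: each input is converted to a StableIValue with torch::stable::detail::from, the values are placed on a small stack, aoti_torch_call_dispatcher invokes the named overload, and outputs are read back off the stack with torch::stable::detail::to. A sketch of the same pattern for a single-tensor op follows; the aten::mul "Tensor" overload is used purely for illustration, and the stable-ABI declarations are assumed to come from the same headers these files include.

    #include <array>

    using torch::stable::Tensor;

    Tensor my_mul(Tensor self, Tensor other) {
      // Box the two inputs onto a StableIValue stack.
      std::array<StableIValue, 2> stack = {
          torch::stable::detail::from(self),
          torch::stable::detail::from(other)};
      // Dispatch to aten::mul (overload "Tensor"); the output overwrites the stack.
      aoti_torch_call_dispatcher("aten::mul", "Tensor", stack.data());
      // Unbox the single output tensor.
      return torch::stable::detail::to<Tensor>(stack[0]);
    }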
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp
new file mode 100644
index 000000000000..6278dca9f281
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp
@@ -0,0 +1,25 @@
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+using torch::stable::Tensor;
+
+Tensor my_empty(
+    torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
+    std::optional<torch::headeronly::ScalarType> dtype,
+    std::optional<torch::stable::Device> device,
+    std::optional<bool> pin_memory) {
+  return empty(size, dtype, device, pin_memory);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def(
+      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my_empty", TORCH_BOX(&my_empty));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp
new file mode 100644
index 000000000000..0a2b1f70f215
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp
@@ -0,0 +1,17 @@
+#include 
+#include 
+#include 
+
+using torch::stable::Tensor;
+
+Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
+  return reshape(t, shape);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my_reshape", TORCH_BOX(&my_reshape));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp
new file mode 100644
index 000000000000..25d8c5458924
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp
@@ -0,0 +1,20 @@
+#include 
+#include 
+#include 
+
+using torch::stable::Tensor;
+
+Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
+  return view(t, size);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my_view(Tensor t, int[] size) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(
+    libtorch_agnostic_2_10,
+    CompositeExplicitAutograd,
+    m) {
+  m.impl("my_view", TORCH_BOX(&my_view));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h
new file mode 100644
index 000000000000..f1031f38060c
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include 
+#include 
+
+template <typename T, size_t N>
+using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define MAYBE_GLOBAL __global__
+
+template <typename T, size_t N>
+using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<T, N>;
+
+#else
+#define MAYBE_GLOBAL
+#endif
+
+template
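The tensor_accessor_kernel.h hunk is cut off at its final template declaration. Judging from the CPU and CUDA callers above, the missing piece is an accessor-based mv kernel roughly like the sketch below; this is a reconstruction under the assumption that the header-only accessors expose size() and operator[] like at::TensorAccessor, not the file's actual contents.

    #include <cstdint>
    #include <type_traits>

    // MAYBE_GLOBAL expands to __global__ under CUDA/HIP and to nothing on CPU, so the
    // same serial body serves both the direct CPU call and the <<<1, 1, 0, 0>>> launch.
    template <typename ResAcc, typename MAcc, typename VAcc>
    MAYBE_GLOBAL void mv_tensor_accessor_kernel(ResAcc res, MAcc m, VAcc v) {
      using scalar_t = std::decay_t<decltype(res[0])>;
      for (int64_t i = 0; i < res.size(0); ++i) {
        scalar_t acc = scalar_t(0);
        for (int64_t j = 0; j < v.size(0); ++j) {
          acc += m[i][j] * v[j];  // res[i] = sum_j m[i, j] * v[j]
        }
        res[i] = acc;
      }
    }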