Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7d8992c
Add empty to stable ops
mikaylagawarecki Nov 12, 2025
85b1f20
Add reshape, view, flatten to torch/csrc/stable
mikaylagawarecki Nov 12, 2025
c96c2f8
Update on "Add reshape, view, flatten to torch/csrc/stable"
mikaylagawarecki Nov 12, 2025
52c7274
Update base for Update on "Add reshape, view, flatten to torch/csrc/s…
mikaylagawarecki Nov 13, 2025
aa52712
Update on "Add reshape, view, flatten to torch/csrc/stable"
mikaylagawarecki Nov 13, 2025
0b0e15a
Fix TORCH_FEATURE_VERSION guards
mikaylagawarecki Nov 14, 2025
65708c0
Split libtorch agnostic tests by feature version
mikaylagawarecki Nov 14, 2025
3e01a0d
Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch ve…
mikaylagawarecki Nov 14, 2025
7e517c5
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 14, 2025
fd02d3f
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 14, 2025
f3897df
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 14, 2025
9db7e47
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 14, 2025
f816f3d
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 14, 2025
a3ccf16
Update base for Update on "Test libtorch_agnostic with TORCH_TARGET_V…
mikaylagawarecki Nov 17, 2025
86dddb2
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 17, 2025
605d100
Update base for Update on "Test libtorch_agnostic with TORCH_TARGET_V…
mikaylagawarecki Nov 17, 2025
6431f2c
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 17, 2025
f7054bd
Update base for Update on "Test libtorch_agnostic with TORCH_TARGET_V…
mikaylagawarecki Nov 17, 2025
e5dd318
Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target…
mikaylagawarecki Nov 17, 2025
b476610
Test that TORCH_FEATURE_VERSION guards are used where needed
mikaylagawarecki Nov 17, 2025
aec0929
Update base for Update on "Test that TORCH_FEATURE_VERSION guards are…
mikaylagawarecki Nov 17, 2025
0b4d592
Update on "Test that TORCH_FEATURE_VERSION guards are used where needed"
mikaylagawarecki Nov 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions .ci/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,97 @@ test_custom_script_ops() {
assert_git_not_dirty
}

# Checks that a libtorch_agnostic extension wheel built against the current
# PyTorch checkout (with TORCH_TARGET_VERSION 2_9_0) installs and passes its
# tests inside a fresh venv that runs the released PyTorch 2.9.0.
#
# Returns 1 when PyTorch 2.9.0 cannot be installed or when the tests fail.
# The venv and the staged wheel directory are removed on every exit path this
# function handles itself.
test_libtorch_agnostic_targetting() {
  echo "Testing libtorch_agnostic runs correctly on TORCH_TARGET_VERSION"

  REPO_DIR=$(pwd)
  WHEEL_DIR="${REPO_DIR}/test/cpp_extensions/.wheels"

  # Build wheel with current PyTorch (this has TORCH_TARGET_VERSION 2_9_0)
  echo "Building 2.9 extension wheel with current PyTorch..."
  pushd test/cpp_extensions/libtorch_agnostic_2_9_extension
  time python setup.py bdist_wheel

  # Save the wheel. Start from an empty staging directory so that a wheel left
  # behind by an earlier, interrupted run can never be the one picked below.
  rm -rf "$WHEEL_DIR"
  mkdir -p "$WHEEL_DIR"
  cp dist/*.whl "$WHEEL_DIR/"
  WHEEL_FILE=$(find "$WHEEL_DIR" -maxdepth 1 -name "*.whl" -type f | head -1)
  echo "Built wheel: $(basename "$WHEEL_FILE")"
  popd

  # Create venv and install PyTorch 2.9
  python -m venv venv_pytorch_2_9
  # shellcheck disable=SC1091
  . venv_pytorch_2_9/bin/activate

  # Clear PYTHONPATH to avoid using the development PyTorch
  echo "Clearing PYTHONPATH to use only venv packages..."
  unset PYTHONPATH

  # Upgrade pip to latest version
  echo "Upgrading pip to latest version..."
  pip install --upgrade pip
  pip --version

  echo "Installing PyTorch 2.9..."

  # Install from release channel only
  PYTORCH_VERSION="2.9.0"

  # Extract CUDA version from BUILD_ENVIRONMENT (e.g., "cuda12.1" -> "cu121")
  if [[ "$BUILD_ENVIRONMENT" =~ cuda([0-9]+)\.([0-9]+) ]]; then
    CUDA_MAJOR="${BASH_REMATCH[1]}"
    CUDA_MINOR="${BASH_REMATCH[2]}"
    CUDA_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}"
    echo " Detected CUDA ${CUDA_MAJOR}.${CUDA_MINOR} from BUILD_ENVIRONMENT, using ${CUDA_VERSION}"
  else
    # Default to CPU build
    CUDA_VERSION="cpu"
    echo " No CUDA detected in BUILD_ENVIRONMENT, using CPU build"
  fi

  if pip install torch=="${PYTORCH_VERSION}" --index-url "https://download.pytorch.org/whl/${CUDA_VERSION}/"; then
    echo "Installed PyTorch ${PYTORCH_VERSION} from release channel (${CUDA_VERSION})"
  else
    echo " FAILED to install PyTorch 2.9.0 from release channel"
    echo " URL: https://download.pytorch.org/whl/${CUDA_VERSION}/"
    deactivate
    # Remove the staged wheel as well as the venv, matching the other exits.
    rm -rf venv_pytorch_2_9 "$WHEEL_DIR"
    return 1
  fi

  INSTALLED_VERSION=$(python -c "import torch; print(torch.__version__)" 2>/dev/null || echo "unknown")
  echo " Installed version: $INSTALLED_VERSION"

  # Install test dependencies
  echo "Installing test dependencies..."
  pip install expecttest numpy unittest-xml-reporting

  # Install the pre-built wheel
  echo ""
  echo "Installing pre-built 2.9 extension wheel (built with PyTorch 2.10)..."
  pip install "$WHEEL_FILE"
  echo "Installed $(basename "$WHEEL_FILE") into PyTorch 2.9 environment"

  # Run tests with PyTorch 2.9 runtime (2.10 tests will be skipped automatically)
  echo ""
  echo "Running tests with PyTorch 2.9 runtime (using wheel built on PyTorch 2.10)..."
  if time python test/cpp_extensions/test_libtorch_agnostic.py -v; then
    echo ""
    echo " Wheel built with current torch and TORCH_TARGET_VERSION 2_9_0 works with PyTorch 2.9 runtime!"
  else
    echo "targeting test failed"
    deactivate
    rm -rf venv_pytorch_2_9 "$WHEEL_DIR"
    return 1
  fi

  deactivate
  rm -rf venv_pytorch_2_9 "$WHEEL_DIR"

  assert_git_not_dirty
}

test_jit_hooks() {
echo "Testing jit hooks in cpp"
HOOK_BUILD="${CUSTOM_TEST_ARTIFACT_BUILD_DIR}/jit-hook-build"
Expand Down Expand Up @@ -1722,6 +1813,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]];
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
test_forward_backward_compatibility
# Do NOT add tests after bc check tests, see its comment.
elif [[ "${TEST_CONFIG}" == *libtorch_agnostic_targetting* ]]; then
test_libtorch_agnostic_targetting
elif [[ "${TEST_CONFIG}" == *xla* ]]; then
install_torchvision
build_xla
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]}
secrets: inherit

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:
{ config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
{ config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
]}
secrets: inherit

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>

#include <vector>

using torch::stable::Tensor;

// Declare my__foreach_mul (defined in my__foreach_mul.cpp)
extern std::vector<Tensor> my__foreach_mul(
torch::headeronly::HeaderOnlyArrayRef<Tensor> self,
torch::headeronly::HeaderOnlyArrayRef<Tensor> other);

// Small cloning helper: passes `t` to clone() and hands back what it returns.
Tensor my_clone(Tensor t) {
  Tensor copy = clone(t);
  return copy;
}

std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
  // Exercises my__foreach_mul with braced lists (std::initializer_list)
  // instead of std::vector arguments. Each input is cloned twice so the two
  // lists passed below are built from distinct clone results.
  Tensor lhs_first = my_clone(t1);
  Tensor rhs_first = my_clone(t1);
  Tensor lhs_second = my_clone(t2);
  Tensor rhs_second = my_clone(t2);
  return my__foreach_mul(
      {lhs_first, lhs_second}, {rhs_first, rhs_second});
}

// Declares the operator schema in the libtorch_agnostic_2_10 namespace:
// two Tensor inputs, a Tensor[] result.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, lib) {
  lib.def(
      "make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
}

// Registers the TORCH_BOX-wrapped C++ function as the implementation of
// make_tensor_clones_and_call_foreach for CompositeExplicitAutograd.
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, lib) {
  lib.impl(
      "make_tensor_clones_and_call_foreach",
      TORCH_BOX(&make_tensor_clones_and_call_foreach));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// This is duplicated from the libtorch_agnostic_2_9_extension
// as a negative test for test_version_compatibility.py

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>

#include "tensor_accessor_kernel.h"

using torch::stable::Tensor;

// Multiplies the 2D tensor `m` by the 1D tensor `v` using CPU tensor
// accessors and returns a newly allocated 1D tensor of length m.size(0).
// Fails a STD_TORCH_CHECK when the ranks, the shared dimension, the dtypes or
// the devices of the inputs do not line up.
Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) {
  STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
  STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
  STD_TORCH_CHECK(
      m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
  STD_TORCH_CHECK(
      m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
  STD_TORCH_CHECK(
      m.device() == v.device(), "m and v must be on the same device");

  // Output is allocated from `m` via new_empty with one entry per row of `m`.
  Tensor out = new_empty(m, {m.size(0)});

  // Dispatch on m's scalar type (AT_FLOATING_TYPES) so that `scalar_t` is
  // available when building the typed accessors.
  THO_DISPATCH_V2(
      m.scalar_type(),
      "mv_tensor_accessor_cpu",
      AT_WRAP(([&]() {
        Accessor_cpu<scalar_t, 1> out_acc(
            reinterpret_cast<scalar_t*>(out.data_ptr()),
            out.sizes().data(),
            out.strides().data());
        Accessor_cpu<scalar_t, 2> mat_acc(
            reinterpret_cast<scalar_t*>(m.data_ptr()),
            m.sizes().data(),
            m.strides().data());
        Accessor_cpu<scalar_t, 1> vec_acc(
            reinterpret_cast<scalar_t*>(v.data_ptr()),
            v.sizes().data(),
            v.strides().data());
        mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(
            out_acc, mat_acc, vec_acc);
      })),
      AT_FLOATING_TYPES);
  return out;
}

// Schema for mv_tensor_accessor_cpu. The argument list matches the C++
// function registered for this operator, which takes exactly (m, v) and
// allocates its own result, so no `res` input is declared.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
  m.def("mv_tensor_accessor_cpu(Tensor m, Tensor v) -> Tensor");
}

// Registers the TORCH_BOX-wrapped mv_tensor_accessor_cpu under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(
    libtorch_agnostic_2_10,
    CompositeExplicitAutograd,
    lib) {
  lib.impl("mv_tensor_accessor_cpu", TORCH_BOX(&mv_tensor_accessor_cpu));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// This is duplicated from the libtorch_agnostic_2_9_extension
// as a negative test for test_version_compatibility.py

#include "tensor_accessor_kernel.h"

#include <cuda_runtime.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// CUDA counterpart of the accessor-based matrix-vector product: validates the
// inputs, allocates the output from `m`, and launches
// mv_tensor_accessor_kernel with packed CUDA accessors.
Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
  STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
  STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
  STD_TORCH_CHECK(
      m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
  STD_TORCH_CHECK(
      m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
  STD_TORCH_CHECK(
      m.device() == v.device(), "m and v must be on the same device");

  // One output element per row of `m`.
  Tensor out = new_empty(m, {m.size(0)});

  THO_DISPATCH_V2(
      m.scalar_type(),
      "mv_tensor_accessor_cuda",
      AT_WRAP(([&]() {
        Accessor_cuda<scalar_t, 1> out_acc(
            reinterpret_cast<scalar_t*>(out.data_ptr()),
            out.sizes().data(),
            out.strides().data());
        Accessor_cuda<scalar_t, 2> mat_acc(
            reinterpret_cast<scalar_t*>(m.data_ptr()),
            m.sizes().data(),
            m.strides().data());
        Accessor_cuda<scalar_t, 1> vec_acc(
            reinterpret_cast<scalar_t*>(v.data_ptr()),
            v.sizes().data(),
            v.strides().data());
        // Launch configuration: 1 block, 1 thread, 0 bytes of dynamic shared
        // memory, stream 0.
        mv_tensor_accessor_kernel<Accessor_cuda, scalar_t>
            <<<1, 1, 0, 0>>>(out_acc, mat_acc, vec_acc);
      })),
      AT_FLOATING_TYPES);
  return out;
}

// Registers the TORCH_BOX-wrapped mv_tensor_accessor_cuda for the CUDA
// dispatch key under the operator name "mv_tensor_accessor".
// NOTE(review): no m.def for "mv_tensor_accessor" is visible alongside this
// registration -- confirm the schema is declared elsewhere in
// libtorch_agnostic_2_10.
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CUDA, lib) {
  lib.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <vector>

using torch::stable::Tensor;

// Calls aten::_foreach_mul (overload "List") through the dispatcher shim.
//
// `self` and `other` are converted to StableIValues and placed on a
// two-element stack; after the call the result is read back from stack[0] and
// converted to a std::vector<Tensor>.
//
// A non-success code from the dispatcher call is surfaced through
// TORCH_ERROR_CODE_CHECK.
std::vector<Tensor> my__foreach_mul(
    torch::headeronly::HeaderOnlyArrayRef<Tensor> self,
    torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
  std::array<StableIValue, 2> stack = {
      torch::stable::detail::from(self), torch::stable::detail::from(other)};
  // The shim signals failure only through its return code. Without this check
  // a failed dispatch would fall through and stack[0] would be converted as if
  // it held the result.
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data()));
  return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}

// Declares my__foreach_mul: two Tensor[] inputs, a Tensor[] result.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, lib) {
  lib.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
}

// Registers the TORCH_BOX-wrapped my__foreach_mul under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(
    libtorch_agnostic_2_10,
    CompositeExplicitAutograd,
    lib) {
  lib.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

using torch::stable::Tensor;

// Calls aten::_foreach_mul_ (overload "List") through the dispatcher shim.
//
// `self` and `other` are converted to StableIValues and passed on a
// two-element stack; nothing is read back afterwards.
//
// A non-success code from the dispatcher call is surfaced through
// TORCH_ERROR_CODE_CHECK.
void my__foreach_mul_(
    torch::headeronly::HeaderOnlyArrayRef<Tensor> self,
    torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
  std::array<StableIValue, 2> stack = {
      torch::stable::detail::from(self), torch::stable::detail::from(other)};
  // The shim signals failure only through its return code; dropping it would
  // let a failed call return as though it had succeeded.
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data()));
}

// Declares my__foreach_mul_: `self` is annotated as mutated in place
// (Tensor(a!)[]) and the operator returns nothing.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, lib) {
  lib.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
}

// Registers the TORCH_BOX-wrapped my__foreach_mul_ under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(
    libtorch_agnostic_2_10,
    CompositeExplicitAutograd,
    lib) {
  lib.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/ops.h>

#include <optional>

using torch::stable::Tensor;

// Forwards all four arguments, unchanged and in order, to empty() and returns
// the tensor it produces.
Tensor my_empty(
    torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
    std::optional<torch::headeronly::ScalarType> dtype,
    std::optional<torch::stable::Device> device,
    std::optional<bool> pin_memory) {
  Tensor created = empty(size, dtype, device, pin_memory);
  return created;
}

// Declares my_empty: a required int[] size plus optional dtype, device and
// pin_memory arguments that each default to None.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, lib) {
  lib.def(
      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
}

// Registers the TORCH_BOX-wrapped my_empty under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(
    libtorch_agnostic_2_10,
    CompositeExplicitAutograd,
    lib) {
  lib.impl("my_empty", TORCH_BOX(&my_empty));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>

using torch::stable::Tensor;

// Forwards `t` and `shape` to reshape() and returns its result.
Tensor my_reshape(
    Tensor t,
    torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
  Tensor reshaped = reshape(t, shape);
  return reshaped;
}

// Declares my_reshape: a Tensor and an int[] shape in, a Tensor out.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, lib) {
  lib.def("my_reshape(Tensor t, int[] shape) -> Tensor");
}

// Registers the TORCH_BOX-wrapped my_reshape under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(
    libtorch_agnostic_2_10,
    CompositeExplicitAutograd,
    lib) {
  lib.impl("my_reshape", TORCH_BOX(&my_reshape));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>

using torch::stable::Tensor;

// Forwards `t` and `size` to view() and returns its result.
Tensor my_view(
    Tensor t,
    torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
  Tensor viewed = view(t, size);
  return viewed;
}

// Declares my_view: a Tensor and an int[] size in, a Tensor out.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, lib) {
  lib.def("my_view(Tensor t, int[] size) -> Tensor");
}

// Registers the TORCH_BOX-wrapped my_view under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, lib) {
  lib.impl("my_view", TORCH_BOX(&my_view));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>

// Accessor alias used on the host path: element type T, N dimensions.
template <typename T, size_t N>
using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;

// When this header is compiled by a CUDA or HIP compiler, MAYBE_GLOBAL expands
// to __global__ and the packed accessor alias below is made available.
// Otherwise MAYBE_GLOBAL expands to nothing and Accessor_cuda is not declared.
#if defined(__CUDACC__) || defined(__HIPCC__)
#define MAYBE_GLOBAL __global__

// Accessor alias used on the device path: generic packed accessor with
// RestrictPtrTraits, element type T, N dimensions.
template <typename T, size_t N>
using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<T, N, torch::headeronly::RestrictPtrTraits>;

#else
#define MAYBE_GLOBAL
#endif

// Matrix-vector product over accessors: for every i in [0, resa.size(0)),
// stores sum over j in [0, ma.size(1)) of ma[i][j] * va[j] into resa[i].
// All rows are processed serially within a single invocation. The function is
// marked MAYBE_GLOBAL, so whether it is a device kernel depends on how that
// macro is defined at the point of inclusion.
template <template <typename, size_t> class Accessor, typename scalar_t>
MAYBE_GLOBAL void mv_tensor_accessor_kernel(
    Accessor<scalar_t, 1> resa,
    Accessor<scalar_t, 2> ma,
    Accessor<scalar_t, 1> va) {
  for (int64_t row = 0; row < resa.size(0); ++row) {
    scalar_t acc = 0;
    for (int64_t col = 0; col < ma.size(1); ++col) {
      acc += ma[row][col] * va[col];
    }
    resa[row] = acc;
  }
}
Loading