From b47661028d45e4b1296f0c6ac4932b4d2ad6f1e4 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Sun, 16 Nov 2025 21:28:27 -0800
Subject: [PATCH] Test that TORCH_FEATURE_VERSION guards are used where needed

[ghstack-poisoned]
---
 .../libtorch_agnostic_2_10/csrc/kernel.cpp    | 205 ------------
 .../make_tensor_clones_and_call_foreach.cpp   |  41 +++
 .../csrc/mv_tensor_accessor_cpu.cpp           |  40 +++
 .../csrc/mv_tensor_accessor_cuda.cu           |  47 +++
 .../csrc/my__foreach_mul.cpp                  |  20 ++
 .../csrc/my__foreach_mul_.cpp                 |  19 ++
 .../libtorch_agnostic_2_10/csrc/my_empty.cpp  |  25 ++
 .../csrc/my_reshape.cpp                       |  17 +
 .../libtorch_agnostic_2_10/csrc/my_view.cpp   |  20 ++
 .../csrc/tensor_accessor_kernel.h             |  28 ++
 .../csrc/test_device_constructor.cpp          |  37 +++
 .../csrc/test_device_equality.cpp             |  14 +
 .../csrc/test_device_index.cpp                |  14 +
 .../csrc/test_device_is_cpu.cpp               |  14 +
 .../csrc/test_device_is_cuda.cpp              |  14 +
 .../csrc/test_device_set_index.cpp            |  17 +
 .../csrc/test_get_num_threads.cpp             |  14 +
 .../csrc/test_parallel_for.cpp                |  49 +++
 .../csrc/test_tensor_device.cpp               |  17 +
 .../test_version_compatibility.py             | 300 ++++++++++++++++++
 20 files changed, 747 insertions(+), 205 deletions(-)
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_constructor.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_equality.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_index.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cpu.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cuda.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_set_index.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_get_num_threads.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_parallel_for.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_tensor_device.cpp
 create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py

diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
deleted file mode 100644
index 72c78984b5215..0000000000000
--- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
+++ /dev/null
@@ -1,205 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#ifdef LAE_USE_CUDA
-#include
-#endif
-
-#include
-
-using torch::stable::Tensor;
-
-std::vector my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) {
-  std::array stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
-  aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
-  return torch::stable::detail::to>(stack[0]);
-}
-
-void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) {
-  std::array stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
-  aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
-}
-
-Tensor my_clone(Tensor t) {
-  return clone(t);
-}
-
-std::vector make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
-  // This function tests that my__foreach_mul can take in std::initializer_lists
-  // in addition to std::vectors.
-  Tensor t1_1 = my_clone(t1);
-  Tensor t1_2 = my_clone(t1);
-  Tensor t2_1 = my_clone(t2);
-  Tensor t2_2 = my_clone(t2);
-  return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
-}
-
-STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
-  m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
-  m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
-  m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
-}
-
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
-  m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
-  m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
-  m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach));
-}
-
-// Test functions for torch::stable::Tensor device method
-
-torch::stable::Device test_tensor_device(torch::stable::Tensor tensor) {
-  return tensor.device();
-}
-
-// Test functions for torch::stable::Device
-
-torch::stable::Device test_device_constructor(
-    bool is_cuda,
-    torch::stable::DeviceIndex index,
-    bool use_str) {
-  using torch::stable::Device;
-  using torch::stable::DeviceType;
-
-  if (use_str) {
-    std::string device_str;
-    if (is_cuda) {
-      device_str = "cuda:" + std::to_string(index);
-    } else {
-      device_str = "cpu";
-    }
-    return Device(device_str);
-  } else {
-    if (is_cuda) {
-      return Device(DeviceType::CUDA, index);
-    } else {
-      return Device(DeviceType::CPU);
-    }
-  }
-}
-
-bool test_device_equality(torch::stable::Device d1, torch::stable::Device d2) {
-  return d1 == d2;
-}
-
-torch::stable::Device test_device_set_index(
-    torch::stable::Device device,
-    torch::stable::DeviceIndex index) {
-  device.set_index(index);
-  return device;
-}
-
-torch::stable::DeviceIndex test_device_index(torch::stable::Device device) {
-  return device.index();
-}
-
-bool test_device_is_cuda(torch::stable::Device device) {
-  return device.is_cuda();
-}
-
-bool test_device_is_cpu(torch::stable::Device device) {
-  return device.is_cpu();
-}
-
-STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
-  m.def("test_tensor_device(Tensor t) -> Device");
-  m.def(
-      "test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device");
-  m.def("test_device_equality(Device d1, Device d2) -> bool");
-  m.def("test_device_set_index(Device device, DeviceIndex index) -> Device");
-  m.def("test_device_index(Device device) -> DeviceIndex");
-  m.def("test_device_is_cuda(Device device) -> bool");
-  m.def("test_device_is_cpu(Device device) -> bool");
-}
-
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
-  m.impl("test_tensor_device", TORCH_BOX(&test_tensor_device));
-  m.impl("test_device_constructor", TORCH_BOX(&test_device_constructor));
-  m.impl("test_device_equality", TORCH_BOX(&test_device_equality));
-  m.impl("test_device_set_index", TORCH_BOX(&test_device_set_index));
-  m.impl("test_device_index", TORCH_BOX(&test_device_index));
-  m.impl("test_device_is_cuda", TORCH_BOX(&test_device_is_cuda));
-  m.impl("test_device_is_cpu", TORCH_BOX(&test_device_is_cpu));
-}
-
-Tensor test_parallel_for(int64_t size, int64_t grain_size) {
-  AtenTensorHandle tensor_handle;
-  int64_t stride = 1;
-
-  aoti_torch_empty_strided(
-      1,
-      &size,
-      &stride,
-      aoti_torch_dtype_int64(),
-      aoti_torch_device_type_cpu(),
-      0,
-      &tensor_handle);
-
-  Tensor tensor(tensor_handle);
-  int64_t* data_ptr = reinterpret_cast(tensor.data_ptr());
-
-  torch::stable::zero_(tensor);
-
-  // Use parallel_for to fill each element with its index
-  // If using a parallel path, the thread id is encoded in the upper 32 bits
-  torch::stable::parallel_for(
-      0, size, grain_size, [data_ptr](int64_t begin, int64_t end) {
-        for (auto i = begin; i < end; i++) {
-          STD_TORCH_CHECK(i <= UINT32_MAX);
-          uint32_t thread_id;
-          torch_get_thread_idx(&thread_id);
-          data_ptr[i] = i | (static_cast(thread_id) << 32);
-        }
-      });
-
-  return tensor;
-}
-
-uint32_t test_get_num_threads() {
-  return torch::stable::get_num_threads();
-}
-
-STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
-  m.def("test_parallel_for(int size, int grain_size) -> Tensor");
-  m.def("test_get_num_threads() -> int");
-}
-
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
-  m.impl("test_parallel_for", TORCH_BOX(&test_parallel_for));
-  m.impl("test_get_num_threads", TORCH_BOX(&test_get_num_threads));
-}
-
-Tensor my_empty(
-    torch::headeronly::HeaderOnlyArrayRef size,
-    std::optional dtype,
-    std::optional device,
-    std::optional pin_memory) {
-  return empty(size, dtype, device, pin_memory);
-}
-
-Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef shape) {
-  return reshape(t, shape);
-}
-
-Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef size) {
-  return view(t, size);
-}
-
-STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
-  m.def(
-      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
pin_memory=None) -> Tensor"); - m.def("my_reshape(Tensor t, int[] shape) -> Tensor"); - m.def("my_view(Tensor t, int[] size) -> Tensor"); -} - -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { - m.impl("my_empty", TORCH_BOX(&my_empty)); - m.impl("my_reshape", TORCH_BOX(&my_reshape)); - m.impl("my_view", TORCH_BOX(&my_view)); -} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp new file mode 100644 index 0000000000000..d3dbab5891394 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp @@ -0,0 +1,41 @@ +#include +#include +#include + +#include + +using torch::stable::Tensor; + +// Declare my__foreach_mul (defined in my__foreach_mul.cpp) +extern std::vector my__foreach_mul( + torch::headeronly::HeaderOnlyArrayRef self, + torch::headeronly::HeaderOnlyArrayRef other); + +// Helper function for cloning +Tensor my_clone(Tensor t) { + return clone(t); +} + +std::vector make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) { + // This function tests that my__foreach_mul can take in std::initializer_lists + // in addition to std::vectors. + Tensor t1_1 = my_clone(t1); + Tensor t1_2 = my_clone(t1); + Tensor t2_1 = my_clone(t2); + Tensor t2_2 = my_clone(t2); + return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2}); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def( + "make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl( + "make_tensor_clones_and_call_foreach", + TORCH_BOX(&make_tensor_clones_and_call_foreach)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp new file mode 100644 index 0000000000000..705439efffe63 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cpu.cpp @@ -0,0 +1,40 @@ +// This is duplicated from the libtorch_agnostic_2_9_extension +// as a negative test for test_version_compatibility.py + +#include +#include +#include +#include +#include +#include +#include + +#include "tensor_accessor_kernel.h" + +using torch::stable::Tensor; + +Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) { + STD_TORCH_CHECK(m.dim() == 2, "m must be 2D"); + STD_TORCH_CHECK(v.dim() == 1, "v must be 1D"); + STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold"); + STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype"); + STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device"); + Tensor res = new_empty(m, {m.size(0)}); + THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu", + AT_WRAP(([&]() { + auto resa = Accessor_cpu(reinterpret_cast(res.data_ptr()), res.sizes().data(), res.strides().data()); + auto ma = Accessor_cpu(reinterpret_cast(m.data_ptr()), m.sizes().data(), m.strides().data()); + auto va = Accessor_cpu(reinterpret_cast(v.data_ptr()), v.sizes().data(), v.strides().data()); + mv_tensor_accessor_kernel(resa, ma, va); + })), + AT_FLOATING_TYPES); + return res; +} + 
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("mv_tensor_accessor_cpu(Tensor res, Tensor m, Tensor v) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("mv_tensor_accessor_cpu", TORCH_BOX(&mv_tensor_accessor_cpu));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu
new file mode 100644
index 0000000000000..7773210a089ee
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu
@@ -0,0 +1,47 @@
+// This is duplicated from the libtorch_agnostic_2_9_extension
+// as a negative test for test_version_compatibility.py
+
+#include "tensor_accessor_kernel.h"
+
+#include
+#include
+#include
+#include
+
+using torch::stable::Tensor;
+
+Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
+  STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
+  STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
+  STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
+  STD_TORCH_CHECK(
+      m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
+  STD_TORCH_CHECK(
+      m.device() == v.device(), "m and v must be on the same device");
+  Tensor res = new_empty(m, {m.size(0)});
+  THO_DISPATCH_V2(
+      m.scalar_type(),
+      "mv_tensor_accessor_cuda",
+      AT_WRAP(([&]() {
+        auto resa = Accessor_cuda(
+            reinterpret_cast(res.data_ptr()),
+            res.sizes().data(),
+            res.strides().data());
+        auto ma = Accessor_cuda(
+            reinterpret_cast(m.data_ptr()),
+            m.sizes().data(),
+            m.strides().data());
+        auto va = Accessor_cuda(
+            reinterpret_cast(v.data_ptr()),
+            v.sizes().data(),
+            v.strides().data());
+        mv_tensor_accessor_kernel
+            <<<1, 1, 0, 0>>>(resa, ma, va);
+      })),
+      AT_FLOATING_TYPES);
+  return res;
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CUDA, m) {
+  m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp
new file mode 100644
index 0000000000000..834a63afea646
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp
@@ -0,0 +1,20 @@
+#include
+#include
+#include
+#include
+
+using torch::stable::Tensor;
+
+std::vector my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) {
+  std::array stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
+  aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
+  return torch::stable::detail::to>(stack[0]);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp
new file mode 100644
index 0000000000000..8409e6890bdd0
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_.cpp
@@ -0,0 +1,19 @@
+#include
+#include
+#include
+#include
+
+using torch::stable::Tensor;
+
+void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) {
+  std::array stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
+  aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp
new file mode 100644
index 0000000000000..6278dca9f281d
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp
@@ -0,0 +1,25 @@
+#include
+#include
+#include
+#include
+
+#include
+
+using torch::stable::Tensor;
+
+Tensor my_empty(
+    torch::headeronly::HeaderOnlyArrayRef size,
+    std::optional dtype,
+    std::optional device,
+    std::optional pin_memory) {
+  return empty(size, dtype, device, pin_memory);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def(
+      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my_empty", TORCH_BOX(&my_empty));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp
new file mode 100644
index 0000000000000..0a2b1f70f2156
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_reshape.cpp
@@ -0,0 +1,17 @@
+#include
+#include
+#include
+
+using torch::stable::Tensor;
+
+Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef shape) {
+  return reshape(t, shape);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my_reshape", TORCH_BOX(&my_reshape));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp
new file mode 100644
index 0000000000000..25d8c54589247
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_view.cpp
@@ -0,0 +1,20 @@
+#include
+#include
+#include
+
+using torch::stable::Tensor;
+
+Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef size) {
+  return view(t, size);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my_view(Tensor t, int[] size) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(
+    libtorch_agnostic_2_10,
+    CompositeExplicitAutograd,
+    m) {
+  m.impl("my_view", TORCH_BOX(&my_view));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h
new file mode 100644
index 0000000000000..f1031f38060cf
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/tensor_accessor_kernel.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include
+#include
+
+template
+using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor;
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define MAYBE_GLOBAL __global__
+
+template
+using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor;
+
+#else
+#define MAYBE_GLOBAL
+#endif
+
+template