From 7d8992cea61fc1a3639536aa8839bfa8bb4f063b Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 11 Nov 2025 16:10:55 -0800 Subject: [PATCH 01/15] Add empty to stable ops [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 29 +++++++++++++++++ .../libtorch_agnostic/ops.py | 14 ++++++++ .../test/test_libtorch_agnostic.py | 32 +++++++++++++++++++ torch/csrc/stable/ops.h | 20 ++++++++++++ 4 files changed, 95 insertions(+) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 96b6a17cf9187..b41820ce22964 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -824,3 +824,32 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_parallel_for", &boxed_test_parallel_for); m.impl("test_get_num_threads", &boxed_test_get_num_threads); } + +Tensor my_empty( + torch::headeronly::HeaderOnlyArrayRef size, + std::optional dtype, + std::optional device) { + return empty(size, dtype, device); +} + +void boxed_my_empty( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + Tensor res = my_empty( + torch::stable::detail::to>(stack[0]), + torch::stable::detail::to>( + stack[1]), + torch::stable::detail::to>( + stack[2])); + stack[0] = torch::stable::detail::from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def( + "my_empty(int[] size, ScalarType? dtype=None, Device? device=None) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("my_empty", &boxed_my_empty); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 59d8c17b68d77..5375f7b9b378d 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -487,3 +487,17 @@ def test_get_num_threads() -> int: Returns: int - the number of threads for the parallel backend """ return torch.ops.libtorch_agnostic.test_get_num_threads.default() + + +def my_empty(size, dtype=None, device=None) -> Tensor: + """ + Creates an empty tensor with the specified size, dtype, and device. 
+ + Args: + size: list[int] - size of the tensor to create + dtype: ScalarType or None - data type of the tensor + device: Device or None - device on which to create the tensor + + Returns: Tensor - an uninitialized tensor with the specified properties + """ + return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index 1149be388795a..ce1d8ff20059e 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -525,6 +525,38 @@ def test_get_num_threads(self, device): expected_num_threads = torch.get_num_threads() self.assertEqual(num_threads, expected_num_threads) + def test_my_empty(self, device): + import libtorch_agnostic + + deterministic = torch.are_deterministic_algorithms_enabled() + try: + # set use_deterministic_algorithms to fill uninitialized memory + torch.use_deterministic_algorithms(True) + + # Test with just size + size = [2, 3] + result = libtorch_agnostic.ops.my_empty(size, None, None) + expected = torch.empty(size) + self.assertEqual(result, expected, exact_device=True) + + # Test with size and dtype + result_float = libtorch_agnostic.ops.my_empty(size, torch.float32, None) + expected_float = torch.empty(size, dtype=torch.float32) + self.assertEqual(result_float, expected_float, exact_device=True) + + # Test with size, dtype, and device + result_with_device = libtorch_agnostic.ops.my_empty( + size, torch.float64, device + ) + expected_with_device = torch.empty( + size, dtype=torch.float64, device=device + ) + self.assertEqual( + result_with_device, expected_with_device, exact_device=True + ) + finally: + torch.use_deterministic_algorithms(deterministic) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 1a14cf9765094..968f8d0aecaa9 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -306,6 +306,26 @@ inline uint32_t get_num_threads() { return num_threads; } +// We expect this to be the stable version of the empty op that takes in +// device and dtype parameters. The empty op creates a tensor with uninitialized +// values of the specified size, dtype, and device. 
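+// Only size, dtype, and device are exposed here; layout, pin_memory, and
+// memory_format are forwarded to the dispatcher as nullopt, so the ATen
+// defaults apply for those arguments.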
+inline torch::stable::Tensor empty( + torch::headeronly::IntHeaderOnlyArrayRef size, + std::optional dtype = std::nullopt, + std::optional device = std::nullopt) { + const auto num_args = 6; + std::array stack{ + torch::stable::detail::from(size), + torch::stable::detail::from(dtype), + torch::stable::detail::from(std::nullopt), + torch::stable::detail::from(device), + torch::stable::detail::from(std::nullopt), + torch::stable::detail::from(std::nullopt)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + #endif HIDDEN_NAMESPACE_END(torch, stable) From 85b1f20921d94fd40f49fb7bb991359683979a85 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 11 Nov 2025 17:29:04 -0800 Subject: [PATCH 02/15] Add reshape, view, flatten to torch/csrc/stable [ghstack-poisoned] --- .../libtorch_agnostic/csrc/kernel.cpp | 49 ++++++++++++++++++ .../libtorch_agnostic/ops.py | 40 +++++++++++++++ .../test/test_libtorch_agnostic.py | 50 +++++++++++++++++++ torch/csrc/stable/ops.h | 43 ++++++++++++++++ 4 files changed, 182 insertions(+) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index b41820ce22964..a70b522b4e6f8 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -845,11 +845,60 @@ void boxed_my_empty( stack[0] = torch::stable::detail::from(res); } +Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) { + return flatten(t, start_dim, end_dim); +} + +void boxed_my_flatten( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + Tensor res = my_flatten( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1]), + torch::stable::detail::to(stack[2])); + stack[0] = torch::stable::detail::from(res); +} + +Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef shape) { + return reshape(t, shape); +} + +void boxed_my_reshape( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + Tensor res = my_reshape( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to>(stack[1])); + stack[0] = torch::stable::detail::from(res); +} + +Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef size) { + return view(t, size); +} + +void boxed_my_view( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + Tensor res = my_view( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to>(stack[1])); + stack[0] = torch::stable::detail::from(res); +} + STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def( "my_empty(int[] size, ScalarType? dtype=None, Device? 
device=None) -> Tensor"); + m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor"); + m.def("my_reshape(Tensor t, int[] shape) -> Tensor"); + m.def("my_view(Tensor t, int[] size) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_empty", &boxed_my_empty); + m.impl("my_flatten", &boxed_my_flatten); + m.impl("my_reshape", &boxed_my_reshape); + m.impl("my_view", &boxed_my_view); } diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 5375f7b9b378d..19adf378111b0 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -501,3 +501,43 @@ def my_empty(size, dtype=None, device=None) -> Tensor: Returns: Tensor - an uninitialized tensor with the specified properties """ return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device) + + +def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor: + """ + Flattens the input tensor from start_dim to end_dim into a single dimension. + + Args: + t: Tensor - tensor to flatten + start_dim: int - first dimension to flatten (default: 0) + end_dim: int - last dimension to flatten (default: -1) + + Returns: Tensor - flattened tensor + """ + return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim) + + +def my_reshape(t, shape) -> Tensor: + """ + Returns a tensor with the same data but different shape. + + Args: + t: Tensor - tensor to reshape + shape: list[int] - new shape for the tensor + + Returns: Tensor - reshaped tensor + """ + return torch.ops.libtorch_agnostic.my_reshape.default(t, shape) + + +def my_view(t, size) -> Tensor: + """ + Returns a new tensor with the same data as the input tensor but of a different shape. 
+ + Args: + t: Tensor - tensor to view + size: list[int] - new size for the tensor + + Returns: Tensor - tensor with new view + """ + return torch.ops.libtorch_agnostic.my_view.default(t, size) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index ce1d8ff20059e..74e098bcfe94c 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -557,6 +557,56 @@ def test_my_empty(self, device): finally: torch.use_deterministic_algorithms(deterministic) + def test_my_flatten(self, device): + import libtorch_agnostic + + t = torch.randn(2, 3, 4, device=device) + result = libtorch_agnostic.ops.my_flatten(t) + expected = torch.flatten(t) + self.assertEqual(result, expected) + + result_start = libtorch_agnostic.ops.my_flatten(t, 1) + expected_start = torch.flatten(t, 1) + self.assertEqual(result_start, expected_start) + + result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1) + expected_range = torch.flatten(t, 2, -1) + self.assertEqual(result_range, expected_range) + + def test_my_reshape(self, device): + import libtorch_agnostic + + t = torch.randn(2, 3, 4, device=device) + + result = libtorch_agnostic.ops.my_reshape(t, [6, 4]) + expected = torch.reshape(t, [6, 4]) + self.assertEqual(result, expected) + + result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4]) + expected_infer = torch.reshape(t, [-1, 4]) + self.assertEqual(result_infer, expected_infer) + + result_flat = libtorch_agnostic.ops.my_reshape(t, [-1]) + expected_flat = torch.reshape(t, [-1]) + self.assertEqual(result_flat, expected_flat) + + def test_my_view(self, device): + import libtorch_agnostic + + t = torch.randn(2, 3, 4, device=device) + + result = libtorch_agnostic.ops.my_view(t, [6, 4]) + expected = t.view([6, 4]) + self.assertEqual(result, expected) + + result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4]) + expected_infer = t.view([-1, 4]) + self.assertEqual(result_infer, expected_infer) + + result_flat = libtorch_agnostic.ops.my_view(t, [-1]) + expected_flat = t.view([-1]) + self.assertEqual(result_flat, expected_flat) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 968f8d0aecaa9..ba55788f1934c 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -326,6 +326,49 @@ inline torch::stable::Tensor empty( return torch::stable::detail::to(stack[0]); } +// We expect this to be the stable version of the flatten.using_ints op. +// This flattens the tensor from start_dim to end_dim into a single dimension. +inline torch::stable::Tensor flatten( + const torch::stable::Tensor& self, + int64_t start_dim = 0, + int64_t end_dim = -1) { + const auto num_args = 3; + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(start_dim), + torch::stable::detail::from(end_dim)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// We expect this to be the stable version of the reshape op. +// This returns a tensor with the same data but different shape. 
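+// Like the aten::reshape op it forwards to, this returns a view when possible
+// and a copy otherwise; the view() wrapper below requires the requested size
+// to be view-compatible with the input.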
+inline torch::stable::Tensor reshape( + const torch::stable::Tensor& self, + torch::headeronly::IntHeaderOnlyArrayRef shape) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(shape)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::reshape", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// We expect this to be the stable version of the view op. +// This returns a new tensor with the same data as the self tensor but of a +// different shape. +inline torch::stable::Tensor view( + const torch::stable::Tensor& self, + torch::headeronly::IntHeaderOnlyArrayRef size) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(size)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::view", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + #endif HIDDEN_NAMESPACE_END(torch, stable) From c96c2f80efe5961a41e93883fd34338fbbde987e Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 11 Nov 2025 17:32:30 -0800 Subject: [PATCH 03/15] Update on "Add reshape, view, flatten to torch/csrc/stable" [ghstack-poisoned] --- torch/csrc/stable/ops.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index ba55788f1934c..9da92abe21a63 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -327,7 +327,6 @@ inline torch::stable::Tensor empty( } // We expect this to be the stable version of the flatten.using_ints op. -// This flattens the tensor from start_dim to end_dim into a single dimension. inline torch::stable::Tensor flatten( const torch::stable::Tensor& self, int64_t start_dim = 0, @@ -343,7 +342,6 @@ inline torch::stable::Tensor flatten( } // We expect this to be the stable version of the reshape op. -// This returns a tensor with the same data but different shape. inline torch::stable::Tensor reshape( const torch::stable::Tensor& self, torch::headeronly::IntHeaderOnlyArrayRef shape) { @@ -356,8 +354,6 @@ inline torch::stable::Tensor reshape( } // We expect this to be the stable version of the view op. -// This returns a new tensor with the same data as the self tensor but of a -// different shape. inline torch::stable::Tensor view( const torch::stable::Tensor& self, torch::headeronly::IntHeaderOnlyArrayRef size) { From 0b0e15a534398f6cd852703fae4d3ec3c3c1ea1b Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 13 Nov 2025 21:15:27 -0800 Subject: [PATCH 04/15] Fix TORCH_FEATURE_VERSION guards [ghstack-poisoned] --- torch/csrc/stable/ops.h | 4 +- torch/csrc/stable/stableivalue_conversions.h | 131 ++++++++++--------- 2 files changed, 72 insertions(+), 63 deletions(-) diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 18d24fce2721c..b1ffe63b61ce6 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -231,6 +231,8 @@ inline torch::stable::Tensor zero_(torch::stable::Tensor& self) { return torch::stable::detail::to(stack[0]); } +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + // We expect this to be the stable version of the copy_ op with // identical semantics to the existing copy_ op. 
inline torch::stable::Tensor copy_( @@ -269,8 +271,6 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) { return torch::stable::detail::to(stack[0]); } -#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 - // New ops should be added here if they use a brand new shim API // Parallel utility wrapper that provides a stable interface to at::parallel_for diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 600a661962f2b..62d6aaa33f87e 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -126,37 +127,6 @@ struct FromImpl { } }; -// Specialization for torch::headeronly::DeviceType => StableIValue -// Note that we call into the shim to translate between the user's -// DeviceType and libtorch's DeviceType, which can be different! -using torch::headeronly::DeviceType; -template <> -struct FromImpl { - static StableIValue call( - DeviceType val, - [[maybe_unused]] uint64_t extension_build_version, - [[maybe_unused]] bool is_internal) { - switch (val) { - case DeviceType::CPU: - return from(aoti_torch_device_type_cpu()); - case DeviceType::CUDA: - return from(aoti_torch_device_type_cuda()); - case DeviceType::Meta: - return from(aoti_torch_device_type_meta()); - case DeviceType::XPU: - return from(aoti_torch_device_type_xpu()); - case DeviceType::MPS: - return from(aoti_torch_device_type_mps()); - case DeviceType::PrivateUse1: - return from(aoti_torch_device_type_privateuse1()); - default: - STD_TORCH_CHECK( - false, - "Not yet supported DeviceType, please file an issue describing your use case."); - } - } -}; - // Specialization for std::nullopt_t => StableIValue template <> struct FromImpl { @@ -225,6 +195,8 @@ struct FromImpl { } }; +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + // Specialization for torch::headeronly::HeaderOnlyArrayRef => StableIValue // Returns a new owning reference of the underlying list. template @@ -287,6 +259,39 @@ struct FromImpl { } }; +// Specialization for torch::headeronly::DeviceType => StableIValue +// Note that we call into the shim to translate between the user's +// DeviceType and libtorch's DeviceType, which can be different! 
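+// This specialization and its ToImpl counterpart below sit inside the
+// TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 block, so these DeviceType
+// conversions are only available to extensions targeting 2.10 or newer.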
+using torch::headeronly::DeviceType; +template <> +struct FromImpl { + static StableIValue call( + DeviceType val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + switch (val) { + case DeviceType::CPU: + return from(aoti_torch_device_type_cpu()); + case DeviceType::CUDA: + return from(aoti_torch_device_type_cuda()); + case DeviceType::Meta: + return from(aoti_torch_device_type_meta()); + case DeviceType::XPU: + return from(aoti_torch_device_type_xpu()); + case DeviceType::MPS: + return from(aoti_torch_device_type_mps()); + case DeviceType::PrivateUse1: + return from(aoti_torch_device_type_privateuse1()); + default: + STD_TORCH_CHECK( + false, + "Not yet supported DeviceType, please file an issue describing your use case."); + } + } +}; + +#endif + // ============================================================================= // TO CONVERSIONS (StableIValue -> T) // ============================================================================= @@ -387,36 +392,6 @@ struct ToImpl { } }; -// Specialization for StableIValue => torch::headeronly::DeviceType -template <> -struct ToImpl { - static DeviceType call( - StableIValue val, - [[maybe_unused]] uint64_t extension_build_version, - [[maybe_unused]] bool is_internal) { - int32_t shim_devicetype = to(val); - if (shim_devicetype == aoti_torch_device_type_cpu()) { - return DeviceType::CPU; - } else if (shim_devicetype == aoti_torch_device_type_cuda()) { - return DeviceType::CUDA; - } else if (shim_devicetype == aoti_torch_device_type_meta()) { - return DeviceType::Meta; - } else if (shim_devicetype == aoti_torch_device_type_xpu()) { - return DeviceType::XPU; - } else if (shim_devicetype == aoti_torch_device_type_mps()) { - return DeviceType::MPS; - } else if (shim_devicetype == aoti_torch_device_type_privateuse1()) { - return DeviceType::PrivateUse1; - } else { - STD_TORCH_CHECK( - false, - "Not yet supported DeviceType ", - std::to_string(shim_devicetype), - ", please file an issue describing your use case."); - } - } -}; - // Specialization for StableIValue => std::nullopt_t template <> struct ToImpl { @@ -467,6 +442,8 @@ struct ToImpl { } }; +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + // Specialization for StableIValue => std::vector // std::vector should be represented as a StableListHandle // filled with StableIValues @@ -517,6 +494,38 @@ struct ToImpl { } }; +// Specialization for StableIValue => torch::headeronly::DeviceType +template <> +struct ToImpl { + static DeviceType call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + int32_t shim_devicetype = to(val); + if (shim_devicetype == aoti_torch_device_type_cpu()) { + return DeviceType::CPU; + } else if (shim_devicetype == aoti_torch_device_type_cuda()) { + return DeviceType::CUDA; + } else if (shim_devicetype == aoti_torch_device_type_meta()) { + return DeviceType::Meta; + } else if (shim_devicetype == aoti_torch_device_type_xpu()) { + return DeviceType::XPU; + } else if (shim_devicetype == aoti_torch_device_type_mps()) { + return DeviceType::MPS; + } else if (shim_devicetype == aoti_torch_device_type_privateuse1()) { + return DeviceType::PrivateUse1; + } else { + STD_TORCH_CHECK( + false, + "Not yet supported DeviceType ", + std::to_string(shim_devicetype), + ", please file an issue describing your use case."); + } + } +}; + +#endif + // ============================================================================= // end to helpers for converting between 
StableIValue and T // ============================================================================= From 65708c0064577168c5977a0dbb7b1c9096e1e7ec Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 13 Nov 2025 21:15:33 -0800 Subject: [PATCH 05/15] Split libtorch agnostic tests by feature version [ghstack-poisoned] --- .../libtorch_agnostic_2_10}/__init__.py | 0 .../libtorch_agnostic_2_10}/csrc/kernel.cpp | 70 +-- .../libtorch_agnostic_2_10}/ops.py | 86 ++-- .../setup.py | 18 +- .../libtorch_agnostic_2_9/__init__.py | 21 + .../libtorch_agnostic_2_9/csrc/kernel.cpp | 411 ++++++++++++++++++ .../libtorch_agnostic_2_9/ops.py | 309 +++++++++++++ .../libtorch_agnostic_2_9_extension/setup.py | 78 ++++ ...py => test_libtorch_agnostic_versioned.py} | 410 +++++++++-------- 9 files changed, 1105 insertions(+), 298 deletions(-) rename test/cpp_extensions/{libtorch_agnostic_extension/libtorch_agnostic => libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10}/__init__.py (100%) rename test/cpp_extensions/{libtorch_agnostic_extension/libtorch_agnostic => libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10}/csrc/kernel.cpp (88%) rename test/cpp_extensions/{libtorch_agnostic_extension/libtorch_agnostic => libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10}/ops.py (76%) rename test/cpp_extensions/{libtorch_agnostic_extension => libtorch_agnostic_2_10_extension}/setup.py (76%) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/__init__.py create mode 100644 test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp create mode 100644 test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py create mode 100644 test/cpp_extensions/libtorch_agnostic_2_9_extension/setup.py rename test/cpp_extensions/{libtorch_agnostic_extension/test/test_libtorch_agnostic.py => test_libtorch_agnostic_versioned.py} (54%) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/__init__.py similarity index 100% rename from test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py rename to test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/__init__.py diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp similarity index 88% rename from test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp rename to test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp index 28c96f05c4f14..caed10e3fabc2 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp @@ -66,11 +66,11 @@ Tensor sgd_out_of_place( return out; } -STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY(libtorch_agnostic_2_10, m) { m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CPU, m) { m.impl("sgd_out_of_place", TORCH_BOX(&sgd_out_of_place)); } @@ -79,15 +79,15 @@ Tensor identity(Tensor t) { } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, 
m) { m.def("identity(Tensor t) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CUDA, m) { m.impl("identity", TORCH_BOX(&identity)); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CPU, m) { m.impl("identity", TORCH_BOX(&identity)); } @@ -99,11 +99,11 @@ Tensor my_abs(Tensor t) { return torch::stable::detail::to(stack[0]); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("my_abs(Tensor t) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("my_abs", TORCH_BOX(&my_abs)); } @@ -125,11 +125,11 @@ Tensor my_ones_like(Tensor t, StableIValue device) { return torch::stable::detail::to(stack[0]); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("my_ones_like(Tensor t, Device d) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("my_ones_like", TORCH_BOX(&my_ones_like)); } @@ -152,11 +152,11 @@ std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3 torch::stable::detail::to(stack_is_leaf[0])); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("exp_neg_is_leaf", TORCH_BOX(&exp_neg_is_leaf)); } @@ -168,11 +168,11 @@ Tensor neg_exp(Tensor t) { return torch::stable::detail::to(stack[0]); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("neg_exp(Tensor t) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("neg_exp", TORCH_BOX(&neg_exp)); } @@ -192,11 +192,11 @@ Tensor divide_neg_exp(Tensor t) { return torch::stable::detail::to(stack_div[0]); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("divide_neg_exp(Tensor t) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("divide_neg_exp", TORCH_BOX(÷_neg_exp)); } @@ -204,11 +204,11 @@ bool is_contiguous(Tensor t) { return t.is_contiguous(); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("is_contiguous(Tensor t) -> bool"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("is_contiguous", TORCH_BOX(&is_contiguous)); } @@ -263,7 +263,7 @@ Tensor my_clone(Tensor t) { return clone(t); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); m.def("my_empty_like(Tensor t) -> Tensor"); 
m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)"); @@ -275,7 +275,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("my_clone(Tensor t) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("my_transpose", TORCH_BOX(&my_transpose)); m.impl("my_empty_like", TORCH_BOX(&my_empty_like)); m.impl("fill_infinity", TORCH_BOX(&fill_infinity)); @@ -286,7 +286,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_clone", TORCH_BOX(&my_clone)); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeImplicitAutograd, m) { m.impl("my_pad", TORCH_BOX(&my_pad)); m.impl("my_narrow", TORCH_BOX(&my_narrow)); } @@ -303,7 +303,7 @@ Tensor my_amax_vec(Tensor t) { return amax(t, {0,1}, false); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)"); m.def("my_amax(Tensor a) -> Tensor"); m.def("my_amax_vec(Tensor a) -> Tensor"); @@ -331,7 +331,7 @@ bool test_default_constructor(bool defined) { } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("my_zero_", TORCH_BOX(&my_zero_)); m.impl("my_amax", TORCH_BOX(&my_amax)); m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec)); @@ -359,13 +359,13 @@ std::vector make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) { return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2}); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]"); m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()"); m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul)); m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_)); m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach)); @@ -493,7 +493,7 @@ void boxed_test_device_is_cpu( stack[0] = torch::stable::detail::from(res); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("test_tensor_device(Tensor t) -> Device"); m.def( "test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device"); @@ -504,7 +504,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("test_device_is_cpu(Device device) -> bool"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("test_tensor_device", &boxed_test_tensor_device); m.impl("test_device_constructor", &boxed_test_device_constructor); m.impl("test_device_equality", &boxed_test_device_equality); @@ -556,14 +556,14 @@ int64_t test_get_current_device_index() { return torch::stable::accelerator::getCurrentDeviceIndex(); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("test_device_guard(int device_index) -> int"); 
m.def("test_device_guard_set_index() -> int"); m.def("test_stream(int device_index) -> int"); m.def("test_get_current_device_index() -> int"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("test_device_guard", TORCH_BOX(&test_device_guard)); m.impl("test_device_guard_set_index", TORCH_BOX(&test_device_guard_set_index)); m.impl("test_stream", TORCH_BOX(&test_stream)); @@ -609,8 +609,8 @@ void boxed_test_parallel_for( StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor res = test_parallel_for(to(stack[0]), to(stack[1])); - stack[0] = from(res); + Tensor res = test_parallel_for(torch::stable::detail::to(stack[0]), torch::stable::detail::to(stack[1])); + stack[0] = torch::stable::detail::from(res); } uint32_t test_get_num_threads() { @@ -622,15 +622,15 @@ void boxed_test_get_num_threads( uint64_t num_args, uint64_t num_outputs) { uint32_t res = test_get_num_threads(); - stack[0] = from(res); + stack[0] = torch::stable::detail::from(res); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def("test_parallel_for(int size, int grain_size) -> Tensor"); m.def("test_get_num_threads() -> int"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("test_parallel_for", &boxed_test_parallel_for); m.impl("test_get_num_threads", &boxed_test_get_num_threads); } @@ -655,7 +655,7 @@ Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef size) { return view(t, size); } -STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def( "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor"); m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor"); @@ -663,7 +663,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("my_view(Tensor t, int[] size) -> Tensor"); } -STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { m.impl("my_empty", TORCH_BOX(&my_empty)); m.impl("my_flatten", TORCH_BOX(&my_flatten)); m.impl("my_reshape", TORCH_BOX(&my_reshape)); diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py similarity index 76% rename from test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py rename to test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index b920574b2c205..908c2bd411770 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -20,7 +20,7 @@ def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor: a 1D float Tensor the same shape as param """ - return torch.ops.libtorch_agnostic.sgd_out_of_place.default( + return torch.ops.libtorch_agnostic_2_10.sgd_out_of_place.default( param, grad, weight_decay, lr, maximize ) @@ -35,7 +35,7 @@ def identity(t) -> Tensor: Returns: a Tensor, the same as input. 
""" - return torch.ops.libtorch_agnostic.identity.default(t) + return torch.ops.libtorch_agnostic_2_10.identity.default(t) def my_abs(t) -> Tensor: @@ -48,7 +48,7 @@ def my_abs(t) -> Tensor: Returns: a Tensor """ - return torch.ops.libtorch_agnostic.my_abs.default(t) + return torch.ops.libtorch_agnostic_2_10.my_abs.default(t) def my_is_cpu(t) -> bool: @@ -61,7 +61,7 @@ def my_is_cpu(t) -> bool: Returns: a bool """ - return torch.ops.libtorch_agnostic.my_is_cpu.default(t) + return torch.ops.libtorch_agnostic_2_10.my_is_cpu.default(t) def my_ones_like(tensor, device) -> Tensor: @@ -76,7 +76,7 @@ def my_ones_like(tensor, device) -> Tensor: a ones Tensor with the same dtype and shape and other attributes like the input tensor """ - return torch.ops.libtorch_agnostic.my_ones_like.default(tensor, device) + return torch.ops.libtorch_agnostic_2_10.my_ones_like.default(tensor, device) def exp_neg_is_leaf(t1, t2, t3) -> tuple[Tensor, Tensor, bool]: @@ -92,7 +92,7 @@ def exp_neg_is_leaf(t1, t2, t3) -> tuple[Tensor, Tensor, bool]: Returns: (exp(t1), neg(t2), is_leaf(t3)) """ - return torch.ops.libtorch_agnostic.exp_neg_is_leaf.default(t1, t2, t3) + return torch.ops.libtorch_agnostic_2_10.exp_neg_is_leaf.default(t1, t2, t3) def neg_exp(t) -> Tensor: @@ -104,7 +104,7 @@ def neg_exp(t) -> Tensor: Returns: neg(exp(t)) """ - return torch.ops.libtorch_agnostic.neg_exp.default(t) + return torch.ops.libtorch_agnostic_2_10.neg_exp.default(t) def divide_neg_exp(t) -> Tensor: @@ -116,7 +116,7 @@ def divide_neg_exp(t) -> Tensor: Returns: divide(neg(t), exp(t)) """ - return torch.ops.libtorch_agnostic.divide_neg_exp.default(t) + return torch.ops.libtorch_agnostic_2_10.divide_neg_exp.default(t) def is_contiguous(t) -> bool: @@ -128,7 +128,7 @@ def is_contiguous(t) -> bool: Returns: is_contiguous(t) """ - return torch.ops.libtorch_agnostic.is_contiguous.default(t) + return torch.ops.libtorch_agnostic_2_10.is_contiguous.default(t) def my_transpose(t, dim0, dim1) -> Tensor: @@ -140,7 +140,7 @@ def my_transpose(t, dim0, dim1) -> Tensor: Returns: my_transpose(t, dim0, dim1) """ - return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1) + return torch.ops.libtorch_agnostic_2_10.my_transpose.default(t, dim0, dim1) def my_empty_like(t) -> Tensor: @@ -152,7 +152,7 @@ def my_empty_like(t) -> Tensor: Returns: my_empty_like(t) """ - return torch.ops.libtorch_agnostic.my_empty_like.default(t) + return torch.ops.libtorch_agnostic_2_10.my_empty_like.default(t) def my_zero_(t) -> Tensor: @@ -164,7 +164,7 @@ def my_zero_(t) -> Tensor: Returns: my_zero_(t) """ - return torch.ops.libtorch_agnostic.my_zero_.default(t) + return torch.ops.libtorch_agnostic_2_10.my_zero_.default(t) def my_amax(t) -> Tensor: @@ -176,7 +176,7 @@ def my_amax(t) -> Tensor: Returns: amax(t) """ - return torch.ops.libtorch_agnostic.my_amax.default(t) + return torch.ops.libtorch_agnostic_2_10.my_amax.default(t) def my_amax_vec(t) -> Tensor: @@ -188,7 +188,7 @@ def my_amax_vec(t) -> Tensor: Returns: amax(t) """ - return torch.ops.libtorch_agnostic.my_amax_vec.default(t) + return torch.ops.libtorch_agnostic_2_10.my_amax_vec.default(t) def fill_infinity(t) -> Tensor: @@ -200,7 +200,7 @@ def fill_infinity(t) -> Tensor: Returns: The modified tensor (same as input) """ - return torch.ops.libtorch_agnostic.fill_infinity.default(t) + return torch.ops.libtorch_agnostic_2_10.fill_infinity.default(t) def test_default_constructor(defined) -> bool: @@ -212,7 +212,7 @@ def test_default_constructor(defined) -> bool: Returns: bool - result of calling .defined() 
on the tensor """ - return torch.ops.libtorch_agnostic.test_default_constructor.default(defined) + return torch.ops.libtorch_agnostic_2_10.test_default_constructor.default(defined) def test_tensor_device(t): @@ -224,7 +224,7 @@ def test_tensor_device(t): Returns: Device - device of the tensor """ - return torch.ops.libtorch_agnostic.test_tensor_device.default(t) + return torch.ops.libtorch_agnostic_2_10.test_tensor_device.default(t) def my_pad(t) -> Tensor: @@ -236,7 +236,7 @@ def my_pad(t) -> Tensor: Returns: Padded tensor with padding [1, 2, 2, 1], mode "constant", value 0.0 """ - return torch.ops.libtorch_agnostic.my_pad.default(t) + return torch.ops.libtorch_agnostic_2_10.my_pad.default(t) def my_narrow(t, dim, start, length) -> Tensor: @@ -251,7 +251,7 @@ def my_narrow(t, dim, start, length) -> Tensor: Returns: Narrowed tensor """ - return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length) + return torch.ops.libtorch_agnostic_2_10.my_narrow.default(t, dim, start, length) def my_copy_(dst, src, non_blocking) -> Tensor: @@ -265,7 +265,7 @@ def my_copy_(dst, src, non_blocking) -> Tensor: Returns: Updated tensor """ - return torch.ops.libtorch_agnostic.my_copy_.default(dst, src, non_blocking) + return torch.ops.libtorch_agnostic_2_10.my_copy_.default(dst, src, non_blocking) def my_clone(t) -> Tensor: @@ -277,7 +277,7 @@ def my_clone(t) -> Tensor: Returns: Cloned tensor """ - return torch.ops.libtorch_agnostic.my_clone.default(t) + return torch.ops.libtorch_agnostic_2_10.my_clone.default(t) def test_device_guard(device_index) -> int: @@ -289,7 +289,7 @@ def test_device_guard(device_index) -> int: Returns: result of cudaGetDevice() as an integer after using the guard """ - return torch.ops.libtorch_agnostic.test_device_guard.default(device_index) + return torch.ops.libtorch_agnostic_2_10.test_device_guard.default(device_index) def test_device_guard_set_index() -> int: @@ -299,7 +299,7 @@ def test_device_guard_set_index() -> int: Returns: result of cudaGetDevice() as an integer after using set_index """ - return torch.ops.libtorch_agnostic.test_device_guard_set_index.default() + return torch.ops.libtorch_agnostic_2_10.test_device_guard_set_index.default() def test_stream(device_index) -> int: @@ -311,7 +311,7 @@ def test_stream(device_index) -> int: Returns: Stream ID as an integer """ - return torch.ops.libtorch_agnostic.test_stream.default(device_index) + return torch.ops.libtorch_agnostic_2_10.test_stream.default(device_index) def test_get_current_device_index() -> int: @@ -320,7 +320,7 @@ def test_get_current_device_index() -> int: Returns: Current device index as an integer """ - return torch.ops.libtorch_agnostic.test_get_current_device_index.default() + return torch.ops.libtorch_agnostic_2_10.test_get_current_device_index.default() def my_new_empty_dtype_variant(t) -> Tensor: @@ -332,7 +332,7 @@ def my_new_empty_dtype_variant(t) -> Tensor: Returns: New empty tensor with shape [2, 5] and dtype bfloat16 """ - return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t) + return torch.ops.libtorch_agnostic_2_10.my_new_empty_dtype_variant.default(t) def my_new_zeros_dtype_variant(t) -> Tensor: @@ -344,7 +344,7 @@ def my_new_zeros_dtype_variant(t) -> Tensor: Returns: New zeros tensor """ - return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t) + return torch.ops.libtorch_agnostic_2_10.my_new_zeros_dtype_variant.default(t) def my__foreach_mul_(tensors, others) -> (): @@ -357,7 +357,7 @@ def my__foreach_mul_(tensors, others) -> (): Returns: 
nothing, tensors is updated in place. """ - torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others) + torch.ops.libtorch_agnostic_2_10.my__foreach_mul_.default(tensors, others) def my__foreach_mul(tensors, others) -> list[Tensor]: @@ -371,7 +371,7 @@ def my__foreach_mul(tensors, others) -> list[Tensor]: Returns: list of multiplied tensors """ - return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others) + return torch.ops.libtorch_agnostic_2_10.my__foreach_mul.default(tensors, others) def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]: @@ -384,7 +384,7 @@ def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]: Returns: list of [t1^2, t2^2] """ - return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default( + return torch.ops.libtorch_agnostic_2_10.make_tensor_clones_and_call_foreach.default( t1, t2 ) @@ -400,7 +400,7 @@ def test_device_constructor(is_cuda, index, use_str): Returns: Device - A device with the specified type and index """ - return torch.ops.libtorch_agnostic.test_device_constructor.default( + return torch.ops.libtorch_agnostic_2_10.test_device_constructor.default( is_cuda, index, use_str ) @@ -415,7 +415,7 @@ def test_device_equality(d1, d2) -> bool: Returns: bool - True if devices are equal """ - return torch.ops.libtorch_agnostic.test_device_equality.default(d1, d2) + return torch.ops.libtorch_agnostic_2_10.test_device_equality.default(d1, d2) def test_device_set_index(device, index): @@ -428,7 +428,7 @@ def test_device_set_index(device, index): Returns: Device - device with updated index """ - return torch.ops.libtorch_agnostic.test_device_set_index.default(device, index) + return torch.ops.libtorch_agnostic_2_10.test_device_set_index.default(device, index) def test_device_index(device) -> int: @@ -440,7 +440,7 @@ def test_device_index(device) -> int: Returns: int - device index """ - return torch.ops.libtorch_agnostic.test_device_index.default(device) + return torch.ops.libtorch_agnostic_2_10.test_device_index.default(device) def test_device_is_cuda(device) -> bool: @@ -452,7 +452,7 @@ def test_device_is_cuda(device) -> bool: Returns: bool - True if device is CUDA """ - return torch.ops.libtorch_agnostic.test_device_is_cuda.default(device) + return torch.ops.libtorch_agnostic_2_10.test_device_is_cuda.default(device) def test_device_is_cpu(device) -> bool: @@ -464,7 +464,7 @@ def test_device_is_cpu(device) -> bool: Returns: bool - True if device is CPU """ - return torch.ops.libtorch_agnostic.test_device_is_cpu.default(device) + return torch.ops.libtorch_agnostic_2_10.test_device_is_cpu.default(device) def test_parallel_for(size, grain_size) -> Tensor: @@ -476,7 +476,7 @@ def test_parallel_for(size, grain_size) -> Tensor: Returns: Tensor - a 1D int64 tensor where each element contains its index (if multiple threads are used the threadid will be encoded in the upper 32 bits) """ - return torch.ops.libtorch_agnostic.test_parallel_for.default(size, grain_size) + return torch.ops.libtorch_agnostic_2_10.test_parallel_for.default(size, grain_size) def test_get_num_threads() -> int: @@ -486,7 +486,7 @@ def test_get_num_threads() -> int: Returns: int - the number of threads for the parallel backend """ - return torch.ops.libtorch_agnostic.test_get_num_threads.default() + return torch.ops.libtorch_agnostic_2_10.test_get_num_threads.default() def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor: @@ -501,7 +501,9 @@ def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor: Returns: 
Tensor - an uninitialized tensor with the specified properties """ - return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory) + return torch.ops.libtorch_agnostic_2_10.my_empty.default( + size, dtype, device, pin_memory + ) def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor: @@ -515,7 +517,7 @@ def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor: Returns: Tensor - flattened tensor """ - return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim) + return torch.ops.libtorch_agnostic_2_10.my_flatten.default(t, start_dim, end_dim) def my_reshape(t, shape) -> Tensor: @@ -528,7 +530,7 @@ def my_reshape(t, shape) -> Tensor: Returns: Tensor - reshaped tensor """ - return torch.ops.libtorch_agnostic.my_reshape.default(t, shape) + return torch.ops.libtorch_agnostic_2_10.my_reshape.default(t, shape) def my_view(t, size) -> Tensor: @@ -541,4 +543,4 @@ def my_view(t, size) -> Tensor: Returns: Tensor - tensor with new view """ - return torch.ops.libtorch_agnostic.my_view.default(t, size) + return torch.ops.libtorch_agnostic_2_10.my_view.default(t, size) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py similarity index 76% rename from test/cpp_extensions/libtorch_agnostic_extension/setup.py rename to test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index 30a93193d557d..2e2ed62e32fea 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -9,7 +9,7 @@ ROOT_DIR = Path(__file__).parent -CSRC_DIR = ROOT_DIR / "libtorch_agnostic" / "csrc" +CSRC_DIR = ROOT_DIR / "libtorch_agnostic_2_10" / "csrc" class clean(distutils.command.clean.clean): @@ -18,13 +18,13 @@ def run(self): distutils.command.clean.clean.run(self) # Remove extension - for path in (ROOT_DIR / "libtorch_agnostic").glob("**/*.so"): + for path in (ROOT_DIR / "libtorch_agnostic_2_10").glob("**/*.so"): path.unlink() # Remove build and dist and egg-info directories dirs = [ ROOT_DIR / "build", ROOT_DIR / "dist", - ROOT_DIR / "libtorch_agnostic.egg-info", + ROOT_DIR / "libtorch_agnostic_2_10.egg-info", ] for path in dirs: if path.exists(): @@ -33,7 +33,11 @@ def run(self): def get_extension(): extra_compile_args = { - "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"], + "cxx": [ + "-fdiagnostics-color=always", + "-DTORCH_STABLE_ONLY", + "-DTORCH_TARGET_VERSION=0x020a000000000000", + ], } extension = CppExtension @@ -46,7 +50,7 @@ def get_extension(): return [ extension( - "libtorch_agnostic._C", + "libtorch_agnostic_2_10._C", sources=sorted(str(s) for s in sources), py_limited_api=True, extra_compile_args=extra_compile_args, @@ -56,12 +60,12 @@ def get_extension(): setup( - name="libtorch_agnostic", + name="libtorch_agnostic_2_10", version="0.0", author="PyTorch Core Team", description="Example of libtorch agnostic extension", packages=find_packages(exclude=("test",)), - package_data={"libtorch_agnostic": ["*.dll", "*.dylib", "*.so"]}, + package_data={"libtorch_agnostic_2_10": ["*.dll", "*.dylib", "*.so"]}, install_requires=[ "torch", ], diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/__init__.py b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/__init__.py new file mode 100644 index 0000000000000..7fa8732335cf0 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/__init__.py @@ -0,0 +1,21 @@ +import ctypes 
+from pathlib import Path + +import torch + + +so_files = list(Path(__file__).parent.glob("_C*.so")) +assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" + +# use ctypes.CDLL instead of load_library to be able to test the unload logic +# below code is reduced from the load_library code +with torch._ops.dl_open_guard(): + loaded_lib = ctypes.CDLL(so_files[0]) + +from . import ops + + +__all__ = [ + "loaded_lib", + "ops", +] diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp new file mode 100644 index 0000000000000..8907dc9ea0a68 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp @@ -0,0 +1,411 @@ +#include +#include +#include +#include +#include +#include +#include + +#ifdef LAE_USE_CUDA +#include +#endif + +#include + +void inline sgd_math( + float* param_ptr, + float* grad_ptr, + float* out_ptr, + const float weight_decay, + const double lr, + const bool maximize, + int64_t size +){ + int64_t d = 0; + for (; d < size; d++) { + float grad_val = grad_ptr[d]; + if (maximize) grad_val = -grad_val; + if (weight_decay != 0.0){ + grad_val += param_ptr[d] * weight_decay; + } + out_ptr[d] = param_ptr[d] - grad_val * float(lr); + } +} + +using torch::stable::Tensor; + +Tensor sgd_out_of_place( + const Tensor param, + const Tensor grad, + const double weight_decay, + const double lr, + const bool maximize) { + STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); + + // these test the get_device() and get_device_index() methods + // while ascertaining that we are still on CPU + STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1"); + STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1"); + + // testing Tensor strides + stride + STD_TORCH_CHECK(param.strides()[0] == param.stride(0)); + + auto out = new_empty(param, param.sizes()); + + sgd_math( + reinterpret_cast(param.data_ptr()), + reinterpret_cast(grad.data_ptr()), + reinterpret_cast(out.data_ptr()), + float(weight_decay), + lr, + maximize, + param.numel() + ); + + return out; +} + + +STABLE_TORCH_LIBRARY(libtorch_agnostic_2_9, m) { + m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CPU, m) { + m.impl("sgd_out_of_place", TORCH_BOX(&sgd_out_of_place)); +} + +Tensor identity(Tensor t) { + return t; +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("identity(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CUDA, m) { + m.impl("identity", TORCH_BOX(&identity)); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CPU, m) { + m.impl("identity", TORCH_BOX(&identity)); +} + +Tensor my_abs(Tensor t) { + const auto num_args = 1; + StableIValue stack[num_args]; + stack[0] = torch::stable::detail::from(t); + aoti_torch_call_dispatcher("aten::abs", "", stack); + return torch::stable::detail::to(stack[0]); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("my_abs(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("my_abs", TORCH_BOX(&my_abs)); +} + +Tensor my_ones_like(Tensor t, StableIValue device) { + const auto num_args = 6; + StableIValue stack[num_args]; + + auto mf = aoti_torch_memory_format_contiguous_format(); + + stack[0] = 
torch::stable::detail::from(t); + stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype + stack[2] = torch::stable::detail::from(std::nullopt); // layout + stack[3] = torch::stable::detail::from(std::optional(device)); // device + stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory + stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format + + aoti_torch_call_dispatcher("aten::ones_like", "", stack); + + return torch::stable::detail::to(stack[0]); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("my_ones_like(Tensor t, Device d) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("my_ones_like", TORCH_BOX(&my_ones_like)); +} + +std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { + StableIValue stack_exp[1]; + stack_exp[0] = torch::stable::detail::from(t1); + aoti_torch_call_dispatcher("aten::exp", "", stack_exp); + + StableIValue stack_neg[1]; + stack_neg[0] = torch::stable::detail::from(t2); + aoti_torch_call_dispatcher("aten::neg", "", stack_neg); + + StableIValue stack_is_leaf[1]; + stack_is_leaf[0] = torch::stable::detail::from(t3); + aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); + + return std::make_tuple( + torch::stable::detail::to(stack_exp[0]), + torch::stable::detail::to(stack_neg[0]), + torch::stable::detail::to(stack_is_leaf[0])); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("exp_neg_is_leaf", TORCH_BOX(&exp_neg_is_leaf)); +} + +Tensor neg_exp(Tensor t) { + StableIValue stack[1]; + stack[0] = torch::stable::detail::from(t); + aoti_torch_call_dispatcher("aten::exp", "", stack); + aoti_torch_call_dispatcher("aten::neg", "", stack); + return torch::stable::detail::to(stack[0]); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("neg_exp(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("neg_exp", TORCH_BOX(&neg_exp)); +} + +Tensor divide_neg_exp(Tensor t) { + StableIValue stack_neg[1]; + stack_neg[0] = torch::stable::detail::from(t); + + StableIValue stack_exp[1]; + stack_exp[0] = torch::stable::detail::from(t); + aoti_torch_call_dispatcher("aten::exp", "", stack_exp); + aoti_torch_call_dispatcher("aten::neg", "", stack_neg); + + StableIValue stack_div[2]; + stack_div[0] = stack_neg[0]; + stack_div[1] = stack_exp[0]; + aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div); + return torch::stable::detail::to(stack_div[0]); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("divide_neg_exp(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("divide_neg_exp", TORCH_BOX(÷_neg_exp)); +} + +bool is_contiguous(Tensor t) { + return t.is_contiguous(); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("is_contiguous(Tensor t) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("is_contiguous", TORCH_BOX(&is_contiguous)); +} + +Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { + return transpose(t, dim0, dim1); +} + + +Tensor my_empty_like(Tensor t) { + return empty_like(t); +} + + +bool my_is_cpu(Tensor t) { + return t.is_cpu(); +} + + + 
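+// Fills the input tensor in place with infinity via the stable fill_ helper
+// and returns the same tensor.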
+Tensor fill_infinity(Tensor t) { + auto value = std::numeric_limits::infinity(); + return fill_(t, value); +} + + +Tensor my_pad(Tensor t) { + std::vector padding = {1, 2, 2, 1}; + std::string mode = "constant"; + double value = 0.0; + return pad(t, padding, mode, value); +} + + +Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) { + return narrow(t, dim, start, length); +} + + +Tensor my_new_empty_dtype_variant(Tensor t) { + std::vector sizes = {2, 5}; + auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); + return new_empty(t, sizes, dtype); +} + + +Tensor my_new_zeros_dtype_variant(Tensor t) { + std::vector sizes = {2, 5}; + auto dtype = std::make_optional(at::ScalarType::Float); + return new_zeros(t, sizes, dtype); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); + m.def("my_empty_like(Tensor t) -> Tensor"); + m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)"); + m.def("my_pad(Tensor t) -> Tensor"); + m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor"); + m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor"); + m.def("my_new_zeros_dtype_variant(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("my_transpose", TORCH_BOX(&my_transpose)); + m.impl("my_empty_like", TORCH_BOX(&my_empty_like)); + m.impl("fill_infinity", TORCH_BOX(&fill_infinity)); + m.impl("my_is_cpu", TORCH_BOX(&my_is_cpu)); + m.impl("my_new_empty_dtype_variant", TORCH_BOX(&my_new_empty_dtype_variant)); + m.impl("my_new_zeros_dtype_variant", TORCH_BOX(&my_new_zeros_dtype_variant)); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeImplicitAutograd, m) { + m.impl("my_pad", TORCH_BOX(&my_pad)); + m.impl("my_narrow", TORCH_BOX(&my_narrow)); +} + +Tensor my_zero_(Tensor t) { + return zero_(t); +} + + +Tensor my_amax(Tensor t) { + return amax(t, 0, false); +} + + +Tensor my_amax_vec(Tensor t) { + std::vector v = {0,1}; + return amax(t, v, false); +} + + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("my_zero_(Tensor(a!) 
t) -> Tensor(a!)"); + m.def("my_amax(Tensor a) -> Tensor"); + m.def("my_amax_vec(Tensor a) -> Tensor"); + m.def("my_is_cpu(Tensor t) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CPU, m) { + m.impl("my_zero_", TORCH_BOX(&my_zero_)); +} + +bool test_default_constructor(bool defined) { + Tensor out; + if (defined) { + AtenTensorHandle defined_ath; + int64_t sizes[] = {2, 3}; + int64_t strides[] = {3, 1}; + aoti_torch_empty_strided( + 2, + sizes, + strides, + aoti_torch_dtype_float32(), + aoti_torch_device_type_cpu(), + 0, + &defined_ath); + out = Tensor(defined_ath); + } + return out.defined(); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("test_default_constructor(bool undefined) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("test_default_constructor", TORCH_BOX(&test_default_constructor)); + m.impl("my_amax", TORCH_BOX(&my_amax)); + m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec)); +} + +// Test functions for torch::stable::accelerator APIs + +#ifdef LAE_USE_CUDA +int64_t test_device_guard(int64_t device_index) { + using torch::stable::accelerator::DeviceGuard; + + STD_TORCH_CHECK( + device_index >= std::numeric_limits::min() && + device_index <= std::numeric_limits::max(), + "Device index is out of range of DeviceIndex (int32_t)."); + + DeviceGuard guard(device_index); + int currentDevice; + cudaError_t err = cudaGetDevice(¤tDevice); + STD_TORCH_CHECK(err == cudaSuccess); + return currentDevice; +} + + +int64_t test_device_guard_set_index() { + using torch::stable::accelerator::DeviceGuard; + + DeviceGuard guard(1); + guard.set_index(0); + int currentDevice; + cudaError_t err = cudaGetDevice(¤tDevice); + STD_TORCH_CHECK(err == cudaSuccess); + return currentDevice; +} + + +int64_t test_stream(int32_t device_index) { + STD_TORCH_CHECK( + device_index >= std::numeric_limits::min() && + device_index <= std::numeric_limits::max(), + "Device index is out of range of DeviceIndex (int32_t)."); + + return torch::stable::accelerator::getCurrentStream(device_index).id(); +} + + +int64_t test_get_current_device_index() { + return torch::stable::accelerator::getCurrentDeviceIndex(); +} + + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("test_device_guard(int device_index) -> int"); + m.def("test_device_guard_set_index() -> int"); + m.def("test_stream(int device_index) -> int"); + m.def("test_get_current_device_index() -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("test_device_guard", TORCH_BOX(&test_device_guard)); + m.impl("test_device_guard_set_index", TORCH_BOX(&test_device_guard_set_index)); + m.impl("test_stream", TORCH_BOX(&test_stream)); + m.impl("test_get_current_device_index", TORCH_BOX(&test_get_current_device_index)); +} +#endif // LAE_USE_CUDA diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py new file mode 100644 index 0000000000000..b8f906faab399 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py @@ -0,0 +1,309 @@ +import torch +from torch import Tensor + + +def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor: + """ + Computes a single step of SGD on a single parameter Tensor with grad. + + Assumes: + - param and grad are the same shape and are 1D. 
+ - param and grad are float and on CPU + + Args: + param: a 1D tensor of floats + grad: a 1D tensor of floats + weight_decay: a python double between 0 and 1 + lr: a python double + + Returns: + a 1D float Tensor the same shape as param + + """ + return torch.ops.libtorch_agnostic_2_9.sgd_out_of_place.default( + param, grad, weight_decay, lr, maximize + ) + + +def identity(t) -> Tensor: + """ + Returns the input tensor + + Args: + t: any Tensor + + Returns: + a Tensor, the same as input. + """ + return torch.ops.libtorch_agnostic_2_9.identity.default(t) + + +def my_abs(t) -> Tensor: + """ + Returns abs on the input tensor, outputs a new Tensor + + Args: + t: any Tensor + + Returns: + a Tensor + """ + return torch.ops.libtorch_agnostic_2_9.my_abs.default(t) + + +def my_is_cpu(t) -> bool: + """ + Returns is_cpu on the input tensor. + + Args: + t: any Tensor + + Returns: + a bool + """ + return torch.ops.libtorch_agnostic_2_9.my_is_cpu.default(t) + + +def my_ones_like(tensor, device) -> Tensor: + """ + Returns a new Tensor like the input tensor, but with all ones + + Args: + tensor: any Tensor + device: a device string + + Returns: + a ones Tensor with the same dtype and shape and other attributes + like the input tensor + """ + return torch.ops.libtorch_agnostic_2_9.my_ones_like.default(tensor, device) + + +def exp_neg_is_leaf(t1, t2, t3) -> tuple[Tensor, Tensor, bool]: + """ + Returns a Tensor, Tensor, bool tuple corresponding to the respective inputs + t1, t2, and t3. + + Args: + t1: Tensor + t2: Tensor + t3: Tensor + + Returns: + (exp(t1), neg(t2), is_leaf(t3)) + """ + return torch.ops.libtorch_agnostic_2_9.exp_neg_is_leaf.default(t1, t2, t3) + + +def neg_exp(t) -> Tensor: + """ + Returns a Tensor composing neg of exp + + Args: + t: Tensor + + Returns: neg(exp(t)) + """ + return torch.ops.libtorch_agnostic_2_9.neg_exp.default(t) + + +def divide_neg_exp(t) -> Tensor: + """ + Returns a Tensor division of neg and exp + + Args: + t: Tensor + + Returns: divide(neg(t), exp(t)) + """ + return torch.ops.libtorch_agnostic_2_9.divide_neg_exp.default(t) + + +def is_contiguous(t) -> bool: + """ + Returns a bool indicating if the input tensor is contiguous + + Args: + t: Tensor + + Returns: is_contiguous(t) + """ + return torch.ops.libtorch_agnostic_2_9.is_contiguous.default(t) + + +def my_transpose(t, dim0, dim1) -> Tensor: + """ + Returns t.transpose(dim0, dim1) + + Args: + t: Tensor + + Returns: my_transpose(t, dim0, dim1) + """ + return torch.ops.libtorch_agnostic_2_9.my_transpose.default(t, dim0, dim1) + + +def my_empty_like(t) -> Tensor: + """ + Returns t.empty_like() + + Args: + t: Tensor + + Returns: my_empty_like(t) + """ + return torch.ops.libtorch_agnostic_2_9.my_empty_like.default(t) + + +def my_zero_(t) -> Tensor: + """ + Returns t.zero_() + + Args: + t: Tensor + + Returns: my_zero_(t) + """ + return torch.ops.libtorch_agnostic_2_9.my_zero_.default(t) + + +def my_amax(t) -> Tensor: + """ + Returns t.amax() + + Args: + t: Tensor + + Returns: amax(t) + """ + return torch.ops.libtorch_agnostic_2_9.my_amax.default(t) + + +def my_amax_vec(t) -> Tensor: + """ + Returns t.amax() + + Args: + t: Tensor + + Returns: amax(t) + """ + return torch.ops.libtorch_agnostic_2_9.my_amax_vec.default(t) + + +def fill_infinity(t) -> Tensor: + """ + Fills the tensor with inf. 
+ + Args: + t: Tensor to fill + + Returns: The modified tensor (same as input) + """ + return torch.ops.libtorch_agnostic_2_9.fill_infinity.default(t) + + +def test_default_constructor(defined) -> bool: + """ + Tests the default constructor for torch::stable::Tensor. + + Args: + defined: bool - if True, tests defined tensor; if False, tests undefined tensor + + Returns: bool - result of calling .defined() on the tensor + """ + return torch.ops.libtorch_agnostic_2_9.test_default_constructor.default(defined) + + +def my_pad(t) -> Tensor: + """ + Pads the input tensor with hardcoded padding parameters. + + Args: + t: Input tensor + + Returns: Padded tensor with padding [1, 2, 2, 1], mode "constant", value 0.0 + """ + return torch.ops.libtorch_agnostic_2_9.my_pad.default(t) + + +def my_narrow(t, dim, start, length) -> Tensor: + """ + Returns a new tensor that is a narrowed version of the input tensor. + + Args: + t: Input tensor + dim: Dimension along which to narrow + start: Starting position + length: Length of the narrowed section + + Returns: Narrowed tensor + """ + return torch.ops.libtorch_agnostic_2_9.my_narrow.default(t, dim, start, length) + + +def test_device_guard(device_index) -> int: + """ + Tests the DeviceGuard functionality by creating a device guard and returning an empty tensor. + + Args: + device_index: Device index to set the guard to + + Returns: result of cudaGetDevice() as an integer after using the guard + """ + return torch.ops.libtorch_agnostic_2_9.test_device_guard.default(device_index) + + +def test_device_guard_set_index() -> int: + """ + Tests the DeviceGuard set_index functionality by creating a device guard with index 1, + then setting it to index 0, and returning the current device. + + Returns: result of cudaGetDevice() as an integer after using set_index + """ + return torch.ops.libtorch_agnostic_2_9.test_device_guard_set_index.default() + + +def test_stream(device_index) -> int: + """ + Tests the Stream functionality by getting the current stream ID for the specified device. + + Args: + device_index: Device index to get the stream for + + Returns: Stream ID as an integer + """ + return torch.ops.libtorch_agnostic_2_9.test_stream.default(device_index) + + +def test_get_current_device_index() -> int: + """ + Tests the getCurrentDeviceIndex functionality by getting the current device index. 
+ + Returns: Current device index as an integer + """ + return torch.ops.libtorch_agnostic_2_9.test_get_current_device_index.default() + + +def my_new_empty_dtype_variant(t) -> Tensor: + """ + Returns a new empty tensor with shape [2, 5] and dtype bfloat16 + + Args: + t: Input tensor used as a reference for device and other properties + + Returns: New empty tensor with shape [2, 5] and dtype bfloat16 + """ + return torch.ops.libtorch_agnostic_2_9.my_new_empty_dtype_variant.default(t) + + +def my_new_zeros_dtype_variant(t) -> Tensor: + """ + Returns a new tensor filled with 0s with shape [2, 5] and dtype Float + + Args: + t: Input tensor used as a reference for device and other properties + + Returns: New zeros tensor + """ + return torch.ops.libtorch_agnostic_2_9.my_new_zeros_dtype_variant.default(t) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_9_extension/setup.py new file mode 100644 index 0000000000000..a5ea6040c6990 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/setup.py @@ -0,0 +1,78 @@ +import distutils.command.clean +import shutil +from pathlib import Path + +from setuptools import find_packages, setup + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension + + +ROOT_DIR = Path(__file__).parent +CSRC_DIR = ROOT_DIR / "libtorch_agnostic_2_9" / "csrc" + + +class clean(distutils.command.clean.clean): + def run(self): + # Run default behavior first + distutils.command.clean.clean.run(self) + + # Remove extension + for path in (ROOT_DIR / "libtorch_agnostic_2_9").glob("**/*.so"): + path.unlink() + # Remove build and dist and egg-info directories + dirs = [ + ROOT_DIR / "build", + ROOT_DIR / "dist", + ROOT_DIR / "libtorch_agnostic_2_9.egg-info", + ] + for path in dirs: + if path.exists(): + shutil.rmtree(str(path), ignore_errors=True) + + +def get_extension(): + extra_compile_args = { + "cxx": [ + "-fdiagnostics-color=always", + "-DTORCH_STABLE_ONLY", + "-DTORCH_TARGET_VERSION=0x0209000000000000", + ], + } + + extension = CppExtension + # allow including + if torch.cuda.is_available(): + extra_compile_args["cxx"].append("-DLAE_USE_CUDA") + extension = CUDAExtension + + sources = list(CSRC_DIR.glob("**/*.cpp")) + + return [ + extension( + "libtorch_agnostic_2_9._C", + sources=sorted(str(s) for s in sources), + py_limited_api=True, + extra_compile_args=extra_compile_args, + extra_link_args=[], + ) + ] + + +setup( + name="libtorch_agnostic_2_9", + version="0.0", + author="PyTorch Core Team", + description="Example of libtorch agnostic extension for PyTorch 2.9.0", + packages=find_packages(exclude=("test",)), + package_data={"libtorch_agnostic_2_9": ["*.dll", "*.dylib", "*.so"]}, + install_requires=[ + "torch", + ], + ext_modules=get_extension(), + cmdclass={ + "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), + "clean": clean, + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, +) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic_versioned.py similarity index 54% rename from test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py rename to test/cpp_extensions/test_libtorch_agnostic_versioned.py index 8cfba98349817..cff85479b755c 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic_versioned.py @@ -23,17 +23,57 @@ # LINK : error 
LNK2001: unresolved external symbol PyInit__C if not IS_WINDOWS: - class TestLibtorchAgnostic(TestCase): + class TestLibtorchAgnosticVersioned(TestCase): + """ + Tests for versioned libtorch_agnostic extensions. + + This test class supports testing both: + - libtorch_agnostic_2_9: Extension built with TORCH_TARGET_VERSION=2.9.0 + - libtorch_agnostic_2_10: Extension built with TORCH_TARGET_VERSION=2.10.0 + + Both extensions must be available for the tests to run. + """ + + ops_2_9 = None + ops_2_10 = None + @classmethod def setUpClass(cls): + # Install and import the 2.9 extension + try: + import libtorch_agnostic_2_9 + + cls.ops_2_9 = libtorch_agnostic_2_9.ops + except ImportError: + extension_root = ( + Path(__file__).parent / "libtorch_agnostic_2_9_extension" + ) + install_cpp_extension(extension_root=extension_root) + import libtorch_agnostic_2_9 + + cls.ops_2_9 = libtorch_agnostic_2_9.ops + + # Install and import the 2.10 extension try: - import libtorch_agnostic # noqa: F401 - except Exception: - install_cpp_extension(extension_root=Path(__file__).parent.parent) + import libtorch_agnostic_2_10 + + cls.ops_2_10 = libtorch_agnostic_2_10.ops + except ImportError: + extension_root = ( + Path(__file__).parent / "libtorch_agnostic_2_10_extension" + ) + install_cpp_extension(extension_root=extension_root) + import libtorch_agnostic_2_10 + + cls.ops_2_10 = libtorch_agnostic_2_10.ops + + # ============================================================================ + # Tests for 2.9 features + # ============================================================================ @onlyCPU - def test_slow_sgd(self, device): - import libtorch_agnostic + def test_2_9_slow_sgd(self, device): + ops = self.ops_2_9 param = torch.rand(5, device=device) grad = torch.rand_like(param) @@ -41,9 +81,7 @@ def test_slow_sgd(self, device): lr = 0.001 maximize = False - new_param = libtorch_agnostic.ops.sgd_out_of_place( - param, grad, weight_decay, lr, maximize - ) + new_param = ops.sgd_out_of_place(param, grad, weight_decay, lr, maximize) torch._fused_sgd_( (param,), (grad,), @@ -59,13 +97,13 @@ def test_slow_sgd(self, device): self.assertEqual(new_param, param) @onlyCUDA - def test_identity_does_not_hog_memory(self, device): - import libtorch_agnostic + def test_2_9_identity_does_not_hog_memory(self, device): + ops = self.ops_2_9 def _run_identity(prior_mem): t = torch.rand(32, 32, device=device) self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) - identi_t = libtorch_agnostic.ops.identity(t) + identi_t = ops.identity(t) assert identi_t is t init_mem = torch.cuda.memory_allocated(device) @@ -75,27 +113,27 @@ def _run_identity(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) - def test_exp_neg_is_leaf(self, device): - import libtorch_agnostic + def test_2_9_exp_neg_is_leaf(self, device): + ops = self.ops_2_9 t1 = torch.rand(2, 3, device=device) t2 = torch.rand(3, 2, device=device) t3 = torch.rand(2, device=device) - exp, neg, is_leaf = libtorch_agnostic.ops.exp_neg_is_leaf(t1, t2, t3) + exp, neg, is_leaf = ops.exp_neg_is_leaf(t1, t2, t3) self.assertEqual(exp, torch.exp(t1)) self.assertEqual(neg, torch.neg(t2)) self.assertEqual(is_leaf, t3.is_leaf) - def test_my_abs(self, device): - import libtorch_agnostic + def test_2_9_my_abs(self, device): + ops = self.ops_2_9 t = torch.rand(32, 16, device=device) - 0.5 - res = libtorch_agnostic.ops.my_abs(t) + res = ops.my_abs(t) self.assertEqual(res, torch.abs(t)) def _make_cuda_tensors(prior_mem): - cuda_t = 
libtorch_agnostic.ops.my_abs(t) + cuda_t = ops.my_abs(t) self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) self.assertEqual(cuda_t, torch.abs(t)) @@ -106,15 +144,15 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) - def test_neg_exp(self, device): - import libtorch_agnostic + def test_2_9_neg_exp(self, device): + ops = self.ops_2_9 t = torch.rand(32, 16, device=device) - 0.5 - res = libtorch_agnostic.ops.neg_exp(t) + res = ops.neg_exp(t) self.assertEqual(res, torch.neg(torch.exp(t))) def _make_cuda_tensors(prior_mem): - cuda_res = libtorch_agnostic.ops.neg_exp(t) + cuda_res = ops.neg_exp(t) self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) self.assertEqual(cuda_res, torch.neg(torch.exp(t))) @@ -125,15 +163,15 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) - def test_divide_neg_exp(self, device): - import libtorch_agnostic + def test_2_9_divide_neg_exp(self, device): + ops = self.ops_2_9 t = torch.zeros(2, 3, device=device) - 0.5 - res = libtorch_agnostic.ops.divide_neg_exp(t) + res = ops.divide_neg_exp(t) self.assertEqual(res, torch.neg(t) / torch.exp(t)) def _make_cuda_tensors(prior_mem): - cuda_res = libtorch_agnostic.ops.divide_neg_exp(t) + cuda_res = ops.divide_neg_exp(t) self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) self.assertEqual(cuda_res, torch.neg(t) / torch.exp(t)) @@ -144,27 +182,23 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) - def test_is_contiguous(self, device): - import libtorch_agnostic + def test_2_9_is_contiguous(self, device): + ops = self.ops_2_9 t = torch.rand(2, 7, device=device) - self.assertTrue(libtorch_agnostic.ops.is_contiguous(t)) - self.assertFalse(libtorch_agnostic.ops.is_contiguous(t.transpose(0, 1))) + self.assertTrue(ops.is_contiguous(t)) + self.assertFalse(ops.is_contiguous(t.transpose(0, 1))) - # TODO: Debug this: - # torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: - # call_function libtorch_agnostic.my_ones_like.default(*(FakeTensor(..., size=(3, 1)), 'cpu'), - # **{}): got AssertionError("tensor's device must be `meta`, got cpu instead") @xfailIfTorchDynamo - def test_my_ones_like(self, device): - import libtorch_agnostic + def test_2_9_my_ones_like(self, device): + ops = self.ops_2_9 t = torch.rand(3, 1, device=device) - 0.5 - cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu") + cpu_t = ops.my_ones_like(t, "cpu") self.assertEqual(cpu_t, torch.ones_like(t, device="cpu")) def _make_cuda_tensors(prior_mem): - cuda_t = libtorch_agnostic.ops.my_ones_like(t, device) + cuda_t = ops.my_ones_like(t, device) self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) self.assertEqual(cuda_t, torch.ones_like(t, device=device)) @@ -175,141 +209,134 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) - def test_my_transpose(self, device): - import libtorch_agnostic + def test_2_9_my_transpose(self, device): + ops = self.ops_2_9 t = torch.rand(2, 7, device=device) - out = libtorch_agnostic.ops.my_transpose(t, 0, 1) + out = ops.my_transpose(t, 0, 1) self.assertEqual(out, torch.transpose(t, 0, 1)) with self.assertRaisesRegex(RuntimeError, "API call failed"): - libtorch_agnostic.ops.my_transpose(t, 1, 2) + ops.my_transpose(t, 1, 2) - def test_my_empty_like(self, device): - 
import libtorch_agnostic + def test_2_9_my_empty_like(self, device): + ops = self.ops_2_9 deterministic = torch.are_deterministic_algorithms_enabled() try: - # set use_deterministic_algorithms to fill uninitialized memory torch.use_deterministic_algorithms(True) t = torch.rand(2, 7, device=device) - out = libtorch_agnostic.ops.my_empty_like(t) + out = ops.my_empty_like(t) self.assertTrue(id(out != id(t))) self.assertEqual(out, torch.empty_like(t)) finally: torch.use_deterministic_algorithms(deterministic) @onlyCPU - def test_my_zero_(self, device): - import libtorch_agnostic + def test_2_9_my_zero_(self, device): + ops = self.ops_2_9 t = torch.rand(2, 7, device=device) - out = libtorch_agnostic.ops.my_zero_(t) + out = ops.my_zero_(t) self.assertEqual(id(out), id(t)) self.assertEqual(out, torch.zeros_like(t)) - def test_my_amax(self, device): - import libtorch_agnostic + def test_2_9_my_amax(self, device): + ops = self.ops_2_9 t = torch.rand(2, 7, device=device) - out = libtorch_agnostic.ops.my_amax(t) + out = ops.my_amax(t) self.assertEqual(out, torch.amax(t, 0)) - def test_my_amax_vec(self, device): - import libtorch_agnostic + def test_2_9_my_amax_vec(self, device): + ops = self.ops_2_9 t = torch.rand(2, 7, 5, device=device) - out = libtorch_agnostic.ops.my_amax_vec(t) + out = ops.my_amax_vec(t) self.assertEqual(out, torch.amax(t, (0, 1))) - def test_my_is_cpu(self, device): - import libtorch_agnostic + def test_2_9_my_is_cpu(self, device): + ops = self.ops_2_9 t = torch.rand(2, 7, device=device) - out = libtorch_agnostic.ops.my_is_cpu(t) + out = ops.my_is_cpu(t) self.assertEqual(out, t.is_cpu) - def test_fill_infinity(self, device): - import libtorch_agnostic + def test_2_9_fill_infinity(self, device): + ops = self.ops_2_9 t = torch.rand(3, 4, device=device) - out = libtorch_agnostic.ops.fill_infinity(t) + out = ops.fill_infinity(t) self.assertEqual(id(out), id(t)) expected = torch.full_like(t, math.inf) self.assertEqual(out, expected) @onlyCPU - def test_default_constructor(self): - import libtorch_agnostic + def test_2_9_default_constructor(self): + ops = self.ops_2_9 - defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor( - True - ) + defined_tensor_is_defined = ops.test_default_constructor(True) self.assertTrue(defined_tensor_is_defined) - undefined_tensor_is_defined = ( - libtorch_agnostic.ops.test_default_constructor(False) - ) + undefined_tensor_is_defined = ops.test_default_constructor(False) self.assertFalse(undefined_tensor_is_defined) - def test_my_pad(self, device): - import libtorch_agnostic + def test_2_9_my_pad(self, device): + ops = self.ops_2_9 t = torch.rand(2, 3, device=device) - out = libtorch_agnostic.ops.my_pad(t) + out = ops.my_pad(t) expected = torch.nn.functional.pad(t, [1, 2, 2, 1], "constant", 0.0) self.assertEqual(out, expected) - def test_my_narrow(self, device): - import libtorch_agnostic + def test_2_9_my_narrow(self, device): + ops = self.ops_2_9 t = torch.randn(2, 5, device=device) dim0 = 0 start0 = 0 length0 = 1 - out0 = libtorch_agnostic.ops.my_narrow(t, dim0, start0, length0) + out0 = ops.my_narrow(t, dim0, start0, length0) expected0 = torch.narrow(t, dim0, start0, length0) self.assertEqual(out0, expected0) @onlyCUDA @deviceCountAtLeast(2) - def test_device_guard(self, device): - import libtorch_agnostic + def test_2_9_device_guard(self, device): + ops = self.ops_2_9 device_index = 1 - out = libtorch_agnostic.ops.test_device_guard(device_index) + out = ops.test_device_guard(device_index) self.assertEqual(out, device_index) 
@onlyCUDA @deviceCountAtLeast(2) - def test_device_guard_set_index(self, device): - import libtorch_agnostic + def test_2_9_device_guard_set_index(self, device): + ops = self.ops_2_9 - # This test creates a DeviceGuard with index 1, then sets it to index 0 - # and returns the current device (should be 0) - out = libtorch_agnostic.ops.test_device_guard_set_index() + out = ops.test_device_guard_set_index() self.assertEqual(out, 0) @onlyCUDA - def test_stream(self, device): - import libtorch_agnostic + def test_2_9_stream(self, device): + ops = self.ops_2_9 stream = torch.cuda.Stream() device = torch.cuda.current_device() with stream: expected_stream_id = torch.cuda.current_stream(0).stream_id - stream_id = libtorch_agnostic.ops.test_stream(device) + stream_id = ops.test_stream(device) self.assertEqual(stream_id, expected_stream_id) @onlyCUDA @deviceCountAtLeast(2) - def test_get_current_device_index(self, device): - import libtorch_agnostic + def test_2_9_get_current_device_index(self, device): + ops = self.ops_2_9 prev_device = torch.cuda.current_device() @@ -317,85 +344,88 @@ def test_get_current_device_index(self, device): expected_device = 1 torch.cuda.set_device(expected_device) - current_device = libtorch_agnostic.ops.test_get_current_device_index() + current_device = ops.test_get_current_device_index() self.assertEqual(current_device, expected_device) finally: torch.cuda.set_device(prev_device) - def test_my_new_empty_dtype_variant(self, device): - import libtorch_agnostic + def test_2_9_my_new_empty_dtype_variant(self, device): + ops = self.ops_2_9 deterministic = torch.are_deterministic_algorithms_enabled() try: - # set use_deterministic_algorithms to fill uninitialized memory torch.use_deterministic_algorithms(True) t = torch.randn(3, 4, device=device) - out = libtorch_agnostic.ops.my_new_empty_dtype_variant(t) + out = ops.my_new_empty_dtype_variant(t) ref_out = t.new_empty((2, 5), dtype=torch.bfloat16) self.assertEqual(out, ref_out, exact_device=True) finally: torch.use_deterministic_algorithms(deterministic) - def test_my_new_zeros_dtype_variant(self, device): - import libtorch_agnostic + def test_2_9_my_new_zeros_dtype_variant(self, device): + ops = self.ops_2_9 t = torch.randn(3, 4, device=device) - out = libtorch_agnostic.ops.my_new_zeros_dtype_variant(t) + out = ops.my_new_zeros_dtype_variant(t) ref_out = t.new_zeros((2, 5), dtype=torch.float) self.assertEqual(out, ref_out, exact_device=True) - def test_my_copy_(self, device): - import libtorch_agnostic + # ============================================================================ + # Tests for 2.10 features (only work with 2.10 extension) + # ============================================================================ + + def test_2_10_my_copy_(self, device): + ops = self.ops_2_10 dst = torch.empty(2, 5, device=device) src = torch.randn(2, 5, device=device) - result = libtorch_agnostic.ops.my_copy_(dst, src, False) + result = ops.my_copy_(dst, src, False) expected = src self.assertEqual(result, expected) self.assertEqual(result.data_ptr(), dst.data_ptr()) - def test_my_clone(self, device): - import libtorch_agnostic + def test_2_10_my_clone(self, device): + ops = self.ops_2_10 t = torch.randn(2, 5, device=device) - result = libtorch_agnostic.ops.my_clone(t) + result = ops.my_clone(t) expected = t.clone() self.assertEqual(result, expected) self.assertNotEqual(result.data_ptr(), expected.data_ptr()) self.assertEqual(result.stride(), expected.stride()) - def test_my__foreach_mul_(self, device): - import libtorch_agnostic + 
def test_2_10_my__foreach_mul_(self, device): + ops = self.ops_2_10 N = 5 tensors = [torch.rand(32, 16, device=device) for _ in range(N)] tensors_c = [t.clone() for t in tensors] others = [torch.rand(32, 16, device=device) for _ in range(N)] - libtorch_agnostic.ops.my__foreach_mul_(tensors, others) + ops.my__foreach_mul_(tensors, others) expected_values = torch._foreach_mul(tensors_c, others) for tensor_t, expected_t in zip(tensors, expected_values): self.assertEqual(tensor_t, expected_t) - def test_my__foreach_mul(self, device): - import libtorch_agnostic + def test_2_10_my__foreach_mul(self, device): + ops = self.ops_2_10 N = 5 tensors = [torch.rand(32, 16, device=device) for _ in range(N)] others = [torch.rand(32, 16, device=device) for _ in range(N)] - result = libtorch_agnostic.ops.my__foreach_mul(tensors, others) + result = ops.my__foreach_mul(tensors, others) expected = torch._foreach_mul(tensors, others) for result_t, expected_t in zip(result, expected): self.assertEqual(result_t, expected_t) def _make_cuda_tensors(prior_mem): - cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others) + cuda_res = ops.my__foreach_mul(tensors, others) self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) expected = torch._foreach_mul(tensors, others) @@ -409,98 +439,76 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) - def test_make_tensor_clones_and_call_foreach(self, device): - import libtorch_agnostic + def test_2_10_make_tensor_clones_and_call_foreach(self, device): + ops = self.ops_2_10 t1 = torch.rand(2, 5, device=device) t2 = torch.rand(3, 4, device=device) - result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2) + result = ops.make_tensor_clones_and_call_foreach(t1, t2) self.assertEqual(result[0], t1 * t1) self.assertEqual(result[1], t2 * t2) @onlyCUDA - def test_device(self, device): - import libtorch_agnostic + def test_2_10_device(self, device): + ops = self.ops_2_10 - cuda_device = libtorch_agnostic.ops.test_device_constructor( + cuda_device = ops.test_device_constructor( is_cuda=True, index=1, use_str=False ) self.assertEqual(cuda_device, torch.device("cuda:1")) - cuda_device = libtorch_agnostic.ops.test_device_constructor( + cuda_device = ops.test_device_constructor( is_cuda=True, index=1, use_str=True ) self.assertEqual(cuda_device, torch.device("cuda:1")) - self.assertEqual(libtorch_agnostic.ops.test_device_index(cuda_device), 1) + self.assertEqual(ops.test_device_index(cuda_device), 1) self.assertTrue( - libtorch_agnostic.ops.test_device_equality( - cuda_device, torch.device("cuda:1") - ) + ops.test_device_equality(cuda_device, torch.device("cuda:1")) ) self.assertFalse( - libtorch_agnostic.ops.test_device_equality( - cuda_device, torch.device("cuda:0") - ) + ops.test_device_equality(cuda_device, torch.device("cuda:0")) ) - self.assertFalse(libtorch_agnostic.ops.test_device_is_cpu(cuda_device)) - self.assertTrue(libtorch_agnostic.ops.test_device_is_cuda(cuda_device)) + self.assertFalse(ops.test_device_is_cpu(cuda_device)) + self.assertTrue(ops.test_device_is_cuda(cuda_device)) - cuda_0_device = libtorch_agnostic.ops.test_device_set_index(cuda_device, 0) + cuda_0_device = ops.test_device_set_index(cuda_device, 0) self.assertEqual(cuda_0_device, torch.device("cuda:0")) - cpu_device = libtorch_agnostic.ops.test_device_constructor(False, 0, False) + cpu_device = ops.test_device_constructor(False, 0, False) self.assertEqual(cpu_device, torch.device("cpu")) - self.assertTrue( 
- libtorch_agnostic.ops.test_device_equality( - cpu_device, torch.device("cpu") - ) - ) - self.assertTrue(libtorch_agnostic.ops.test_device_is_cpu(cpu_device)) - self.assertFalse(libtorch_agnostic.ops.test_device_is_cuda(cpu_device)) - self.assertFalse( - libtorch_agnostic.ops.test_device_equality(cpu_device, cuda_device) - ) + self.assertTrue(ops.test_device_equality(cpu_device, torch.device("cpu"))) + self.assertTrue(ops.test_device_is_cpu(cpu_device)) + self.assertFalse(ops.test_device_is_cuda(cpu_device)) + self.assertFalse(ops.test_device_equality(cpu_device, cuda_device)) with self.assertRaisesRegex( RuntimeError, "Device index 129 is out of range for int8_t" ): - libtorch_agnostic.ops.test_device_constructor( - is_cuda=True, index=129, use_str=False - ) + ops.test_device_constructor(is_cuda=True, index=129, use_str=False) with self.assertRaisesRegex( RuntimeError, "Device index 129 is out of range for int8_t" ): - libtorch_agnostic.ops.test_device_set_index(cuda_device, 129) + ops.test_device_set_index(cuda_device, 129) @onlyCUDA @deviceCountAtLeast(2) - def test_tensor_device(self, device): - import libtorch_agnostic + def test_2_10_tensor_device(self, device): + ops = self.ops_2_10 t = torch.randn(2, 3) - self.assertEqual(libtorch_agnostic.ops.test_tensor_device(t), t.device) + self.assertEqual(ops.test_tensor_device(t), t.device) t_cuda = torch.randn(2, 3, device="cuda") - self.assertEqual( - libtorch_agnostic.ops.test_tensor_device(t_cuda), t_cuda.device - ) + self.assertEqual(ops.test_tensor_device(t_cuda), t_cuda.device) t_cuda_1 = torch.randn(2, 3, device="cuda:1") - self.assertEqual( - libtorch_agnostic.ops.test_tensor_device(t_cuda_1), t_cuda_1.device - ) + self.assertEqual(ops.test_tensor_device(t_cuda_1), t_cuda_1.device) @onlyCPU - # TODO: Debug this: - # Dynamo failed to run FX node with fake tensors: - # call_function libtorch_agnostic.test_parallel_for.default(*(100, 10), **{}): - # got RuntimeError('libtorch_agnostic::test_parallel_for() expected at most - # 2 argument(s) but received 3 argument(s). 
- # Declaration: libtorch_agnostic::test_parallel_for(int size, int grain_size) -> Tensor') @xfailIfTorchDynamo - def test_parallel_for(self, device): - import libtorch_agnostic + def test_2_10_parallel_for(self, device): + ops = self.ops_2_10 num_threads = torch.get_num_threads() size = 100 @@ -509,7 +517,7 @@ def test_parallel_for(self, device): (size + grain_size - 1) // grain_size, num_threads ) - result = libtorch_agnostic.ops.test_parallel_for(size, grain_size) + result = ops.test_parallel_for(size, grain_size) result_thread_ids = torch.unique(torch.bitwise_right_shift(result, 32)) result_values = torch.bitwise_and(result, 0xFFFFFFFF) expected = torch.arange(size, dtype=torch.int64) @@ -518,35 +526,30 @@ def test_parallel_for(self, device): self.assertEqual(result_thread_ids, torch.arange(expected_num_threads_used)) @onlyCPU - def test_get_num_threads(self, device): - import libtorch_agnostic + def test_2_10_get_num_threads(self, device): + ops = self.ops_2_10 - num_threads = libtorch_agnostic.ops.test_get_num_threads() + num_threads = ops.test_get_num_threads() expected_num_threads = torch.get_num_threads() self.assertEqual(num_threads, expected_num_threads) - def test_my_empty(self, device): - import libtorch_agnostic + def test_2_10_my_empty(self, device): + ops = self.ops_2_10 deterministic = torch.are_deterministic_algorithms_enabled() try: - # set use_deterministic_algorithms to fill uninitialized memory torch.use_deterministic_algorithms(True) size = [2, 3] - result = libtorch_agnostic.ops.my_empty(size, None, None, None) + result = ops.my_empty(size, None, None, None) expected = torch.empty(size) self.assertEqual(result, expected, exact_device=True) - result_float = libtorch_agnostic.ops.my_empty( - size, torch.float32, None, None - ) + result_float = ops.my_empty(size, torch.float32, None, None) expected_float = torch.empty(size, dtype=torch.float32) self.assertEqual(result_float, expected_float, exact_device=True) - result_with_device = libtorch_agnostic.ops.my_empty( - size, torch.float64, device, None - ) + result_with_device = ops.my_empty(size, torch.float64, device, None) expected_with_device = torch.empty( size, dtype=torch.float64, device=device ) @@ -555,68 +558,47 @@ def test_my_empty(self, device): ) if device == "cuda": - result_pinned = libtorch_agnostic.ops.my_empty( - size, torch.float32, "cpu", True - ) + result_pinned = ops.my_empty(size, torch.float32, "cpu", True) expected_pinned = torch.empty( size, dtype=torch.float32, device="cpu", pin_memory=True ) - self.assertEqual(result_pinned, expected_pinned) + self.assertEqual(result_pinned, expected_pinned, exact_device=True) self.assertTrue(result_pinned.is_pinned()) finally: torch.use_deterministic_algorithms(deterministic) - def test_my_flatten(self, device): - import libtorch_agnostic + def test_2_10_my_flatten(self, device): + ops = self.ops_2_10 t = torch.randn(2, 3, 4, device=device) - result = libtorch_agnostic.ops.my_flatten(t) - expected = torch.flatten(t) + result = ops.my_flatten(t, 0, 1) + expected = torch.flatten(t, 0, 1) self.assertEqual(result, expected) - result_start = libtorch_agnostic.ops.my_flatten(t, 1) - expected_start = torch.flatten(t, 1) - self.assertEqual(result_start, expected_start) + result_all = ops.my_flatten(t, 0, -1) + expected_all = torch.flatten(t, 0, -1) + self.assertEqual(result_all, expected_all) - result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1) - expected_range = torch.flatten(t, 2, -1) - self.assertEqual(result_range, expected_range) - - def 
test_my_reshape(self, device): - import libtorch_agnostic + def test_2_10_my_reshape(self, device): + ops = self.ops_2_10 t = torch.randn(2, 3, 4, device=device) - - result = libtorch_agnostic.ops.my_reshape(t, [6, 4]) - expected = torch.reshape(t, [6, 4]) + shape = [6, 4] + result = ops.my_reshape(t, shape) + expected = torch.reshape(t, shape) self.assertEqual(result, expected) - result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4]) - expected_infer = torch.reshape(t, [-1, 4]) - self.assertEqual(result_infer, expected_infer) - - result_flat = libtorch_agnostic.ops.my_reshape(t, [-1]) - expected_flat = torch.reshape(t, [-1]) - self.assertEqual(result_flat, expected_flat) - - def test_my_view(self, device): - import libtorch_agnostic + def test_2_10_my_view(self, device): + ops = self.ops_2_10 t = torch.randn(2, 3, 4, device=device) - - result = libtorch_agnostic.ops.my_view(t, [6, 4]) - expected = t.view([6, 4]) + size = [6, 4] + result = ops.my_view(t, size) + expected = t.view(size) self.assertEqual(result, expected) - result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4]) - expected_infer = t.view([-1, 4]) - self.assertEqual(result_infer, expected_infer) - - result_flat = libtorch_agnostic.ops.my_view(t, [-1]) - expected_flat = t.view([-1]) - self.assertEqual(result_flat, expected_flat) + instantiate_device_type_tests(TestLibtorchAgnosticVersioned, globals()) - instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": run_tests() From 3e01a0dfc74f1ac07e858d25a44b190db1f99380 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 13 Nov 2025 21:15:39 -0800 Subject: [PATCH 06/15] Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version [ghstack-poisoned] --- .ci/pytorch/test.sh | 96 +++++++++++++++++++ .github/workflows/pull.yml | 1 + .../test_libtorch_agnostic_versioned.py | 65 ++++++++++--- 3 files changed, 150 insertions(+), 12 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 687ec4b9e0ae4..534e21349f0c5 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1243,6 +1243,100 @@ test_custom_script_ops() { assert_git_not_dirty } +test_libtorch_agnostic_targetting() { + echo "Testing libtorch_agnostic backward compatibility" + + # Unset PYTORCH_TESTING_DEVICE_ONLY_FOR to test both CPU and CUDA + unset PYTORCH_TESTING_DEVICE_ONLY_FOR + + REPO_DIR=$(pwd) + WHEEL_DIR="${REPO_DIR}/test/cpp_extensions/.wheels" + + # Build wheel with current PyTorch (this has TORCH_TARGET_VERSION 2_9_0) + echo "Building 2.9 extension wheel with current PyTorch..." + pushd test/cpp_extensions/libtorch_agnostic_2_9_extension + time python setup.py bdist_wheel + + # Save the wheel + mkdir -p "$WHEEL_DIR" + cp dist/*.whl "$WHEEL_DIR/" + WHEEL_FILE=$(ls "$WHEEL_DIR"/*.whl | head -1) + echo "Built wheel: $(basename "$WHEEL_FILE")" + popd + + # Create venv and install PyTorch 2.9 + python -m venv venv_pytorch_2_9 + # shellcheck disable=SC1091 + . venv_pytorch_2_9/bin/activate + + # Clear PYTHONPATH to avoid using the development PyTorch + echo "Clearing PYTHONPATH to use only venv packages..." + unset PYTHONPATH + + # Upgrade pip to latest version + echo "Upgrading pip to latest version..." + pip install --upgrade pip + pip --version + + echo "Installing PyTorch 2.9..." 
+ + # Install from release channel only + PYTORCH_VERSION="2.9.0" + + # Extract CUDA version from BUILD_ENVIRONMENT (e.g., "cuda12.1" -> "cu121") + if [[ "$BUILD_ENVIRONMENT" =~ cuda([0-9]+)\.([0-9]+) ]]; then + CUDA_MAJOR="${BASH_REMATCH[1]}" + CUDA_MINOR="${BASH_REMATCH[2]}" + CUDA_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}" + echo " Detected CUDA ${CUDA_MAJOR}.${CUDA_MINOR} from BUILD_ENVIRONMENT, using ${CUDA_VERSION}" + else + # Default to CPU build + CUDA_VERSION="cpu" + echo " No CUDA detected in BUILD_ENVIRONMENT, using CPU build" + fi + + if pip install torch=="${PYTORCH_VERSION}" --index-url https://download.pytorch.org/whl/${CUDA_VERSION}/; then + echo "Installed PyTorch ${PYTORCH_VERSION} from release channel (${CUDA_VERSION})" + else + echo " FAILED to install PyTorch 2.9.0 from release channel" + echo " URL: https://download.pytorch.org/whl/${CUDA_VERSION}/" + deactivate + rm -rf venv_pytorch_2_9 + return 1 + fi + + INSTALLED_VERSION=$(python -c "import torch; print(torch.__version__)" 2>/dev/null || echo "unknown") + echo " Installed version: $INSTALLED_VERSION" + + # Install test dependencies + echo "Installing test dependencies..." + pip install expecttest numpy + + # Install the pre-built wheel + echo "" + echo "Installing pre-built 2.9 extension wheel (built with PyTorch 2.10)..." + pip install "$WHEEL_FILE" + echo "Installed $(basename "$WHEEL_FILE") into PyTorch 2.9 environment" + + # Run tests with PyTorch 2.9 runtime (2.10 tests will be skipped automatically) + echo "" + echo "Running tests with PyTorch 2.9 runtime (using wheel built on PyTorch 2.10)..." + if time python test/cpp_extensions/test_libtorch_agnostic_versioned.py -v; then + echo "" + echo " Wheel built with current torch and TORCH_TARGET_VERSION 2_9_0 works with PyTorch 2.9 runtime!" + else + echo "targeting test failed" + deactivate + rm -rf venv_pytorch_2_9 "$WHEEL_DIR" + return 1 + fi + + deactivate + rm -rf venv_pytorch_2_9 "$WHEEL_DIR" + + assert_git_not_dirty +} + test_jit_hooks() { echo "Testing jit hooks in cpp" HOOK_BUILD="${CUSTOM_TEST_ARTIFACT_BUILD_DIR}/jit-hook-build" @@ -1715,6 +1809,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; elif [[ "${TEST_CONFIG}" == *backward* ]]; then test_forward_backward_compatibility # Do NOT add tests after bc check tests, see its comment. 
+elif [[ "${TEST_CONFIG}" == *libtorch_agnostic_targetting* ]]; then + test_libtorch_agnostic_targetting elif [[ "${TEST_CONFIG}" == *xla* ]]; then install_torchvision build_xla diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index e5fd10c70db61..47b27b87a93e8 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -265,6 +265,7 @@ jobs: test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, + { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1 }, ]} secrets: inherit diff --git a/test/cpp_extensions/test_libtorch_agnostic_versioned.py b/test/cpp_extensions/test_libtorch_agnostic_versioned.py index cff85479b755c..d7ee526a7be02 100644 --- a/test/cpp_extensions/test_libtorch_agnostic_versioned.py +++ b/test/cpp_extensions/test_libtorch_agnostic_versioned.py @@ -1,6 +1,8 @@ # Owner(s): ["module: cpp"] import math +import re +import unittest from pathlib import Path import torch @@ -19,6 +21,25 @@ ) +def get_pytorch_version(): + """Get the PyTorch version as a tuple (major, minor, patch)""" + version_str = torch.__version__.split("+")[0] # Remove git hash + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_str) + if match: + return tuple(int(x) for x in match.groups()) + return (2, 10, 0) # Default to 2.10.0 + + +PYTORCH_VERSION = get_pytorch_version() +IS_PYTORCH_2_9 = PYTORCH_VERSION < (2, 10, 0) +IS_PYTORCH_2_10_OR_LATER = PYTORCH_VERSION >= (2, 10, 0) + + +def skipIfPyTorch2_9(reason): + """Skip test if running on PyTorch 2.9""" + return unittest.skipIf(IS_PYTORCH_2_9, reason) + + # TODO: Fix this error in Windows: # LINK : error LNK2001: unresolved external symbol PyInit__C if not IS_WINDOWS: @@ -39,7 +60,9 @@ class TestLibtorchAgnosticVersioned(TestCase): @classmethod def setUpClass(cls): - # Install and import the 2.9 extension + print(f"Running tests with PyTorch {'.'.join(map(str, PYTORCH_VERSION))}") + + # Install and import the 2.9 extension (always needed) try: import libtorch_agnostic_2_9 @@ -53,19 +76,23 @@ def setUpClass(cls): cls.ops_2_9 = libtorch_agnostic_2_9.ops - # Install and import the 2.10 extension - try: - import libtorch_agnostic_2_10 + # Install and import the 2.10 extension (only if on PyTorch 2.10+) + if IS_PYTORCH_2_10_OR_LATER: + try: + import libtorch_agnostic_2_10 - cls.ops_2_10 = libtorch_agnostic_2_10.ops - except ImportError: - extension_root = ( - Path(__file__).parent / "libtorch_agnostic_2_10_extension" - ) - install_cpp_extension(extension_root=extension_root) - import libtorch_agnostic_2_10 + cls.ops_2_10 = libtorch_agnostic_2_10.ops + except ImportError: + extension_root = ( + Path(__file__).parent / "libtorch_agnostic_2_10_extension" + ) + install_cpp_extension(extension_root=extension_root) + import libtorch_agnostic_2_10 - cls.ops_2_10 = libtorch_agnostic_2_10.ops + cls.ops_2_10 = libtorch_agnostic_2_10.ops + else: + print("Skipping 2.10 extension (running on PyTorch 2.9)") + cls.ops_2_10 = None # ============================================================================ # Tests for 2.9 features @@ -373,8 +400,10 @@ def test_2_9_my_new_zeros_dtype_variant(self, device): # ============================================================================ # Tests for 2.10 features (only work with 2.10 extension) + # These tests are skipped when running on PyTorch 2.9 runtime # ============================================================================ + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my_copy_(self, device): ops = self.ops_2_10 @@ -386,6 +415,7 @@ 
def test_2_10_my_copy_(self, device): self.assertEqual(result, expected) self.assertEqual(result.data_ptr(), dst.data_ptr()) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my_clone(self, device): ops = self.ops_2_10 @@ -397,6 +427,7 @@ def test_2_10_my_clone(self, device): self.assertNotEqual(result.data_ptr(), expected.data_ptr()) self.assertEqual(result.stride(), expected.stride()) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my__foreach_mul_(self, device): ops = self.ops_2_10 @@ -411,6 +442,7 @@ def test_2_10_my__foreach_mul_(self, device): for tensor_t, expected_t in zip(tensors, expected_values): self.assertEqual(tensor_t, expected_t) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my__foreach_mul(self, device): ops = self.ops_2_10 @@ -439,6 +471,7 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_make_tensor_clones_and_call_foreach(self, device): ops = self.ops_2_10 @@ -448,6 +481,7 @@ def test_2_10_make_tensor_clones_and_call_foreach(self, device): self.assertEqual(result[0], t1 * t1) self.assertEqual(result[1], t2 * t2) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") @onlyCUDA def test_2_10_device(self, device): ops = self.ops_2_10 @@ -491,6 +525,7 @@ def test_2_10_device(self, device): ): ops.test_device_set_index(cuda_device, 129) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") @onlyCUDA @deviceCountAtLeast(2) def test_2_10_tensor_device(self, device): @@ -505,6 +540,7 @@ def test_2_10_tensor_device(self, device): t_cuda_1 = torch.randn(2, 3, device="cuda:1") self.assertEqual(ops.test_tensor_device(t_cuda_1), t_cuda_1.device) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") @onlyCPU @xfailIfTorchDynamo def test_2_10_parallel_for(self, device): @@ -525,6 +561,7 @@ def test_2_10_parallel_for(self, device): self.assertEqual(result_values, expected) self.assertEqual(result_thread_ids, torch.arange(expected_num_threads_used)) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") @onlyCPU def test_2_10_get_num_threads(self, device): ops = self.ops_2_10 @@ -533,6 +570,7 @@ def test_2_10_get_num_threads(self, device): expected_num_threads = torch.get_num_threads() self.assertEqual(num_threads, expected_num_threads) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my_empty(self, device): ops = self.ops_2_10 @@ -567,6 +605,7 @@ def test_2_10_my_empty(self, device): finally: torch.use_deterministic_algorithms(deterministic) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my_flatten(self, device): ops = self.ops_2_10 @@ -579,6 +618,7 @@ def test_2_10_my_flatten(self, device): expected_all = torch.flatten(t, 0, -1) self.assertEqual(result_all, expected_all) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my_reshape(self, device): ops = self.ops_2_10 @@ -588,6 +628,7 @@ def test_2_10_my_reshape(self, device): expected = torch.reshape(t, shape) self.assertEqual(result, expected) + @skipIfPyTorch2_9("Requires PyTorch 2.10+ runtime") def test_2_10_my_view(self, device): ops = self.ops_2_10 From 7e517c5aa80498f04ad371361211db16b88878aa Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 13 Nov 2025 22:10:29 -0800 Subject: [PATCH 07/15] Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version" [ghstack-poisoned] --- .github/workflows/pull.yml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-)

diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 47b27b87a93e8..5200ea00c687f 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -265,7 +265,6 @@ jobs:
       test-matrix: |
         { include: [
           { config: "default", shard: 1, num_shards: 1 },
-          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1 },
         ]}
     secrets: inherit
@@ -330,6 +329,7 @@ jobs:
       test-matrix: |
         { include: [
           { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
+          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1 },
         ]}
     secrets: inherit

From fd02d3f8f927c24b826cf00c7b337291b58950a9 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 14 Nov 2025 09:42:31 -0800
Subject: [PATCH 08/15] Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version"

[ghstack-poisoned]
---
 .github/workflows/pull.yml  | 2 +-
 .github/workflows/trunk.yml | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 5200ea00c687f..51e211a5ad2ad 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -70,6 +70,7 @@ jobs:
           { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
+          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
         ]}
     secrets: inherit
@@ -329,7 +330,6 @@ jobs:
       test-matrix: |
         { include: [
           { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
-          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1 },
         ]}
     secrets: inherit

diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 6ba810c3a9582..667c37727045b 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -83,6 +83,7 @@ jobs:
           { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
           { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
           { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
+          { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
         ]}
     secrets: inherit

From f3897df23707ed50f1d80691c148f95ec3df9628 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 14 Nov 2025 10:16:41 -0800
Subject: [PATCH 09/15] Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version"

[ghstack-poisoned]
---
 .ci/pytorch/test.sh | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 534e21349f0c5..27a7ae0997863 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -1249,6 +1249,9 @@ test_libtorch_agnostic_targetting() {
   # Unset PYTORCH_TESTING_DEVICE_ONLY_FOR to test both CPU and CUDA
   unset PYTORCH_TESTING_DEVICE_ONLY_FOR
 
+  # avoid failure due to import xmlrunner
+  unset TEST_SAVE_XML
+
   REPO_DIR=$(pwd)
   WHEEL_DIR="${REPO_DIR}/test/cpp_extensions/.wheels"

From 9db7e47606125bf248f8e1e475a7bc281db83332 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 14 Nov 2025 10:52:45 -0800
Subject: [PATCH 10/15] Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version"

[ghstack-poisoned]
---
 .ci/pytorch/test.sh | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 27a7ae0997863..5693441fbdbe6 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -1249,9 +1249,6 @@ test_libtorch_agnostic_targetting() {
   # Unset PYTORCH_TESTING_DEVICE_ONLY_FOR to test both CPU and CUDA
   unset PYTORCH_TESTING_DEVICE_ONLY_FOR
 
-  # avoid failure due to import xmlrunner
-  unset TEST_SAVE_XML
-
   REPO_DIR=$(pwd)
   WHEEL_DIR="${REPO_DIR}/test/cpp_extensions/.wheels"
@@ -1313,7 +1310,7 @@ test_libtorch_agnostic_targetting() {
   # Install test dependencies
   echo "Installing test dependencies..."
-  pip install expecttest numpy
+  pip install expecttest numpy xmlrunner
 
   # Install the pre-built wheel
   echo ""

From f816f3dfd1409de86cae84d4ffab556d3e2cda35 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 14 Nov 2025 11:29:43 -0800
Subject: [PATCH 11/15] Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version"

[ghstack-poisoned]
---
 .ci/pytorch/test.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 5693441fbdbe6..6c5f99996ab1a 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -1310,7 +1310,7 @@ test_libtorch_agnostic_targetting() {
   # Install test dependencies
   echo "Installing test dependencies..."
-  pip install expecttest numpy xmlrunner
+  pip install expecttest numpy unittest-xml-reporting
 
   # Install the pre-built wheel
   echo ""

From 605d100dde0f1a468e42e9a06cf7cc986e2adc74 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Sun, 16 Nov 2025 17:44:07 -0800
Subject: [PATCH 12/15] Update base for Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version"

[ghstack-poisoned]
---
 .../libtorch_agnostic_2_10/csrc/kernel.cpp    |   6 -
 .../libtorch_agnostic_2_10/ops.py             |  14 --
 .../libtorch_agnostic_2_9/csrc/kernel.cpp     |  12 +
 .../libtorch_agnostic_2_9/ops.py              | 211 +-----------------
 test/cpp_extensions/test_libtorch_agnostic.py |   3 +-
 5 files changed, 20 insertions(+), 226 deletions(-)

diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
index b923c2bc1f4e6..71a1a24bc0b15 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
@@ -185,10 +185,6 @@ Tensor my_empty(
   return empty(size, dtype, device, pin_memory);
 }
 
-Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
-  return flatten(t, start_dim, end_dim);
-}
-
 Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef shape) {
   return reshape(t, shape);
 }
@@ -200,14 +196,12 @@ Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef size) {
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
   m.def(
       "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
-  m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
   m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
   m.def("my_view(Tensor t, int[] size) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
   m.impl("my_empty", TORCH_BOX(&my_empty));
-  m.impl("my_flatten", TORCH_BOX(&my_flatten));
   m.impl("my_reshape", TORCH_BOX(&my_reshape));
   m.impl("my_view", TORCH_BOX(&my_view));
 }

diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py
index c23e8b699552f..42c437ebf755e 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py
@@ -173,20 +173,6 @@ def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
     )
 
 
-def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
-    """
-    Flattens the input tensor from start_dim to end_dim into a single dimension.
-
-    Args:
-        t: Tensor - tensor to flatten
-        start_dim: int - first dimension to flatten (default: 0)
-        end_dim: int - last dimension to flatten (default: -1)
-
-    Returns: Tensor - flattened tensor
-    """
-    return torch.ops.libtorch_agnostic_2_10.my_flatten.default(t, start_dim, end_dim)
-
-
 def my_reshape(t, shape) -> Tensor:
     """
     Returns a tensor with the same data but different shape.

diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp
index e32a631ec0920..9541c77a87380 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp
@@ -422,3 +422,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
 }
 
 #endif // LAE_USE_CUDA
+
+Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
+  return flatten(t, start_dim, end_dim);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) {
+  m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
+  m.impl("my_flatten", TORCH_BOX(&my_flatten));
+}

diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py
index dd13f55b8e2ab..04a1377836554 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py
@@ -215,16 +215,18 @@ def test_default_constructor(defined) -> bool:
     return torch.ops.libtorch_agnostic_2_9.test_default_constructor.default(defined)
 
 
-def test_tensor_device(t):
+def mv_tensor_accessor(m, v) -> Tensor:
     """
-    Tests Tensor device() method.
+    Returns matrix-vector product.
 
     Args:
-        t: Tensor - tensor to get device from
+        m: any 2-D Tensor with shape (N, M)
+        v: any 1-D Tensor with shape (M,)
 
-    Returns: Device - device of the tensor
+    Returns:
+        a 1-D Tensor with shape (N,)
     """
-    return torch.ops.libtorch_agnostic_2_9.test_tensor_device.default(t)
+    return torch.ops.libtorch_agnostic_2_9.mv_tensor_accessor.default(m, v)
 
 
 def my_pad(t) -> Tensor:
@@ -347,165 +349,6 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
     return torch.ops.libtorch_agnostic_2_9.my_new_zeros_dtype_variant.default(t)
 
 
-def my__foreach_mul_(tensors, others) -> ():
-    """
-    Updates tensors to be the result of pointwise multiplying with others.
-
-    Args:
-        tensors: list of tensors
-        others: list of tensors (with the same corresponding shapes as tensors)
-
-    Returns: nothing, tensors is updated in place.
-    """
-    torch.ops.libtorch_agnostic_2_9.my__foreach_mul_.default(tensors, others)
-
-
-def my__foreach_mul(tensors, others) -> list[Tensor]:
-    """
-    Returns a list of tensors that are the results of pointwise multiplying
-    tensors and others.
-
-    Args:
-        tensors: list of tensors
-        others: list of tensors (with the same corresponding shapes as tensors)
-
-    Returns: list of multiplied tensors
-    """
-    return torch.ops.libtorch_agnostic_2_9.my__foreach_mul.default(tensors, others)
-
-
-def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
-    """
-    Returns a list of 2 tensors corresponding to the square of the inputs.
-
-    Args:
-        t1: Tensor
-        t2: Tensor
-
-    Returns: list of [t1^2, t2^2]
-    """
-    return torch.ops.libtorch_agnostic_2_9.make_tensor_clones_and_call_foreach.default(
-        t1, t2
-    )
-
-
-def test_device_constructor(is_cuda, index, use_str):
-    """
-    Tests creating a Device from DeviceType and index, or from a string.
-
-    Args:
-        is_cuda: bool - if True, creates CUDA device; if False, creates CPU device
-        index: int - device index
-        use_str: bool - if True, constructs from string; if False, constructs from DeviceType
-
-    Returns: Device - A device with the specified type and index
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_device_constructor.default(
-        is_cuda, index, use_str
-    )
-
-
-def test_device_equality(d1, d2) -> bool:
-    """
-    Tests Device equality operator.
-
-    Args:
-        d1: Device - first device
-        d2: Device - second device
-
-    Returns: bool - True if devices are equal
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_device_equality.default(d1, d2)
-
-
-def test_device_set_index(device, index):
-    """
-    Tests Device set_index() method.
-
-    Args:
-        device: Device - device to modify
-        index: int - new device index
-
-    Returns: Device - device with updated index
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_device_set_index.default(device, index)
-
-
-def test_device_index(device) -> int:
-    """
-    Tests Device index() method.
-
-    Args:
-        device: Device - device to query
-
-    Returns: int - device index
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_device_index.default(device)
-
-
-def test_device_is_cuda(device) -> bool:
-    """
-    Tests Device is_cuda() method.
-
-    Args:
-        device: Device - device to check
-
-    Returns: bool - True if device is CUDA
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_device_is_cuda.default(device)
-
-
-def test_device_is_cpu(device) -> bool:
-    """
-    Tests Device is_cpu() method.
-
-    Args:
-        device: Device - device to check
-
-    Returns: bool - True if device is CPU
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_device_is_cpu.default(device)
-
-
-def test_parallel_for(size, grain_size) -> Tensor:
-    """
-    Tests the parallel_for functionality by using it to fill a tensor with indices.
-    Args:
-        size: int - size of the tensor to create
-        grain_size: int - grain size for parallel_for
-    Returns: Tensor - a 1D int64 tensor where each element contains its index
-    (if multiple threads are used the threadid will be encoded in the upper 32 bits)
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_parallel_for.default(size, grain_size)
-
-
-def test_get_num_threads() -> int:
-    """
-    Tests the get_num_threads functionality by returning the number of threads
-    for the parallel backend.
-
-    Returns: int - the number of threads for the parallel backend
-    """
-    return torch.ops.libtorch_agnostic_2_9.test_get_num_threads.default()
-
-
-def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
-    """
-    Creates an empty tensor with the specified size, dtype, device, and pin_memory.
-
-    Args:
-        size: list[int] - size of the tensor to create
-        dtype: ScalarType or None - data type of the tensor
-        device: Device or None - device on which to create the tensor
-        pin_memory: bool or None - whether to use pinned memory
-
-    Returns: Tensor - an uninitialized tensor with the specified properties
-    """
-    return torch.ops.libtorch_agnostic_2_9.my_empty.default(
-        size, dtype, device, pin_memory
-    )
-
-
 def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
     """
     Flattens the input tensor from start_dim to end_dim into a single dimension.
@@ -518,43 +361,3 @@ def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
     Returns: Tensor - flattened tensor
     """
     return torch.ops.libtorch_agnostic_2_9.my_flatten.default(t, start_dim, end_dim)
-
-
-def my_reshape(t, shape) -> Tensor:
-    """
-    Returns a tensor with the same data but different shape.
-
-    Args:
-        t: Tensor - tensor to reshape
-        shape: list[int] - new shape for the tensor
-
-    Returns: Tensor - reshaped tensor
-    """
-    return torch.ops.libtorch_agnostic_2_9.my_reshape.default(t, shape)
-
-
-def my_view(t, size) -> Tensor:
-    """
-    Returns a new tensor with the same data as the input tensor but of a different shape.
-
-    Args:
-        t: Tensor - tensor to view
-        size: list[int] - new size for the tensor
-
-    Returns: Tensor - tensor with new view
-    """
-    return torch.ops.libtorch_agnostic_2_9.my_view.default(t, size)
-
-
-def mv_tensor_accessor(m, v) -> Tensor:
-    """
-    Returns matrix-vector product.
-
-    Args:
-        m: any 2-D Tensor with shape (N, M)
-        v: any 1-D Tensor with shape (M,)
-
-    Returns:
-        a 1-D Tensor with shape (N,)
-    """
-    return torch.ops.libtorch_agnostic_2_9.mv_tensor_accessor.default(m, v)

diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py
index f92948b6fe4bb..2ba1200f230d7 100644
--- a/test/cpp_extensions/test_libtorch_agnostic.py
+++ b/test/cpp_extensions/test_libtorch_agnostic.py
@@ -632,9 +632,8 @@ def test_my_empty(self, device):
         finally:
             torch.use_deterministic_algorithms(deterministic)
 
-    @skipIfTorchVersionLessThan(2, 10)
     def test_my_flatten(self, device):
-        import libtorch_agnostic_2_10 as libtorch_agnostic
+        import libtorch_agnostic_2_9 as libtorch_agnostic
 
         t = torch.randn(2, 3, 4, device=device)
         result = libtorch_agnostic.ops.my_flatten(t)

From f7054bd29b2a3435f053d96c7f42648a8e7f2efb Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Sun, 16 Nov 2025 21:28:27 -0800
Subject: [PATCH 13/15] Update base for Update on "Test libtorch_agnostic with TORCH_TARGET_VERSION on target pytorch version"

[ghstack-poisoned]
---
 .../libtorch_agnostic_2_10/csrc/kernel.cpp    |  2 --
 .../libtorch_agnostic_2_10/csrc/kernel.h      | 26 -------------------
 test/run_test.py                              | 19 +++++++++-----
 3 files changed, 12 insertions(+), 35 deletions(-)
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.h

diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
index 71a1a24bc0b15..72c78984b5215 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp
@@ -1,5 +1,3 @@
-#include "kernel.h"
-
 #include
 #include
 #include

diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.h b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.h
deleted file mode 100644
index 3bbc6d118da52..0000000000000
--- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.h
+++ /dev/null
@@ -1,26 +0,0 @@
-#include
-#include
-
-template
-using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor;
-
-#if defined(__CUDACC__) || defined(__HIPCC__)
-#define MAYBE_GLOBAL __global__
-
-template
-using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor;
-
-#else
-#define MAYBE_GLOBAL
-#endif
-
-template