From 07ca2391005471be6f1b8851b7a5fcd96cf91a66 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Tue, 5 Aug 2025 05:14:42 +0000
Subject: [PATCH 1/7] Add xla random generator.
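
The generator follows PyTorch's per-device generator pattern: the RNG
state is a 64-bit seed plus a 64-bit offset held in a separate,
cloneable XLAGeneratorState object, and get_state()/set_state()
serialize that pair into a 16-byte CPU byte tensor (seed in bytes
[0, 8), offset in bytes [8, 16)). A sketch of the intended round trip,
using only APIs added in this patch:

    // Create a generator for XLA device 0 and round-trip its state
    // through the 16-byte CPU byte-tensor representation.
    at::Generator gen = at::make_generator<at::XLAGeneratorImpl>(0);
    gen.set_current_seed(42);
    at::Tensor state = gen.get_state();  // dtype == at::kByte, numel == 16
    at::Generator other = at::make_generator<at::XLAGeneratorImpl>(0);
    other.set_state(state);  // `other` now reports seed 42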
---
 .github/scripts/run_tests.sh     |   1 +
 BUILD                            |   7 ++-
 test/cpp/BUILD                   |  12 ++++
 test/cpp/run_tests.sh            |   1 +
 test/cpp/test_xla_generator.cpp  | 103 +++++++++++++++++++++++++++++++
 torch_xla/csrc/BUILD             |   2 +
 torch_xla/csrc/xla_generator.cpp |  82 ++++++++++++++++++++++++
 torch_xla/csrc/xla_generator.h   |  56 +++++++++++++++++
 8 files changed, 261 insertions(+), 3 deletions(-)
 create mode 100644 test/cpp/test_xla_generator.cpp
 create mode 100644 torch_xla/csrc/xla_generator.cpp
 create mode 100644 torch_xla/csrc/xla_generator.h

diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh
index d685cc40ee4..ccdc0b5e3d7 100755
--- a/.github/scripts/run_tests.sh
+++ b/.github/scripts/run_tests.sh
@@ -55,6 +55,7 @@ function run_torch_xla_cpp_tests() {
       "test_tensor"
       # disable test_xla_backend_intf since it is flaky on upstream
       #"test_xla_backend_intf"
+      "test_xla_generator"
      "test_xla_sharding"
       "test_runtime"
       "test_status_dont_show_cpp_stacktraces"
diff --git a/BUILD b/BUILD
index ee4fa07844a..900dfa4bc3b 100644
--- a/BUILD
+++ b/BUILD
@@ -72,15 +72,16 @@ test_suite(
         "//test/cpp:test_aten_xla_tensor_4",
         "//test/cpp:test_aten_xla_tensor_5",
         "//test/cpp:test_aten_xla_tensor_6",
+        "//test/cpp:test_debug_macros",
         "//test/cpp:test_ir",
         "//test/cpp:test_lazy",
         "//test/cpp:test_replication",
-        "//test/cpp:test_tensor",
-        "//test/cpp:test_xla_sharding",
         "//test/cpp:test_runtime",
         "//test/cpp:test_status_dont_show_cpp_stacktraces",
         "//test/cpp:test_status_show_cpp_stacktraces",
-        "//test/cpp:test_debug_macros",
+        "//test/cpp:test_tensor",
+        "//test/cpp:test_xla_generator",
+        "//test/cpp:test_xla_sharding",
         "//torch_xla/csrc/runtime:pjrt_computation_client_test",
         # "//torch_xla/csrc/runtime:ifrt_computation_client_test",
     ],
diff --git a/test/cpp/BUILD b/test/cpp/BUILD
index e752eab4f67..dab678af767 100644
--- a/test/cpp/BUILD
+++ b/test/cpp/BUILD
@@ -202,3 +202,15 @@ ptxla_cc_test(
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+ptxla_cc_test(
+    name = "test_xla_generator",
+    srcs = ["test_xla_generator.cpp"],
+    deps = [
+        ":cpp_test_util",
+        ":torch_xla_test",
+        "//torch_xla/csrc:tensor",
+        "//torch_xla/csrc:aten_cuda_functions",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
\ No newline at end of file
diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh
index 8c3fea6bcdc..2da0ccb5569 100755
--- a/test/cpp/run_tests.sh
+++ b/test/cpp/run_tests.sh
@@ -100,6 +100,7 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
       # disable test_xla_backend_intf since it is flaky on upstream
       #"test_xla_backend_intf"
       "test_xla_sharding"
+      "test_xla_generator"
       "test_runtime"
       "test_status_dont_show_cpp_stacktraces"
       "test_status_show_cpp_stacktraces"
diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp
new file mode 100644
index 00000000000..687f7d4eea9
--- /dev/null
+++ b/test/cpp/test_xla_generator.cpp
@@ -0,0 +1,103 @@
+#include <gtest/gtest.h>
+#include <ATen/ATen.h>
+#include "test/cpp/torch_xla_test.h"
+#include "torch_xla/csrc/xla_generator.h"
+
+namespace torch_xla {
+namespace cpp_test {
+
+// Test fixture for XLAGenerator tests
+class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
+ protected:
+    void SetUp() {
+        // Create a generator for XLA device 0
+        gen_ = at::make_generator<at::XLAGeneratorImpl>(0);
+    }
+
+    at::Generator gen_;
+};
+
+TEST_F(XLAGeneratorTest, Constructor) {
+    // Check that the generator was created for the correct device
+    ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA);
+    ASSERT_EQ(gen_.device().index(), 0);
+
+    // Check that the initial seed is 0
+    ASSERT_EQ(gen_.current_seed(), 0);
+}
+
+TEST_F(XLAGeneratorTest, Seed) {
+    // Test setting and getting the current seed
+    uint64_t seed_val = 12345;
+    gen_.set_current_seed(seed_val);
+    ASSERT_EQ(gen_.current_seed(), seed_val);
+
+    // Test the seed() method, which should set a non-deterministic seed
+    uint64_t old_seed = gen_.current_seed();
+    uint64_t new_seed = gen_.seed();
+    // The new seed should be different from the old one and set as the current seed
+    ASSERT_NE(new_seed, old_seed);
+    ASSERT_EQ(gen_.current_seed(), new_seed);
+}
+
+TEST_F(XLAGeneratorTest, GetAndSetState) {
+    uint64_t seed_val = 98765;
+    uint64_t offset_val = 0;
+
+    // Set seed and offset on the original generator
+    gen_.set_current_seed(seed_val);
+    gen_.set_offset(offset_val);
+
+    // Get the state from the original generator
+    at::Tensor state_tensor = gen_.get_state();
+
+    // Create a new generator
+    auto new_gen = at::make_generator<at::XLAGeneratorImpl>(1);
+    ASSERT_NE(new_gen.current_seed(), seed_val);
+
+    // Set the state of the new generator
+    new_gen.set_state(state_tensor);
+
+    // Verify the state of the new generator
+    ASSERT_EQ(new_gen.current_seed(), seed_val);
+    ASSERT_EQ(new_gen.get_offset(), offset_val);
+}
+
+TEST_F(XLAGeneratorTest, SetStateValidation) {
+    // Test that set_state throws with incorrect tensor properties
+    auto new_gen = at::make_generator<at::XLAGeneratorImpl>(0);
+
+    // Incorrect size
+    auto wrong_size_tensor = at::empty({10}, at::kByte);
+    EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error);
+
+    // Incorrect dtype
+    auto wrong_dtype_tensor = at::empty({16}, at::kInt);
+    EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error);
+}
+
+TEST_F(XLAGeneratorTest, Clone) {
+    uint64_t seed_val = 1;
+    uint64_t offset_val = 0;
+
+    // Set state on the original generator
+    gen_.set_current_seed(seed_val);
+    gen_.set_offset(offset_val);
+
+    // Clone the generator
+    auto cloned_gen = gen_.clone();
+
+    // Verify that the cloned generator has the same state but is a different object
+    ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_));
+    ASSERT_EQ(cloned_gen.device(), gen_.device());
+    ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed());
+    ASSERT_EQ(cloned_gen.get_offset(), offset_val);
+
+    // Modify the original generator's seed and check that the clone is unaffected
+    gen_.set_current_seed(9999);
+    ASSERT_EQ(cloned_gen.current_seed(), seed_val);
+    ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
+}
+
+} // namespace cpp_test
+} // namespace torch_xla
\ No newline at end of file
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index 31ab65dbbca..a871feaa346 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -64,6 +64,7 @@ ptxla_cc_library(
         "torch_util.cpp",
         "view.cpp",
         "xla_backend_impl.cpp",
+        "xla_generator.cpp",
         "xla_graph_executor.cpp",
         "xla_lower_util.cpp",
         "xla_op_builder.cpp",
@@ -107,6 +108,7 @@ ptxla_cc_library(
         "torch_util.h",
         "view.h",
         "xla_backend_impl.h",
+        "xla_generator.h",
         "xla_graph_executor.h",
         "xla_lower_util.h",
         "xla_op_builder.h",
diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp
new file mode 100644
index 00000000000..492f086bdce
--- /dev/null
+++ b/torch_xla/csrc/xla_generator.cpp
@@ -0,0 +1,82 @@
+#include "xla_generator.h"
+#include <ATen/core/Generator.h>
+#include <ATen/Functions.h>
+#include <c10/core/Device.h>
+#include <c10/core/DispatchKeySet.h>
+#include <cstring>
+#include <ATen/Tensor.h>
+#include <ATen/Utils.h>
+#include <c10/util/intrusive_ptr.h>
+
+namespace at {
+
+XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index)
+    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)} {
+  state_ = c10::make_intrusive<XLAGeneratorState>();
+}
+
+XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr<XLAGeneratorState> state)
+    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)}, state_(std::move(state)) {}
+
+DeviceType XLAGeneratorImpl::device_type() {
+  return DeviceType::XLA;
+}
+
+std::shared_ptr<XLAGeneratorImpl> XLAGeneratorImpl::clone() const {
+  return std::shared_ptr<XLAGeneratorImpl>(clone_impl());
+}
+
+XLAGeneratorImpl* XLAGeneratorImpl::clone_impl() const {
+  return new XLAGeneratorImpl(device_.index(), state_->clone());
+}
+
+void XLAGeneratorImpl::set_current_seed(uint64_t seed) {
+  state_->seed_ = seed;
+}
+
+uint64_t XLAGeneratorImpl::current_seed() const {
+  return state_->seed_;
+}
+
+uint64_t XLAGeneratorImpl::seed() {
+  uint64_t random = c10::detail::getNonDeterministicRandom(true);
+  set_current_seed(random);
+  return random;
+}
+
+void XLAGeneratorImpl::set_offset(uint64_t offset) {
+  state_->offset_ = offset;
+}
+
+uint64_t XLAGeneratorImpl::get_offset() const {
+  return state_->offset_;
+}
+
+/* Serialize the generator state into a CPU tensor. */
+c10::intrusive_ptr<c10::TensorImpl> XLAGeneratorImpl::get_state() const {
+  static const size_t seed_size = sizeof(uint64_t);
+  static const size_t offset_size = sizeof(uint64_t);
+  static const size_t total_size = seed_size + offset_size;
+
+  auto state_tensor = at::empty({(int64_t)total_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));
+  uint8_t* data_ptr = state_tensor.data_ptr<uint8_t>();
+  memcpy(data_ptr, &state_->seed_, seed_size);
+  memcpy(data_ptr + seed_size, &state_->offset_, offset_size);
+  return state_tensor.getIntrusivePtr();
+}
+
+void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+  static const size_t seed_size = sizeof(uint64_t);
+  static const size_t offset_size = sizeof(uint64_t);
+  static const size_t total_size = seed_size + offset_size;
+
+  TORCH_CHECK(new_state.numel() == total_size, "The given state must be a byte tensor of size ", total_size, ", but was size ", new_state.numel());
+  TORCH_CHECK(new_state.dtype() == at::kByte, "The given state must be a byte tensor, but was ", new_state.dtype());
+  TORCH_CHECK(new_state.is_cpu(), "The given state must be a CPU tensor");
+
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
+  memcpy(&state_->seed_, new_rng_state, seed_size);
+  memcpy(&state_->offset_, new_rng_state + seed_size, offset_size);
+}
+
+} // namespace at
diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h
new file mode 100644
index 00000000000..b8b7dc46e9a
--- /dev/null
+++ b/torch_xla/csrc/xla_generator.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#include <ATen/core/Generator.h>
+#include <ATen/core/TensorBase.h>
+#include <c10/core/GeneratorImpl.h>
+
+#include <cstdint>
+
+namespace at {
+
+// Holds the actual state variables for the XLA generator.
+struct XLAGeneratorState : c10::intrusive_ptr_target {
+  uint64_t seed_ = 0;
+  uint64_t offset_ = 0;
+
+  // Constructor
+  XLAGeneratorState(uint64_t seed = 0, uint64_t offset = 0)
+      : seed_(seed), offset_(offset) {}
+
+  // Cloning method
+  c10::intrusive_ptr<XLAGeneratorState> clone() {
+    return c10::make_intrusive<XLAGeneratorState>(seed_, offset_);
+  }
+};
+
+struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
+  // Constructors
+  XLAGeneratorImpl(DeviceIndex device_index = -1);
+  XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr<XLAGeneratorState> state);
+  ~XLAGeneratorImpl() override = default;
+
+  // Cloning support
+  std::shared_ptr<XLAGeneratorImpl> clone() const;
+
+  // --- Core Virtual Methods to Override ---
+  void set_current_seed(uint64_t seed) override;
+  uint64_t current_seed() const override;
+  uint64_t seed() override;
+  void set_offset(uint64_t offset) override;
+  uint64_t get_offset() const override;
+  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
+  void set_state(const c10::TensorImpl& new_state) override;
+
+  // --- Additional Methods ---
+  static c10::DeviceType device_type();
+
+ private:
+  // Private clone implementation
+  XLAGeneratorImpl* clone_impl() const override;
+
+  // The actual state is held in a separate, cloneable object.
+  c10::intrusive_ptr<XLAGeneratorState> state_;
+
+};
+
+} // namespace at
\ No newline at end of file

From 1d67a02d76f9e3931d5a04d3a25b97a0bb13c7d4 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Wed, 13 Aug 2025 00:13:27 +0000
Subject: [PATCH 2/7] format cpp files

---
 test/cpp/test_xla_generator.cpp  | 131 ++++++++++++++++---------------
 torch_xla/csrc/xla_generator.cpp |  52 ++++++------
 torch_xla/csrc/xla_generator.h   |   4 +-
 3 files changed, 96 insertions(+), 91 deletions(-)

diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp
index 687f7d4eea9..d45991f72d3 100644
--- a/test/cpp/test_xla_generator.cpp
+++ b/test/cpp/test_xla_generator.cpp
@@ -1,5 +1,6 @@
 #include <gtest/gtest.h>
 #include <ATen/ATen.h>
+
 #include "test/cpp/torch_xla_test.h"
 #include "torch_xla/csrc/xla_generator.h"
 
@@ -9,94 +10,96 @@ namespace cpp_test {
 // Test fixture for XLAGenerator tests
 class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
  protected:
-    void SetUp() {
-        // Create a generator for XLA device 0
-        gen_ = at::make_generator<at::XLAGeneratorImpl>(0);
-    }
+  void SetUp() {
+    // Create a generator for XLA device 0
+    gen_ = at::make_generator<at::XLAGeneratorImpl>(0);
+  }
 
-    at::Generator gen_;
+  at::Generator gen_;
 };
 
 TEST_F(XLAGeneratorTest, Constructor) {
-    // Check that the generator was created for the correct device
-    ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA);
-    ASSERT_EQ(gen_.device().index(), 0);
+  // Check that the generator was created for the correct device
+  ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA);
+  ASSERT_EQ(gen_.device().index(), 0);
 
-    // Check that the initial seed is 0
-    ASSERT_EQ(gen_.current_seed(), 0);
+  // Check that the initial seed is 0
+  ASSERT_EQ(gen_.current_seed(), 0);
 }
 
 TEST_F(XLAGeneratorTest, Seed) {
-    // Test setting and getting the current seed
-    uint64_t seed_val = 12345;
-    gen_.set_current_seed(seed_val);
-    ASSERT_EQ(gen_.current_seed(), seed_val);
-
-    // Test the seed() method, which should set a non-deterministic seed
-    uint64_t old_seed = gen_.current_seed();
-    uint64_t new_seed = gen_.seed();
-    // The new seed should be different from the old one and set as the current seed
-    ASSERT_NE(new_seed, old_seed);
-    ASSERT_EQ(gen_.current_seed(), new_seed);
+  // Test setting and getting the current seed
+  uint64_t seed_val = 12345;
+  gen_.set_current_seed(seed_val);
+  ASSERT_EQ(gen_.current_seed(), seed_val);
+
+  // Test the seed() method, which should set a non-deterministic seed
+  uint64_t old_seed = gen_.current_seed();
+  uint64_t new_seed = gen_.seed();
+  // The new seed should be different from the old one and set as the current
+  // seed
+  ASSERT_NE(new_seed, old_seed);
+  ASSERT_EQ(gen_.current_seed(), new_seed);
 }
 
 TEST_F(XLAGeneratorTest, GetAndSetState) {
-    uint64_t seed_val = 98765;
-    uint64_t offset_val = 0;
+  uint64_t seed_val = 98765;
+  uint64_t offset_val = 0;
 
-    // Set seed and offset on the original generator
-    gen_.set_current_seed(seed_val);
-    gen_.set_offset(offset_val);
+  // Set seed and offset on the original generator
+  gen_.set_current_seed(seed_val);
+  gen_.set_offset(offset_val);
 
-    // Get the state from the original generator
-    at::Tensor state_tensor = gen_.get_state();
+  // Get the state from the original generator
+  at::Tensor state_tensor = gen_.get_state();
 
-    // Create a new generator
-    auto new_gen = at::make_generator<at::XLAGeneratorImpl>(1);
-    ASSERT_NE(new_gen.current_seed(), seed_val);
+  // Create a new generator
+  auto new_gen = at::make_generator<at::XLAGeneratorImpl>(1);
+  ASSERT_NE(new_gen.current_seed(), seed_val);
 
-    // Set the state of the new generator
-    new_gen.set_state(state_tensor);
+  // Set the state of the new generator
+  new_gen.set_state(state_tensor);
 
-    // Verify the state of the new generator
-    ASSERT_EQ(new_gen.current_seed(), seed_val);
-    ASSERT_EQ(new_gen.get_offset(), offset_val);
+  // Verify the state of the new generator
+  ASSERT_EQ(new_gen.current_seed(), seed_val);
+  ASSERT_EQ(new_gen.get_offset(), offset_val);
 }
 
 TEST_F(XLAGeneratorTest, SetStateValidation) {
-    // Test that set_state throws with incorrect tensor properties
-    auto new_gen = at::make_generator<at::XLAGeneratorImpl>(0);
+  // Test that set_state throws with incorrect tensor properties
+  auto new_gen = at::make_generator<at::XLAGeneratorImpl>(0);
 
-    // Incorrect size
-    auto wrong_size_tensor = at::empty({10}, at::kByte);
-    EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error);
+  // Incorrect size
+  auto wrong_size_tensor = at::empty({10}, at::kByte);
+  EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error);
 
-    // Incorrect dtype
-    auto wrong_dtype_tensor = at::empty({16}, at::kInt);
-    EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error);
+  // Incorrect dtype
+  auto wrong_dtype_tensor = at::empty({16}, at::kInt);
+  EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error);
 }
 
 TEST_F(XLAGeneratorTest, Clone) {
-    uint64_t seed_val = 1;
-    uint64_t offset_val = 0;
-
-    // Set state on the original generator
-    gen_.set_current_seed(seed_val);
-    gen_.set_offset(offset_val);
-
-    // Clone the generator
-    auto cloned_gen = gen_.clone();
-
-    // Verify that the cloned generator has the same state but is a different object
-    ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_));
-    ASSERT_EQ(cloned_gen.device(), gen_.device());
-    ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed());
-    ASSERT_EQ(cloned_gen.get_offset(), offset_val);
-
-    // Modify the original generator's seed and check that the clone is unaffected
-    gen_.set_current_seed(9999);
-    ASSERT_EQ(cloned_gen.current_seed(), seed_val);
-    ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
+  uint64_t seed_val = 1;
+  uint64_t offset_val = 0;
+
+  // Set state on the original generator
+  gen_.set_current_seed(seed_val);
+  gen_.set_offset(offset_val);
+
+  // Clone the generator
+  auto cloned_gen = gen_.clone();
+
+  // Verify that the cloned generator has the same state but is a different
+  // object
+  ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_));
+  ASSERT_EQ(cloned_gen.device(), gen_.device());
+  ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed());
+  ASSERT_EQ(cloned_gen.get_offset(), offset_val);
+
+  // Modify the original generator's seed and check that the clone is unaffected
+  gen_.set_current_seed(9999);
+  ASSERT_EQ(cloned_gen.current_seed(), seed_val);
+  ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
 }
 
 } // namespace cpp_test
diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp
index 492f086bdce..5d0a7c15866 100644
--- a/torch_xla/csrc/xla_generator.cpp
+++ b/torch_xla/csrc/xla_generator.cpp
@@ -1,26 +1,30 @@
 #include "xla_generator.h"
+
+#include <ATen/Functions.h>
+#include <ATen/Tensor.h>
+#include <ATen/Utils.h>
 #include <ATen/core/Generator.h>
-#include <ATen/Functions.h>
 #include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
-#include <cstring>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
 #include <c10/util/intrusive_ptr.h>
+
+#include <cstring>
 
 namespace at {
 
 XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index)
-    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)} {
+    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index),
+                         DispatchKeySet(c10::DispatchKey::XLA)} {
   state_ = c10::make_intrusive<XLAGeneratorState>();
 }
 
-XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr<XLAGeneratorState> state)
-    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)}, state_(std::move(state)) {}
+XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index,
+                                   c10::intrusive_ptr<XLAGeneratorState> state)
+    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index),
+                         DispatchKeySet(c10::DispatchKey::XLA)},
+      state_(std::move(state)) {}
 
-DeviceType XLAGeneratorImpl::device_type() {
-  return DeviceType::XLA;
-}
+DeviceType XLAGeneratorImpl::device_type() { return DeviceType::XLA; }
 
 std::shared_ptr<XLAGeneratorImpl> XLAGeneratorImpl::clone() const {
   return std::shared_ptr<XLAGeneratorImpl>(clone_impl());
@@ -30,13 +34,9 @@ XLAGeneratorImpl* XLAGeneratorImpl::clone_impl() const {
   return new XLAGeneratorImpl(device_.index(), state_->clone());
 }
 
-void XLAGeneratorImpl::set_current_seed(uint64_t seed) {
-  state_->seed_ = seed;
-}
+void XLAGeneratorImpl::set_current_seed(uint64_t seed) { state_->seed_ = seed; }
 
-uint64_t XLAGeneratorImpl::current_seed() const {
-  return state_->seed_;
-}
+uint64_t XLAGeneratorImpl::current_seed() const { return state_->seed_; }
 
 uint64_t XLAGeneratorImpl::seed() {
   uint64_t random = c10::detail::getNonDeterministicRandom(true);
@@ -44,13 +44,9 @@ uint64_t XLAGeneratorImpl::seed() {
   return random;
 }
 
-void XLAGeneratorImpl::set_offset(uint64_t offset) {
-  state_->offset_ = offset;
-}
+void XLAGeneratorImpl::set_offset(uint64_t offset) { state_->offset_ = offset; }
 
-uint64_t XLAGeneratorImpl::get_offset() const {
-  return state_->offset_;
-}
+uint64_t XLAGeneratorImpl::get_offset() const { return state_->offset_; }
 
 /* Serialize the generator state into a CPU tensor. */
 c10::intrusive_ptr<c10::TensorImpl> XLAGeneratorImpl::get_state() const {
@@ -58,7 +54,9 @@ c10::intrusive_ptr<c10::TensorImpl> XLAGeneratorImpl::get_state() const {
   static const size_t offset_size = sizeof(uint64_t);
   static const size_t total_size = seed_size + offset_size;
 
-  auto state_tensor = at::empty({(int64_t)total_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));
+  auto state_tensor =
+      at::empty({(int64_t)total_size},
+                at::TensorOptions().dtype(at::kByte).device(at::kCPU));
   uint8_t* data_ptr = state_tensor.data_ptr<uint8_t>();
   memcpy(data_ptr, &state_->seed_, seed_size);
   memcpy(data_ptr + seed_size, &state_->offset_, offset_size);
@@ -70,8 +68,12 @@ void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   static const size_t offset_size = sizeof(uint64_t);
   static const size_t total_size = seed_size + offset_size;
 
-  TORCH_CHECK(new_state.numel() == total_size, "The given state must be a byte tensor of size ", total_size, ", but was size ", new_state.numel());
-  TORCH_CHECK(new_state.dtype() == at::kByte, "The given state must be a byte tensor, but was ", new_state.dtype());
+  TORCH_CHECK(new_state.numel() == total_size,
+              "The given state must be a byte tensor of size ", total_size,
+              ", but was size ", new_state.numel());
+  TORCH_CHECK(new_state.dtype() == at::kByte,
+              "The given state must be a byte tensor, but was ",
+              new_state.dtype());
   TORCH_CHECK(new_state.is_cpu(), "The given state must be a CPU tensor");
 
   auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
@@ -79,4 +81,4 @@ void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   memcpy(&state_->offset_, new_rng_state + seed_size, offset_size);
 }
 
-} // namespace at
+}  // namespace at
diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h
index b8b7dc46e9a..330d3286120 100644
--- a/torch_xla/csrc/xla_generator.h
+++ b/torch_xla/csrc/xla_generator.h
@@ -26,7 +26,8 @@ struct XLAGeneratorState : c10::intrusive_ptr_target {
 struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
   // Constructors
   XLAGeneratorImpl(DeviceIndex device_index = -1);
-  XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr<XLAGeneratorState> state);
+  XLAGeneratorImpl(DeviceIndex device_index,
+                   c10::intrusive_ptr<XLAGeneratorState> state);
   ~XLAGeneratorImpl() override = default;
 
   // Cloning support
@@ -50,7 +51,6 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
 
   // The actual state is held in a separate, cloneable object.
   c10::intrusive_ptr<XLAGeneratorState> state_;
-
 };
 
 } // namespace at
\ No newline at end of file

From 79ff99bce64b34b0e7d6d52b4d1998c17cd44567 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Mon, 20 Oct 2025 03:15:34 +0000
Subject: [PATCH 3/7] Add helper functions `getDefaultXLAGenerator` and
 `createXLAGenerator` to XLA random number generator
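
`getDefaultXLAGenerator` lazily creates one process-global generator per
local XLA device and seeds it non-deterministically on first use;
`createXLAGenerator` returns a fresh, independent generator seeded with
`c10::default_rng_seed_val`. Intended usage (a sketch; a device index of
-1 means "use the default/current device"):

    // The default generator carries the global RNG stream for device 0;
    // a created generator is private and does not perturb that stream.
    const at::Generator& def = at::detail::getDefaultXLAGenerator(0);
    at::Generator priv = at::detail::createXLAGenerator(0);
    priv.set_current_seed(1234);  // affects only `priv`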
---
 torch_xla/csrc/xla_generator.cpp | 89 ++++++++++++++++++++++++++++++++
 torch_xla/csrc/xla_generator.h   | 11 +++-
 2 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp
index 5d0a7c15866..5102be5df4e 100644
--- a/torch_xla/csrc/xla_generator.cpp
+++ b/torch_xla/csrc/xla_generator.cpp
@@ -6,9 +6,98 @@
 #include <ATen/core/Generator.h>
 #include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
+#include <c10/util/CallOnce.h>
 #include <c10/util/intrusive_ptr.h>
+#include <c10/core/GeneratorImpl.h>
 
+// XLA headers
+#include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/aten_xla_bridge.h"
+
 #include <cstring>
+#include <deque>
+#include <vector>
 
+namespace at {
+
+namespace detail {
+
+namespace {
+
+// Total number of XLA devices in the system.
+static int64_t num_xla_devices;
+
+// One flag per device, ensuring each entry of default_gens_xla is
+// initialized only once.
+static std::deque<c10::once_flag> xla_gens_init_flag;
+
+// Default, global XLA generators, one per XLA device.
+static std::vector<at::Generator> default_gens_xla;
+
+/*
+ * Populates the global variables related to XLA generators.
+ * Safe to call repeatedly; the initialization body runs only once.
+ */
+static void initXLAGenVector() {
+  // Ensures we query the device count only once.
+  static bool num_xla_device_init_flag [[maybe_unused]] = []() {
+    // Get the local number of XLA devices.
+    auto maybe_client = torch_xla::runtime::GetComputationClient();
+    if (!maybe_client.ok()) {
+      // If runtime client initialization failed, default to 1 device.
+      num_xla_devices = 1;
+    } else {
+      auto* client = maybe_client.value();
+      num_xla_devices = static_cast<int64_t>(client->GetNumDevices());
+    }
+    xla_gens_init_flag.resize(num_xla_devices);
+    default_gens_xla.resize(num_xla_devices);
+    return true;
+  }();
+}
+
+} // anonymous namespace
+
+/**
+ * PyTorch maintains a collection of default generators that get
+ * initialized once. The purpose of these default generators is to
+ * maintain a global running state of the pseudo random number generation,
+ * when a user does not explicitly mention any generator.
+ * getDefaultXLAGenerator gets the default generator for a particular
+ * XLA device.
+ */
+const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
+  initXLAGenVector();
+  c10::DeviceIndex idx = device_index;
+  if (idx == -1) {
+    idx = 0; // Default to device 0 for XLA
+  } else {
+    TORCH_CHECK(idx >= 0 && idx < num_xla_devices);
+  }
+  c10::call_once(xla_gens_init_flag[idx], [&] {
+    default_gens_xla[idx] = at::make_generator<at::XLAGeneratorImpl>(idx);
+    default_gens_xla[idx].seed();
+  });
+  return default_gens_xla[idx];
+}
+
+/**
+ * Utility to create an XLAGeneratorImpl. Returns an at::Generator that
+ * owns the new instance.
+ */
+at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
+  initXLAGenVector();
+  c10::DeviceIndex idx = device_index;
+  if (idx == -1) {
+    idx = torch_xla::bridge::GetCurrentAtenDevice().index(); // Use current XLA device
+  }
+  TORCH_CHECK(idx >= 0 && idx < num_xla_devices, "The device_index is invalid.");
+  auto gen = at::make_generator<at::XLAGeneratorImpl>(idx);
+  auto xla_gen = at::check_generator<at::XLAGeneratorImpl>(gen);
+  xla_gen->set_current_seed(c10::default_rng_seed_val);
+  return gen;
+}
+
+} // namespace detail
+} // namespace at
+
 namespace at {
 
diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h
index 330d3286120..62621f7c37c 100644
--- a/torch_xla/csrc/xla_generator.h
+++ b/torch_xla/csrc/xla_generator.h
@@ -2,6 +2,8 @@
 
 #include <ATen/core/Generator.h>
 #include <ATen/core/TensorBase.h>
+#include <c10/core/Device.h>
+#include <c10/core/DeviceType.h>
 #include <c10/core/GeneratorImpl.h>
 
 #include <cstdint>
@@ -53,4 +55,11 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
   c10::intrusive_ptr<XLAGeneratorState> state_;
 };
 
-} // namespace at
\ No newline at end of file
+namespace detail {
+
+const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
+at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);
+
+} // namespace detail
+
+} // namespace at

From 79d4b4236aebd2f4b7624237015f30ca4be4b15f Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Mon, 20 Oct 2025 03:17:51 +0000
Subject: [PATCH 4/7] implement `XLAHooks` and register it to PyTorch when
 loaded.
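
Registering the hooks lets device-generic code in PyTorch core reach the
XLA backend through the accelerator-hooks indirection instead of linking
torch_xla directly. A sketch of the consumer side (assuming the
at::detail::getXLAHooks() accessor that pairs with REGISTER_XLA_HOOKS;
names not introduced in this patch are illustrative):

    // Query XLA availability and fetch its default generator through
    // the hooks interface, without a compile-time torch_xla dependency.
    const at::XLAHooksInterface& hooks = at::detail::getXLAHooks();
    if (hooks.isAvailable()) {
      const at::Generator& gen = hooks.getDefaultGenerator(/*device_index=*/0);
    }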
---
 torch_xla/csrc/BUILD         | 19 +++++++
 torch_xla/csrc/xla_hooks.cpp | 99 ++++++++++++++++++++++++++++++++++++
 torch_xla/csrc/xla_hooks.h   | 40 +++++++++++++++
 3 files changed, 158 insertions(+)
 create mode 100644 torch_xla/csrc/xla_hooks.cpp
 create mode 100644 torch_xla/csrc/xla_hooks.h

diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index 8132ae73316..9c8455be184 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -270,6 +270,7 @@ ptxla_cc_library(
         ":status",
         ":tensor",
         ":version",
+        ":xla_hooks",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:pjrt_computation_client",
         "//torch_xla/csrc/runtime:metrics",
@@ -374,3 +375,21 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
     ],
 )
+
+ptxla_cc_library(
+    name = "xla_hooks",
+    srcs = [
+        "xla_hooks.cpp",
+    ],
+    hdrs = [
+        "xla_hooks.h",
+    ],
+    deps = [
+        "//torch_xla/csrc:device",
+        "//torch_xla/csrc:tensor",
+        "//torch_xla/csrc/runtime:computation_client",
+        "//torch_xla/csrc/runtime",
+        "//torch_xla/csrc/runtime:xla_util",
+    ],
+)
+
diff --git a/torch_xla/csrc/xla_hooks.cpp b/torch_xla/csrc/xla_hooks.cpp
new file mode 100644
index 00000000000..257dc7677fe
--- /dev/null
+++ b/torch_xla/csrc/xla_hooks.cpp
@@ -0,0 +1,99 @@
+#include "xla_hooks.h"
+
+#include <sstream>
+#include <string>
+
+// PyTorch integration headers
+#include <ATen/detail/XLAHooksInterface.h>
+#include <ATen/Context.h>
+#include <ATen/core/Generator.h>
+#include <c10/core/Allocator.h>
+#include <c10/core/Device.h>
+#include <c10/util/Exception.h>
+#include <c10/util/Logging.h>
+
+// XLA headers
+#include "xla_generator.h"
+#include "xla_backend_impl.h"
+#include "torch_xla/csrc/aten_xla_bridge.h"
+#include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/runtime.h"
+
+
+namespace torch_xla::detail {
+
+void XLAHooks::init() const {
+  C10_LOG_API_USAGE_ONCE("aten.init.xla");
+
+  // Initialize XLA backend - this registers XLA functions and sets up
+  // the backend infrastructure
+  torch_xla::InitXlaBackend();
+}
+
+bool XLAHooks::hasXLA() const {
+  return isAvailable();
+}
+
+bool XLAHooks::isAvailable() const {
+  try {
+    return deviceCount() > 0;
+  } catch (...) {
+    // If device enumeration fails, XLA is not available
+    return false;
+  }
+}
+
+std::string XLAHooks::showConfig() const {
+  std::ostringstream oss;
+  oss << "XLA Backend Configuration:\n";
+  oss << " - XLA devices available: " << deviceCount() << "\n";
+  return oss.str();
+}
+
+c10::DeviceIndex XLAHooks::deviceCount() const {
+  auto maybe_client = torch_xla::runtime::GetComputationClient();
+  if (!maybe_client.ok()) {
+    // If runtime client initialization failed, return 0 devices
+    return 0;
+  }
+
+  auto* client = maybe_client.value();
+  return static_cast<c10::DeviceIndex>(client->GetNumDevices());
+}
+
+c10::DeviceIndex XLAHooks::getCurrentDevice() const {
+  return bridge::GetCurrentAtenDevice().index();
+}
+
+bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const {
+  TORCH_CHECK(false, "hasPrimaryContext is not implemented.");
+}
+
+bool XLAHooks::isPinnedPtr(const void* data) const {
+  TORCH_CHECK(false, "isPinnedPtr is not implemented.");
+}
+
+c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const {
+  TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented.");
+}
+
+c10::Device XLAHooks::getDeviceFromPtr(void* data) const {
+  TORCH_CHECK(false, "getDeviceFromPtr is not implemented.");
+}
+
+const at::Generator& XLAHooks::getDefaultGenerator(c10::DeviceIndex device_index) const {
+  return at::detail::getDefaultXLAGenerator(device_index);
+}
+
+at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const {
+  // Create and return a new XLA generator using the make_generator template function
+  return at::make_generator<at::XLAGeneratorImpl>(device_index);
+}
+
+} // namespace torch_xla::detail
+
+// Register XLA hooks with PyTorch on module load
+namespace at {
+REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks)
+} // namespace at
diff --git a/torch_xla/csrc/xla_hooks.h b/torch_xla/csrc/xla_hooks.h
new file mode 100644
index 00000000000..f56c039a8a9
--- /dev/null
+++ b/torch_xla/csrc/xla_hooks.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <string>
+
+// PyTorch integration headers
+#include <ATen/detail/XLAHooksInterface.h>
+#include <ATen/core/Generator.h>
+#include <c10/core/Allocator.h>
+#include <c10/core/Device.h>
+#include <c10/util/Exception.h>
+
+namespace torch_xla::detail {
+
+// XLA hooks implementation following PyTorch patterns
+struct XLAHooks : public at::XLAHooksInterface {
+  XLAHooks(const at::XLAHooksArgs& args) {}
+
+  // Core accelerator interface methods
+  void init() const override;
+  bool hasXLA() const override;
+  bool isAvailable() const override;
+  bool isBuilt() const override { return true; }
+  std::string showConfig() const override;
+
+  // Device management
+  c10::DeviceIndex deviceCount() const override;
+  c10::DeviceIndex getCurrentDevice() const override;
+  bool hasPrimaryContext(c10::DeviceIndex device_index) const override;
+
+  // Memory management
+  bool isPinnedPtr(const void* data) const override;
+  c10::Allocator* getPinnedMemoryAllocator() const override;
+  c10::Device getDeviceFromPtr(void* data) const override;
+
+  // Generator methods
+  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index = -1) const override;
+  at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;
+};
+
+} // namespace torch_xla::detail

From 914d708989aac940caa5a8b6c5ce15a8705f82bc Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Mon, 20 Oct 2025 03:39:41 +0000
Subject: [PATCH 5/7] format

---
 torch_xla/csrc/xla_generator.cpp | 24 +++++++++++++-----------
 torch_xla/csrc/xla_generator.h   |  2 +-
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp
index 5102be5df4e..56aa2fe3bf4 100644
--- a/torch_xla/csrc/xla_generator.cpp
+++ b/torch_xla/csrc/xla_generator.cpp
@@ -5,19 +5,19 @@
 #include <ATen/Utils.h>
 #include <ATen/core/Generator.h>
 #include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
+#include <c10/core/GeneratorImpl.h>
 #include <c10/util/CallOnce.h>
 #include <c10/util/intrusive_ptr.h>
-#include <c10/core/GeneratorImpl.h>
 
-// XLA headers
-#include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/aten_xla_bridge.h"
-
 #include <cstring>
 #include <deque>
 #include <vector>
+
+#include "torch_xla/csrc/aten_xla_bridge.h"
+#include "torch_xla/csrc/runtime/computation_client.h"
 
 namespace at {
 
 namespace detail {
@@ -55,7 +55,7 @@ static void initXLAGenVector() {
   }();
 }
 
-} // anonymous namespace
+}  // anonymous namespace
 
 /**
  * PyTorch maintains a collection of default generators that get
@@ -69,7 +69,7 @@ const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
   initXLAGenVector();
   c10::DeviceIndex idx = device_index;
   if (idx == -1) {
-    idx = 0; // Default to device 0 for XLA
+    idx = 0;  // Default to device 0 for XLA
   } else {
     TORCH_CHECK(idx >= 0 && idx < num_xla_devices);
   }
@@ -87,17 +87,19 @@ at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
   initXLAGenVector();
   c10::DeviceIndex idx = device_index;
   if (idx == -1) {
-    idx = torch_xla::bridge::GetCurrentAtenDevice().index(); // Use current XLA device
+    idx = torch_xla::bridge::GetCurrentAtenDevice()
+              .index();  // Use current XLA device
   }
-  TORCH_CHECK(idx >= 0 && idx < num_xla_devices, "The device_index is invalid.");
+  TORCH_CHECK(idx >= 0 && idx < num_xla_devices,
+              "The device_index is invalid.");
   auto gen = at::make_generator<at::XLAGeneratorImpl>(idx);
   auto xla_gen = at::check_generator<at::XLAGeneratorImpl>(gen);
   xla_gen->set_current_seed(c10::default_rng_seed_val);
   return gen;
 }
 
-} // namespace detail
-} // namespace at
+}  // namespace detail
+}  // namespace at
 
 namespace at {
diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h
index 62621f7c37c..0d0173157df 100644
--- a/torch_xla/csrc/xla_generator.h
+++ b/torch_xla/csrc/xla_generator.h
@@ -60,6 +60,6 @@ namespace detail {
 const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
 at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);
 
-} // namespace detail
+}  // namespace detail
 
 } // namespace at

From 1224531fd0aad36177cc3b765a840cc4bf39ed2c Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Mon, 20 Oct 2025 03:42:57 +0000
Subject: [PATCH 6/7] Revert "implement `XLAHooks` and register it to PyTorch
 when loaded."

This reverts commit 79d4b4236aebd2f4b7624237015f30ca4be4b15f.
---
 torch_xla/csrc/BUILD         | 19 -------
 torch_xla/csrc/xla_hooks.cpp | 99 ------------------------------------
 torch_xla/csrc/xla_hooks.h   | 40 ---------------
 3 files changed, 158 deletions(-)
 delete mode 100644 torch_xla/csrc/xla_hooks.cpp
 delete mode 100644 torch_xla/csrc/xla_hooks.h

diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index 9c8455be184..8132ae73316 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -270,7 +270,6 @@ ptxla_cc_library(
         ":status",
         ":tensor",
         ":version",
-        ":xla_hooks",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:pjrt_computation_client",
         "//torch_xla/csrc/runtime:metrics",
@@ -375,21 +374,3 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
     ],
 )
-
-ptxla_cc_library(
-    name = "xla_hooks",
-    srcs = [
-        "xla_hooks.cpp",
-    ],
-    hdrs = [
-        "xla_hooks.h",
-    ],
-    deps = [
-        "//torch_xla/csrc:device",
-        "//torch_xla/csrc:tensor",
-        "//torch_xla/csrc/runtime:computation_client",
-        "//torch_xla/csrc/runtime",
-        "//torch_xla/csrc/runtime:xla_util",
-    ],
-)
-
diff --git a/torch_xla/csrc/xla_hooks.cpp b/torch_xla/csrc/xla_hooks.cpp
deleted file mode 100644
index 257dc7677fe..00000000000
--- a/torch_xla/csrc/xla_hooks.cpp
+++ /dev/null
@@ -1,99 +0,0 @@
-#include "xla_hooks.h"
-
-#include <sstream>
-#include <string>
-
-// PyTorch integration headers
-#include <ATen/detail/XLAHooksInterface.h>
-#include <ATen/Context.h>
-#include <ATen/core/Generator.h>
-#include <c10/core/Allocator.h>
-#include <c10/core/Device.h>
-#include <c10/util/Exception.h>
-#include <c10/util/Logging.h>
-
-// XLA headers
-#include "xla_generator.h"
-#include "xla_backend_impl.h"
-#include "torch_xla/csrc/aten_xla_bridge.h"
-#include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/runtime/debug_macros.h"
-#include "torch_xla/csrc/runtime/runtime.h"
-
-
-namespace torch_xla::detail {
-
-void XLAHooks::init() const {
-  C10_LOG_API_USAGE_ONCE("aten.init.xla");
-
-  // Initialize XLA backend - this registers XLA functions and sets up
-  // the backend infrastructure
-  torch_xla::InitXlaBackend();
-}
-
-bool XLAHooks::hasXLA() const {
-  return isAvailable();
-}
-
-bool XLAHooks::isAvailable() const {
-  try {
-    return deviceCount() > 0;
-  } catch (...) {
-    // If device enumeration fails, XLA is not available
-    return false;
-  }
-}
-
-std::string XLAHooks::showConfig() const {
-  std::ostringstream oss;
-  oss << "XLA Backend Configuration:\n";
-  oss << " - XLA devices available: " << deviceCount() << "\n";
-  return oss.str();
-}
-
-c10::DeviceIndex XLAHooks::deviceCount() const {
-  auto maybe_client = torch_xla::runtime::GetComputationClient();
-  if (!maybe_client.ok()) {
-    // If runtime client initialization failed, return 0 devices
-    return 0;
-  }
-
-  auto* client = maybe_client.value();
-  return static_cast<c10::DeviceIndex>(client->GetNumDevices());
-}
-
-c10::DeviceIndex XLAHooks::getCurrentDevice() const {
-  return bridge::GetCurrentAtenDevice().index();
-}
-
-bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const {
-  TORCH_CHECK(false, "hasPrimaryContext is not implemented.");
-}
-
-bool XLAHooks::isPinnedPtr(const void* data) const {
-  TORCH_CHECK(false, "isPinnedPtr is not implemented.");
-}
-
-c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const {
-  TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented.");
-}
-
-c10::Device XLAHooks::getDeviceFromPtr(void* data) const {
-  TORCH_CHECK(false, "getDeviceFromPtr is not implemented.");
-}
-
-const at::Generator& XLAHooks::getDefaultGenerator(c10::DeviceIndex device_index) const {
-  return at::detail::getDefaultXLAGenerator(device_index);
-}
-
-at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const {
-  // Create and return a new XLA generator using the make_generator template function
-  return at::make_generator<at::XLAGeneratorImpl>(device_index);
-}
-
-} // namespace torch_xla::detail
-
-// Register XLA hooks with PyTorch on module load
-namespace at {
-REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks)
-} // namespace at
diff --git a/torch_xla/csrc/xla_hooks.h b/torch_xla/csrc/xla_hooks.h
deleted file mode 100644
index f56c039a8a9..00000000000
--- a/torch_xla/csrc/xla_hooks.h
+++ /dev/null
@@ -1,40 +0,0 @@
-#pragma once
-
-#include <string>
-
-// PyTorch integration headers
-#include <ATen/detail/XLAHooksInterface.h>
-#include <ATen/core/Generator.h>
-#include <c10/core/Allocator.h>
-#include <c10/core/Device.h>
-#include <c10/util/Exception.h>
-
-namespace torch_xla::detail {
-
-// XLA hooks implementation following PyTorch patterns
-struct XLAHooks : public at::XLAHooksInterface {
-  XLAHooks(const at::XLAHooksArgs& args) {}
-
-  // Core accelerator interface methods
-  void init() const override;
-  bool hasXLA() const override;
-  bool isAvailable() const override;
-  bool isBuilt() const override { return true; }
-  std::string showConfig() const override;
-
-  // Device management
-  c10::DeviceIndex deviceCount() const override;
-  c10::DeviceIndex getCurrentDevice() const override;
-  bool hasPrimaryContext(c10::DeviceIndex device_index) const override;
-
-  // Memory management
-  bool isPinnedPtr(const void* data) const override;
-  c10::Allocator* getPinnedMemoryAllocator() const override;
-  c10::Device getDeviceFromPtr(void* data) const override;
-
-  // Generator methods
-  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index = -1) const override;
-  at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;
-};
-
-} // namespace torch_xla::detail

From 80b2078302ccc600c96d688fbf7132efa5339094 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Tue, 21 Oct 2025 03:41:56 +0000
Subject: [PATCH 7/7] Add missing include

xla_generator.cpp calls torch_xla::runtime::GetComputationClient(),
which is declared in runtime/runtime.h, so include that header
directly.

---
 torch_xla/csrc/xla_generator.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp
index 56aa2fe3bf4..e7d2115552a 100644
--- a/torch_xla/csrc/xla_generator.cpp
+++ b/torch_xla/csrc/xla_generator.cpp
@@ -17,6 +17,7 @@
 
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/runtime.h"
 
 namespace at {