diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 5d0a7c15866..e7d2115552a 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -5,10 +5,102 @@ #include #include #include +#include #include +#include #include +// XLA headers #include +#include +#include + +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/runtime.h" + +namespace at { + +namespace detail { + +namespace { + +// Total number of XLA devices in the system. +static int64_t num_xla_devices; + +// Ensures default_gens_xla is initialized once. +static std::deque xla_gens_init_flag; + +// Default, global XLA generators, one per XLA device. +static std::vector default_gens_xla; + +/* + * Populates the global variables related to XLA generators + * Warning: this function must only be called once! + */ +static void initXLAGenVector() { + // Ensures we only call deviceCount only once. + static bool num_xla_device_init_flag [[maybe_unused]] = []() { + // Get local num of XLA devices + auto maybe_client = torch_xla::runtime::GetComputationClient(); + if (!maybe_client.ok()) { + // If runtime client initialization failed, default to 1 device + num_xla_devices = 1; + } else { + auto* client = maybe_client.value(); + num_xla_devices = static_cast(client->GetNumDevices()); + } + xla_gens_init_flag.resize(num_xla_devices); + default_gens_xla.resize(num_xla_devices); + return true; + }(); +} + +} // anonymous namespace + +/** + * PyTorch maintains a collection of default generators that get + * initialized once. The purpose of these default generators is to + * maintain a global running state of the pseudo random number generation, + * when a user does not explicitly mention any generator. + * getDefaultXLAGenerator gets the default generator for a particular + * XLA device. + */ +const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) { + initXLAGenVector(); + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = 0; // Default to device 0 for XLA + } else { + TORCH_CHECK(idx >= 0 && idx < num_xla_devices); + } + c10::call_once(xla_gens_init_flag[idx], [&] { + default_gens_xla[idx] = at::make_generator(idx); + default_gens_xla[idx].seed(); + }); + return default_gens_xla[idx]; +} + +/** + * Utility to create a XLAGeneratorImpl. Returns a shared_ptr + */ +at::Generator createXLAGenerator(c10::DeviceIndex device_index) { + initXLAGenVector(); + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = torch_xla::bridge::GetCurrentAtenDevice() + .index(); // Use current XLA device + } + TORCH_CHECK(idx >= 0 && idx < num_xla_devices, + "The device_index is invalid."); + auto gen = at::make_generator(idx); + auto xla_gen = at::check_generator(gen); + xla_gen->set_current_seed(c10::default_rng_seed_val); + return gen; +} + +} // namespace detail +} // namespace at namespace at { diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 330d3286120..0d0173157df 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -53,4 +55,11 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { c10::intrusive_ptr state_; }; -} // namespace at \ No newline at end of file +namespace detail { + +const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1); +at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1); + +} // namespace detail + +} // namespace at