Skip to content
Open
92 changes: 92 additions & 0 deletions torch_xla/csrc/xla_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,102 @@
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/util/intrusive_ptr.h>

// XLA headers
#include <cstring>
#include <deque>
#include <vector>

#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"

namespace at {

namespace detail {

namespace {

// Total number of XLA devices in the system.
// Written by initXLAGenVector(): the computation client's device count, or 1
// when the client could not be obtained.
static int64_t num_xla_devices;

// Ensures each entry of default_gens_xla is initialized once: one flag per
// XLA device, sized by initXLAGenVector() and consumed by c10::call_once in
// getDefaultXLAGenerator().
// NOTE(review): presumably a deque rather than a vector because
// c10::once_flag cannot be copied or moved -- confirm.
static std::deque<c10::once_flag> xla_gens_init_flag;

// Default, global XLA generators, one per XLA device. The container is sized
// by initXLAGenVector(); individual slots are filled in on first use by
// getDefaultXLAGenerator().
static std::vector<at::Generator> default_gens_xla;

/*
 * Sizes the global XLA generator containers to the number of XLA devices.
 * Safe to call repeatedly: the work lives in the initializer of a
 * function-local static, which executes only on the first call.
 */
static void initXLAGenVector() {
  [[maybe_unused]] static const bool sized_once = [] {
    // Ask the runtime how many XLA devices are visible locally.
    auto client_or = torch_xla::runtime::GetComputationClient();
    // Without a usable runtime client, assume a single device.
    num_xla_devices =
        client_or.ok()
            ? static_cast<int64_t>(client_or.value()->GetNumDevices())
            : 1;
    // One once-flag and one generator slot per device.
    xla_gens_init_flag.resize(num_xla_devices);
    default_gens_xla.resize(num_xla_devices);
    return true;
  }();
}

} // anonymous namespace

/**
 * PyTorch maintains a collection of default generators that get
 * initialized once. The purpose of these default generators is to
 * maintain a global running state of the pseudo random number generation,
 * when a user does not explicitly mention any generator.
 * getDefaultXLAGenerator gets the default generator for a particular
 * XLA device.
 *
 * @param device_index Index of the XLA device; -1 selects device 0.
 * @return Reference to the global default generator of that device. The
 *         generator is created and seeded the first time the device is
 *         requested.
 *
 * Fails a TORCH_CHECK when the resolved index is outside
 * [0, number of XLA devices).
 */
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
  initXLAGenVector();
  c10::DeviceIndex idx = device_index;
  if (idx == -1) {
    idx = 0;  // Default to device 0 for XLA
  }
  // Validate on every path, including the -1 default: the containers below
  // are sized to num_xla_devices, so an unchecked index 0 would be out of
  // bounds if no device was reported.
  TORCH_CHECK(idx >= 0 && idx < num_xla_devices,
              "The device_index is invalid.");
  // Create and seed this device's generator on first request only.
  c10::call_once(xla_gens_init_flag[idx], [&] {
    default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
    default_gens_xla[idx].seed();
  });
  return default_gens_xla[idx];
}

/**
 * Utility to create a fresh, non-default XLA generator.
 *
 * @param device_index Index of the XLA device; -1 selects the current XLA
 *        device as reported by torch_xla::bridge::GetCurrentAtenDevice().
 * @return A new at::Generator (returned by value) backed by an
 *         XLAGeneratorImpl whose seed is c10::default_rng_seed_val.
 *
 * Fails a TORCH_CHECK when the resolved index is outside
 * [0, number of XLA devices).
 */
at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
  initXLAGenVector();
  // Resolve the -1 sentinel to the current XLA device.
  const c10::DeviceIndex idx =
      (device_index == -1)
          ? torch_xla::bridge::GetCurrentAtenDevice().index()
          : device_index;
  TORCH_CHECK(idx >= 0 && idx < num_xla_devices,
              "The device_index is invalid.");
  at::Generator generator = at::make_generator<XLAGeneratorImpl>(idx);
  // Start every newly created generator from the default seed.
  at::check_generator<XLAGeneratorImpl>(generator)->set_current_seed(
      c10::default_rng_seed_val);
  return generator;
}

} // namespace detail
} // namespace at

namespace at {

Expand Down
11 changes: 10 additions & 1 deletion torch_xla/csrc/xla_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>
Expand Down Expand Up @@ -53,4 +55,11 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<XLAGeneratorState> state_;
};

} // namespace at
namespace detail {

// Returns the global default generator for the given XLA device, creating and
// seeding it on first use. A device_index of -1 selects device 0; an explicit
// index must lie in [0, number of XLA devices).
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
// Creates a new generator for the given XLA device, seeded with
// c10::default_rng_seed_val. A device_index of -1 selects the current XLA
// device; the resolved index must lie in [0, number of XLA devices).
at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);

} // namespace detail

} // namespace at
Loading