From 04e416863b7006ced0b986d213716f2830e70fbb Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 20 Oct 2025 03:17:51 +0000 Subject: [PATCH 1/2] implement `XLAHooks` and register it to PyTorch when loaded. --- torch_xla/csrc/BUILD | 19 +++++++ torch_xla/csrc/xla_hooks.cpp | 99 ++++++++++++++++++++++++++++++++++++ torch_xla/csrc/xla_hooks.h | 40 +++++++++++++++ 3 files changed, 158 insertions(+) create mode 100644 torch_xla/csrc/xla_hooks.cpp create mode 100644 torch_xla/csrc/xla_hooks.h diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 8132ae733160..9c8455be184a 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -270,6 +270,7 @@ ptxla_cc_library( ":status", ":tensor", ":version", + ":xla_hooks", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:pjrt_computation_client", "//torch_xla/csrc/runtime:metrics", @@ -374,3 +375,21 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +ptxla_cc_library( + name = "xla_hooks", + srcs = [ + "xla_hooks.cpp", + ], + hdrs = [ + "xla_hooks.h", + ], + deps = [ + "//torch_xla/csrc:device", + "//torch_xla/csrc:tensor", + "//torch_xla/csrc/runtime:computation_client", + "//torch_xla/csrc/runtime", + "//torch_xla/csrc/runtime:xla_util", + ], +) + diff --git a/torch_xla/csrc/xla_hooks.cpp b/torch_xla/csrc/xla_hooks.cpp new file mode 100644 index 000000000000..257dc7677fef --- /dev/null +++ b/torch_xla/csrc/xla_hooks.cpp @@ -0,0 +1,99 @@ +#include "xla_hooks.h" + +#include +#include + +// PyTorch integration headers +#include +#include +#include +#include +#include +#include +#include + +// XLA headers +#include "xla_generator.h" +#include "xla_backend_impl.h" +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/runtime.h" + + +namespace torch_xla::detail { + +void XLAHooks::init() const { + C10_LOG_API_USAGE_ONCE("aten.init.xla"); + + // Initialize XLA backend - this registers XLA functions and sets up + // the backend infrastructure + torch_xla::InitXlaBackend(); +} + +bool XLAHooks::hasXLA() const { + return isAvailable(); +} + +bool XLAHooks::isAvailable() const { + try { + return deviceCount() > 0; + } catch (...) 
{
+    // If device enumeration fails, XLA is not available
+    return false;
+  }
+}
+
+std::string XLAHooks::showConfig() const {
+  std::ostringstream oss;
+  oss << "XLA Backend Configuration:\n";
+  oss << "  - XLA devices available: " << deviceCount() << "\n";
+  return oss.str();
+}
+
+c10::DeviceIndex XLAHooks::deviceCount() const {
+  auto maybe_client = torch_xla::runtime::GetComputationClient();
+  if (!maybe_client.ok()) {
+    // If runtime client initialization failed, return 0 devices
+    return 0;
+  }
+
+  auto* client = maybe_client.value();
+  return static_cast<c10::DeviceIndex>(client->GetNumDevices());
+}
+
+c10::DeviceIndex XLAHooks::getCurrentDevice() const {
+  return bridge::GetCurrentAtenDevice().index();
+}
+
+bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const {
+  TORCH_CHECK(false, "hasPrimaryContext is not implemented.");
+}
+
+bool XLAHooks::isPinnedPtr(const void* data) const {
+  TORCH_CHECK(false, "isPinnedPtr is not implemented.");
+}
+
+c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const {
+  TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented.");
+}
+
+c10::Device XLAHooks::getDeviceFromPtr(void* data) const {
+  TORCH_CHECK(false, "getDeviceFromPtr is not implemented.");
+}
+
+const at::Generator& XLAHooks::getDefaultGenerator(c10::DeviceIndex device_index) const {
+  return at::detail::getDefaultXLAGenerator(device_index);
+}
+
+at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const {
+  // Create and return a new XLA generator using the make_generator template function
+  return at::make_generator<at::XLAGeneratorImpl>(device_index);
+}
+
+} // namespace torch_xla::detail
+
+// Register XLA hooks with PyTorch on module load
+namespace at {
+REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks)
+} // namespace at
diff --git a/torch_xla/csrc/xla_hooks.h b/torch_xla/csrc/xla_hooks.h
new file mode 100644
index 000000000000..f56c039a8a95
--- /dev/null
+++ b/torch_xla/csrc/xla_hooks.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include
+
+// PyTorch integration headers
+#include
+#include
+#include
+#include
+#include
+
+namespace torch_xla::detail {
+
+// XLA hooks implementation following PyTorch patterns
+struct XLAHooks : public at::XLAHooksInterface {
+  XLAHooks(const at::XLAHooksArgs& args) {}
+
+  // Core accelerator interface methods
+  void init() const override;
+  bool hasXLA() const override;
+  bool isAvailable() const override;
+  bool isBuilt() const override { return true; }
+  std::string showConfig() const override;
+
+  // Device management
+  c10::DeviceIndex deviceCount() const override;
+  c10::DeviceIndex getCurrentDevice() const override;
+  bool hasPrimaryContext(c10::DeviceIndex device_index) const override;
+
+  // Memory management
+  bool isPinnedPtr(const void* data) const override;
+  c10::Allocator* getPinnedMemoryAllocator() const override;
+  c10::Device getDeviceFromPtr(void* data) const override;
+
+  // Generator methods
+  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index = -1) const override;
+  at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;
+};
+
+} // namespace torch_xla::detail

From 539ece41e39991753abcde24419e66216533bba1 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Mon, 20 Oct 2025 03:47:58 +0000
Subject: [PATCH 2/2] format

---
 torch_xla/csrc/xla_hooks.cpp | 31 +++++++++++++++----------------
 torch_xla/csrc/xla_hooks.h   | 22 ++++++++++++----------
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/torch_xla/csrc/xla_hooks.cpp b/torch_xla/csrc/xla_hooks.cpp
index 257dc7677fef..b36e6d5abf69 100644
--- a/torch_xla/csrc/xla_hooks.cpp
+++ b/torch_xla/csrc/xla_hooks.cpp
@@ -1,39 +1,36 @@
 #include "xla_hooks.h"
 
-#include
 #include
+#include
 
 // PyTorch integration headers
 #include
 #include
-#include
 #include
 #include
-#include
 #include
+#include
+#include
 
 // XLA headers
-#include "xla_generator.h"
-#include "xla_backend_impl.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
-
+#include "xla_backend_impl.h"
+#include "xla_generator.h"
 
 namespace torch_xla::detail {
 
 void XLAHooks::init() const {
   C10_LOG_API_USAGE_ONCE("aten.init.xla");
-
-  // Initialize XLA backend - this registers XLA functions and sets up
+
+  // Initialize XLA backend - this registers XLA functions and sets up
   // the backend infrastructure
   torch_xla::InitXlaBackend();
 }
 
-bool XLAHooks::hasXLA() const {
-  return isAvailable();
-}
+bool XLAHooks::hasXLA() const { return isAvailable(); }
 
 bool XLAHooks::isAvailable() const {
   try {
@@ -57,7 +54,7 @@ c10::DeviceIndex XLAHooks::deviceCount() const {
     // If runtime client initialization failed, return 0 devices
     return 0;
   }
-
+
   auto* client = maybe_client.value();
   return static_cast<c10::DeviceIndex>(client->GetNumDevices());
 }
@@ -82,18 +79,20 @@ c10::Device XLAHooks::getDeviceFromPtr(void* data) const {
   TORCH_CHECK(false, "getDeviceFromPtr is not implemented.");
 }
 
-const at::Generator& XLAHooks::getDefaultGenerator(c10::DeviceIndex device_index) const {
+const at::Generator& XLAHooks::getDefaultGenerator(
+    c10::DeviceIndex device_index) const {
   return at::detail::getDefaultXLAGenerator(device_index);
 }
 
 at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const {
-  // Create and return a new XLA generator using the make_generator template function
+  // Create and return a new XLA generator using the make_generator template
+  // function
   return at::make_generator<at::XLAGeneratorImpl>(device_index);
 }
 
-} // namespace torch_xla::detail
+}  // namespace torch_xla::detail
 
 // Register XLA hooks with PyTorch on module load
 namespace at {
 REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks)
-} // namespace at
+}  // namespace at
diff --git a/torch_xla/csrc/xla_hooks.h b/torch_xla/csrc/xla_hooks.h
index f56c039a8a95..66323e967e21 100644
--- a/torch_xla/csrc/xla_hooks.h
+++ b/torch_xla/csrc/xla_hooks.h
@@ -3,38 +3,40 @@
 #include
 
 // PyTorch integration headers
+#include
 #include
-#include
 #include
+#include
 #include
-#include
 
 namespace torch_xla::detail {
 
 // XLA hooks implementation following PyTorch patterns
 struct XLAHooks : public at::XLAHooksInterface {
   XLAHooks(const at::XLAHooksArgs& args) {}
-
+
   // Core accelerator interface methods
   void init() const override;
   bool hasXLA() const override;
   bool isAvailable() const override;
   bool isBuilt() const override { return true; }
   std::string showConfig() const override;
-
+
   // Device management
   c10::DeviceIndex deviceCount() const override;
   c10::DeviceIndex getCurrentDevice() const override;
   bool hasPrimaryContext(c10::DeviceIndex device_index) const override;
-
-  // Memory management
+
+  // Memory management
   bool isPinnedPtr(const void* data) const override;
   c10::Allocator* getPinnedMemoryAllocator() const override;
   c10::Device getDeviceFromPtr(void* data) const override;
-
+
  // Generator methods
-  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index = -1) const override;
-  at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;
+  const at::Generator&
getDefaultGenerator( + c10::DeviceIndex device_index = -1) const override; + at::Generator getNewGenerator( + c10::DeviceIndex device_index = -1) const override; }; -} // namespace torch_xla::detail +} // namespace torch_xla::detail
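
Note (not part of either patch): once the torch_xla extension is loaded, REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks) makes the hooks discoverable through ATen's XLA hooks registry, and core PyTorch resolves them lazily the first time an XLA query is made. The sketch below shows one way a consumer could exercise the registered interface; it assumes the registry accessor is at::detail::getXLAHooks() from <ATen/detail/XLAHooksInterface.h>, mirroring the other accelerator hooks, and it only calls methods that this patch actually implements.

    #include <iostream>
    #include <ATen/detail/XLAHooksInterface.h>

    // Assumes the process has already loaded the torch_xla extension, so the
    // REGISTER_XLA_HOOKS registration above has run; otherwise the registry
    // hands back the default stub interface, whose hasXLA() returns false.
    void probe_xla_hooks() {
      const at::XLAHooksInterface& hooks = at::detail::getXLAHooks();
      if (hooks.isBuilt() && hooks.hasXLA()) {
        hooks.init();  // forwards to torch_xla::InitXlaBackend()
        std::cout << hooks.showConfig()
                  << "current device index: "
                  << static_cast<int>(hooks.getCurrentDevice()) << "\n";
      } else {
        std::cout << "XLA hooks are not registered in this process\n";
      }
    }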