From e3c6453a62174cd92ca10fcbf8c9b8edf3c5f99d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 13 Nov 2025 16:50:40 -0300 Subject: [PATCH 1/8] Make `DeviceType` constructor not throw. --- torch_xla/csrc/device.cpp | 84 ++++++++++++++++++++++++--------------- torch_xla/csrc/device.h | 54 +++++++++++++++++++------ 2 files changed, 92 insertions(+), 46 deletions(-) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index ead6ec9dea5..cdf608774cc 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -3,12 +3,17 @@ #include #include #include +#include +#include +#include +#include +#include #include "absl/status/status.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" -#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/status.h" @@ -20,45 +25,58 @@ namespace { static bool spmd_config_is_locked = false; static bool use_virtual_device = false; -} // namespace - -std::string DeviceType::XlaDeviceTypeToString(XlaDeviceType hw_type) { - XLA_CHECK(hw_type != XlaDeviceType::PLUGIN) << "PLUGIN type name unknown"; - - switch (hw_type) { - case XlaDeviceType::CPU: - return "CPU"; - case XlaDeviceType::CUDA: - return "CUDA"; - case XlaDeviceType::TPU: - return "TPU"; - case XlaDeviceType::NEURON: - return "NEURON"; - case XlaDeviceType::SPMD: - return "SPMD"; - default: - XLA_ERROR() << "Invalid device type"; - } +constexpr std::size_t kNativeXlaDeviceTypeNumber = 5; +constexpr std::array, + kNativeXlaDeviceTypeNumber> + kNativeXlaDeviceTypeWithName = {{ +#define XLA_DEVICE_NAME_PAIR(name, _) {XlaDeviceType::name, #name}, + XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(XLA_DEVICE_NAME_PAIR) +#undef XLA_DEVICE_NAME_PAIR + }}; + +std::string_view XlaDeviceTypeToString(XlaDeviceType type) { + int8_t value = static_cast(type); + // This check makes sure we are not dealing with: + // + // 1. Invalid XlaDeviceType (i.e. result of conversion of a number bigger + // than PLUGIN -- the last enum value) + // + // 2. The XlaDeviceType::PLUGIN enum, since it's not considered a "native" + // device type + ABSL_CHECK(value < kNativeXlaDeviceTypeNumber); + return kNativeXlaDeviceTypeWithName[value].second; } -XlaDeviceType DeviceType::StringToXlaDeviceType(const std::string& type_name) { - if (type_name == "SPMD") { - return XlaDeviceType::SPMD; - } else if (type_name == "TPU") { - return XlaDeviceType::TPU; - } else if (type_name == "CPU") { - return XlaDeviceType::CPU; - } else if (type_name == "CUDA") { - return XlaDeviceType::CUDA; - } else if (type_name == "NEURON") { - return XlaDeviceType::NEURON; +XlaDeviceType StringToXlaDeviceType(std::string_view type_name) { + std::array, + kNativeXlaDeviceTypeNumber>::const_iterator it = + std::find_if(kNativeXlaDeviceTypeWithName.begin(), + kNativeXlaDeviceTypeWithName.end(), + [=](const std::pair& pair) { + return pair.second == type_name; + }); + if (it != kNativeXlaDeviceTypeWithName.end()) { + return it->first; } - return XlaDeviceType::PLUGIN; } +} // namespace + +DeviceType::DeviceType(XlaDeviceType xla_device_type) + : torch::lazy::BackendDeviceType(static_cast(xla_device_type)), + type_name_() {} + +DeviceType::DeviceType(std::string_view type_name) + : torch::lazy::BackendDeviceType( + static_cast(StringToXlaDeviceType(type_name))), + type_name_(type_name) {} + std::string DeviceType::toString() const { - return absl::StrCat(type_name_, ":"); + std::string_view str = (getType() == XlaDeviceType::PLUGIN) + ? type_name_ + : XlaDeviceTypeToString(getType()); + return absl::StrCat(str, ":"); } XlaDeviceType DeviceType::getType() const { diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index d0fd9423c53..b57eead9a3d 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -8,32 +8,60 @@ #include #include -#include "torch_xla/csrc/runtime/util.h" - namespace torch_xla { +// Convenient macro for applying another macro to all native device types. +// +// Add new device type +// =================== +// +// Add a new line to the macro below: +// +// _(, ) +// +// Where is the enum of the given device, and is the +// previous number plus 1. +// +#define XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(_) \ + _(CPU, 0) \ + _(CUDA, 1) \ + _(TPU, 2) \ + _(NEURON, 3) \ + _(SPMD, 4) + // TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToDevice` // until after the paritioning pass. This avoids transfering the full input // tensor to the device. -enum class XlaDeviceType { CPU, CUDA, TPU, NEURON, SPMD, PLUGIN }; +enum class XlaDeviceType : int8_t { +#define XLA_DECLARE_ENUM(name, value) name = value, + XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(XLA_DECLARE_ENUM) +#undef XLA_DECLARE_ENUM + + // Plugin is not considered a native device type. + // It has a special treatment for some functions. + PLUGIN, +}; struct DeviceType : public torch::lazy::BackendDeviceType { - DeviceType(XlaDeviceType xla_device_type) - : torch::lazy::BackendDeviceType(static_cast(xla_device_type)), - type_name_(XlaDeviceTypeToString(xla_device_type)) {} - DeviceType(const std::string& type_name) - : torch::lazy::BackendDeviceType( - static_cast(StringToXlaDeviceType(type_name))), - type_name_(type_name) {} + DeviceType(XlaDeviceType xla_device_type); + + // Constructor parses the `type_name` into an `XlaDeviceType`. + // + // This should in 2 cases: + // + // 1. When using non-native device types. + // Although `XlaDeviceType::PLUGIN` will be used, the `type_name` + // parameter will be stored internally. + // + // 2. When parsing string device types. + // + DeviceType(std::string_view type_name); std::string toString() const override; XlaDeviceType getType() const; private: std::string type_name_; - - static std::string XlaDeviceTypeToString(XlaDeviceType hw_type); - static XlaDeviceType StringToXlaDeviceType(const std::string& type_name); }; // Parses the given `device_spec` into a new `BackendDevice`. From 857328e15d7726576a1b11548388279947236995 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 14 Nov 2025 18:04:34 -0300 Subject: [PATCH 2/8] Rename function. --- torch_xla/csrc/device.cpp | 4 ++-- torch_xla/csrc/device.h | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index cdf608774cc..9ef6b272a5e 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -34,7 +34,7 @@ constexpr std::array, #undef XLA_DEVICE_NAME_PAIR }}; -std::string_view XlaDeviceTypeToString(XlaDeviceType type) { +std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) { int8_t value = static_cast(type); // This check makes sure we are not dealing with: // @@ -75,7 +75,7 @@ DeviceType::DeviceType(std::string_view type_name) std::string DeviceType::toString() const { std::string_view str = (getType() == XlaDeviceType::PLUGIN) ? type_name_ - : XlaDeviceTypeToString(getType()); + : NativeXlaDeviceTypeToString(getType()); return absl::StrCat(str, ":"); } diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index b57eead9a3d..c18ffb6ab1e 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -47,14 +47,13 @@ struct DeviceType : public torch::lazy::BackendDeviceType { // Constructor parses the `type_name` into an `XlaDeviceType`. // - // This should in 2 cases: + // This should be used in 2 cases: // // 1. When using non-native device types. // Although `XlaDeviceType::PLUGIN` will be used, the `type_name` // parameter will be stored internally. // // 2. When parsing string device types. - // DeviceType(std::string_view type_name); std::string toString() const override; From 4e14cb9cde80dbab412e3a02140f6ddd886235d1 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 14 Nov 2025 18:07:17 -0300 Subject: [PATCH 3/8] Fix lint issues. --- torch_xla/csrc/device.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 9ef6b272a5e..da04d161cee 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -1,16 +1,14 @@ #include "torch_xla/csrc/device.h" -#include -#include -#include #include #include +#include #include #include #include -#include "absl/status/status.h" #include "absl/log/absl_check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" From 15054f4041db56fde71da6bb03a5fc9aa2d860a9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 14 Nov 2025 20:59:46 -0300 Subject: [PATCH 4/8] Fix include mess. --- torch_xla/csrc/device.cpp | 7 +++++++ torch_xla/csrc/device.h | 4 +++- torch_xla/csrc/layout_manager.cpp | 3 +-- torch_xla/csrc/layout_manager.h | 4 ++-- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index da04d161cee..da11397fe5e 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -2,13 +2,20 @@ #include #include +#include +#include +#include #include #include #include #include +#include + +#include #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index c18ffb6ab1e..95f6bc9c121 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -1,13 +1,15 @@ #ifndef XLA_TORCH_XLA_CSRC_DEVICE_H_ #define XLA_TORCH_XLA_CSRC_DEVICE_H_ -#include #include +#include #include #include #include +#include "absl/status/statusor.h" + namespace torch_xla { // Convenient macro for applying another macro to all native device types. diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp index 36fb5d8e390..e1ac00a890a 100644 --- a/torch_xla/csrc/layout_manager.cpp +++ b/torch_xla/csrc/layout_manager.cpp @@ -2,12 +2,12 @@ #include #include -#include #include #include #include #include +#include #include #include "absl/strings/str_split.h" @@ -17,7 +17,6 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/util.h" namespace torch_xla { namespace { diff --git a/torch_xla/csrc/layout_manager.h b/torch_xla/csrc/layout_manager.h index 540c6d221e1..c220e7fd64d 100644 --- a/torch_xla/csrc/layout_manager.h +++ b/torch_xla/csrc/layout_manager.h @@ -3,7 +3,7 @@ #include "absl/types/span.h" #include "xla/shape.h" -#include "xla/types.h" +#include "xla/xla_data.pb.h" #include "torch_xla/csrc/device.h" @@ -26,4 +26,4 @@ xla::Shape MakeArrayShapeFromDimensions( } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_LAYOUT_MANAGER_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_LAYOUT_MANAGER_H_ From e07404202d3996ee9a2d2cdfdf2e15ed92362cb5 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 14 Nov 2025 22:12:43 -0300 Subject: [PATCH 5/8] Refactor. --- test/cpp/test_device.cpp | 61 +++++++++++++++++++++++++++++++++++---- torch_xla/csrc/device.cpp | 15 +++++++++- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/test/cpp/test_device.cpp b/test/cpp/test_device.cpp index d63b8164349..79189b53af5 100644 --- a/test/cpp/test_device.cpp +++ b/test/cpp/test_device.cpp @@ -1,22 +1,27 @@ #include +#include + +#include +#include +#include #include "absl/strings/str_cat.h" #include "torch_xla/csrc/device.h" +namespace torch_xla { + static void CheckFormatError(const std::string& spec) { - absl::StatusOr r = - torch_xla::SafeParseDeviceString(spec); - EXPECT_FALSE(r.ok()); + absl::StatusOr r = SafeParseDeviceString(spec); + ASSERT_FALSE(r.ok()); EXPECT_EQ(r.status().message(), absl::StrCat("expected the device string `", spec, "` to be in the format: `:`.")); } static void CheckIndexParseError(const std::string& spec) { - absl::StatusOr r = - torch_xla::SafeParseDeviceString(spec); - EXPECT_FALSE(r.ok()); + absl::StatusOr r = SafeParseDeviceString(spec); + ASSERT_FALSE(r.ok()); EXPECT_EQ( r.status().message(), absl::StrCat("error while parsing the device spec `", spec, "`: stoi")); @@ -33,3 +38,47 @@ TEST(DeviceTest, ParseDeviceStringIndexParseError) { CheckIndexParseError("xla:xla"); CheckIndexParseError("xla:x11"); } + +static void CheckDeviceTypeConstructionWithString( + XlaDeviceType xla_device_type, std::string_view device_type_str) { + DeviceType device_type(device_type_str); + EXPECT_EQ(device_type.getType(), xla_device_type); + EXPECT_EQ(device_type.toString(), absl::StrCat(device_type_str, ":")); +} + +TEST(DeviceTest, ConstructNativeDeviceTypeWithString) { +#define XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING(type, _) \ + CheckDeviceTypeConstructionWithString(XlaDeviceType::type, #type); + XLA_FOR_ALL_NATIVE_DEVICE_TYPES_( + XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING) +#undef XLA_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING +} + +TEST(DeviceTest, ConstructPluginDeviceTypeWithString) { + DeviceType device_type("OTHER"); + EXPECT_EQ(device_type.getType(), XlaDeviceType::PLUGIN); + EXPECT_EQ(device_type.toString(), "OTHER:"); +} + +static void CheckDeviceTypeConstructionWithEnum( + XlaDeviceType xla_device_type, std::string_view device_type_str) { + DeviceType device_type(xla_device_type); + ASSERT_EQ(device_type.getType(), xla_device_type); + EXPECT_EQ(device_type.toString(), absl::StrCat(device_type_str, ":")); +} + +TEST(DeviceTest, ConstructNativeDeviceTypeWithEnum) { +#define XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_ENUM(type, _) \ + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::type, #type); + XLA_FOR_ALL_NATIVE_DEVICE_TYPES_( + XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_ENUM) +#undef XLA_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING +} + +TEST(DeviceTest, ConstructPluginDeviceTypeWithEnumError) { + EXPECT_DEATH({ DeviceType device_type(XlaDeviceType::PLUGIN); }, + absl::StrCat("invalid native XlaDeviceType value: ", + static_cast(XlaDeviceType::PLUGIN))); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index da11397fe5e..74a6f0fb7c7 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -39,6 +39,19 @@ constexpr std::array, #undef XLA_DEVICE_NAME_PAIR }}; +absl::Status CheckIsNativeXlaDeviceType(int8_t value) { + if (value < 0 || value >= kNativeXlaDeviceTypeNumber) { + return XLA_ERROR_WITH_LOCATION(absl::InternalError( + absl::StrCat("invalid native XlaDeviceType value: ", value, + " (casted to int). It should be non-negative, less than ", + kNativeXlaDeviceTypeNumber, + " (number of native XlaDeviceType). This shouldn't be " + "called for XlaDeviceType::PLUGIN (", + static_cast(XlaDeviceType::PLUGIN), ")."))); + } + return absl::OkStatus(); +} + std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) { int8_t value = static_cast(type); // This check makes sure we are not dealing with: @@ -48,7 +61,7 @@ std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) { // // 2. The XlaDeviceType::PLUGIN enum, since it's not considered a "native" // device type - ABSL_CHECK(value < kNativeXlaDeviceTypeNumber); + XLA_CHECK_OK(CheckIsNativeXlaDeviceType(value)); return kNativeXlaDeviceTypeWithName[value].second; } From 220a25e7fd29ee657ef7dfb309db8d435a759e3e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 18 Nov 2025 15:29:19 -0300 Subject: [PATCH 6/8] Fix test. --- torch_xla/csrc/device.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 74a6f0fb7c7..27eca0a5f8c 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -83,7 +83,9 @@ XlaDeviceType StringToXlaDeviceType(std::string_view type_name) { DeviceType::DeviceType(XlaDeviceType xla_device_type) : torch::lazy::BackendDeviceType(static_cast(xla_device_type)), - type_name_() {} + type_name_() { + XLA_CHECK_OK(CheckIsNativeXlaDeviceType(type)); +} DeviceType::DeviceType(std::string_view type_name) : torch::lazy::BackendDeviceType( From 51d7c3b15cd65b8dc8a1378cd018fd81f3eda19d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 18 Nov 2025 18:54:08 -0300 Subject: [PATCH 7/8] Fix lint issues. --- torch_xla/csrc/layout_manager.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/layout_manager.h b/torch_xla/csrc/layout_manager.h index c220e7fd64d..e270d2b8cbd 100644 --- a/torch_xla/csrc/layout_manager.h +++ b/torch_xla/csrc/layout_manager.h @@ -9,16 +9,17 @@ namespace torch_xla { -// Creates a minor-to-major layout from given dimensions. The dynamic_dimensions -// slice should be either empty, or of the same size as dimensions. +// Creates a minor-to-major layout from given dimensions. The +// dynamic_dimensions slice should be either empty, or of the same size as +// dimensions. xla::Shape MakeTorchTensorLayout(absl::Span dimensions, absl::Span dynamic_dimensions, xla::PrimitiveType type); // Create an XLA shape with the given dimensions and type, suitable to be used -// in the specified device type. The type of device can affect the choice of the -// XLA layout. The dynamic_dimensions slice should be either empty, or of the -// same size as dimensions. +// in the specified device type. The type of device can affect the choice of +// the XLA layout. The dynamic_dimensions slice should be either empty, or of +// the same size as dimensions. xla::Shape MakeArrayShapeFromDimensions( absl::Span dimensions, absl::Span dynamic_dimensions, xla::PrimitiveType type, From dd8e41798f867ace7a01c387375545b9e57a0e6f Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 19 Nov 2025 15:20:37 -0300 Subject: [PATCH 8/8] Address review. --- test/cpp/test_device.cpp | 24 +++++++++++++----------- torch_xla/csrc/device.cpp | 39 ++++++++++++++++++--------------------- torch_xla/csrc/device.h | 32 ++------------------------------ 3 files changed, 33 insertions(+), 62 deletions(-) diff --git a/test/cpp/test_device.cpp b/test/cpp/test_device.cpp index 79189b53af5..79cce566b3b 100644 --- a/test/cpp/test_device.cpp +++ b/test/cpp/test_device.cpp @@ -1,10 +1,12 @@ #include -#include #include #include #include +#include + +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "torch_xla/csrc/device.h" @@ -47,11 +49,11 @@ static void CheckDeviceTypeConstructionWithString( } TEST(DeviceTest, ConstructNativeDeviceTypeWithString) { -#define XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING(type, _) \ - CheckDeviceTypeConstructionWithString(XlaDeviceType::type, #type); - XLA_FOR_ALL_NATIVE_DEVICE_TYPES_( - XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING) -#undef XLA_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING + CheckDeviceTypeConstructionWithString(XlaDeviceType::CPU, "CPU"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::CUDA, "CUDA"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::TPU, "TPU"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::NEURON, "NEURON"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::SPMD, "SPMD"); } TEST(DeviceTest, ConstructPluginDeviceTypeWithString) { @@ -68,11 +70,11 @@ static void CheckDeviceTypeConstructionWithEnum( } TEST(DeviceTest, ConstructNativeDeviceTypeWithEnum) { -#define XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_ENUM(type, _) \ - CheckDeviceTypeConstructionWithEnum(XlaDeviceType::type, #type); - XLA_FOR_ALL_NATIVE_DEVICE_TYPES_( - XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_ENUM) -#undef XLA_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::CPU, "CPU"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::CUDA, "CUDA"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::TPU, "TPU"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::NEURON, "NEURON"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::SPMD, "SPMD"); } TEST(DeviceTest, ConstructPluginDeviceTypeWithEnumError) { diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 27eca0a5f8c..c016b60d53d 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -5,15 +5,14 @@ #include #include #include +#include #include #include #include -#include #include #include -#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -30,14 +29,13 @@ namespace { static bool spmd_config_is_locked = false; static bool use_virtual_device = false; -constexpr std::size_t kNativeXlaDeviceTypeNumber = 5; -constexpr std::array, - kNativeXlaDeviceTypeNumber> - kNativeXlaDeviceTypeWithName = {{ -#define XLA_DEVICE_NAME_PAIR(name, _) {XlaDeviceType::name, #name}, - XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(XLA_DEVICE_NAME_PAIR) -#undef XLA_DEVICE_NAME_PAIR - }}; +constexpr int8_t kNativeXlaDeviceTypeNumber = + static_cast(XlaDeviceType::PLUGIN); + +// The elements in this array should match the order in the XlaDeviceType enum +// declaration. So, if you modify one of them, make sure to keep them in sync. +constexpr std::array + kNativeXlaDeviceTypeNames = {"CPU", "CUDA", "TPU", "NEURON", "SPMD"}; absl::Status CheckIsNativeXlaDeviceType(int8_t value) { if (value < 0 || value >= kNativeXlaDeviceTypeNumber) { @@ -62,21 +60,20 @@ std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) { // 2. The XlaDeviceType::PLUGIN enum, since it's not considered a "native" // device type XLA_CHECK_OK(CheckIsNativeXlaDeviceType(value)); - return kNativeXlaDeviceTypeWithName[value].second; + return kNativeXlaDeviceTypeNames[value]; } XlaDeviceType StringToXlaDeviceType(std::string_view type_name) { - std::array, - kNativeXlaDeviceTypeNumber>::const_iterator it = - std::find_if(kNativeXlaDeviceTypeWithName.begin(), - kNativeXlaDeviceTypeWithName.end(), - [=](const std::pair& pair) { - return pair.second == type_name; - }); - if (it != kNativeXlaDeviceTypeWithName.end()) { - return it->first; + std::array::const_iterator it = + std::find(kNativeXlaDeviceTypeNames.begin(), + kNativeXlaDeviceTypeNames.end(), type_name); + + if (it == kNativeXlaDeviceTypeNames.end()) { + return XlaDeviceType::PLUGIN; } - return XlaDeviceType::PLUGIN; + + std::size_t index = std::distance(kNativeXlaDeviceTypeNames.begin(), it); + return static_cast(index); } } // namespace diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index 95f6bc9c121..f6b8f151ad8 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -1,48 +1,20 @@ #ifndef XLA_TORCH_XLA_CSRC_DEVICE_H_ #define XLA_TORCH_XLA_CSRC_DEVICE_H_ +#include #include #include #include -#include -#include #include "absl/status/statusor.h" namespace torch_xla { -// Convenient macro for applying another macro to all native device types. -// -// Add new device type -// =================== -// -// Add a new line to the macro below: -// -// _(, ) -// -// Where is the enum of the given device, and is the -// previous number plus 1. -// -#define XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(_) \ - _(CPU, 0) \ - _(CUDA, 1) \ - _(TPU, 2) \ - _(NEURON, 3) \ - _(SPMD, 4) - // TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToDevice` // until after the paritioning pass. This avoids transfering the full input // tensor to the device. -enum class XlaDeviceType : int8_t { -#define XLA_DECLARE_ENUM(name, value) name = value, - XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(XLA_DECLARE_ENUM) -#undef XLA_DECLARE_ENUM - - // Plugin is not considered a native device type. - // It has a special treatment for some functions. - PLUGIN, -}; +enum class XlaDeviceType : int8_t { CPU = 0, CUDA, TPU, NEURON, SPMD, PLUGIN }; struct DeviceType : public torch::lazy::BackendDeviceType { DeviceType(XlaDeviceType xla_device_type);