diff --git a/test/cpp/test_device.cpp b/test/cpp/test_device.cpp index d63b8164349..79cce566b3b 100644 --- a/test/cpp/test_device.cpp +++ b/test/cpp/test_device.cpp @@ -1,22 +1,29 @@ #include +#include +#include +#include + +#include + +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "torch_xla/csrc/device.h" +namespace torch_xla { + static void CheckFormatError(const std::string& spec) { - absl::StatusOr r = - torch_xla::SafeParseDeviceString(spec); - EXPECT_FALSE(r.ok()); + absl::StatusOr r = SafeParseDeviceString(spec); + ASSERT_FALSE(r.ok()); EXPECT_EQ(r.status().message(), absl::StrCat("expected the device string `", spec, "` to be in the format: `:`.")); } static void CheckIndexParseError(const std::string& spec) { - absl::StatusOr r = - torch_xla::SafeParseDeviceString(spec); - EXPECT_FALSE(r.ok()); + absl::StatusOr r = SafeParseDeviceString(spec); + ASSERT_FALSE(r.ok()); EXPECT_EQ( r.status().message(), absl::StrCat("error while parsing the device spec `", spec, "`: stoi")); @@ -33,3 +40,47 @@ TEST(DeviceTest, ParseDeviceStringIndexParseError) { CheckIndexParseError("xla:xla"); CheckIndexParseError("xla:x11"); } + +static void CheckDeviceTypeConstructionWithString( + XlaDeviceType xla_device_type, std::string_view device_type_str) { + DeviceType device_type(device_type_str); + EXPECT_EQ(device_type.getType(), xla_device_type); + EXPECT_EQ(device_type.toString(), absl::StrCat(device_type_str, ":")); +} + +TEST(DeviceTest, ConstructNativeDeviceTypeWithString) { + CheckDeviceTypeConstructionWithString(XlaDeviceType::CPU, "CPU"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::CUDA, "CUDA"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::TPU, "TPU"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::NEURON, "NEURON"); + CheckDeviceTypeConstructionWithString(XlaDeviceType::SPMD, "SPMD"); +} + +TEST(DeviceTest, ConstructPluginDeviceTypeWithString) { + DeviceType device_type("OTHER"); + EXPECT_EQ(device_type.getType(), XlaDeviceType::PLUGIN); + EXPECT_EQ(device_type.toString(), "OTHER:"); +} + +static void CheckDeviceTypeConstructionWithEnum( + XlaDeviceType xla_device_type, std::string_view device_type_str) { + DeviceType device_type(xla_device_type); + ASSERT_EQ(device_type.getType(), xla_device_type); + EXPECT_EQ(device_type.toString(), absl::StrCat(device_type_str, ":")); +} + +TEST(DeviceTest, ConstructNativeDeviceTypeWithEnum) { + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::CPU, "CPU"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::CUDA, "CUDA"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::TPU, "TPU"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::NEURON, "NEURON"); + CheckDeviceTypeConstructionWithEnum(XlaDeviceType::SPMD, "SPMD"); +} + +TEST(DeviceTest, ConstructPluginDeviceTypeWithEnumError) { + EXPECT_DEATH({ DeviceType device_type(XlaDeviceType::PLUGIN); }, + absl::StrCat("invalid native XlaDeviceType value: ", + static_cast(XlaDeviceType::PLUGIN))); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index ead6ec9dea5..c016b60d53d 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -1,14 +1,23 @@ #include "torch_xla/csrc/device.h" +#include +#include +#include +#include +#include +#include #include +#include #include -#include +#include + +#include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" -#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/status.h" @@ -20,45 +29,71 @@ namespace { static bool spmd_config_is_locked = false; static bool use_virtual_device = false; -} // namespace - -std::string DeviceType::XlaDeviceTypeToString(XlaDeviceType hw_type) { - XLA_CHECK(hw_type != XlaDeviceType::PLUGIN) << "PLUGIN type name unknown"; - - switch (hw_type) { - case XlaDeviceType::CPU: - return "CPU"; - case XlaDeviceType::CUDA: - return "CUDA"; - case XlaDeviceType::TPU: - return "TPU"; - case XlaDeviceType::NEURON: - return "NEURON"; - case XlaDeviceType::SPMD: - return "SPMD"; - default: - XLA_ERROR() << "Invalid device type"; +constexpr int8_t kNativeXlaDeviceTypeNumber = + static_cast(XlaDeviceType::PLUGIN); + +// The elements in this array should match the order in the XlaDeviceType enum +// declaration. So, if you modify one of them, make sure to keep them in sync. +constexpr std::array + kNativeXlaDeviceTypeNames = {"CPU", "CUDA", "TPU", "NEURON", "SPMD"}; + +absl::Status CheckIsNativeXlaDeviceType(int8_t value) { + if (value < 0 || value >= kNativeXlaDeviceTypeNumber) { + return XLA_ERROR_WITH_LOCATION(absl::InternalError( + absl::StrCat("invalid native XlaDeviceType value: ", value, + " (casted to int). It should be non-negative, less than ", + kNativeXlaDeviceTypeNumber, + " (number of native XlaDeviceType). This shouldn't be " + "called for XlaDeviceType::PLUGIN (", + static_cast(XlaDeviceType::PLUGIN), ")."))); } + return absl::OkStatus(); } -XlaDeviceType DeviceType::StringToXlaDeviceType(const std::string& type_name) { - if (type_name == "SPMD") { - return XlaDeviceType::SPMD; - } else if (type_name == "TPU") { - return XlaDeviceType::TPU; - } else if (type_name == "CPU") { - return XlaDeviceType::CPU; - } else if (type_name == "CUDA") { - return XlaDeviceType::CUDA; - } else if (type_name == "NEURON") { - return XlaDeviceType::NEURON; +std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) { + int8_t value = static_cast(type); + // This check makes sure we are not dealing with: + // + // 1. Invalid XlaDeviceType (i.e. result of conversion of a number bigger + // than PLUGIN -- the last enum value) + // + // 2. The XlaDeviceType::PLUGIN enum, since it's not considered a "native" + // device type + XLA_CHECK_OK(CheckIsNativeXlaDeviceType(value)); + return kNativeXlaDeviceTypeNames[value]; +} + +XlaDeviceType StringToXlaDeviceType(std::string_view type_name) { + std::array::const_iterator it = + std::find(kNativeXlaDeviceTypeNames.begin(), + kNativeXlaDeviceTypeNames.end(), type_name); + + if (it == kNativeXlaDeviceTypeNames.end()) { + return XlaDeviceType::PLUGIN; } - return XlaDeviceType::PLUGIN; + std::size_t index = std::distance(kNativeXlaDeviceTypeNames.begin(), it); + return static_cast(index); } +} // namespace + +DeviceType::DeviceType(XlaDeviceType xla_device_type) + : torch::lazy::BackendDeviceType(static_cast(xla_device_type)), + type_name_() { + XLA_CHECK_OK(CheckIsNativeXlaDeviceType(type)); +} + +DeviceType::DeviceType(std::string_view type_name) + : torch::lazy::BackendDeviceType( + static_cast(StringToXlaDeviceType(type_name))), + type_name_(type_name) {} + std::string DeviceType::toString() const { - return absl::StrCat(type_name_, ":"); + std::string_view str = (getType() == XlaDeviceType::PLUGIN) + ? type_name_ + : NativeXlaDeviceTypeToString(getType()); + return absl::StrCat(str, ":"); } XlaDeviceType DeviceType::getType() const { diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index d0fd9423c53..f6b8f151ad8 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -1,39 +1,40 @@ #ifndef XLA_TORCH_XLA_CSRC_DEVICE_H_ #define XLA_TORCH_XLA_CSRC_DEVICE_H_ -#include +#include #include +#include #include -#include -#include -#include "torch_xla/csrc/runtime/util.h" +#include "absl/status/statusor.h" namespace torch_xla { // TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToDevice` // until after the paritioning pass. This avoids transfering the full input // tensor to the device. -enum class XlaDeviceType { CPU, CUDA, TPU, NEURON, SPMD, PLUGIN }; +enum class XlaDeviceType : int8_t { CPU = 0, CUDA, TPU, NEURON, SPMD, PLUGIN }; struct DeviceType : public torch::lazy::BackendDeviceType { - DeviceType(XlaDeviceType xla_device_type) - : torch::lazy::BackendDeviceType(static_cast(xla_device_type)), - type_name_(XlaDeviceTypeToString(xla_device_type)) {} - DeviceType(const std::string& type_name) - : torch::lazy::BackendDeviceType( - static_cast(StringToXlaDeviceType(type_name))), - type_name_(type_name) {} + DeviceType(XlaDeviceType xla_device_type); + + // Constructor parses the `type_name` into an `XlaDeviceType`. + // + // This should be used in 2 cases: + // + // 1. When using non-native device types. + // Although `XlaDeviceType::PLUGIN` will be used, the `type_name` + // parameter will be stored internally. + // + // 2. When parsing string device types. + DeviceType(std::string_view type_name); std::string toString() const override; XlaDeviceType getType() const; private: std::string type_name_; - - static std::string XlaDeviceTypeToString(XlaDeviceType hw_type); - static XlaDeviceType StringToXlaDeviceType(const std::string& type_name); }; // Parses the given `device_spec` into a new `BackendDevice`. diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp index 36fb5d8e390..e1ac00a890a 100644 --- a/torch_xla/csrc/layout_manager.cpp +++ b/torch_xla/csrc/layout_manager.cpp @@ -2,12 +2,12 @@ #include #include -#include #include #include #include #include +#include #include #include "absl/strings/str_split.h" @@ -17,7 +17,6 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/util.h" namespace torch_xla { namespace { diff --git a/torch_xla/csrc/layout_manager.h b/torch_xla/csrc/layout_manager.h index 540c6d221e1..e270d2b8cbd 100644 --- a/torch_xla/csrc/layout_manager.h +++ b/torch_xla/csrc/layout_manager.h @@ -3,22 +3,23 @@ #include "absl/types/span.h" #include "xla/shape.h" -#include "xla/types.h" +#include "xla/xla_data.pb.h" #include "torch_xla/csrc/device.h" namespace torch_xla { -// Creates a minor-to-major layout from given dimensions. The dynamic_dimensions -// slice should be either empty, or of the same size as dimensions. +// Creates a minor-to-major layout from given dimensions. The +// dynamic_dimensions slice should be either empty, or of the same size as +// dimensions. xla::Shape MakeTorchTensorLayout(absl::Span dimensions, absl::Span dynamic_dimensions, xla::PrimitiveType type); // Create an XLA shape with the given dimensions and type, suitable to be used -// in the specified device type. The type of device can affect the choice of the -// XLA layout. The dynamic_dimensions slice should be either empty, or of the -// same size as dimensions. +// in the specified device type. The type of device can affect the choice of +// the XLA layout. The dynamic_dimensions slice should be either empty, or of +// the same size as dimensions. xla::Shape MakeArrayShapeFromDimensions( absl::Span dimensions, absl::Span dynamic_dimensions, xla::PrimitiveType type, @@ -26,4 +27,4 @@ xla::Shape MakeArrayShapeFromDimensions( } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_LAYOUT_MANAGER_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_LAYOUT_MANAGER_H_