Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions test/cpp/test_device.cpp
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
#include <gtest/gtest.h>

#include <cstdint>
#include <string>
#include <string_view>

#include <torch/csrc/lazy/backend/backend_device.h>

#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

#include "torch_xla/csrc/device.h"

namespace torch_xla {

// Verifies that parsing `spec` fails and that the failure carries the
// "wrong format" message, i.e. `spec` is not of the form `<type>:<index>`.
static void CheckFormatError(const std::string& spec) {
  absl::StatusOr<torch::lazy::BackendDevice> r = SafeParseDeviceString(spec);
  // Fatal assertion: if parsing unexpectedly succeeded there is no error
  // message worth comparing, so stop this helper right here.
  ASSERT_FALSE(r.ok());
  EXPECT_EQ(r.status().message(),
            absl::StrCat("expected the device string `", spec,
                         "` to be in the format: `<type>:<index>`."));
}

static void CheckIndexParseError(const std::string& spec) {
absl::StatusOr<torch::lazy::BackendDevice> r =
torch_xla::SafeParseDeviceString(spec);
EXPECT_FALSE(r.ok());
absl::StatusOr<torch::lazy::BackendDevice> r = SafeParseDeviceString(spec);
ASSERT_FALSE(r.ok());
EXPECT_EQ(
r.status().message(),
absl::StrCat("error while parsing the device spec `", spec, "`: stoi"));
Expand All @@ -33,3 +40,47 @@ TEST(DeviceTest, ParseDeviceStringIndexParseError) {
CheckIndexParseError("xla:xla");
CheckIndexParseError("xla:x11");
}

// Builds a `DeviceType` out of the name `device_type_str` and checks that:
//   - it resolves to the enum value `xla_device_type`, and
//   - its string form is the very same name followed by ':'.
static void CheckDeviceTypeConstructionWithString(
    XlaDeviceType xla_device_type, std::string_view device_type_str) {
  const DeviceType parsed(device_type_str);
  const std::string expected_repr = absl::StrCat(device_type_str, ":");
  EXPECT_EQ(parsed.getType(), xla_device_type);
  EXPECT_EQ(parsed.toString(), expected_repr);
}

// Each native device name must parse into its matching enum value and print
// back as "<name>:".
TEST(DeviceTest, ConstructNativeDeviceTypeWithString) {
  struct NativeCase {
    XlaDeviceType type;
    std::string_view name;
  };
  const NativeCase native_cases[] = {
      {XlaDeviceType::CPU, "CPU"},       {XlaDeviceType::CUDA, "CUDA"},
      {XlaDeviceType::TPU, "TPU"},       {XlaDeviceType::NEURON, "NEURON"},
      {XlaDeviceType::SPMD, "SPMD"},
  };
  for (const NativeCase& native_case : native_cases) {
    CheckDeviceTypeConstructionWithString(native_case.type, native_case.name);
  }
}

// A name that is not one of the native device names is classified as PLUGIN,
// while the original name is still what gets printed.
TEST(DeviceTest, ConstructPluginDeviceTypeWithString) {
  const DeviceType plugin_device("OTHER");
  const XlaDeviceType resolved_type = plugin_device.getType();
  const std::string printed = plugin_device.toString();
  EXPECT_EQ(resolved_type, XlaDeviceType::PLUGIN);
  EXPECT_EQ(printed, "OTHER:");
}

// Builds a `DeviceType` directly from the enum value `xla_device_type` and
// checks that the value round-trips through `getType()` and that the string
// form is `device_type_str` followed by ':'.
static void CheckDeviceTypeConstructionWithEnum(
    XlaDeviceType xla_device_type, std::string_view device_type_str) {
  const DeviceType from_enum(xla_device_type);
  // Fatal on mismatch: the string comparison below is skipped in that case.
  ASSERT_EQ(from_enum.getType(), xla_device_type);
  const std::string expected_repr = absl::StrCat(device_type_str, ":");
  EXPECT_EQ(from_enum.toString(), expected_repr);
}

// Each native enum value must be constructible and print as "<name>:".
TEST(DeviceTest, ConstructNativeDeviceTypeWithEnum) {
  struct NativeCase {
    XlaDeviceType type;
    std::string_view name;
  };
  const NativeCase native_cases[] = {
      {XlaDeviceType::CPU, "CPU"},       {XlaDeviceType::CUDA, "CUDA"},
      {XlaDeviceType::TPU, "TPU"},       {XlaDeviceType::NEURON, "NEURON"},
      {XlaDeviceType::SPMD, "SPMD"},
  };
  for (const NativeCase& native_case : native_cases) {
    CheckDeviceTypeConstructionWithEnum(native_case.type, native_case.name);
  }
}

// Constructing a `DeviceType` from `XlaDeviceType::PLUGIN` is expected to
// terminate the process with a message naming the offending enum value.
TEST(DeviceTest, ConstructPluginDeviceTypeWithEnumError) {
  const std::string expected_message =
      absl::StrCat("invalid native XlaDeviceType value: ",
                   static_cast<int8_t>(XlaDeviceType::PLUGIN));
  EXPECT_DEATH({ DeviceType device_type(XlaDeviceType::PLUGIN); },
               expected_message);
}

} // namespace torch_xla
101 changes: 68 additions & 33 deletions torch_xla/csrc/device.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
#include "torch_xla/csrc/device.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iterator>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <torch/csrc/lazy/backend/backend_device.h>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/status.h"

Expand All @@ -20,45 +29,71 @@ namespace {
static bool spmd_config_is_locked = false;
static bool use_virtual_device = false;

} // namespace

std::string DeviceType::XlaDeviceTypeToString(XlaDeviceType hw_type) {
XLA_CHECK(hw_type != XlaDeviceType::PLUGIN) << "PLUGIN type name unknown";

switch (hw_type) {
case XlaDeviceType::CPU:
return "CPU";
case XlaDeviceType::CUDA:
return "CUDA";
case XlaDeviceType::TPU:
return "TPU";
case XlaDeviceType::NEURON:
return "NEURON";
case XlaDeviceType::SPMD:
return "SPMD";
default:
XLA_ERROR() << "Invalid device type";
// Number of native (i.e. non-plugin) device types. It is taken from the
// underlying value of XlaDeviceType::PLUGIN, so it only equals the count of
// native types while PLUGIN stays the last enumerator and the enumerators
// before it are numbered contiguously from 0.
constexpr int8_t kNativeXlaDeviceTypeNumber =
static_cast<int8_t>(XlaDeviceType::PLUGIN);

// Names of the native device types, indexed by the underlying value of the
// corresponding XlaDeviceType enumerator.
//
// The elements in this array should match the order in the XlaDeviceType enum
// declaration. So, if you modify one of them, make sure to keep them in sync.
constexpr std::array<std::string_view, kNativeXlaDeviceTypeNumber>
kNativeXlaDeviceTypeNames = {"CPU", "CUDA", "TPU", "NEURON", "SPMD"};

// Validates that `value` is the underlying value of a native XlaDeviceType,
// i.e. that it lies in [0, kNativeXlaDeviceTypeNumber).
//
// Returns an OK status for native values, and an internal error describing
// the accepted range otherwise (this includes XlaDeviceType::PLUGIN).
absl::Status CheckIsNativeXlaDeviceType(int8_t value) {
  const bool is_native = 0 <= value && value < kNativeXlaDeviceTypeNumber;
  if (is_native) {
    return absl::OkStatus();
  }
  return XLA_ERROR_WITH_LOCATION(absl::InternalError(absl::StrCat(
      "invalid native XlaDeviceType value: ", value,
      " (casted to int). It should be non-negative, less than ",
      kNativeXlaDeviceTypeNumber,
      " (number of native XlaDeviceType). This shouldn't be "
      "called for XlaDeviceType::PLUGIN (",
      static_cast<int8_t>(XlaDeviceType::PLUGIN), ").")));
}

XlaDeviceType DeviceType::StringToXlaDeviceType(const std::string& type_name) {
if (type_name == "SPMD") {
return XlaDeviceType::SPMD;
} else if (type_name == "TPU") {
return XlaDeviceType::TPU;
} else if (type_name == "CPU") {
return XlaDeviceType::CPU;
} else if (type_name == "CUDA") {
return XlaDeviceType::CUDA;
} else if (type_name == "NEURON") {
return XlaDeviceType::NEURON;
// Looks up the name of the native device type `type` in
// kNativeXlaDeviceTypeNames.
//
// `type` must be a native device type: the lookup is guarded by a checked
// validation, so neither XlaDeviceType::PLUGIN nor an out-of-range value
// converted to the enum ever reaches the table access.
std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) {
  const int8_t value{static_cast<int8_t>(type)};
  // Two kinds of input are rejected here:
  //   - values outside of the enum (e.g. a number past PLUGIN, the last
  //     enumerator, that was converted to XlaDeviceType);
  //   - XlaDeviceType::PLUGIN itself, which is not a "native" device type and
  //     therefore has no entry in the name table.
  XLA_CHECK_OK(CheckIsNativeXlaDeviceType(value));
  return kNativeXlaDeviceTypeNames[value];
}

// Maps a device type name to its XlaDeviceType.
//
// `type_name` is compared (exact match) against the native device type names
// in kNativeXlaDeviceTypeNames. On a match, the position in that table is the
// underlying value of the enumerator to return. Any other name is treated as
// a plugin device and yields XlaDeviceType::PLUGIN.
XlaDeviceType StringToXlaDeviceType(std::string_view type_name) {
  const auto it = std::find(kNativeXlaDeviceTypeNames.begin(),
                            kNativeXlaDeviceTypeNames.end(), type_name);

  // Not a native name: this is the only path that may return PLUGIN.
  if (it == kNativeXlaDeviceTypeNames.end()) {
    return XlaDeviceType::PLUGIN;
  }

  // The name table is ordered like the enum declaration, so the index of the
  // match converts directly into the enumerator.
  const std::size_t index =
      std::distance(kNativeXlaDeviceTypeNames.begin(), it);
  return static_cast<XlaDeviceType>(index);
}

} // namespace

// Constructs a DeviceType from a native XlaDeviceType.
//
// The enum value is forwarded to the base class as an int8_t, and the stored
// type name is left empty. The body then runs a checked validation that only
// accepts native device types, so passing XlaDeviceType::PLUGIN (or a value
// outside of the enum) fails the check.
DeviceType::DeviceType(XlaDeviceType xla_device_type)
: torch::lazy::BackendDeviceType(static_cast<int8_t>(xla_device_type)),
type_name_() {
// NOTE(review): `type` is not declared in this file section -- presumably it
// is the value stored by the BackendDeviceType base initialized above; confirm.
XLA_CHECK_OK(CheckIsNativeXlaDeviceType(type));
}

// Constructs a DeviceType from a device type name.
//
// The name is converted with StringToXlaDeviceType and the resulting enum
// value is forwarded to the base class as an int8_t. The name itself is
// copied into `type_name_` regardless of which enum value it resolved to.
DeviceType::DeviceType(std::string_view type_name)
: torch::lazy::BackendDeviceType(
static_cast<int8_t>(StringToXlaDeviceType(type_name))),
type_name_(type_name) {}

std::string DeviceType::toString() const {
return absl::StrCat(type_name_, ":");
std::string_view str = (getType() == XlaDeviceType::PLUGIN)
? type_name_
: NativeXlaDeviceTypeToString(getType());
return absl::StrCat(str, ":");
}

XlaDeviceType DeviceType::getType() const {
Expand Down
31 changes: 16 additions & 15 deletions torch_xla/csrc/device.h
Original file line number Diff line number Diff line change
@@ -1,39 +1,40 @@
#ifndef XLA_TORCH_XLA_CSRC_DEVICE_H_
#define XLA_TORCH_XLA_CSRC_DEVICE_H_

#include <iostream>
#include <cstdint>
#include <string>
#include <string_view>

#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/util.h>

#include "torch_xla/csrc/runtime/util.h"
#include "absl/status/statusor.h"

namespace torch_xla {

// Device types known to torch_xla.
//
// The underlying type is int8_t and the enumerators are numbered from 0.
// PLUGIN must remain the last enumerator: its underlying value is used as the
// number of native device types that come before it.
//
// TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToDevice`
// until after the partitioning pass. This avoids transferring the full input
// tensor to the device.
enum class XlaDeviceType : int8_t { CPU = 0, CUDA, TPU, NEURON, SPMD, PLUGIN };

struct DeviceType : public torch::lazy::BackendDeviceType {
DeviceType(XlaDeviceType xla_device_type)
: torch::lazy::BackendDeviceType(static_cast<int>(xla_device_type)),
type_name_(XlaDeviceTypeToString(xla_device_type)) {}
DeviceType(const std::string& type_name)
: torch::lazy::BackendDeviceType(
static_cast<int>(StringToXlaDeviceType(type_name))),
type_name_(type_name) {}
DeviceType(XlaDeviceType xla_device_type);

// Constructor parses the `type_name` into an `XlaDeviceType`.
//
// This should be used in 2 cases:
//
// 1. When using non-native device types.
// Although `XlaDeviceType::PLUGIN` will be used, the `type_name`
// parameter will be stored internally.
//
// 2. When parsing string device types.
DeviceType(std::string_view type_name);

std::string toString() const override;
XlaDeviceType getType() const;

private:
std::string type_name_;

static std::string XlaDeviceTypeToString(XlaDeviceType hw_type);
static XlaDeviceType StringToXlaDeviceType(const std::string& type_name);
};

// Parses the given `device_spec` into a new `BackendDevice`.
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/layout_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

#include <algorithm>
#include <exception>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/util.h>

#include "absl/strings/str_split.h"
Expand All @@ -17,7 +17,6 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"

namespace torch_xla {
namespace {
Expand Down
15 changes: 8 additions & 7 deletions torch_xla/csrc/layout_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,28 @@

#include "absl/types/span.h"
#include "xla/shape.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"

#include "torch_xla/csrc/device.h"

namespace torch_xla {

// Creates a minor-to-major layout from given dimensions. The
// dynamic_dimensions slice should be either empty, or of the same size as
// dimensions.
xla::Shape MakeTorchTensorLayout(absl::Span<const int64_t> dimensions,
absl::Span<const bool> dynamic_dimensions,
xla::PrimitiveType type);

// Create an XLA shape with the given dimensions and type, suitable to be used
// in the specified device type. The type of device can affect the choice of
// the XLA layout. The dynamic_dimensions slice should be either empty, or of
// the same size as dimensions.
xla::Shape MakeArrayShapeFromDimensions(
absl::Span<const int64_t> dimensions,
absl::Span<const bool> dynamic_dimensions, xla::PrimitiveType type,
XlaDeviceType hw_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_LAYOUT_MANAGER_H_
#endif // XLA_TORCH_XLA_CSRC_LAYOUT_MANAGER_H_