Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include "aten_xla_bridge.h"
#include "aten_xla_type_instances.h"
#include "cpp_test_util.h"
#include "tensor_impl.h"
#include "tensorflow/compiler/xla/xla_client/metrics.h"
#include "torch_util.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/aten_xla_type_instances.h"
#include "torch_xla/csrc/tensor_impl.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla_test.h"

namespace torch_xla {
Expand Down
59 changes: 58 additions & 1 deletion test/cpp/test_tensor.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,75 @@
#include <gtest/gtest.h>

#include <limits>
#include <vector>

#include <ATen/ATen.h>
#include "cpp_test_util.h"
#include "tensor.h"
#include "torch/csrc/autograd/variable.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla_test.h"

namespace torch_xla {
namespace cpp_test {
namespace {

// Round-trips `input` through an XLA literal and back, reporting whether the
// values survive the trip. The tensor is lowered with GetTensorLiteral, the
// literal is optionally converted to `xla_type` (e.g. PRED) so the reverse
// mapping is exercised from that representation, and the result is lifted
// back as an at::Tensor with element type `dest_element_type`.
// Returns true iff the reconstructed tensor compares equal to `input`.
bool CheckBidirectionalConversion(
    const at::Tensor& input, at::ScalarType dest_element_type,
    c10::optional<xla::PrimitiveType> xla_type = c10::nullopt) {
  xla::Literal literal = GetTensorLiteral(input, /*shape=*/nullptr);
  if (xla_type.has_value()) {
    // Re-encode the literal in the requested XLA primitive type before the
    // reverse conversion.
    literal = literal.Convert(xla_type.value()).ConsumeValueOrDie();
  }
  return EqualValues(MakeTensorFromXlaLiteral(literal, dest_element_type),
                     input);
}

} // namespace

using TensorTest = TorchXlaTest;

// Verifies tensor <-> XLA-literal round trips across element types: each
// tensor is lowered to a literal, (optionally) converted to another XLA
// primitive type, lifted back as the destination scalar type, and compared
// against the input.
// Fix: the original discarded CheckBidirectionalConversion's boolean result,
// so a failing conversion could never fail this test; every call is now
// wrapped in EXPECT_TRUE so the verdict is actually asserted.
TEST_F(TensorTest, TestConversions) {
  {
    // uint8 widened to the signed integer types.
    at::Tensor a = at::randint(std::numeric_limits<uint8_t>::min(),
                               std::numeric_limits<uint8_t>::max(), {2, 2},
                               at::TensorOptions(at::kByte));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Short));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Int));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
  }
  {
    // int8 widened to the larger signed integer types.
    at::Tensor a = at::randint(std::numeric_limits<int8_t>::min(),
                               std::numeric_limits<int8_t>::max(), {2, 2},
                               at::TensorOptions(at::kChar));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Short));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Int));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
  }
  {
    // int16 widened to int32/int64.
    at::Tensor a = at::randint(std::numeric_limits<int16_t>::min(),
                               std::numeric_limits<int16_t>::max(), {2, 2},
                               at::TensorOptions(at::kShort));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Int));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
  }
  {
    // int32 widened to int64.
    at::Tensor a = at::randint(std::numeric_limits<int32_t>::min(),
                               std::numeric_limits<int32_t>::max(), {2, 2},
                               at::TensorOptions(at::kInt));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
  }
  {
    // 0/1-valued byte tensor routed through the XLA PRED (boolean) type.
    at::Tensor a = at::randint(0, 1, {2, 2}, at::TensorOptions(at::kByte));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Byte,
                                             xla::PrimitiveType::PRED));
  }
  {
    // float widened to double.
    at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Double));
  }
}

TEST_F(TensorTest, TestAdd) {
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
Expand Down
95 changes: 62 additions & 33 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,44 @@ at::Tensor XlaLiteralToTensor(const xla::Literal& literal, at::ScalarType atype,
return tensor;
}

// Dispatches on the destination at::ScalarType and instantiates
// XlaLiteralToTensor with the matching destination C++ element type, reading
// the literal's elements as SType. Supported destinations are the integral
// types Byte/Char/Short/Int/Long plus Float and Double; any other scalar
// type is reported via XLA_ERROR.
template <typename SType>
at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
                                    xla::PrimitiveType source_element_type,
                                    at::ScalarType dest_element_type) {
  switch (dest_element_type) {
    case at::ScalarType::Byte:
      return XlaLiteralToTensor<SType, uint8_t>(literal, dest_element_type,
                                                source_element_type);
    case at::ScalarType::Char:
      return XlaLiteralToTensor<SType, int8_t>(literal, dest_element_type,
                                               source_element_type);
    case at::ScalarType::Short:
      return XlaLiteralToTensor<SType, int16_t>(literal, dest_element_type,
                                                source_element_type);
    case at::ScalarType::Int:
      return XlaLiteralToTensor<SType, int32_t>(literal, dest_element_type,
                                                source_element_type);
    case at::ScalarType::Long:
      return XlaLiteralToTensor<SType, int64_t>(literal, dest_element_type,
                                                source_element_type);
    case at::ScalarType::Float:
      return XlaLiteralToTensor<SType, float>(literal, dest_element_type,
                                              source_element_type);
    case at::ScalarType::Double:
      return XlaLiteralToTensor<SType, double>(literal, dest_element_type,
                                               source_element_type);
    default:
      // XLA_ERROR aborts/throws, so falling out of the switch is unreachable.
      XLA_ERROR() << "Unsupported scalar type: " << dest_element_type;
  }
}

} // namespace

namespace detail {
Expand All @@ -318,39 +356,30 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
at::ScalarType dest_element_type) {
xla::PrimitiveType element_type = literal.shape().element_type();
switch (element_type) {
case xla::PrimitiveType::PRED: {
XLA_CHECK_EQ(dest_element_type, at::ScalarType::Byte);
return XlaLiteralToTensor<bool, uint8_t>(literal, at::ScalarType::Byte,
xla::PrimitiveType::PRED);
}
case xla::PrimitiveType::BF16: {
return XlaLiteralToTensor<tensorflow::bfloat16, float>(
literal, at::ScalarType::Float, xla::PrimitiveType::F32);
}
case xla::PrimitiveType::F32: {
return XlaLiteralToTensor<float, float>(literal, at::ScalarType::Float,
xla::PrimitiveType::F32);
}
case xla::PrimitiveType::U8: {
return XlaLiteralToTensor<xla::uint8, uint8_t>(
literal, at::ScalarType::Byte, xla::PrimitiveType::U8);
}
case xla::PrimitiveType::S8: {
return XlaLiteralToTensor<xla::int8, int8_t>(
literal, at::ScalarType::Char, xla::PrimitiveType::S8);
}
case xla::PrimitiveType::S16: {
return XlaLiteralToTensor<xla::int16, int16_t>(
literal, at::ScalarType::Short, xla::PrimitiveType::S16);
}
case xla::PrimitiveType::S32: {
return XlaLiteralToTensor<xla::int32, int32_t>(
literal, at::ScalarType::Int, xla::PrimitiveType::S32);
}
case xla::PrimitiveType::S64: {
return XlaLiteralToTensor<xla::int64, int64_t>(
literal, at::ScalarType::Long, xla::PrimitiveType::S64);
}
case xla::PrimitiveType::PRED:
return XlaLiteralToTensorHelper<bool>(literal, element_type,
dest_element_type);
case xla::PrimitiveType::BF16:
return XlaLiteralToTensorHelper<tensorflow::bfloat16>(
literal, element_type, dest_element_type);
case xla::PrimitiveType::F32:
return XlaLiteralToTensorHelper<float>(literal, element_type,
dest_element_type);
case xla::PrimitiveType::U8:
return XlaLiteralToTensorHelper<xla::uint8>(literal, element_type,
dest_element_type);
case xla::PrimitiveType::S8:
return XlaLiteralToTensorHelper<xla::int8>(literal, element_type,
dest_element_type);
case xla::PrimitiveType::S16:
return XlaLiteralToTensorHelper<xla::int16>(literal, element_type,
dest_element_type);
case xla::PrimitiveType::S32:
return XlaLiteralToTensorHelper<xla::int32>(literal, element_type,
dest_element_type);
case xla::PrimitiveType::S64:
return XlaLiteralToTensorHelper<xla::int64>(literal, element_type,
dest_element_type);
default:
XLA_ERROR() << "Unsupported literal type: " << literal.shape();
}
Expand Down