From f32c5f26f13946eeb929f0abbc582bdd24778a5a Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Mon, 18 Feb 2019 08:23:04 -0800
Subject: [PATCH] Enable the conversion API to work across allowed types.

---
 test/cpp/test_aten_xla_tensor.cpp |  8 +--
 test/cpp/test_tensor.cpp          | 59 ++++++++++++++++++-
 torch_xla/csrc/tensor_util.cpp    | 95 ++++++++++++++++++++-----------
 3 files changed, 124 insertions(+), 38 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 129b278339c9..99bd8cf5562b 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -7,12 +7,12 @@
 #include <iostream>
 #include <string>
 
-#include "aten_xla_bridge.h"
-#include "aten_xla_type_instances.h"
 #include "cpp_test_util.h"
-#include "tensor_impl.h"
 #include "tensorflow/compiler/xla/xla_client/metrics.h"
-#include "torch_util.h"
+#include "torch_xla/csrc/aten_xla_bridge.h"
+#include "torch_xla/csrc/aten_xla_type_instances.h"
+#include "torch_xla/csrc/tensor_impl.h"
+#include "torch_xla/csrc/torch_util.h"
 #include "torch_xla_test.h"
 
 namespace torch_xla {
diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp
index cd946a1f392e..32a1b13571ea 100644
--- a/test/cpp/test_tensor.cpp
+++ b/test/cpp/test_tensor.cpp
@@ -1,18 +1,75 @@
 #include <gtest/gtest.h>
+#include <limits>
 
 #include <iostream>
 #include <string>
 
 #include "cpp_test_util.h"
-#include "tensor.h"
 #include "torch/csrc/autograd/variable.h"
+#include "torch_xla/csrc/tensor.h"
+#include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla_test.h"
 
 namespace torch_xla {
 namespace cpp_test {
+namespace {
+
+bool CheckBidirectionalConversion(
+    const at::Tensor& input, at::ScalarType dest_element_type,
+    c10::optional<xla::PrimitiveType> xla_type = c10::nullopt) {
+  xla::Literal literal = GetTensorLiteral(input, /*shape=*/nullptr);
+  if (xla_type) {
+    literal = literal.Convert(*xla_type).ConsumeValueOrDie();
+  }
+  at::Tensor converted = MakeTensorFromXlaLiteral(literal, dest_element_type);
+  return EqualValues(converted, input);
+}
+
+}  // namespace
 
 using TensorTest = TorchXlaTest;
 
+TEST_F(TensorTest, TestConversions) {
+  {
+    at::Tensor a = at::randint(std::numeric_limits<uint8_t>::min(),
+                               std::numeric_limits<uint8_t>::max(), {2, 2},
+                               at::TensorOptions(at::kByte));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Short));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Int));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
+  }
+  {
+    at::Tensor a = at::randint(std::numeric_limits<int8_t>::min(),
+                               std::numeric_limits<int8_t>::max(), {2, 2},
+                               at::TensorOptions(at::kChar));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Short));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Int));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
+  }
+  {
+    at::Tensor a = at::randint(std::numeric_limits<int16_t>::min(),
+                               std::numeric_limits<int16_t>::max(), {2, 2},
+                               at::TensorOptions(at::kShort));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Int));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
+  }
+  {
+    at::Tensor a = at::randint(std::numeric_limits<int32_t>::min(),
+                               std::numeric_limits<int32_t>::max(), {2, 2},
+                               at::TensorOptions(at::kInt));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
+  }
+  {
+    at::Tensor a = at::randint(0, 1, {2, 2}, at::TensorOptions(at::kByte));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Byte,
+                                             xla::PrimitiveType::PRED));
+  }
+  {
+    at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
+    EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Double));
+  }
+}
+
 TEST_F(TensorTest, TestAdd) {
   at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 29bad421d9a6..cbb386edd12c 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -293,6 +293,44 @@ at::Tensor XlaLiteralToTensor(const xla::Literal& literal, at::ScalarType atype,
   return tensor;
 }
 
+template <typename SType>
+at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
+                                    xla::PrimitiveType source_element_type,
+                                    at::ScalarType dest_element_type) {
+  switch (dest_element_type) {
+    case at::ScalarType::Byte: {
+      return XlaLiteralToTensor<SType, uint8_t>(literal, dest_element_type,
+                                                source_element_type);
+    }
+    case at::ScalarType::Char: {
+      return XlaLiteralToTensor<SType, int8_t>(literal, dest_element_type,
+                                               source_element_type);
+    }
+    case at::ScalarType::Short: {
+      return XlaLiteralToTensor<SType, int16_t>(literal, dest_element_type,
+                                                source_element_type);
+    }
+    case at::ScalarType::Int: {
+      return XlaLiteralToTensor<SType, int32_t>(literal, dest_element_type,
+                                                source_element_type);
+    }
+    case at::ScalarType::Long: {
+      return XlaLiteralToTensor<SType, int64_t>(literal, dest_element_type,
+                                                source_element_type);
+    }
+    case at::ScalarType::Float: {
+      return XlaLiteralToTensor<SType, float>(literal, dest_element_type,
+                                              source_element_type);
+    }
+    case at::ScalarType::Double: {
+      return XlaLiteralToTensor<SType, double>(literal, dest_element_type,
+                                               source_element_type);
+    }
+    default:
+      XLA_ERROR() << "Unsupported scalar type: " << dest_element_type;
+  }
+}
+
 }  // namespace
 
 namespace detail {
@@ -318,39 +356,30 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
                                     at::ScalarType dest_element_type) {
   xla::PrimitiveType element_type = literal.shape().element_type();
   switch (element_type) {
-    case xla::PrimitiveType::PRED: {
-      XLA_CHECK_EQ(dest_element_type, at::ScalarType::Byte);
-      return XlaLiteralToTensor<bool, uint8_t>(literal, at::ScalarType::Byte,
-                                               xla::PrimitiveType::PRED);
-    }
-    case xla::PrimitiveType::BF16: {
-      return XlaLiteralToTensor<tensorflow::bfloat16, float>(
-          literal, at::ScalarType::Float, xla::PrimitiveType::F32);
-    }
-    case xla::PrimitiveType::F32: {
-      return XlaLiteralToTensor<float, float>(literal, at::ScalarType::Float,
-                                              xla::PrimitiveType::F32);
-    }
-    case xla::PrimitiveType::U8: {
-      return XlaLiteralToTensor<xla::uint8, uint8_t>(
-          literal, at::ScalarType::Byte, xla::PrimitiveType::U8);
-    }
-    case xla::PrimitiveType::S8: {
-      return XlaLiteralToTensor<xla::int8, int8_t>(
-          literal, at::ScalarType::Char, xla::PrimitiveType::S8);
-    }
-    case xla::PrimitiveType::S16: {
-      return XlaLiteralToTensor<xla::int16, int16_t>(
-          literal, at::ScalarType::Short, xla::PrimitiveType::S16);
-    }
-    case xla::PrimitiveType::S32: {
-      return XlaLiteralToTensor<xla::int32, int32_t>(
-          literal, at::ScalarType::Int, xla::PrimitiveType::S32);
-    }
-    case xla::PrimitiveType::S64: {
-      return XlaLiteralToTensor<xla::int64, int64_t>(
-          literal, at::ScalarType::Long, xla::PrimitiveType::S64);
-    }
+    case xla::PrimitiveType::PRED:
+      return XlaLiteralToTensorHelper<bool>(literal, element_type,
+                                            dest_element_type);
+    case xla::PrimitiveType::BF16:
+      return XlaLiteralToTensorHelper<tensorflow::bfloat16>(
+          literal, element_type, dest_element_type);
+    case xla::PrimitiveType::F32:
+      return XlaLiteralToTensorHelper<float>(literal, element_type,
+                                             dest_element_type);
+    case xla::PrimitiveType::U8:
+      return XlaLiteralToTensorHelper<xla::uint8>(literal, element_type,
+                                                  dest_element_type);
+    case xla::PrimitiveType::S8:
+      return XlaLiteralToTensorHelper<xla::int8>(literal, element_type,
+                                                 dest_element_type);
+    case xla::PrimitiveType::S16:
+      return XlaLiteralToTensorHelper<xla::int16>(literal, element_type,
+                                                  dest_element_type);
+    case xla::PrimitiveType::S32:
+      return XlaLiteralToTensorHelper<xla::int32>(literal, element_type,
+                                                  dest_element_type);
+    case xla::PrimitiveType::S64:
+      return XlaLiteralToTensorHelper<xla::int64>(literal, element_type,
+                                                  dest_element_type);
     default:
       XLA_ERROR() << "Unsupported literal type: " << literal.shape();
   }
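
Reviewer note (illustration, not part of the patch): the rewrite turns the
old one-to-one literal-to-tensor mapping into a two-level dispatch.
MakeTensorFromXlaLiteral() switches on the literal's XLA element type to fix
the source C++ type, and XlaLiteralToTensorHelper<SType>() then switches on
the requested ATen scalar type to fix the destination C++ type, so every
allowed source/destination pair is reachable. A minimal sketch of what this
enables, mirroring the new TestConversions test; the RoundTripAsLong helper
below is hypothetical, while GetTensorLiteral() and MakeTensorFromXlaLiteral()
are the functions exercised above:

    #include "torch_xla/csrc/tensor_util.h"

    namespace torch_xla {

    // Hypothetical helper: round-trip a kByte tensor through an XLA literal
    // (element type U8) and materialize it as a 64-bit integer tensor.
    at::Tensor RoundTripAsLong(const at::Tensor& byte_tensor) {
      xla::Literal literal = GetTensorLiteral(byte_tensor, /*shape=*/nullptr);
      // Before this patch a U8 literal always came back as a Byte tensor (and
      // a PRED literal CHECK-failed unless Byte was requested); the requested
      // destination type is now honored for any allowed combination.
      return MakeTensorFromXlaLiteral(literal, at::ScalarType::Long);
    }

    }  // namespace torch_xla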