Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ set(TORCH_XLA_TEST_SOURCES
test_xla_backend_intf.cpp
test_symint.cpp
test_xla_sharding.cpp
test_lazy.cpp
)

add_executable(test_ptxla ${TORCH_XLA_TEST_SOURCES})
Expand Down
137 changes: 137 additions & 0 deletions test/cpp/test_lazy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#include <gtest/gtest.h>

#include "tensorflow/compiler/xla/shape.h"
#include "torch/csrc/lazy/core/shape.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla_test.h"

namespace torch_xla {
namespace cpp_test {

class LazyTest : public TorchXlaTest {};

TEST_F(LazyTest, TestXlaShapeToLazyWithF64) {
int64_t dimensions[] = {1};
bool dynamic_dimensions[] = {false};
absl::Span<const int64_t> xla_dimensions =
absl::Span<const int64_t>(dimensions);
absl::Span<const bool> xla_dynamic_dimensions =
absl::Span<const bool>(dynamic_dimensions);
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>();
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::F64, xla_dimensions,
xla_dynamic_dimensions, xla_tuple_shapes);

torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape);
std::vector<int64_t> lazy_dimensions =
xla::util::ToVector<int64_t>(lazy_shape.sizes());
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions =
lazy_shape.is_symbolic();
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Double);
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions));
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false);
}

TEST_F(LazyTest, TestXlaShapeToLazyWithPred) {
int64_t dimensions[] = {1};
bool dynamic_dimensions[] = {false};
absl::Span<const int64_t> xla_dimensions =
absl::Span<const int64_t>(dimensions);
absl::Span<const bool> xla_dynamic_dimensions =
absl::Span<const bool>(dynamic_dimensions);
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>();
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::PRED, xla_dimensions,
xla_dynamic_dimensions, xla_tuple_shapes);

torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape);
std::vector<int64_t> lazy_dimensions =
xla::util::ToVector<int64_t>(lazy_shape.sizes());
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions =
lazy_shape.is_symbolic();
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Bool);
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions));
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false);
}

TEST_F(LazyTest, TestXlaShapeToLazyWithU64) {
int64_t dimensions[] = {1};
bool dynamic_dimensions[] = {false};
absl::Span<const int64_t> xla_dimensions =
absl::Span<const int64_t>(dimensions);
absl::Span<const bool> xla_dynamic_dimensions =
absl::Span<const bool>(dynamic_dimensions);
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>();
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::U64, xla_dimensions,
xla_dynamic_dimensions, xla_tuple_shapes);

torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape);
std::vector<int64_t> lazy_dimensions =
xla::util::ToVector<int64_t>(lazy_shape.sizes());
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions =
lazy_shape.is_symbolic();
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Long);
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions));
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false);
}

TEST_F(LazyTest, TestXlaShapeToLazyWithMultipleDimensions) {
int64_t dimensions[] = {2, 1, 3};
bool dynamic_dimensions[] = {false, false, false};
absl::Span<const int64_t> xla_dimensions =
absl::Span<const int64_t>(dimensions);
absl::Span<const bool> xla_dynamic_dimensions =
absl::Span<const bool>(dynamic_dimensions);
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>();
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::F64, xla_dimensions,
xla_dynamic_dimensions, xla_tuple_shapes);

torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape);
std::vector<int64_t> lazy_dimensions =
xla::util::ToVector<int64_t>(lazy_shape.sizes());
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions =
lazy_shape.is_symbolic();
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Double);
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions));
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false);
}

TEST_F(LazyTest, TestXlaShapeToLazyWithDynamicDimensions) {
int64_t dimensions[] = {2, 1, 3};
bool dynamic_dimensions[] = {true, false, true};
absl::Span<const int64_t> xla_dimensions =
absl::Span<const int64_t>(dimensions);
absl::Span<const bool> xla_dynamic_dimensions =
absl::Span<const bool>(dynamic_dimensions);
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>();
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::F64, xla_dimensions,
xla_dynamic_dimensions, xla_tuple_shapes);

torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape);
std::vector<int64_t> lazy_dimensions =
xla::util::ToVector<int64_t>(lazy_shape.sizes());
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions =
lazy_shape.is_symbolic();
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Double);
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions));
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), true);
EXPECT_EQ(lazy_dynamic_dimensions.value(),
std::vector<bool>(std::begin(dynamic_dimensions),
std::end(dynamic_dimensions)));
}

TEST_F(LazyTest, TestXlaShapeToLazyWithUnsupportedPrimitiveType) {
int64_t dimensions[] = {1};
bool dynamic_dimensions[] = {false};
absl::Span<const int64_t> xla_dimensions =
absl::Span<const int64_t>(dimensions);
absl::Span<const bool> xla_dynamic_dimensions =
absl::Span<const bool>(dynamic_dimensions);
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>();
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::TUPLE, xla_dimensions,
xla_dynamic_dimensions, xla_tuple_shapes);

EXPECT_THROW(XlaHelpers::ConvertXlaShapeToLazy(xla_shape),
std::runtime_error);
}

} // namespace cpp_test
} // namespace torch_xla
14 changes: 14 additions & 0 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,4 +619,18 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
return builder.Build(orig_result);
}

torch::lazy::Shape XlaHelpers::ConvertXlaShapeToLazy(const xla::Shape& shape) {
at::ScalarType scalar_type = TensorTypeFromXlaType(shape.element_type());
c10::optional<std::vector<bool>> is_symbolic = c10::nullopt;
if (shape.is_dynamic()) {
std::vector<bool> xla_dynamic_dimensions =
xla::util::ToVector<bool>(shape.dynamic_dimensions());
is_symbolic = c10::make_optional(xla_dynamic_dimensions);
}

return torch::lazy::Shape(scalar_type,
xla::util::ToVector<int64_t>(shape.dimensions()),
std::move(is_symbolic));
}

} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "torch/csrc/lazy/core/shape.h"
#include "torch/csrc/lazy/core/util.h"

namespace torch_xla {
Expand Down Expand Up @@ -339,6 +340,8 @@ class XlaHelpers {
const std::vector<xla::Shape>& parameter_shapes,
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair);

static torch::lazy::Shape ConvertXlaShapeToLazy(const xla::Shape& shape);

private:
static xla::PrecisionConfig::Precision s_mat_mul_precision;
};
Expand Down