-
Notifications
You must be signed in to change notification settings - Fork 565
Add ConvertXlaShapeToLazy helper function #4081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+155
−0
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
1820d44
Add ConvertXlaShapeToLazy helper function
wonjoo-wj ebd567b
Fix build failures during type conversions
wonjoo-wj 6bde8a5
Add unit tests for lazy
wonjoo-wj ff8ce67
Remove debugging print statements
wonjoo-wj 958f35f
Run linter
wonjoo-wj fea01a5
Update element_type conversion to an existing util function
wonjoo-wj 5dc3627
Update minor fixes and add more unit tests
wonjoo-wj 9c0fb95
Update tests
wonjoo-wj 68c5178
Run linter again
wonjoo-wj File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#include <gtest/gtest.h> | ||
|
||
#include "tensorflow/compiler/xla/shape.h" | ||
#include "torch/csrc/lazy/core/shape.h" | ||
#include "torch_xla/csrc/helpers.h" | ||
#include "torch_xla_test.h" | ||
|
||
namespace torch_xla { | ||
namespace cpp_test { | ||
|
||
class LazyTest : public TorchXlaTest {}; | ||
|
||
TEST_F(LazyTest, TestXlaShapeToLazyWithF64) { | ||
int64_t dimensions[] = {1}; | ||
bool dynamic_dimensions[] = {false}; | ||
absl::Span<const int64_t> xla_dimensions = | ||
absl::Span<const int64_t>(dimensions); | ||
absl::Span<const bool> xla_dynamic_dimensions = | ||
absl::Span<const bool>(dynamic_dimensions); | ||
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>(); | ||
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::F64, xla_dimensions, | ||
xla_dynamic_dimensions, xla_tuple_shapes); | ||
|
||
torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape); | ||
std::vector<int64_t> lazy_dimensions = | ||
xla::util::ToVector<int64_t>(lazy_shape.sizes()); | ||
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions = | ||
lazy_shape.is_symbolic(); | ||
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Double); | ||
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions)); | ||
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false); | ||
} | ||
|
||
TEST_F(LazyTest, TestXlaShapeToLazyWithPred) { | ||
int64_t dimensions[] = {1}; | ||
bool dynamic_dimensions[] = {false}; | ||
absl::Span<const int64_t> xla_dimensions = | ||
absl::Span<const int64_t>(dimensions); | ||
absl::Span<const bool> xla_dynamic_dimensions = | ||
absl::Span<const bool>(dynamic_dimensions); | ||
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>(); | ||
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::PRED, xla_dimensions, | ||
xla_dynamic_dimensions, xla_tuple_shapes); | ||
|
||
torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape); | ||
std::vector<int64_t> lazy_dimensions = | ||
xla::util::ToVector<int64_t>(lazy_shape.sizes()); | ||
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions = | ||
lazy_shape.is_symbolic(); | ||
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Bool); | ||
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions)); | ||
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false); | ||
} | ||
|
||
TEST_F(LazyTest, TestXlaShapeToLazyWithU64) { | ||
int64_t dimensions[] = {1}; | ||
bool dynamic_dimensions[] = {false}; | ||
absl::Span<const int64_t> xla_dimensions = | ||
absl::Span<const int64_t>(dimensions); | ||
absl::Span<const bool> xla_dynamic_dimensions = | ||
absl::Span<const bool>(dynamic_dimensions); | ||
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>(); | ||
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::U64, xla_dimensions, | ||
xla_dynamic_dimensions, xla_tuple_shapes); | ||
|
||
torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape); | ||
std::vector<int64_t> lazy_dimensions = | ||
xla::util::ToVector<int64_t>(lazy_shape.sizes()); | ||
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions = | ||
lazy_shape.is_symbolic(); | ||
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Long); | ||
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions)); | ||
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false); | ||
} | ||
|
||
TEST_F(LazyTest, TestXlaShapeToLazyWithMultipleDimensions) { | ||
int64_t dimensions[] = {2, 1, 3}; | ||
bool dynamic_dimensions[] = {false, false, false}; | ||
absl::Span<const int64_t> xla_dimensions = | ||
absl::Span<const int64_t>(dimensions); | ||
absl::Span<const bool> xla_dynamic_dimensions = | ||
absl::Span<const bool>(dynamic_dimensions); | ||
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>(); | ||
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::F64, xla_dimensions, | ||
xla_dynamic_dimensions, xla_tuple_shapes); | ||
|
||
torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape); | ||
std::vector<int64_t> lazy_dimensions = | ||
xla::util::ToVector<int64_t>(lazy_shape.sizes()); | ||
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions = | ||
lazy_shape.is_symbolic(); | ||
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Double); | ||
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions)); | ||
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), false); | ||
} | ||
|
||
TEST_F(LazyTest, TestXlaShapeToLazyWithDynamicDimensions) { | ||
int64_t dimensions[] = {2, 1, 3}; | ||
bool dynamic_dimensions[] = {true, false, true}; | ||
absl::Span<const int64_t> xla_dimensions = | ||
absl::Span<const int64_t>(dimensions); | ||
absl::Span<const bool> xla_dynamic_dimensions = | ||
absl::Span<const bool>(dynamic_dimensions); | ||
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>(); | ||
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::F64, xla_dimensions, | ||
xla_dynamic_dimensions, xla_tuple_shapes); | ||
|
||
torch::lazy::Shape lazy_shape = XlaHelpers::ConvertXlaShapeToLazy(xla_shape); | ||
std::vector<int64_t> lazy_dimensions = | ||
xla::util::ToVector<int64_t>(lazy_shape.sizes()); | ||
const c10::optional<std::vector<bool>>& lazy_dynamic_dimensions = | ||
lazy_shape.is_symbolic(); | ||
EXPECT_EQ(lazy_shape.scalar_type(), at::ScalarType::Double); | ||
EXPECT_EQ(lazy_dimensions, xla::util::ToVector<int64_t>(xla_dimensions)); | ||
EXPECT_EQ(lazy_dynamic_dimensions.has_value(), true); | ||
EXPECT_EQ(lazy_dynamic_dimensions.value(), | ||
std::vector<bool>(std::begin(dynamic_dimensions), | ||
std::end(dynamic_dimensions))); | ||
} | ||
|
||
TEST_F(LazyTest, TestXlaShapeToLazyWithUnsupportedPrimitiveType) { | ||
int64_t dimensions[] = {1}; | ||
bool dynamic_dimensions[] = {false}; | ||
absl::Span<const int64_t> xla_dimensions = | ||
absl::Span<const int64_t>(dimensions); | ||
absl::Span<const bool> xla_dynamic_dimensions = | ||
absl::Span<const bool>(dynamic_dimensions); | ||
std::vector<xla::Shape> xla_tuple_shapes = std::vector<xla::Shape>(); | ||
xla::Shape xla_shape = xla::Shape(xla::PrimitiveType::TUPLE, xla_dimensions, | ||
xla_dynamic_dimensions, xla_tuple_shapes); | ||
|
||
EXPECT_THROW(XlaHelpers::ConvertXlaShapeToLazy(xla_shape), | ||
std::runtime_error); | ||
} | ||
|
||
} // namespace cpp_test | ||
} // namespace torch_xla |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.