Skip to content

Commit

Permalink
Add unbounded dynamism test for CholeskyOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622946035
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 8, 2024
1 parent f259352 commit 24f5fe5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
11 changes: 11 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -1979,6 +1979,17 @@ TEST(XlaBuilderTest, UnboundedBroadcastInDimUnsupported) {
"static or bounded dynamic")));
}

TEST(XlaBuilderTest, UnboundedCholesky) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
Cholesky(Parameter(&b, 0, a, "a"), /*lower=*/true);
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedClamp) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs,
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/service/shape_inference.cc
Expand Up @@ -2370,9 +2370,10 @@ ShapeInference::InferScalarBroadcastShape(absl::Span<const Shape> shapes) {
"The 'a' argument to Cholesky must have rank >= 2, got shape %s",
a.ToString());
}
if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
if (!CompatibleDimensionSizes(a.dimensions(a.rank() - 2),
a.dimensions(a.rank() - 1))) {
return InvalidArgument(
"The two minor dimensions of 'a' must have equal size, got %s.",
"The two minor dimensions of 'a' must have compatible size, got %s.",
a.ToString());
}
return a;
Expand Down
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4194,6 +4194,16 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupported) {
HasSubstr("Non-broadcast dimensions must not be dynamic."));
}

TEST_F(ShapeInferenceTest, UnboundedCholesky) {
TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape,
ShapeInference::InferCholeskyShape(a));
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
<< "inferred: " << ShapeUtil::HumanString(inferred_shape)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_P(UnboundedClampOpShapeInferenceTest, UnboundedClamp) {
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0]));
TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1]));
Expand Down

0 comments on commit 24f5fe5

Please sign in to comment.