Skip to content

Commit

Permalink
Add unbounded dynamism test for WhileOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622773485
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 8, 2024
1 parent d1ca8d8 commit dd22334
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
42 changes: 42 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -2537,6 +2537,48 @@ TEST(XlaBuilderTest, UnboundedTranspose) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedWhile) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape init, ParseShape("f32[?]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]"));

XlaComputation add;
{
const std::unique_ptr<XlaBuilder> sub_builder = b.CreateSubBuilder("add");
Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32),
"arg0"),
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32),
"arg1"));
TF_ASSERT_OK_AND_ASSIGN(add, sub_builder->Build());
}

XlaComputation condition;
{
const std::unique_ptr<XlaBuilder> sub_builder =
b.CreateSubBuilder("compare");
Ge(/*lhs=*/ConstantR0<float>(sub_builder.get(), 10.0f),
/*rhs=*/Reduce(/*operand=*/Parameter(sub_builder.get(), 0, init, "prev"),
ConstantR0<float>(sub_builder.get(), 0.0f), add,
/*dimensions_to_reduce=*/{0}));
TF_ASSERT_OK_AND_ASSIGN(condition, sub_builder->Build());
}

XlaComputation body;
{
const std::unique_ptr<XlaBuilder> sub_builder = b.CreateSubBuilder("add");
Add(ConstantR1<float>(sub_builder.get(), {1.0f}),
Parameter(sub_builder.get(), 0, init, "prev"),
/*broadcast_dimensions=*/{0});
TF_ASSERT_OK_AND_ASSIGN(body, sub_builder->Build());
}

While(condition, body, Parameter(&b, 0, init, "init"));
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedXor) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs,
Expand Down
15 changes: 15 additions & 0 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4892,6 +4892,21 @@ TEST_F(ShapeInferenceTest, UnboundedTransposeRank1) {
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedWhile) {
TF_ASSERT_OK_AND_ASSIGN(const Shape init, ParseShape("f32[?]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, ParseShape("f32[?]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]"));
TF_ASSERT_OK_AND_ASSIGN(
const Shape inferred_shape,
ShapeInference::InferWhileShape(
/*condition=*/ShapeUtil::MakeProgramShape({result_shape}, pred_),
/*body=*/ShapeUtil::MakeProgramShape({result_shape}, result_shape),
/*init=*/init));
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
<< "inferred: " << ShapeUtil::HumanString(inferred_shape)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedXor) {
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs));
TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs));
Expand Down

0 comments on commit dd22334

Please sign in to comment.