Skip to content

Commit

Permalink
Add unbounded dynamism test for SortOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623013151
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 9, 2024
1 parent 6b5998c commit 484e9fc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
25 changes: 25 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -2595,6 +2595,31 @@ TEST(XlaBuilderTest, UnboundedSlice) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedSort) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));

XlaComputation comparator;
{
const std::unique_ptr<XlaBuilder> sub_builder =
b.CreateSubBuilder("compare");
Compare(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32),
"arg0"),
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32),
"arg1"),
ComparisonDirection::kLt);
TF_ASSERT_OK_AND_ASSIGN(comparator, sub_builder->Build());
}

Sort({Parameter(&b, 0, operand, "operand")}, comparator,
/*dimension=*/0, /*is_stable=*/true);
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedTranspose) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand,
Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4882,6 +4882,17 @@ TEST_F(ShapeInferenceTest, UnboundedSlice) {
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedSort) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(
const Shape inferred_shape,
ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&operand}));
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
<< "inferred: " << ShapeUtil::HumanString(inferred_shape)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) {
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs));
TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs));
Expand Down

0 comments on commit 484e9fc

Please sign in to comment.