Skip to content

Commit

Permalink
Add unbounded dynamism test for DynamicSliceOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623015159
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 9, 2024
1 parent 484e9fc commit 08a430e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
20 changes: 19 additions & 1 deletion third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -2172,6 +2172,23 @@ TEST(XlaBuilderTest, UnboundedDotGeneral) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedDynamicSlice) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, ParseShape("s32[]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2]"));
DynamicSlice(Parameter(&b, 0, operand, "operand"),
/*start_indices=*/
{
Parameter(&b, 1, start_indices, "start_indices0"),
Parameter(&b, 2, start_indices, "start_indices1"),
},
/*slice_sizes=*/{2, 2});
TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedGather) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]"));
Expand Down Expand Up @@ -2590,7 +2607,8 @@ TEST(XlaBuilderTest, UnboundedSlice) {
/*start_indices=*/{0, 1, 2},
/*limit_indices=*/{1, 3, 5},
/*strides=*/{1, 1, 1});
TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/shape_inference.cc
Expand Up @@ -3040,7 +3040,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) {
return InvalidArgument("Negative size index to dynamic slice: %d.",
slice_dim_size);
}
if (slice_dim_size > input_dim_size) {
if (!IsUnboundedDynamicSize(input_dim_size) &&
slice_dim_size > input_dim_size) {
return InvalidArgument(
"Slice dim size %d greater than dynamic slice dimension: %d.",
slice_dim_size, input_dim_size);
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4415,6 +4415,20 @@ TEST_F(ShapeInferenceTest, UnboundedDotGeneral) {
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedDynamicSlice) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape start_index, ParseShape("s32[]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2]"));
TF_ASSERT_OK_AND_ASSIGN(
const Shape inferred_shape,
ShapeInference::InferDynamicSliceShape(
operand, /*start_index_shapes=*/{start_index, start_index},
/*slice_sizes=*/{2, 2}, /*allow_scalar_indices=*/true));
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
<< "inferred: " << ShapeUtil::HumanString(inferred_shape)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedGather) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices,
Expand Down

0 comments on commit 08a430e

Please sign in to comment.