Skip to content

Commit

Permalink
Add unbounded dynamism test for DynamicUpdateSliceOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623549454
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 10, 2024
1 parent 943bacb commit 1a2c921
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
17 changes: 17 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -2189,6 +2189,23 @@ TEST(XlaBuilderTest, UnboundedDynamicSlice) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedDynamicUpdateSlice) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape update, ParseShape("f32[?, 5]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, ParseShape("s32[]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
DynamicUpdateSlice(Parameter(&b, 0, operand, "operand"),
Parameter(&b, 1, update, "update"),
/*start_indices=*/
{Parameter(&b, 2, start_indices, "start_indices0"),
Parameter(&b, 3, start_indices, "start_indices1")});
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedGather) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]"));
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/shape_inference.cc
Expand Up @@ -3165,7 +3165,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) {
for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) {
const int64_t input_dim_size = operand_shape.dimensions(dim);
const int64_t update_dim_size = update_shape.dimensions(dim);
if (update_dim_size < 0) {
if (!IsUnboundedDynamicSize(update_dim_size) && update_dim_size < 0) {
return InvalidArgument(
"Size index %d to dynamic update slice must be >= 0.",
update_dim_size);
Expand Down
15 changes: 15 additions & 0 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4429,6 +4429,21 @@ TEST_F(ShapeInferenceTest, UnboundedDynamicSlice) {
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedDynamicUpdateSlice) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape update, ParseShape("f32[?, 5]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape start_index, ParseShape("s32[]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(
const Shape inferred_shape,
ShapeInference::InferDynamicUpdateSliceShape(
operand, update, /*start_index_shapes=*/{start_index, start_index},
/*allow_scalar_indices=*/true));
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
<< "inferred: " << ShapeUtil::HumanString(inferred_shape)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedGather) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices,
Expand Down

0 comments on commit 1a2c921

Please sign in to comment.