Skip to content

Commit

Permalink
Add unbounded dynamism test for SelectAndScatterOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622731814
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 8, 2024
1 parent b19fcf7 commit 9497c24
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 31 deletions.
50 changes: 45 additions & 5 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -2352,11 +2352,10 @@ TEST(XlaBuilderTest, UnboundedScatter) {
XlaComputation update_computation;
{
const std::unique_ptr<XlaBuilder> sub_builder = b.CreateSubBuilder("add");
const XlaOp arg0 = Parameter(sub_builder.get(), 0,
ShapeUtil::MakeScalarShape(F32), "arg0");
const XlaOp arg1 = Parameter(sub_builder.get(), 1,
ShapeUtil::MakeScalarShape(F32), "arg1");
Add(arg0, arg1);
Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32),
"arg0"),
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32),
"arg1"));
TF_ASSERT_OK_AND_ASSIGN(update_computation, sub_builder->Build());
}

Expand Down Expand Up @@ -2460,6 +2459,47 @@ TEST(XlaBuilderTest,
StatusIs(_, HasSubstr("Unimplemented implicit broadcast.")));
}

TEST(XlaBuilderTest, UnboundedSelectAndScatter) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape source, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape init_value, ParseShape("f32[]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));

XlaComputation select;
{
const std::unique_ptr<XlaBuilder> sub_builder =
b.CreateSubBuilder("compare");
Compare(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32),
"arg0"),
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32),
"arg1"),
ComparisonDirection::kGe);
TF_ASSERT_OK_AND_ASSIGN(select, sub_builder->Build());
}

XlaComputation scatter;
{
const std::unique_ptr<XlaBuilder> sub_builder = b.CreateSubBuilder("add");
Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32),
"arg0"),
Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32),
"arg1"));
TF_ASSERT_OK_AND_ASSIGN(scatter, sub_builder->Build());
}

SelectAndScatter(Parameter(&b, 0, operand, "operand"), select,
/*window_dimensions=*/
std::array<int64_t, 2>({3, 1}),
/*window_strides=*/std::array<int64_t, 2>({2, 1}),
Padding::kValid, Parameter(&b, 1, source, "source"),
Parameter(&b, 2, init_value, "init_value"), scatter);

TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedSlice) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]"));
Expand Down
91 changes: 65 additions & 26 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4686,6 +4686,31 @@ TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedMixOfDynamism) {
"not supported."));
}

TEST_F(ShapeInferenceTest, UnboundedScatter) {
TF_ASSERT_OK_AND_ASSIGN(Shape input, ParseShape("f32[?, ?, ?]"));
TF_ASSERT_OK_AND_ASSIGN(Shape scatter_indices, ParseShape("s32[?, ?, ?]"));
TF_ASSERT_OK_AND_ASSIGN(Shape updates, ParseShape("f32[?, ?, ?, ?]"));
TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[?, ?, ?]"));

const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);

ScatterDimensionNumbers dimension_numbers;
dimension_numbers.add_update_window_dims(2);
dimension_numbers.add_update_window_dims(3);
dimension_numbers.add_inserted_window_dims(0);
dimension_numbers.add_scatter_dims_to_operand_dims(1);
dimension_numbers.add_scatter_dims_to_operand_dims(0);
dimension_numbers.set_index_vector_dim(2);

TF_ASSERT_OK_AND_ASSIGN(
Shape result,
ShapeInference::InferScatterShape({&input, &scatter_indices, &updates},
to_apply, dimension_numbers));
EXPECT_TRUE(ShapeUtil::Equal(result, expected))
<< "inferred: " << ShapeUtil::HumanString(result)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_P(UnboundedSelectOpShapeInferenceTest, UnboundedSelect) {
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0]));
TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1]));
Expand Down Expand Up @@ -4714,6 +4739,46 @@ TEST_F(ShapeInferenceTest, UnboundedSelectWithTupleUnsupported) {
"(pred[2], pred[?])."));
}

TEST_F(ShapeInferenceTest, UnboundedSelectAndScatter) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape source, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape init_value, ParseShape("f32[]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));

Window window;
WindowDimension dim0;
dim0.set_base_dilation(1);
dim0.set_size(3);
dim0.set_stride(2);
dim0.set_padding_low(0);
dim0.set_padding_high(1);
dim0.set_window_dilation(1);

WindowDimension dim1;
dim1.set_base_dilation(1);
dim1.set_size(1);
dim1.set_stride(1);
dim1.set_padding_low(0);
dim1.set_padding_high(0);
dim1.set_window_dilation(1);

*window.add_dimensions() = dim0;
*window.add_dimensions() = dim1;

TF_ASSERT_OK_AND_ASSIGN(
Shape result,
ShapeInference::InferSelectAndScatterShape(
operand,
/*select_shape=*/ShapeUtil::MakeProgramShape({f32_, f32_}, pred_),
window, source, init_value,
/*scatter_shape=*/
ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)));

EXPECT_TRUE(ShapeUtil::Equal(result, expected))
<< "inferred: " << ShapeUtil::HumanString(result)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftLeft) {
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs));
TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs));
Expand Down Expand Up @@ -4803,32 +4868,6 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) {
}
}

TEST_F(ShapeInferenceTest, UnboundedScatter) {
TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, ?, ?]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_indices,
ParseShape("s32[?, ?, ?]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape updates, ParseShape("f32[?, ?, ?, ?]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, ?]"));

const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);

ScatterDimensionNumbers dimension_numbers;
dimension_numbers.add_update_window_dims(2);
dimension_numbers.add_update_window_dims(3);
dimension_numbers.add_inserted_window_dims(0);
dimension_numbers.add_scatter_dims_to_operand_dims(1);
dimension_numbers.add_scatter_dims_to_operand_dims(0);
dimension_numbers.set_index_vector_dim(2);

TF_ASSERT_OK_AND_ASSIGN(
const Shape result,
ShapeInference::InferScatterShape({&input, &scatter_indices, &updates},
to_apply, dimension_numbers));
EXPECT_TRUE(ShapeUtil::Equal(result, expected))
<< "inferred: " << ShapeUtil::HumanString(result)
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedTranspose) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand,
ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}"));
Expand Down

0 comments on commit 9497c24

Please sign in to comment.