Skip to content

Commit

Permalink
Add unbounded dynamism test for ComplexOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622343055
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 6, 2024
1 parent 38da638 commit ba6437b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 7 deletions.
14 changes: 10 additions & 4 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -1797,7 +1797,8 @@ TEST_P(XlaBuilderUnboundedUnaryOpTest, UnboundedUnaryOpTest) {
TF_ASSERT_OK_AND_ASSIGN(const Shape expected,
ParseShape(GetParam().expected));
GetParam().unary_op(Parameter(&b, 0, operand, "operand"));
TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<xla::HloModule> module,
BuildHloModule(b));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}
Expand All @@ -1811,9 +1812,9 @@ TEST_P(XlaBuilderUnboundedBinaryOpTest, UnboundedBinaryOpTest) {
GetParam().binary_op(Parameter(&b, 0, lhs, "lhs"),
Parameter(&b, 1, rhs, "rhs"),
GetParam().broadcast_dimensions);
if (auto result = BuildHloModule(b); result.ok()) {
const std::unique_ptr<HloModule> module = std::move(*result);
EXPECT_THAT(GetRoot(*module),
if (const auto result = BuildHloModule(b); result.ok()) {
ASSERT_NE(*result, nullptr);
EXPECT_THAT(GetRoot(**result),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
} else {
ASSERT_TRUE(GetParam().error_message.has_value());
Expand Down Expand Up @@ -2540,6 +2541,11 @@ INSTANTIATE_TEST_SUITE_P(
{"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]",
/*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]",
&Atan2},
{"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]",
/*broadcast_dimensions=*/empty_array, "c64[?, ?, 2, 2, <=2, <=2, ?]",
&Complex},
{"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array,
"c64[?, 10]", &Complex},
{"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]",
/*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]",
&Div},
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/BUILD
Expand Up @@ -599,12 +599,12 @@ xla_cc_test(
"//xla:statusor",
"//xla:test",
"//xla:test_helpers",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/client:padding",
"//xla/hlo/ir:hlo",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
Expand Down
59 changes: 57 additions & 2 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -15,14 +15,18 @@ limitations under the License.

#include "xla/service/shape_inference.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
Expand All @@ -32,10 +36,8 @@ limitations under the License.
#include "xla/service/hlo_parser.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#include "xla/test.h"
#include "xla/test_helpers.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -141,6 +143,10 @@ class UnboundedBinaryOpShapeInferenceTest
class UnboundedCompareOpShapeInferenceTest
: public ::testing::TestWithParam<BinaryOpTestCase> {};

// Subclass for testing unbounded dynamic complex op
class UnboundedComplexOpShapeInferenceTest
: public ::testing::TestWithParam<BinaryOpTestCase> {};

// Subclass for testing unbounded dynamic concatenate op
class UnboundedConcatenateOpShapeInferenceTest
: public ::testing::TestWithParam<std::vector<std::string>> {};
Expand Down Expand Up @@ -4236,6 +4242,25 @@ TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) {
}
}

TEST_P(UnboundedComplexOpShapeInferenceTest, UnboundedComplex) {
TF_ASSERT_OK_AND_ASSIGN(const Shape real, ParseShape(GetParam().lhs));
TF_ASSERT_OK_AND_ASSIGN(const Shape imag, ParseShape(GetParam().rhs));
const absl::StatusOr<Shape> inferred_shape =
ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, real, imag,
GetParam().broadcast_dimensions);
if (inferred_shape.ok()) {
TF_ASSERT_OK_AND_ASSIGN(const Shape expected,
ParseShape(GetParam().expected));
EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected))
<< "inferred: " << ShapeUtil::HumanString(*inferred_shape)
<< " expected: " << ShapeUtil::HumanString(expected);
} else {
ASSERT_TRUE(GetParam().error_message.has_value());
EXPECT_THAT(inferred_shape.status().message(),
HasSubstr(*GetParam().error_message));
}
}

TEST_P(UnboundedConcatenateOpShapeInferenceTest, UnboundedConcatenate) {
TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape(GetParam()[0]));
TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape(GetParam()[1]));
Expand Down Expand Up @@ -4879,6 +4904,36 @@ INSTANTIATE_TEST_SUITE_P(UnboundedDynamism,
"",
kIncompatibleBinaryOpShapeErrorMessage}}));

INSTANTIATE_TEST_SUITE_P(UnboundedDynamism,
UnboundedComplexOpShapeInferenceTest,
::testing::ValuesIn<BinaryOpTestCase>(
{// LHS | RHS | bdims | Res
// 1 | ? | [] | ?
{"f32[1]", "f32[?]", {}, "c64[?]"},
// ? | 1 | [] | ?
{"f32[?]", "f32[1]", {}, "c64[?]"},
// 2 | ? | [] | 2
{"f32[2]", "f32[?]", {}, "c64[2]"},
// ? | 2 | [] | 2
{"f32[?]", "f32[2]", {}, "c64[2]"},
// <=2 | ? | [] | <=2
{"f32[<=2]", "f32[?]", {}, "c64[<=2]"},
// ? | <=2 | [] | <=2
{"f32[?]", "f32[<=2]", {}, "c64[<=2]"},
// ? | ? | [] | ?
{"f32[?]", "f32[?]", {}, "c64[?]"},
// 1 | ?,3 | [0] | ?,3
{"f32[1]", "f32[?,3]", zero_array, "c64[?,3]"},
// 2 | ?,3 | [0] | err
{"f32[2]", "f32[?,3]", zero_array, "",
kBroadcastDimensionMismatchErrorMessage},
// ?,2 | ?,3 | [] | err
{"f32[?,2]",
"f32[?,3]",
{},
"",
kIncompatibleBinaryOpShapeErrorMessage}}));

INSTANTIATE_TEST_SUITE_P(
UnboundedDynamism, UnboundedConcatenateOpShapeInferenceTest,
::testing::Values(
Expand Down

0 comments on commit ba6437b

Please sign in to comment.