diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc
index 8b6ee8630dfdf2..9ff782bf566fd5 100644
--- a/third_party/xla/xla/client/xla_builder.cc
+++ b/third_party/xla/xla/client/xla_builder.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "xla/client/xla_builder.h"
 
+#include <algorithm>
 #include <array>
 #include <cstddef>
 #include <cstdint>
@@ -3800,15 +3801,55 @@ XlaOp XlaBuilder::AllToAllArray(
       return all_to_all;
     }
     DimensionVector sizes;
+    const bool is_unbounded = operand_shape->is_unbounded_dynamic();
+    std::vector<XlaOp> dynamic_sizes;
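+    // Helper that materializes dimension `dimension` of `operand` as a
+    // rank-1 i32 tensor: read at run time via GetDimensionSize when the
+    // dimension is unbounded dynamic, otherwise folded to a compile-time
+    // constant.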
+    auto GetR1DimensionSizeOrConstant = [&](XlaOp operand,
+                                            int64_t dimension) -> XlaOp {
+      if (operand_shape->is_unbounded_dynamic_dimension(dimension)) {
+        return Reshape(GetDimensionSize(operand, dimension), {1});
+      }
+      return ConstantR1<int32_t>(
+          this, {static_cast<int32_t>(operand_shape->dimensions(dimension))});
+    };
+    XlaOp r1_split_count =
+        ConstantR1<int32_t>(this, {static_cast<int32_t>(split_count)});
     for (int64_t i = 0; i < operand_shape->rank(); ++i) {
       if (i != split_dimension) {
         sizes.push_back(operand_shape->dimensions(i));
+        if (is_unbounded) {
+          dynamic_sizes.push_back(GetR1DimensionSizeOrConstant(operand, i));
+        }
         continue;
       }
       sizes.push_back(split_count);
-      sizes.push_back(operand_shape->dimensions(i) / split_count);
+      sizes.push_back(operand_shape->is_unbounded_dynamic_dimension(i)
+                          ? Shape::kUnboundedSize
+                          : operand_shape->dimensions(i) / split_count);
+
+      if (is_unbounded) {
+        dynamic_sizes.push_back(r1_split_count);
+        dynamic_sizes.push_back(
+            operand_shape->is_unbounded_dynamic_dimension(i)
+                ? Div(GetR1DimensionSizeOrConstant(operand, i), r1_split_count)
+                : ConstantR1<int32_t>(this,
+                                      {static_cast<int32_t>(sizes.back())}));
+      }
     }
+
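+    // The intermediate split shape is only partially known at compile time
+    // when the operand is unbounded dynamic, so the split reshape is emitted
+    // as an mhlo.dynamic_reshape whose output dimension sizes are assembled
+    // at run time from `dynamic_sizes`; the static Reshape path is kept
+    // otherwise.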
+    if (is_unbounded) {
+      std::vector<bool> dynamic_dimensions;
+      std::transform(
+          sizes.begin(), sizes.end(), std::back_inserter(dynamic_dimensions),
+          [](int64_t size) { return size == Shape::kUnboundedSize; });
+      TF_ASSIGN_OR_RETURN(
+          const Shape shape,
+          ShapeUtil::MakeValidatedShape(all_to_all_shape.element_type(), sizes,
+                                        dynamic_dimensions));
+      all_to_all =
+          MhloDynamicReshape(all_to_all, ConcatInDim(dynamic_sizes, 0), shape);
+    } else {
+      all_to_all = Reshape(all_to_all, sizes);
+    }
-    all_to_all = Reshape(all_to_all, sizes);
 
     std::vector<int64_t> permutation;
     const auto rank = operand_shape->rank();
@@ -3821,6 +3862,21 @@ XlaOp XlaBuilder::AllToAllArray(
       permutation.push_back(dim_after_reshape);
     }
     all_to_all = Transpose(all_to_all, permutation);
+
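+    // Collapse the transposed result back to the inferred AllToAll shape.
+    // With an unbounded operand the final dimension sizes (the split
+    // dimension divided by split_count, the concat dimension multiplied by
+    // it) are likewise computed at run time.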
+    if (is_unbounded) {
+      std::vector<XlaOp> new_dimensions;
+      for (int64_t i = 0; i < operand_shape->rank(); ++i) {
+        new_dimensions.push_back(GetR1DimensionSizeOrConstant(operand, i));
+      }
+      new_dimensions[split_dimension] =
+          Div(new_dimensions[split_dimension], r1_split_count);
+      new_dimensions[concat_dimension] =
+          Mul(new_dimensions[concat_dimension], r1_split_count);
+
+      return MhloDynamicReshape(all_to_all, ConcatInDim(new_dimensions, 0),
+                                all_to_all_shape);
+    }
+
     return Reshape(all_to_all_shape, all_to_all);
   });
 }
@@ -3876,6 +3932,13 @@ XlaOp XlaBuilder::AllToAllTuple(
                                 const std::optional<ChannelHandle>& channel_id) {
   return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+    if (operand_shape->is_unbounded_dynamic() ||
+        split_dimension == Shape::kUnboundedSize ||
+        concat_dimension == Shape::kUnboundedSize ||
+        split_count == Shape::kUnboundedSize) {
+      return InvalidArgument(
+          "AllToAllTuple does not support unbounded dynamic shapes");
+    }
 
     // The HloInstruction for AllToAll currently only handles the data
     // communication: it accepts N already split parts and scatters them to N
@@ -3901,14 +3964,14 @@ XlaOp XlaBuilder::AllToAllTuple(
     }
 
     // Handle data communication.
-    XlaOp alltoall =
+    XlaOp all_to_all =
         this->AllToAllTuple(slices, replica_groups, layout, channel_id);
 
     // Concat the N received parts.
     std::vector<XlaOp> received;
     received.reserve(split_count);
     for (int i = 0; i < split_count; i++) {
-      received.push_back(this->GetTupleElement(alltoall, i));
+      received.push_back(this->GetTupleElement(all_to_all, i));
     }
     return this->ConcatInDim(received, concat_dimension);
   });
diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h
index 24ca2db618664e..6d57c871a28546 100644
--- a/third_party/xla/xla/client/xla_builder.h
+++ b/third_party/xla/xla/client/xla_builder.h
@@ -2566,7 +2566,10 @@ XlaOp ReduceScatter(
     const std::optional<Layout>& layout = std::nullopt,
     std::optional<bool> use_global_device_ids = std::nullopt);
 
-// Enqueues an operation that do an Alltoall of the operand cross cores.
+// Enqueues an operation that does an AllToAll of the operand across cores.
+// This involves AllToAll, followed by Reshape, Transpose, and another Reshape
+// to get proper codegen. See implementation for additional details.
+//
 // An optional `layout` can be specified to force the layout of the instruction.
 // This is used to guarantee the same layout for a group of AllToAll ops
 // compiled separately.
diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc
index c4d43bc1dc1c8e..8a980accf22433 100644
--- a/third_party/xla/xla/client/xla_builder_test.cc
+++ b/third_party/xla/xla/client/xla_builder_test.cc
@@ -1990,6 +1990,130 @@ TEST(XlaBuilderTest, UnboundedAllReduce) {
               GmockMatch(m::Op().WithShapeEqualTo(&expected)));
 }
 
+TEST(XlaBuilderTest, UnboundedAllToAllDynamicSplitDimension) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]"));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 45]"));
+  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
+           /*split_dimension=*/0,
+           /*concat_dimension=*/1,
+           /*split_count=*/3,
+           /*replica_groups=*/{});
+  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
+                          BuildHloModule(b));
+  std::cout << module->ToString() << "\n";
+  EXPECT_THAT(GetRoot(*module),
+              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
+}
+
+TEST(XlaBuilderTest, UnboundedAllToAllDynamicConcatDimension) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]"));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 5]"));
+  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
+           /*split_dimension=*/1,
+           /*concat_dimension=*/0,
+           /*split_count=*/3,
+           /*replica_groups=*/{});
+  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
+                          BuildHloModule(b));
+  std::cout << module->ToString() << "\n";
+  EXPECT_THAT(GetRoot(*module),
+              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
+}
+
+TEST(XlaBuilderTest, UnboundedAllToAllDynamicSplitAndConcatDimensionEqual) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]"));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 15]"));
+  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
+           /*split_dimension=*/0,
+           /*concat_dimension=*/0,
+           /*split_count=*/3,
+           /*replica_groups=*/{});
+  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
+                          BuildHloModule(b));
+  std::cout << module->ToString() << "\n";
+  EXPECT_THAT(GetRoot(*module),
+              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
+}
+
+TEST(XlaBuilderTest, UnboundedAllToAllFullyDynamic) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?]"));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?]"));
+  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
+           /*split_dimension=*/0,
+           /*concat_dimension=*/1,
+           /*split_count=*/3,
+           /*replica_groups=*/{});
+  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
+                          BuildHloModule(b));
+  std::cout << module->ToString() << "\n";
+  EXPECT_THAT(GetRoot(*module),
+              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
+}
+
+TEST(XlaBuilderTest, UnboundedAllToAllTupleVariadicUnsupported) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]{1,0}"));
+  b.ReportErrorOrReturn(
+      AllToAllTuple(/*operands=*/{Parameter(&b, 0, operand, "operand0"),
+                                  Parameter(&b, 1, operand, "operand1")},
+                    /*replica_groups=*/{}));
+  EXPECT_THAT(
+      BuildHloModule(b),
+      StatusIs(_,
+               HasSubstr(
+                   "AllToAllTuple does not support unbounded dynamic shapes")));
+}
+
+TEST(XlaBuilderTest, UnboundedAllToAllTupleUnsupported) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]{1,0}"));
+  b.ReportErrorOrReturn(
+      AllToAllTuple(/*operand=*/Parameter(&b, 0, operand, "operand"),
+                    /*split_dimension=*/0,
+                    /*concat_dimension=*/1,
+                    /*split_count=*/3,
+                    /*replica_groups=*/{}));
+  EXPECT_THAT(
+      BuildHloModule(b),
+      StatusIs(_,
+               HasSubstr(
+                   "AllToAllTuple does not support unbounded dynamic shapes")));
+}
+
+TEST(XlaBuilderTest, BoundedAllToAllTupleUnsupported) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, <=15]{1,0}"));
+  b.ReportErrorOrReturn(
+      AllToAllTuple(/*operand=*/Parameter(&b, 0, operand, "operand"),
+                    /*split_dimension=*/0,
+                    /*concat_dimension=*/1,
+                    /*split_count=*/3,
+                    /*replica_groups=*/{}));
+  EXPECT_THAT(
+      BuildHloModule(b),
+      StatusIs(_,
+               HasSubstr("AllToAll does not support bounded dynamic shapes")));
+}
+
+TEST(XlaBuilderTest, BoundedAllToAllUnsupported) {
+  XlaBuilder b(TestName());
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, <=15]{1,0}"));
+  b.ReportErrorOrReturn(
+      AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
+               /*split_dimension=*/0,
+               /*concat_dimension=*/1,
+               /*split_count=*/3,
+               /*replica_groups=*/{}));
+  EXPECT_THAT(
+      BuildHloModule(b),
+      StatusIs(_,
+               HasSubstr("AllToAll does not support bounded dynamic shapes")));
+}
+
 TEST(XlaBuilderTest, UnboundedAnd) {
   XlaBuilder b(TestName());
   TF_ASSERT_OK_AND_ASSIGN(const Shape lhs,
diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc
index a6ba2645d35273..10d2f5d3f75b44 100644
--- a/third_party/xla/xla/service/shape_inference.cc
+++ b/third_party/xla/xla/service/shape_inference.cc
@@ -2492,6 +2492,8 @@ ShapeInference::InferScalarBroadcastShape(absl::Span<const Shape> shapes) {
     const Shape& shape, int64_t split_dimension, int64_t concat_dimension,
     int64_t split_count) {
   TF_RET_CHECK(split_count > 0);
+  TF_RET_CHECK(!shape.is_bounded_dynamic())
+      << "AllToAll does not support bounded dynamic shapes";
   if (split_dimension >= shape.rank() || split_dimension < 0) {
     return InvalidArgument(
         "AllToAll split_dimension %d is out-of-bounds in shape %s.",
         split_dimension, ShapeUtil::HumanString(shape));
@@ -2502,25 +2504,41 @@ ShapeInference::InferScalarBroadcastShape(absl::Span<const Shape> shapes) {
         "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
         concat_dimension, ShapeUtil::HumanString(shape));
   }
-  if (shape.dimensions(split_dimension) % split_count != 0) {
+  int64_t split_dimension_size = shape.dimensions(split_dimension);
+  if (!IsUnboundedDynamicSize(split_dimension_size) &&
+      split_dimension_size % split_count != 0) {
     return InvalidArgument(
         "AllToAll split dimension size %d must be dividable by split_count "
        "%d.",
-        shape.dimensions(split_dimension), split_count);
+        split_dimension_size, split_count);
   }
   std::vector<int64_t> new_dimensions(shape.dimensions().begin(),
                                       shape.dimensions().end());
-  new_dimensions[split_dimension] /= split_count;
-  new_dimensions[concat_dimension] *= split_count;
-  return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
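+  // An unbounded input dimension stays unbounded in the result; only static
+  // sizes are divided/multiplied, and the input's per-dimension dynamism is
+  // carried over to the inferred shape.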
+  new_dimensions[split_dimension] =
+      IsUnboundedDynamicSize(new_dimensions[split_dimension])
+          ? Shape::kUnboundedSize
+          : new_dimensions[split_dimension] / split_count;
+  new_dimensions[concat_dimension] =
+      IsUnboundedDynamicSize(new_dimensions[concat_dimension])
+          ? Shape::kUnboundedSize
+          : new_dimensions[concat_dimension] * split_count;
+
+  const std::vector<bool> dynamic_dimensions(shape.dynamic_dimensions().begin(),
+                                             shape.dynamic_dimensions().end());
+  return ShapeUtil::MakeShape(shape.element_type(), new_dimensions,
+                              dynamic_dimensions);
 }
 
 /* static */ absl::StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
     absl::Span<const Shape* const> operand_shapes) {
-  // An Alltoall HLO instruction receives N operands (with the same shape) and
+  // An AllToAll HLO instruction receives N operands (with the same shape) and
   // returns a tuple that contains N array shapes.
   TF_RET_CHECK(!operand_shapes.empty());
   for (int i = 0; i < operand_shapes.size(); i++) {
+    if (operand_shapes[i]->is_unbounded_dynamic()) {
+      return InvalidArgument(
+          "AllToAllTuple does not support unbounded dynamic shapes");
+    }
     if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
                                                     *operand_shapes[i])) {
       return InvalidArgument(
diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc
index 3afa14fa302eb8..2ec1968a94b7c2 100644
--- a/third_party/xla/xla/service/shape_inference_test.cc
+++ b/third_party/xla/xla/service/shape_inference_test.cc
@@ -4056,6 +4056,32 @@ TEST_F(ShapeInferenceTest, UnboundedAllReduce) {
       << " expected: " << ShapeUtil::HumanString(expected);
 }
 
+TEST_F(ShapeInferenceTest, UnboundedAllToAll) {
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
+  TF_ASSERT_OK_AND_ASSIGN(
+      const Shape inferred_shape,
+      ShapeInference::InferAllToAllShape(/*shape=*/operand,
+                                         /*split_dimension=*/0,
+                                         /*concat_dimension=*/0,
+                                         /*split_count=*/3));
+  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
+      << "inferred: " << ShapeUtil::HumanString(inferred_shape)
+      << " expected: " << ShapeUtil::HumanString(expected);
+}
+
+TEST_F(ShapeInferenceTest, UnboundedAllToAllTupleUnsupported) {
+  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
+  TF_ASSERT_OK_AND_ASSIGN(const Shape expected,
+                          ParseShape("(f32[?, 10], f32[?, 10])"));
+  const absl::StatusOr<Shape> inferred_shape =
+      ShapeInference::InferAllToAllTupleShape(
+          /*operand_shapes=*/{&operand, &operand});
+  EXPECT_THAT(
+      inferred_shape.status().message(),
+      HasSubstr("AllToAllTuple does not support unbounded dynamic shapes"));
+}
+
 TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedAnd) {
   TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs));
   TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs));
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
index 7e08ca271a02f9..5cdc1ce69ca7a4 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
@@ -117,7 +117,10 @@ xla_test(
             "notap",
         ],
     },
-    backends = ["gpu"],
+    backends = [
+        "cpu",
+        "gpu",
+    ],
    data = [
        "data/sharded_16_devices.hlo",
        "data/sharded_2_devices.hlo",
@@ -127,6 +130,8 @@
    tags = ["nomac"],
    deps = [
        ":functional_hlo_runner",
+        "//xla:statusor",
+        "//xla/pjrt:pjrt_client",
        "//xla/tests:filecheck",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/lib/core:status_test_util",
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
index 662bbdb7e46dcb..369eb4d68a4367 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
@@ -20,6 +20,8 @@ limitations under the License.
 #include <memory>
 #include <string>
 
+#include "xla/pjrt/pjrt_client.h"
+#include "xla/statusor.h"
 #include "xla/tests/filecheck.h"
 #include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
@@ -33,6 +35,20 @@ namespace {
 
 using ::testing::SizeIs;
 
+bool IsTestingCpu() {
+#ifdef XLA_TEST_BACKEND_CPU
+  return true;
+#endif
+  return false;
+}
+
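+// Returns a host (CPU) PjRt client when this test target is built for the
+// CPU backend, and a GPU client otherwise, so the tests below can run on
+// either backend.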
+absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetPjRtClient() {
+  if (IsTestingCpu()) {
+    return xla::FunctionalHloRunner::CreateHostClient();
+  }
+  return xla::FunctionalHloRunner::CreateGpuClient();
+}
+
 class FunctionalHloRunnerTest : public ::testing::Test {
  protected:
   std::string GetHloPath(std::string file_name) {
@@ -43,7 +59,7 @@ class FunctionalHloRunnerTest : public ::testing::Test {
 
 TEST_F(FunctionalHloRunnerTest, SingleDeviceHlo) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
-                          xla::FunctionalHloRunner::CreateGpuClient());
+                          GetPjRtClient());
 
   // Options corresponding to --num_replicas=1 --num_partitions=1
   xla::DebugOptions debug_options;
@@ -60,7 +76,7 @@ TEST_F(FunctionalHloRunnerTest, SingleDeviceHlo) {
 
 TEST_F(FunctionalHloRunnerTest, Sharded2Devices) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
-                          xla::FunctionalHloRunner::CreateGpuClient());
+                          GetPjRtClient());
 
   constexpr int kRequiredDeviceCount = 2;
   const int kDeviceCount = client->device_count();
@@ -89,7 +105,7 @@ TEST_F(FunctionalHloRunnerTest, Sharded2Devices) {
 
 TEST_F(FunctionalHloRunnerTest, UseZerosAsInputs) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
-                          xla::FunctionalHloRunner::CreateGpuClient());
+                          GetPjRtClient());
 
   constexpr int kRequiredDeviceCount = 2;
   const int kDeviceCount = client->device_count();
@@ -121,7 +137,7 @@ TEST_F(FunctionalHloRunnerTest, UseZerosAsInputs) {
 
 TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
-                          xla::FunctionalHloRunner::CreateGpuClient());
+                          GetPjRtClient());
 
   constexpr int kRequiredDeviceCount = 2;
   const int kDeviceCount = client->device_count();
@@ -153,7 +169,7 @@ TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {
 
 TEST_F(FunctionalHloRunnerTest, UseUninitializedInputsWithTupledArguments) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
-                          xla::FunctionalHloRunner::CreateGpuClient());
+                          GetPjRtClient());
 
   // Options corresponding to:
   //   --num_replicas=1 --num_partitions=1
@@ -196,7 +212,7 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) {
   raw_compile_options.xla_dump_to = dump_dir;
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
-                          xla::FunctionalHloRunner::CreateGpuClient());
+                          GetPjRtClient());
   TF_EXPECT_OK(FunctionalHloRunner::LoadAndCompile(
       *client, debug_options, preproc_options, raw_compile_options,
       GetHloPath("sharded_16_devices.hlo"), InputFormat::kText));
@@ -212,8 +228,8 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) {
   TF_ASSERT_OK(
       tsl::ReadFileToString(env, after_opt_hlo_paths[0], &after_opt_hlo));
   absl::StatusOr<bool> file_check_result = RunFileCheck(after_opt_hlo, R"(
-  // CHECK: param = f32[16,1]{1,0}
-  // CHECK: add = f32[16,1]{1,0}
+  // CHECK: param{{.*}} = f32[16,1]{1,0}
+  // CHECK: add{{.*}} = f32[16,1]{1,0}
   )");
   TF_ASSERT_OK(file_check_result.status());
   EXPECT_TRUE(file_check_result.value());