Skip to content

Commit

Permalink
Add CollectiveBroadcast support for XLA builder API.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611650063
  • Loading branch information
chaserileyroberts authored and tensorflower-gardener committed Mar 1, 2024
1 parent aeb5233 commit 33b8afb
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
35 changes: 35 additions & 0 deletions third_party/xla/xla/client/xla_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3799,6 +3799,34 @@ XlaOp XlaBuilder::AllToAllTuple(
});
}

XlaOp XlaBuilder::CollectiveBroadcast(
XlaOp operand, absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id) {
return CollectiveBroadcastImpl(operand, replica_groups, channel_id);
}

XlaOp XlaBuilder::CollectiveBroadcastImpl(
XlaOp operand, absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferCollectiveBroadcastShape({operand_shape}));
*instr.mutable_shape() = shape.ToProto();
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
}
if (channel_id.has_value()) {
instr.set_channel_id(channel_id->handle());
}

return AddInstruction(std::move(instr), HloOpcode::kCollectiveBroadcast,
{operand});
});
}

XlaOp XlaBuilder::CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
Expand Down Expand Up @@ -5312,6 +5340,13 @@ XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension,
replica_groups, layout, channel_id);
}

XlaOp CollectiveBroadcast(const XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id) {
return operand.builder()->CollectiveBroadcast(operand, replica_groups,
channel_id);
}

XlaOp CollectivePermute(
const XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
Expand Down
15 changes: 15 additions & 0 deletions third_party/xla/xla/client/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,10 @@ class XlaBuilder {
const std::optional<Layout>& layout,
const std::optional<ChannelHandle>& channel_id = std::nullopt);

XlaOp CollectiveBroadcast(
XlaOp operand, absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id = std::nullopt);

XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
Expand Down Expand Up @@ -1492,6 +1496,9 @@ class XlaBuilder {
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout,
const std::optional<ChannelHandle>& channel_id);
friend XlaOp CollectiveBroadcast(
XlaOp operand, absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id);
friend XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
Expand Down Expand Up @@ -1640,6 +1647,10 @@ class XlaBuilder {
const std::optional<Shape>& layout,
std::optional<bool> use_global_device_ids, bool async);

XlaOp CollectiveBroadcastImpl(XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id);

XlaOp CollectivePermuteImpl(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
Expand Down Expand Up @@ -2532,6 +2543,10 @@ XlaOp AllToAllTuple(
const std::optional<Layout>& layout = std::nullopt,
const std::optional<ChannelHandle>& channel_id = std::nullopt);

XlaOp CollectiveBroadcast(
XlaOp operand, absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id = std::nullopt);

// Enqueues an collective operation that sends and receives data cross replicas.
//
// - `source_target_pair`: a list of (source_replica_id, target_replica_id)
Expand Down
12 changes: 12 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,18 @@ TEST(XlaBuilderTest, AllReduceTuple) {
.WithShapeEqualTo(&tuple_shape)));
}

TEST(XlaBuilderTest, CollectiveBroadcast) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
ReplicaGroup replica_group;
replica_group.add_replica_ids(0);
replica_group.add_replica_ids(1);
CollectiveBroadcast(x, {replica_group});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b));
auto root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kCollectiveBroadcast);
}

TEST(XlaBuilderTest, CollectivePermute) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
Expand Down

0 comments on commit 33b8afb

Please sign in to comment.