Skip to content

Commit

Permalink
[xla][gpu] Fixed two problems in p2p-schedule-preparation.
Browse files Browse the repository at this point in the history
Previously, we recorded the address of a P2PGroupNode for a complement-group that
forms a cycle with a group. This address can change because the P2PGroupNode is an
element of a map. Fix this problem by recording the channel ID of the
complement-group instead.

Make sure that a group and its complement-group are both annotated with
a stream for pipelining.

PiperOrigin-RevId: 625770614
  • Loading branch information
bixia1 authored and tensorflower-gardener committed Apr 17, 2024
1 parent c98067c commit d82330b
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 58 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/BUILD
Expand Up @@ -1266,14 +1266,14 @@ cc_library(
":collective_ops_utils",
":hlo_pass",
"//xla:status",
"//xla:statusor",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_reachability",
"//xla/hlo/utils:hlo_query",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
Expand Up @@ -4176,6 +4176,7 @@ xla_cc_test(
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc",
Expand Down
129 changes: 72 additions & 57 deletions third_party/xla/xla/service/p2p_schedule_preparation.cc
Expand Up @@ -17,12 +17,14 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
Expand Down Expand Up @@ -189,6 +191,8 @@ struct P2PGroupNode {
return send_stream;
}

// Returns the channel id carried by this node's Recv instruction. Dereferences
// `recv` and calls value() on its channel_id(), so it must only be called once
// `recv` has been set and that instruction has a channel id.
int64_t GetChannel() const { return recv->channel_id().value(); }

HloRecvDoneInstruction* recv_done = nullptr;
HloSendDoneInstruction* send_done = nullptr;
HloRecvInstruction* recv = nullptr;
Expand Down Expand Up @@ -301,7 +305,7 @@ struct P2PGroup {
// Records the other group that forms a cycle with this group, assuming that
// we handle only two groups that form a cycle.
Status RecordComplementGroup(P2PGroupMap& p2p_group_map) {
CHECK(complement_group == nullptr && runtime_stream == kStream1);
CHECK(!complement_group_channel.has_value() && runtime_stream == kStream1);
for (auto& [channel, p2p_group] : p2p_group_map) {
if (&p2p_group == this ||
p2p_group.ChildComputation() != ChildComputation()) {
Expand All @@ -315,12 +319,14 @@ struct P2PGroup {
return Internal(
"Expected different pipeline stream for complement group");
}
complement_group = &p2p_group;
p2p_group.complement_group = this;
// Set the complement_group_channel for the current group.
complement_group_channel = channel;
// Set the complement_group_channel for the complement-group.
p2p_group.complement_group_channel = GetChannel();
} else if (p2p_group.kind == kUnpipelined &&
p2p_group.runtime_stream != kStream1) {
complement_group = &p2p_group;
p2p_group.complement_group = this;
p2p_group.runtime_stream == kStream0) {
complement_group_channel = channel;
p2p_group.complement_group_channel = GetChannel();
}
}
return OkStatus();
Expand All @@ -332,6 +338,7 @@ struct P2PGroup {
// Returns the child computation for the group.
HloComputation* ChildComputation() const { return GetChild().computation; }

// Returns the channel id of this group, read from the node stored at
// kUnpipelinedNodeIdx.
int64_t GetChannel() const { return nodes[kUnpipelinedNodeIdx].GetChannel(); }
P2PGroupNode& GetChild() { return nodes[kPipelinedChildNodeIdx]; }
P2PGroupNode& GetParent() { return nodes[kPipelinedParentNodeIdx]; }
const P2PGroupNode& GetChild() const { return nodes[kPipelinedChildNodeIdx]; }
Expand All @@ -341,46 +348,53 @@ struct P2PGroup {

// Returns the start and end of a region marked by a pipelined chain in the
// given computation, which is the region with the pipelined P2P instructions.
ChainStartEnd GetChainStartEnd(HloComputation* computation) const {
if (kind == kUnpipelined) {
if (!InCycle()) {
return std::make_pair(GetChild().recv, GetChild().send_done);
}
CHECK(runtime_stream == kStream1);
return std::make_pair(complement_group->GetChild().recv,
GetChild().send_done);
}

CHECK(kind == kPipelined);
ChainStartEnd GetChainStartEnd(const HloComputation* computation,
const P2PGroupMap& p2p_group_map) const {
if (computation == ChildComputation()) {
if (!InCycle()) {
return std::make_pair(GetChild().recv, GetChild().send_done);
}
CHECK(runtime_stream == kStream1);
return std::make_pair(complement_group->GetChild().recv,
GetChild().send_done);
if (runtime_stream == kStream1) {
return std::make_pair(
GetComplementGroup(p2p_group_map)->GetChild().recv,
GetChild().send_done);
}
return std::make_pair(
GetChild().recv,
GetComplementGroup(p2p_group_map)->GetChild().send_done);
}

CHECK(computation == ParentComputation());
CHECK(kind == kPipelined && computation == ParentComputation());
if (!InCycle()) {
return std::make_pair(GetParent().recv, GetParent().send_done);
}
CHECK(runtime_stream == kStream1);
return std::make_pair(complement_group->GetParent().recv,
GetParent().send_done);
if (runtime_stream == kStream1) {
return std::make_pair(GetComplementGroup(p2p_group_map)->GetParent().recv,
GetParent().send_done);
}
return std::make_pair(
GetParent().recv,
GetComplementGroup(p2p_group_map)->GetParent().send_done);
}

// Returns the while_loop recorded on the node at kPipelinedParentNodeIdx.
HloInstruction* GetWhileOp() const {
return nodes[kPipelinedParentNodeIdx].while_loop;
}

bool InCycle() const { return complement_group != nullptr; }

// Returns true if the channel of a complement group has been recorded for this
// group, i.e. this group forms a cycle with another group.
bool InCycle() const { return complement_group_channel.has_value(); }
// Returns the group that forms a cycle with this group, looked up in
// `p2p_group_map` by the recorded complement-group channel. CHECK-fails if
// this group is not in a cycle. The result points at an element of the map.
P2PGroup* GetComplementGroup(P2PGroupMap& p2p_group_map) const {
CHECK(InCycle());
return &p2p_group_map.at(*complement_group_channel);
}
// Const-map overload: returns a read-only pointer to the group that forms a
// cycle with this group, looked up in `p2p_group_map` by the recorded
// complement-group channel. CHECK-fails if this group is not in a cycle.
const P2PGroup* GetComplementGroup(const P2PGroupMap& p2p_group_map) const {
CHECK(InCycle());
return &p2p_group_map.at(*complement_group_channel);
}
P2PGroupKind kind = kUnpipelined;
P2PGroupNode nodes[2];
P2PRuntimeStream runtime_stream = kUnknown;
// Another P2PGroup that forms a cycle with this group.
P2PGroup* complement_group = nullptr;
// The channel id for another P2PGroup that forms a cycle with this group.
std::optional<int64_t> complement_group_channel = std::nullopt;
};

bool MayInvokeCollectiveOp(
Expand Down Expand Up @@ -506,9 +520,11 @@ Status ConnectP2P2NodeChain(const P2PGroupNode& node0,
// while-body computation, we enforce this ordering:
// recv.0 => send.0 => recv.1 => send.1 =>
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group) {
return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(),
p2p_group.GetChild());
Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group,
const P2PGroupMap& p2p_group_map) {
return ConnectP2P2NodeChain(
p2p_group.GetComplementGroup(p2p_group_map)->GetChild(),
p2p_group.GetChild());
}

// For a pipelined Send-Recv chain with one group in the while-body calling
Expand All @@ -522,20 +538,24 @@ Status ConnectPipelined1P2PParent(const P2PGroup& p2p_group) {
// in the while-body calling computation, we enforce this ordering:
// recv.0 => send.0 => recv.1 => send.1 =>
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
Status ConnectPipelined2P2PParent(const P2PGroup& p2p_group) {
return ConnectP2P2NodeChain(p2p_group.complement_group->GetParent(),
p2p_group.GetParent());
Status ConnectPipelined2P2PParent(const P2PGroup& p2p_group,
const P2PGroupMap& p2p_group_map) {
return ConnectP2P2NodeChain(
p2p_group.GetComplementGroup(p2p_group_map)->GetParent(),
p2p_group.GetParent());
}

// For a Send-Recv chain with two channel groups forming a cycle in a while-body
// annotated for pipelining but not pipelined (due to skip pipelining pass), we
// enforce this ordering:
// recv.0 => send.0 => recv.1 => send.1 =>
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
Status ConnectUnpipelined2P2P(const P2PGroup& p2p_group) {
Status ConnectUnpipelined2P2P(const P2PGroup& p2p_group,
const P2PGroupMap& p2p_group_map) {
CHECK(p2p_group.runtime_stream == kStream1);
return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(),
p2p_group.GetChild());
return ConnectP2P2NodeChain(
p2p_group.GetComplementGroup(p2p_group_map)->GetChild(),
p2p_group.GetChild());
}

// Collects P2P send-done and recv-done instructions from the computation,
Expand Down Expand Up @@ -585,7 +605,7 @@ Status GatherP2PGroupsAndCollectiveInfo(
}
// We can't rely on the operation on p2p_group_map above to find out
// whether it is the first time to handle this channel for the current
// computation, as we may drop information in the present of kUncognized
// computation, as we may drop information in the presence of kUnrecognized
// groups.
auto p2p_in_comp = p2p_in_computation.find(computation);
if (p2p_in_comp == p2p_in_computation.end()) {
Expand Down Expand Up @@ -634,11 +654,9 @@ Status GatherP2PGroupsAndCollectiveInfo(
for (auto& [channel, p2p_group] : p2p_group_map) {
if ((p2p_group.kind == kPipelined &&
p2p_group.ParentComputation() != computation) ||
p2p_group.complement_group != nullptr ||
p2p_group.runtime_stream != kStream1) {
p2p_group.InCycle() || p2p_group.runtime_stream != kStream1) {
continue;
}

TF_RETURN_IF_ERROR(p2p_group.RecordComplementGroup(p2p_group_map));
}

Expand Down Expand Up @@ -670,12 +688,12 @@ absl::StatusOr<std::pair<int, const P2PGroup*>> ConnectP2PChain(
if (!p2p_group.InCycle()) {
TF_RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group));
} else if (p2p_group.runtime_stream == kStream1) {
TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group));
TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group, p2p_group_map));
}
continue;
}

if (p2p_group.complement_group == nullptr) {
if (!p2p_group.InCycle()) {
if (computation == p2p_group.ParentComputation()) {
TF_RETURN_IF_ERROR(ConnectPipelined1P2PParent(p2p_group));
} else {
Expand All @@ -696,15 +714,15 @@ absl::StatusOr<std::pair<int, const P2PGroup*>> ConnectP2PChain(
}

if (computation == p2p_group.ParentComputation()) {
TF_RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group));
TF_RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group, p2p_group_map));
} else {
if (pipelined_group != nullptr) {
return Internal(
"Expected only two pipelined groups forming a cycle in a "
"while-body");
}
pipelined_group = &p2p_group;
TF_RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group));
TF_RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group, p2p_group_map));
}
}
return std::make_pair(num_p2p_chains, pipelined_group);
Expand Down Expand Up @@ -738,8 +756,7 @@ Status LinearizeCollectivesWithOtherP2P(
const std::vector<HloInstruction*>::iterator& end_iter,
HloReachabilityMap* reachability) {
HloComputation* computation = (*chain_start_iter)->parent();
ChainStartEnd start_end = group.GetChainStartEnd(computation);

ChainStartEnd start_end = group.GetChainStartEnd(computation, p2p_group_map);
// We refer to the P2P chain represented by `group` chain A.
for (auto it = begin_iter; it != end_iter; ++it) {
HloInstruction* hlo = *it;
Expand All @@ -758,8 +775,8 @@ Status LinearizeCollectivesWithOtherP2P(
// LinearizeCollectivesWithPipelinedP2PChild.
continue;
}

ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation);
ChainStartEnd cur_start_end =
cur_group.GetChainStartEnd(computation, p2p_group_map);
if (cur_start_end.first != hlo) {
// We will linearize the two chains when we see the first instruction in
// chain B.
Expand Down Expand Up @@ -822,7 +839,7 @@ Status LinearizeCollectivesWithPipelinedP2PChild(
const P2PGroupMap& p2p_group_map, const P2PGroup& group,
const CollectiveInComputation& collective_in_computation,
HloComputation* computation, HloReachabilityMap* reachability) {
ChainStartEnd start_end = group.GetChainStartEnd(computation);
ChainStartEnd start_end = group.GetChainStartEnd(computation, p2p_group_map);

// If an hlo may invoke collective operation, we add control dependence to
// make sure that the hlo is scheduled before the pipelined chain starts.
Expand Down Expand Up @@ -852,7 +869,8 @@ Status LinearizeCollectivesWithPipelinedP2PChild(
continue;
}

ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation);
ChainStartEnd cur_start_end =
cur_group.GetChainStartEnd(computation, p2p_group_map);
TF_RETURN_IF_ERROR(
OrderBefore(reachability, cur_start_end.second, start_end.first));

Expand Down Expand Up @@ -957,12 +975,9 @@ absl::StatusOr<bool> P2PSchedulePreparation::Run(
// to other collectives.
continue;
}
if (group.InCycle() && group.runtime_stream != kStream1) {
// We process a chain with two groups when we see the group for
// kStream1.
continue;
}
ChainStartEnd start_end = group.GetChainStartEnd(computation);

ChainStartEnd start_end =
group.GetChainStartEnd(computation, p2p_group_map);

// Handle the group when we see the beginning of the chain.
if (start_end.first != hlo) {
Expand Down

0 comments on commit d82330b

Please sign in to comment.