Skip to content

Commit d82330b

Browse files
bixia1tensorflower-gardener
authored andcommitted
[xla][gpu] Fixed two problems in p2p-schedule-preparation.
Previously, we record the address of a P2PGroupNode for a complement-group that forms a cycle with a group. This address can be changed as P2PGroupNode is an entity of a map. Fixing this problem by recording the channel ID of the complement-group instead. Make sure that a group and its complement-group are both annotated with a stream for pipelining. PiperOrigin-RevId: 625770614
1 parent c98067c commit d82330b

File tree

4 files changed

+177
-58
lines changed

4 files changed

+177
-58
lines changed

third_party/xla/xla/service/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,14 +1266,14 @@ cc_library(
12661266
":collective_ops_utils",
12671267
":hlo_pass",
12681268
"//xla:status",
1269-
"//xla:statusor",
12701269
"//xla:util",
12711270
"//xla/hlo/ir:hlo",
12721271
"//xla/hlo/ir:hlo_reachability",
12731272
"//xla/hlo/utils:hlo_query",
12741273
"@com_google_absl//absl/container:flat_hash_map",
12751274
"@com_google_absl//absl/container:flat_hash_set",
12761275
"@com_google_absl//absl/log",
1276+
"@com_google_absl//absl/log:check",
12771277
"@com_google_absl//absl/strings",
12781278
"@local_tsl//tsl/platform:errors",
12791279
"@local_tsl//tsl/platform:statusor",

third_party/xla/xla/service/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4176,6 +4176,7 @@ xla_cc_test(
41764176
"@com_google_absl//absl/log",
41774177
"@com_google_absl//absl/strings:string_view",
41784178
"@com_google_googletest//:gtest",
4179+
"@local_tsl//tsl/platform:env",
41794180
"@local_tsl//tsl/platform:status",
41804181
"@local_tsl//tsl/platform:statusor",
41814182
"@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc",

third_party/xla/xla/service/p2p_schedule_preparation.cc

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ limitations under the License.
1717

1818
#include <cstdint>
1919
#include <memory>
20+
#include <optional>
2021
#include <set>
2122
#include <utility>
2223
#include <vector>
2324

2425
#include "absl/container/flat_hash_map.h"
2526
#include "absl/container/flat_hash_set.h"
27+
#include "absl/log/check.h"
2628
#include "absl/log/log.h"
2729
#include "absl/strings/string_view.h"
2830
#include "xla/hlo/ir/hlo_casting_utils.h"
@@ -189,6 +191,8 @@ struct P2PGroupNode {
189191
return send_stream;
190192
}
191193

194+
int64_t GetChannel() const { return recv->channel_id().value(); }
195+
192196
HloRecvDoneInstruction* recv_done = nullptr;
193197
HloSendDoneInstruction* send_done = nullptr;
194198
HloRecvInstruction* recv = nullptr;
@@ -301,7 +305,7 @@ struct P2PGroup {
301305
// Records the other group that forms a cycle with this group, assuming that
302306
// we handle only two groups that form a cycle.
303307
Status RecordComplementGroup(P2PGroupMap& p2p_group_map) {
304-
CHECK(complement_group == nullptr && runtime_stream == kStream1);
308+
CHECK(!complement_group_channel.has_value() && runtime_stream == kStream1);
305309
for (auto& [channel, p2p_group] : p2p_group_map) {
306310
if (&p2p_group == this ||
307311
p2p_group.ChildComputation() != ChildComputation()) {
@@ -315,12 +319,14 @@ struct P2PGroup {
315319
return Internal(
316320
"Expected different pipeline stream for complement group");
317321
}
318-
complement_group = &p2p_group;
319-
p2p_group.complement_group = this;
322+
// Set the complement_group_channel for the current group.
323+
complement_group_channel = channel;
324+
// Set the complement_group_channel for the complement-group.
325+
p2p_group.complement_group_channel = GetChannel();
320326
} else if (p2p_group.kind == kUnpipelined &&
321-
p2p_group.runtime_stream != kStream1) {
322-
complement_group = &p2p_group;
323-
p2p_group.complement_group = this;
327+
p2p_group.runtime_stream == kStream0) {
328+
complement_group_channel = channel;
329+
p2p_group.complement_group_channel = GetChannel();
324330
}
325331
}
326332
return OkStatus();
@@ -332,6 +338,7 @@ struct P2PGroup {
332338
// Returns the child computation for the group.
333339
HloComputation* ChildComputation() const { return GetChild().computation; }
334340

341+
int64_t GetChannel() const { return nodes[kUnpipelinedNodeIdx].GetChannel(); }
335342
P2PGroupNode& GetChild() { return nodes[kPipelinedChildNodeIdx]; }
336343
P2PGroupNode& GetParent() { return nodes[kPipelinedParentNodeIdx]; }
337344
const P2PGroupNode& GetChild() const { return nodes[kPipelinedChildNodeIdx]; }
@@ -341,46 +348,53 @@ struct P2PGroup {
341348

342349
// Returns the start and end of a region marked by a pipelined chain in the
343350
// given computation, which is the region with the pipelined P2P instructions.
344-
ChainStartEnd GetChainStartEnd(HloComputation* computation) const {
345-
if (kind == kUnpipelined) {
346-
if (!InCycle()) {
347-
return std::make_pair(GetChild().recv, GetChild().send_done);
348-
}
349-
CHECK(runtime_stream == kStream1);
350-
return std::make_pair(complement_group->GetChild().recv,
351-
GetChild().send_done);
352-
}
353-
354-
CHECK(kind == kPipelined);
351+
ChainStartEnd GetChainStartEnd(const HloComputation* computation,
352+
const P2PGroupMap& p2p_group_map) const {
355353
if (computation == ChildComputation()) {
356354
if (!InCycle()) {
357355
return std::make_pair(GetChild().recv, GetChild().send_done);
358356
}
359-
CHECK(runtime_stream == kStream1);
360-
return std::make_pair(complement_group->GetChild().recv,
361-
GetChild().send_done);
357+
if (runtime_stream == kStream1) {
358+
return std::make_pair(
359+
GetComplementGroup(p2p_group_map)->GetChild().recv,
360+
GetChild().send_done);
361+
}
362+
return std::make_pair(
363+
GetChild().recv,
364+
GetComplementGroup(p2p_group_map)->GetChild().send_done);
362365
}
363366

364-
CHECK(computation == ParentComputation());
367+
CHECK(kind == kPipelined && computation == ParentComputation());
365368
if (!InCycle()) {
366369
return std::make_pair(GetParent().recv, GetParent().send_done);
367370
}
368-
CHECK(runtime_stream == kStream1);
369-
return std::make_pair(complement_group->GetParent().recv,
370-
GetParent().send_done);
371+
if (runtime_stream == kStream1) {
372+
return std::make_pair(GetComplementGroup(p2p_group_map)->GetParent().recv,
373+
GetParent().send_done);
374+
}
375+
return std::make_pair(
376+
GetParent().recv,
377+
GetComplementGroup(p2p_group_map)->GetParent().send_done);
371378
}
372379

373380
HloInstruction* GetWhileOp() const {
374381
return nodes[kPipelinedParentNodeIdx].while_loop;
375382
}
376383

377-
bool InCycle() const { return complement_group != nullptr; }
378-
384+
bool InCycle() const { return complement_group_channel.has_value(); }
385+
P2PGroup* GetComplementGroup(P2PGroupMap& p2p_group_map) const {
386+
CHECK(InCycle());
387+
return &p2p_group_map.at(*complement_group_channel);
388+
}
389+
const P2PGroup* GetComplementGroup(const P2PGroupMap& p2p_group_map) const {
390+
CHECK(InCycle());
391+
return &p2p_group_map.at(*complement_group_channel);
392+
}
379393
P2PGroupKind kind = kUnpipelined;
380394
P2PGroupNode nodes[2];
381395
P2PRuntimeStream runtime_stream = kUnknown;
382-
// Another P2PGroup that forms a cycle with this group.
383-
P2PGroup* complement_group = nullptr;
396+
// The channel id for another P2PGroup that forms a cycle with this group.
397+
std::optional<int64_t> complement_group_channel = std::nullopt;
384398
};
385399

386400
bool MayInvokeCollectiveOp(
@@ -506,9 +520,11 @@ Status ConnectP2P2NodeChain(const P2PGroupNode& node0,
506520
// while-body computation, we enforce this ordering:
507521
// recv.0 => send.0 => recv.1 => send.1 =>
508522
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
509-
Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group) {
510-
return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(),
511-
p2p_group.GetChild());
523+
Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group,
524+
const P2PGroupMap& p2p_group_map) {
525+
return ConnectP2P2NodeChain(
526+
p2p_group.GetComplementGroup(p2p_group_map)->GetChild(),
527+
p2p_group.GetChild());
512528
}
513529

514530
// For a pipelined Send-Recv chain with one group in the while-body calling
@@ -522,20 +538,24 @@ Status ConnectPipelined1P2PParent(const P2PGroup& p2p_group) {
522538
// in the while-body calling computation, we enforce this ordering:
523539
// recv.0 => send.0 => recv.1 => send.1 => =>
524540
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
525-
Status ConnectPipelined2P2PParent(const P2PGroup& p2p_group) {
526-
return ConnectP2P2NodeChain(p2p_group.complement_group->GetParent(),
527-
p2p_group.GetParent());
541+
Status ConnectPipelined2P2PParent(const P2PGroup& p2p_group,
542+
const P2PGroupMap& p2p_group_map) {
543+
return ConnectP2P2NodeChain(
544+
p2p_group.GetComplementGroup(p2p_group_map)->GetParent(),
545+
p2p_group.GetParent());
528546
}
529547

530548
// For a Send-Recv chain with two channel groups forming a cycle in a while-body
531549
// annotated for pipelining but not pipelined (due to skip pipelining pass), we
532550
// enforece this ordering:
533551
// recv.0 => send.0 => recv.1 => send.1 =>
534552
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
535-
Status ConnectUnpipelined2P2P(const P2PGroup& p2p_group) {
553+
Status ConnectUnpipelined2P2P(const P2PGroup& p2p_group,
554+
const P2PGroupMap& p2p_group_map) {
536555
CHECK(p2p_group.runtime_stream == kStream1);
537-
return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(),
538-
p2p_group.GetChild());
556+
return ConnectP2P2NodeChain(
557+
p2p_group.GetComplementGroup(p2p_group_map)->GetChild(),
558+
p2p_group.GetChild());
539559
}
540560

541561
// Collects P2P send-done and recv-done instructions from the computation,
@@ -585,7 +605,7 @@ Status GatherP2PGroupsAndCollectiveInfo(
585605
}
586606
// We can't rely on the operation on p2p_group_map above to find out
587607
// whether it is the first time to handle this channel for the current
588-
// computation, as we may drop information in the present of kUncognized
608+
// computation, as we may drop information in the present of kUnrecognized
589609
// groups.
590610
auto p2p_in_comp = p2p_in_computation.find(computation);
591611
if (p2p_in_comp == p2p_in_computation.end()) {
@@ -634,11 +654,9 @@ Status GatherP2PGroupsAndCollectiveInfo(
634654
for (auto& [channel, p2p_group] : p2p_group_map) {
635655
if ((p2p_group.kind == kPipelined &&
636656
p2p_group.ParentComputation() != computation) ||
637-
p2p_group.complement_group != nullptr ||
638-
p2p_group.runtime_stream != kStream1) {
657+
p2p_group.InCycle() || p2p_group.runtime_stream != kStream1) {
639658
continue;
640659
}
641-
642660
TF_RETURN_IF_ERROR(p2p_group.RecordComplementGroup(p2p_group_map));
643661
}
644662

@@ -670,12 +688,12 @@ absl::StatusOr<std::pair<int, const P2PGroup*>> ConnectP2PChain(
670688
if (!p2p_group.InCycle()) {
671689
TF_RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group));
672690
} else if (p2p_group.runtime_stream == kStream1) {
673-
TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group));
691+
TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group, p2p_group_map));
674692
}
675693
continue;
676694
}
677695

678-
if (p2p_group.complement_group == nullptr) {
696+
if (!p2p_group.InCycle()) {
679697
if (computation == p2p_group.ParentComputation()) {
680698
TF_RETURN_IF_ERROR(ConnectPipelined1P2PParent(p2p_group));
681699
} else {
@@ -696,15 +714,15 @@ absl::StatusOr<std::pair<int, const P2PGroup*>> ConnectP2PChain(
696714
}
697715

698716
if (computation == p2p_group.ParentComputation()) {
699-
TF_RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group));
717+
TF_RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group, p2p_group_map));
700718
} else {
701719
if (pipelined_group != nullptr) {
702720
return Internal(
703721
"Expected only two pipelined groups forming a cycle in a "
704722
"while-body");
705723
}
706724
pipelined_group = &p2p_group;
707-
TF_RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group));
725+
TF_RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group, p2p_group_map));
708726
}
709727
}
710728
return std::make_pair(num_p2p_chains, pipelined_group);
@@ -738,8 +756,7 @@ Status LinearizeCollectivesWithOtherP2P(
738756
const std::vector<HloInstruction*>::iterator& end_iter,
739757
HloReachabilityMap* reachability) {
740758
HloComputation* computation = (*chain_start_iter)->parent();
741-
ChainStartEnd start_end = group.GetChainStartEnd(computation);
742-
759+
ChainStartEnd start_end = group.GetChainStartEnd(computation, p2p_group_map);
743760
// We refer to the P2P chain represented by `group` chain A.
744761
for (auto it = begin_iter; it != end_iter; ++it) {
745762
HloInstruction* hlo = *it;
@@ -758,8 +775,8 @@ Status LinearizeCollectivesWithOtherP2P(
758775
// LinearizeCollectivesWithPipelinedP2PChild.
759776
continue;
760777
}
761-
762-
ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation);
778+
ChainStartEnd cur_start_end =
779+
cur_group.GetChainStartEnd(computation, p2p_group_map);
763780
if (cur_start_end.first != hlo) {
764781
// We will linearize the two chains when we see the first instruction in
765782
// chain B.
@@ -822,7 +839,7 @@ Status LinearizeCollectivesWithPipelinedP2PChild(
822839
const P2PGroupMap& p2p_group_map, const P2PGroup& group,
823840
const CollectiveInComputation& collective_in_computation,
824841
HloComputation* computation, HloReachabilityMap* reachability) {
825-
ChainStartEnd start_end = group.GetChainStartEnd(computation);
842+
ChainStartEnd start_end = group.GetChainStartEnd(computation, p2p_group_map);
826843

827844
// If an hlo may invoke collective operation, we add control dependence to
828845
// make sure that the hlo is scheduled before the pipelined chain starts.
@@ -852,7 +869,8 @@ Status LinearizeCollectivesWithPipelinedP2PChild(
852869
continue;
853870
}
854871

855-
ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation);
872+
ChainStartEnd cur_start_end =
873+
cur_group.GetChainStartEnd(computation, p2p_group_map);
856874
TF_RETURN_IF_ERROR(
857875
OrderBefore(reachability, cur_start_end.second, start_end.first));
858876

@@ -957,12 +975,9 @@ absl::StatusOr<bool> P2PSchedulePreparation::Run(
957975
// to other collectives.
958976
continue;
959977
}
960-
if (group.InCycle() && group.runtime_stream != kStream1) {
961-
// We process a chain with two groups when we see the group for
962-
// kStream1.
963-
continue;
964-
}
965-
ChainStartEnd start_end = group.GetChainStartEnd(computation);
978+
979+
ChainStartEnd start_end =
980+
group.GetChainStartEnd(computation, p2p_group_map);
966981

967982
// Handle the group when we see the beginning of the chain.
968983
if (start_end.first != hlo) {

0 commit comments

Comments
 (0)