@@ -17,12 +17,14 @@ limitations under the License.
17
17
18
18
#include < cstdint>
19
19
#include < memory>
20
+ #include < optional>
20
21
#include < set>
21
22
#include < utility>
22
23
#include < vector>
23
24
24
25
#include " absl/container/flat_hash_map.h"
25
26
#include " absl/container/flat_hash_set.h"
27
+ #include " absl/log/check.h"
26
28
#include " absl/log/log.h"
27
29
#include " absl/strings/string_view.h"
28
30
#include " xla/hlo/ir/hlo_casting_utils.h"
@@ -189,6 +191,8 @@ struct P2PGroupNode {
189
191
return send_stream;
190
192
}
191
193
194
+ int64_t GetChannel () const { return recv->channel_id ().value (); }
195
+
192
196
HloRecvDoneInstruction* recv_done = nullptr ;
193
197
HloSendDoneInstruction* send_done = nullptr ;
194
198
HloRecvInstruction* recv = nullptr ;
@@ -301,7 +305,7 @@ struct P2PGroup {
301
305
// Records the other group that forms a cycle with this group, assuming that
302
306
// we handle only two groups that form a cycle.
303
307
Status RecordComplementGroup (P2PGroupMap& p2p_group_map) {
304
- CHECK (complement_group == nullptr && runtime_stream == kStream1 );
308
+ CHECK (!complement_group_channel. has_value () && runtime_stream == kStream1 );
305
309
for (auto & [channel, p2p_group] : p2p_group_map) {
306
310
if (&p2p_group == this ||
307
311
p2p_group.ChildComputation () != ChildComputation ()) {
@@ -315,12 +319,14 @@ struct P2PGroup {
315
319
return Internal (
316
320
" Expected different pipeline stream for complement group" );
317
321
}
318
- complement_group = &p2p_group;
319
- p2p_group.complement_group = this ;
322
+ // Set the complement_group_channel for the current group.
323
+ complement_group_channel = channel;
324
+ // Set the complement_group_channel for the complement-group.
325
+ p2p_group.complement_group_channel = GetChannel ();
320
326
} else if (p2p_group.kind == kUnpipelined &&
321
- p2p_group.runtime_stream != kStream1 ) {
322
- complement_group = &p2p_group ;
323
- p2p_group.complement_group = this ;
327
+ p2p_group.runtime_stream == kStream0 ) {
328
+ complement_group_channel = channel ;
329
+ p2p_group.complement_group_channel = GetChannel () ;
324
330
}
325
331
}
326
332
return OkStatus ();
@@ -332,6 +338,7 @@ struct P2PGroup {
332
338
// Returns the child computation for the group.
333
339
HloComputation* ChildComputation () const { return GetChild ().computation ; }
334
340
341
+ int64_t GetChannel () const { return nodes[kUnpipelinedNodeIdx ].GetChannel (); }
335
342
P2PGroupNode& GetChild () { return nodes[kPipelinedChildNodeIdx ]; }
336
343
P2PGroupNode& GetParent () { return nodes[kPipelinedParentNodeIdx ]; }
337
344
const P2PGroupNode& GetChild () const { return nodes[kPipelinedChildNodeIdx ]; }
@@ -341,46 +348,53 @@ struct P2PGroup {
341
348
342
349
// Returns the start and end of a region marked by a pipelined chain in the
343
350
// given computation, which is the region with the pipelined P2P instructions.
344
- ChainStartEnd GetChainStartEnd (HloComputation* computation) const {
345
- if (kind == kUnpipelined ) {
346
- if (!InCycle ()) {
347
- return std::make_pair (GetChild ().recv , GetChild ().send_done );
348
- }
349
- CHECK (runtime_stream == kStream1 );
350
- return std::make_pair (complement_group->GetChild ().recv ,
351
- GetChild ().send_done );
352
- }
353
-
354
- CHECK (kind == kPipelined );
351
+ ChainStartEnd GetChainStartEnd (const HloComputation* computation,
352
+ const P2PGroupMap& p2p_group_map) const {
355
353
if (computation == ChildComputation ()) {
356
354
if (!InCycle ()) {
357
355
return std::make_pair (GetChild ().recv , GetChild ().send_done );
358
356
}
359
- CHECK (runtime_stream == kStream1 );
360
- return std::make_pair (complement_group->GetChild ().recv ,
361
- GetChild ().send_done );
357
+ if (runtime_stream == kStream1 ) {
358
+ return std::make_pair (
359
+ GetComplementGroup (p2p_group_map)->GetChild ().recv ,
360
+ GetChild ().send_done );
361
+ }
362
+ return std::make_pair (
363
+ GetChild ().recv ,
364
+ GetComplementGroup (p2p_group_map)->GetChild ().send_done );
362
365
}
363
366
364
- CHECK (computation == ParentComputation ());
367
+ CHECK (kind == kPipelined && computation == ParentComputation ());
365
368
if (!InCycle ()) {
366
369
return std::make_pair (GetParent ().recv , GetParent ().send_done );
367
370
}
368
- CHECK (runtime_stream == kStream1 );
369
- return std::make_pair (complement_group->GetParent ().recv ,
370
- GetParent ().send_done );
371
+ if (runtime_stream == kStream1 ) {
372
+ return std::make_pair (GetComplementGroup (p2p_group_map)->GetParent ().recv ,
373
+ GetParent ().send_done );
374
+ }
375
+ return std::make_pair (
376
+ GetParent ().recv ,
377
+ GetComplementGroup (p2p_group_map)->GetParent ().send_done );
371
378
}
372
379
373
380
HloInstruction* GetWhileOp () const {
374
381
return nodes[kPipelinedParentNodeIdx ].while_loop ;
375
382
}
376
383
377
- bool InCycle () const { return complement_group != nullptr ; }
378
-
384
+ bool InCycle () const { return complement_group_channel.has_value (); }
385
+ P2PGroup* GetComplementGroup (P2PGroupMap& p2p_group_map) const {
386
+ CHECK (InCycle ());
387
+ return &p2p_group_map.at (*complement_group_channel);
388
+ }
389
+ const P2PGroup* GetComplementGroup (const P2PGroupMap& p2p_group_map) const {
390
+ CHECK (InCycle ());
391
+ return &p2p_group_map.at (*complement_group_channel);
392
+ }
379
393
P2PGroupKind kind = kUnpipelined ;
380
394
P2PGroupNode nodes[2 ];
381
395
P2PRuntimeStream runtime_stream = kUnknown ;
382
- // Another P2PGroup that forms a cycle with this group.
383
- P2PGroup* complement_group = nullptr ;
396
+ // The channel id for another P2PGroup that forms a cycle with this group.
397
+ std::optional< int64_t > complement_group_channel = std::nullopt ;
384
398
};
385
399
386
400
bool MayInvokeCollectiveOp (
@@ -506,9 +520,11 @@ Status ConnectP2P2NodeChain(const P2PGroupNode& node0,
506
520
// while-body computation, we enforce this ordering:
507
521
// recv.0 => send.0 => recv.1 => send.1 =>
508
522
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
509
- Status ConnectPipelined2P2PChild (const P2PGroup& p2p_group) {
510
- return ConnectP2P2NodeChain (p2p_group.complement_group ->GetChild (),
511
- p2p_group.GetChild ());
523
+ Status ConnectPipelined2P2PChild (const P2PGroup& p2p_group,
524
+ const P2PGroupMap& p2p_group_map) {
525
+ return ConnectP2P2NodeChain (
526
+ p2p_group.GetComplementGroup (p2p_group_map)->GetChild (),
527
+ p2p_group.GetChild ());
512
528
}
513
529
514
530
// For a pipelined Send-Recv chain with one group in the while-body calling
@@ -522,20 +538,24 @@ Status ConnectPipelined1P2PParent(const P2PGroup& p2p_group) {
522
538
// in the while-body calling computation, we enforce this ordering:
523
539
// recv.0 => send.0 => recv.1 => send.1 => =>
524
540
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
525
- Status ConnectPipelined2P2PParent (const P2PGroup& p2p_group) {
526
- return ConnectP2P2NodeChain (p2p_group.complement_group ->GetParent (),
527
- p2p_group.GetParent ());
541
+ Status ConnectPipelined2P2PParent (const P2PGroup& p2p_group,
542
+ const P2PGroupMap& p2p_group_map) {
543
+ return ConnectP2P2NodeChain (
544
+ p2p_group.GetComplementGroup (p2p_group_map)->GetParent (),
545
+ p2p_group.GetParent ());
528
546
}
529
547
530
548
// For a Send-Recv chain with two channel groups forming a cycle in a while-body
531
549
// annotated for pipelining but not pipelined (due to skip pipelining pass), we
532
550
// enforece this ordering:
533
551
// recv.0 => send.0 => recv.1 => send.1 =>
534
552
// recv-done.0 => recv-done.1 => send-done.0 => send-done.1
535
- Status ConnectUnpipelined2P2P (const P2PGroup& p2p_group) {
553
+ Status ConnectUnpipelined2P2P (const P2PGroup& p2p_group,
554
+ const P2PGroupMap& p2p_group_map) {
536
555
CHECK (p2p_group.runtime_stream == kStream1 );
537
- return ConnectP2P2NodeChain (p2p_group.complement_group ->GetChild (),
538
- p2p_group.GetChild ());
556
+ return ConnectP2P2NodeChain (
557
+ p2p_group.GetComplementGroup (p2p_group_map)->GetChild (),
558
+ p2p_group.GetChild ());
539
559
}
540
560
541
561
// Collects P2P send-done and recv-done instructions from the computation,
@@ -585,7 +605,7 @@ Status GatherP2PGroupsAndCollectiveInfo(
585
605
}
586
606
// We can't rely on the operation on p2p_group_map above to find out
587
607
// whether it is the first time to handle this channel for the current
588
- // computation, as we may drop information in the present of kUncognized
608
+ // computation, as we may drop information in the present of kUnrecognized
589
609
// groups.
590
610
auto p2p_in_comp = p2p_in_computation.find (computation);
591
611
if (p2p_in_comp == p2p_in_computation.end ()) {
@@ -634,11 +654,9 @@ Status GatherP2PGroupsAndCollectiveInfo(
634
654
for (auto & [channel, p2p_group] : p2p_group_map) {
635
655
if ((p2p_group.kind == kPipelined &&
636
656
p2p_group.ParentComputation () != computation) ||
637
- p2p_group.complement_group != nullptr ||
638
- p2p_group.runtime_stream != kStream1 ) {
657
+ p2p_group.InCycle () || p2p_group.runtime_stream != kStream1 ) {
639
658
continue ;
640
659
}
641
-
642
660
TF_RETURN_IF_ERROR (p2p_group.RecordComplementGroup (p2p_group_map));
643
661
}
644
662
@@ -670,12 +688,12 @@ absl::StatusOr<std::pair<int, const P2PGroup*>> ConnectP2PChain(
670
688
if (!p2p_group.InCycle ()) {
671
689
TF_RETURN_IF_ERROR (ConnectUnpipelinedP2P (p2p_group));
672
690
} else if (p2p_group.runtime_stream == kStream1 ) {
673
- TF_RETURN_IF_ERROR (ConnectUnpipelined2P2P (p2p_group));
691
+ TF_RETURN_IF_ERROR (ConnectUnpipelined2P2P (p2p_group, p2p_group_map ));
674
692
}
675
693
continue ;
676
694
}
677
695
678
- if (p2p_group.complement_group == nullptr ) {
696
+ if (! p2p_group.InCycle () ) {
679
697
if (computation == p2p_group.ParentComputation ()) {
680
698
TF_RETURN_IF_ERROR (ConnectPipelined1P2PParent (p2p_group));
681
699
} else {
@@ -696,15 +714,15 @@ absl::StatusOr<std::pair<int, const P2PGroup*>> ConnectP2PChain(
696
714
}
697
715
698
716
if (computation == p2p_group.ParentComputation ()) {
699
- TF_RETURN_IF_ERROR (ConnectPipelined2P2PParent (p2p_group));
717
+ TF_RETURN_IF_ERROR (ConnectPipelined2P2PParent (p2p_group, p2p_group_map ));
700
718
} else {
701
719
if (pipelined_group != nullptr ) {
702
720
return Internal (
703
721
" Expected only two pipelined groups forming a cycle in a "
704
722
" while-body" );
705
723
}
706
724
pipelined_group = &p2p_group;
707
- TF_RETURN_IF_ERROR (ConnectPipelined2P2PChild (p2p_group));
725
+ TF_RETURN_IF_ERROR (ConnectPipelined2P2PChild (p2p_group, p2p_group_map ));
708
726
}
709
727
}
710
728
return std::make_pair (num_p2p_chains, pipelined_group);
@@ -738,8 +756,7 @@ Status LinearizeCollectivesWithOtherP2P(
738
756
const std::vector<HloInstruction*>::iterator& end_iter,
739
757
HloReachabilityMap* reachability) {
740
758
HloComputation* computation = (*chain_start_iter)->parent ();
741
- ChainStartEnd start_end = group.GetChainStartEnd (computation);
742
-
759
+ ChainStartEnd start_end = group.GetChainStartEnd (computation, p2p_group_map);
743
760
// We refer to the P2P chain represented by `group` chain A.
744
761
for (auto it = begin_iter; it != end_iter; ++it) {
745
762
HloInstruction* hlo = *it;
@@ -758,8 +775,8 @@ Status LinearizeCollectivesWithOtherP2P(
758
775
// LinearizeCollectivesWithPipelinedP2PChild.
759
776
continue ;
760
777
}
761
-
762
- ChainStartEnd cur_start_end = cur_group.GetChainStartEnd (computation);
778
+ ChainStartEnd cur_start_end =
779
+ cur_group.GetChainStartEnd (computation, p2p_group_map );
763
780
if (cur_start_end.first != hlo) {
764
781
// We will linearize the two chains when we see the first instruction in
765
782
// chain B.
@@ -822,7 +839,7 @@ Status LinearizeCollectivesWithPipelinedP2PChild(
822
839
const P2PGroupMap& p2p_group_map, const P2PGroup& group,
823
840
const CollectiveInComputation& collective_in_computation,
824
841
HloComputation* computation, HloReachabilityMap* reachability) {
825
- ChainStartEnd start_end = group.GetChainStartEnd (computation);
842
+ ChainStartEnd start_end = group.GetChainStartEnd (computation, p2p_group_map );
826
843
827
844
// If an hlo may invoke collective operation, we add control dependence to
828
845
// make sure that the hlo is scheduled before the pipelined chain starts.
@@ -852,7 +869,8 @@ Status LinearizeCollectivesWithPipelinedP2PChild(
852
869
continue ;
853
870
}
854
871
855
- ChainStartEnd cur_start_end = cur_group.GetChainStartEnd (computation);
872
+ ChainStartEnd cur_start_end =
873
+ cur_group.GetChainStartEnd (computation, p2p_group_map);
856
874
TF_RETURN_IF_ERROR (
857
875
OrderBefore (reachability, cur_start_end.second , start_end.first ));
858
876
@@ -957,12 +975,9 @@ absl::StatusOr<bool> P2PSchedulePreparation::Run(
957
975
// to other collectives.
958
976
continue ;
959
977
}
960
- if (group.InCycle () && group.runtime_stream != kStream1 ) {
961
- // We process a chain with two groups when we see the group for
962
- // kStream1.
963
- continue ;
964
- }
965
- ChainStartEnd start_end = group.GetChainStartEnd (computation);
978
+
979
+ ChainStartEnd start_end =
980
+ group.GetChainStartEnd (computation, p2p_group_map);
966
981
967
982
// Handle the group when we see the beginning of the chain.
968
983
if (start_end.first != hlo) {
0 commit comments