Skip to content

Commit

Permalink
[xla][gpu] Change the point-to-point pipeliner to produce an intermed…
Browse files Browse the repository at this point in the history
…iate form

of pipelined code.

The pipeliner now rotates all point-to-point instructions for a communication
chain, including Send, Recv, SendDone and RecvDone, in a while-body. We will
add a post-scheduling pass to further transform such a pipelined loop by
pushing the SendDone and RecvDone to the next loop iteration.

Adjust the p2p-schedule-preparation pass to reflect this change.

PiperOrigin-RevId: 621959951
  • Loading branch information
bixia1 authored and tensorflower-gardener committed Apr 4, 2024
1 parent 8976ba9 commit c64eb74
Show file tree
Hide file tree
Showing 4 changed files with 559 additions and 532 deletions.
279 changes: 137 additions & 142 deletions third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
Expand Up @@ -777,30 +777,18 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) {
HloModule test
while_cond {
param = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) parameter(0)
param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
count = get-tuple-element(param), index=0
ub = u32[] constant(25)
ROOT cond-result = pred[] compare(count, ub), direction=LT
}
while_body {
param = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) parameter(0)
param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
count = get-tuple-element(param), index=0
recv.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=1
recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1.q), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0
send.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=2
send-done.1 = token[] send-done(send.1.q), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
c1 = u32[] constant(1)
new-count = u32[] add(count, c1)
Expand All @@ -820,17 +808,24 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) {
after-all.1 = token[] after-all()
send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
channel_id=1, frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
ROOT body-result = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) tuple(new-count, recv.1, send.1)
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
send-done.1 = token[] send-done(send.1), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
tuple(new-count, recv-done.1, send-done.1)
}
ENTRY main {
Expand All @@ -841,35 +836,32 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) {
after-all.2 = token[] after-all()
recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1,
frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1,
frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
while-init = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) tuple(c0, recv.2, send.2)
while-result = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) while(while-init),
body=while_body, condition=while_cond,
backend_config={"known_trip_count":{"n":"25"}}
recv.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=1
recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2.q), channel_id=1,
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
_xla_send_recv_pipeline="0"
}
recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
_xla_send_recv_pipeline="0"
}
send.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=2
send-done.2 = token[] send-done(send.2.q), channel_id=1,
send-done.2 = token[] send-done(send.2), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
_xla_send_recv_pipeline="0"
}
while-init = (u32[], (f32[1,1024,1024], token[]), token[])
tuple(c0, recv-done.2, send-done.2)
while-result = (u32[], (f32[1,1024,1024], token[]), token[])
while(while-init),
body=while_body, condition=while_cond,
backend_config={"known_trip_count":{"n":"25"}}
ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0
recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=1
ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0
}
)";

Expand All @@ -894,20 +886,23 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) {
}) -
instruction_sequence.begin();
};

EXPECT_TRUE(HasValidFingerprint(module.get()));
// The pipelined Send-Recv in the main.
EXPECT_LT(get_index("recv.2", main), get_index("while-result", main));
EXPECT_LT(get_index("send.2", main), get_index("while-result", main));
EXPECT_LT(get_index("while-result", main), get_index("recv-done.2", main));
EXPECT_LT(get_index("while-result", main), get_index("send-done.2", main));

// The pipelined Send-Recv in the while-body.
// The pipelined Send-Recv in the main. A pipelined Recv is scheduled right
// after its corresponding Send due to kForceEarly.
EXPECT_EQ(get_index("recv.2", main) + 1, get_index("send.2", main));
EXPECT_LT(get_index("send.2", main), get_index("recv-done.2", main));
EXPECT_LT(get_index("recv-done.2", main), get_index("send-done.2", main));
EXPECT_LT(get_index("send-done.2", main), get_index("while-result", main));

// The pipelined Send-Recv in the while-body. A pipelined Recv is scheduled
// right after its corresponding Send due to kForceEarly.
EXPECT_EQ(get_index("recv.1", while_body) + 1,
get_index("send.1", while_body));
EXPECT_LT(get_index("send.1", while_body),
get_index("recv-done.1", while_body));
EXPECT_LT(get_index("recv-done.1", while_body),
get_index("send-done.1", while_body));
EXPECT_LT(get_index("send-done.1", while_body),
get_index("recv.1", while_body));
EXPECT_LT(get_index("recv.1", while_body), get_index("send.1", while_body));
}

// Checks that with the dependence added by the gpu-hlo-scheduler, the
Expand All @@ -917,45 +912,22 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) {
HloModule test
while_cond {
param = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) parameter(0)
param = (u32[], (f32[1,1024,1024], token[]), token[],
(f32[1,1024,1024], token[]), token[]) parameter(0)
count = get-tuple-element(param), index=0
ub = u32[] constant(25)
ROOT cond-result = pred[] compare(count, ub), direction=LT
}
while_body {
param = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) parameter(0)
param = (u32[], (f32[1,1024,1024], token[]), token[],
(f32[1,1024,1024], token[]), token[]) parameter(0)
count = get-tuple-element(param), index=0
recv.0.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=1
recv-done.0 = (f32[1,1024,1024], token[]) recv-done(recv.0.q), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0), index=0
send.0.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=2
send-done.0 = token[] send-done(send.0.q), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
recv.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=3
recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1.q), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="1"
}
recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0
send.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=4
send-done.1 = token[] send-done(send.1.q), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="1"
}
recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0
recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3
recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
replica = u32[] replica-id()
constant0 = u32[] constant(0)
Expand All @@ -980,30 +952,46 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) {
after-all.0 = token[] after-all()
send.0 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.0),
channel_id=1, frontend_attributes={
_xla_send_recv_source_target_pairs="{{3,0}}",
_xla_send_recv_pipeline="0"
}
_xla_send_recv_source_target_pairs="{{3,0}}",
_xla_send_recv_pipeline="0"
}
recv.0 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.0), channel_id=1,
frontend_attributes={
_xla_send_recv_source_target_pairs="{{3,0}}",
_xla_send_recv_pipeline="0"
}
_xla_send_recv_source_target_pairs="{{3,0}}",
_xla_send_recv_pipeline="0"
}
recv-done.0 = (f32[1,1024,1024], token[]) recv-done(recv.0), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
send-done.0 = token[] send-done(send.0), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
after-all.1 = token[] after-all()
send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
channel_id=2, frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
_xla_send_recv_pipeline="1"
}
recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2,
frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
_xla_send_recv_pipeline="1"
}
}
recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2,
frontend_attributes={
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
_xla_send_recv_pipeline="1"
}
recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="1"
}
send-done.1 = token[] send-done(send.1), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="1"
}
ROOT body-result = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) tuple(new-count, recv.0, send.0, recv.1, send.1)
ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[],
(f32[1,1024,1024], token[]), token[])
tuple(new-count, recv-done.0, send-done.0, recv-done.1, send-done.1)
}
ENTRY main {
Expand All @@ -1022,6 +1010,14 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) {
_xla_send_recv_source_target_pairs="{{3,0}}",
_xla_send_recv_pipeline="0"
}
recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
send-done.2 = token[] send-done(send.2), channel_id=1,
frontend_attributes={
_xla_send_recv_pipeline="0"
}
after-all.3 = token[] after-all()
recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2,
Expand All @@ -1034,41 +1030,26 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) {
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
_xla_send_recv_pipeline="1"
}
while-init = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) tuple(c0, recv.2, send.2, recv.3, send.3)
while-result = (u32[], (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]),
(f32[1,1024,1024], u32[], token[])) while(while-init),
body=while_body, condition=while_cond,
backend_config={"known_trip_count":{"n":"25"}}
recv.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=1
recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2.q), channel_id=1,
recv-done.3 = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="0"
_xla_send_recv_pipeline="1"
}
recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0
send.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=2
send-done.2 = token[] send-done(send.2.q), channel_id=1,
send-done.3 = token[] send-done(send.3), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="0"
_xla_send_recv_pipeline="1"
}
recv.3.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=3
recv-done.3 = (f32[1,1024,1024], token[]) recv-done(recv.3.q), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="1"
}
recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3), index=0
while-init = (u32[], (f32[1,1024,1024], token[]), token[],
(f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2, send-done.2, recv-done.3, send-done.3)
while-result = (u32[], (f32[1,1024,1024], token[]), token[],
(f32[1,1024,1024], token[]), token[]) while(while-init),
body=while_body, condition=while_cond,
backend_config={"known_trip_count":{"n":"25"}}
send.3.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=4
send-done.3 = token[] send-done(send.3.q), channel_id=2,
frontend_attributes={
_xla_send_recv_pipeline="1"
}
recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=1
recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0
recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=3
recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0
replica = u32[] replica-id()
constant0 = u32[] constant(0)
Expand Down Expand Up @@ -1101,18 +1082,32 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) {
};

EXPECT_TRUE(HasValidFingerprint(module.get()));
// The pipelined Send-Recv in the main.
EXPECT_LT(get_index("recv.2", main), get_index("while-result", main));
EXPECT_LT(get_index("send.2", main), get_index("while-result", main));
EXPECT_LT(get_index("while-result", main), get_index("recv-done.2", main));
EXPECT_LT(get_index("while-result", main), get_index("send-done.2", main));

// The pipelined Send-Recv in the while-body.
// The pipelined Send-Recv in the main. A pipelined Recv is scheduled right
// after its corresponding Send due to kForceEarly.
EXPECT_EQ(get_index("recv.2", main) + 1, get_index("send.2", main));
EXPECT_LT(get_index("send.2", main), get_index("recv.3", main));
EXPECT_EQ(get_index("recv.3", main) + 1, get_index("send.3", main));
EXPECT_LT(get_index("send.3", main), get_index("recv-done.2", main));
EXPECT_LT(get_index("recv-done.2", main), get_index("recv-done.3", main));
EXPECT_LT(get_index("recv-done.3", main), get_index("send-done.2", main));
EXPECT_LT(get_index("send-done.2", main), get_index("send-done.3", main));
EXPECT_LT(get_index("send-done.3", main), get_index("while-result", main));

// The pipelined Send-Recv in the while-body. A pipelined Recv is scheduled
// right after its corresponding Send due to kForceEarly.
EXPECT_EQ(get_index("recv.0", while_body) + 1,
get_index("send.0", while_body));
EXPECT_LT(get_index("send.0", while_body), get_index("recv.1", while_body));
EXPECT_EQ(get_index("recv.1", while_body) + 1,
get_index("send.1", while_body));
EXPECT_LT(get_index("send.1", while_body),
get_index("recv-done.0", while_body));
EXPECT_LT(get_index("recv-done.0", while_body),
get_index("recv-done.1", while_body));
EXPECT_LT(get_index("recv-done.1", while_body),
get_index("send-done.0", while_body));
EXPECT_LT(get_index("send-done.0", while_body),
get_index("send-done.1", while_body));
EXPECT_LT(get_index("send-done.1", while_body),
get_index("recv.1", while_body));
EXPECT_LT(get_index("recv.1", while_body), get_index("send.1", while_body));
}

TEST_F(GpuHloScheduleTest, SkipAlreadyScheduled) {
Expand Down

0 comments on commit c64eb74

Please sign in to comment.