Skip to content

Commit f4f3a55

Browse files
committed
[XLA/GPU] Add CustomCallSchedule to give schedule hints to custom-calls.
Add schedule hints EARLY_AS_POSSIBLE and LATE_AS_POSSIBLE to custom-calls. This supports the case where a logical operation is lowered into a pair of HLO custom-calls (e.g., PerformX and PerformXDone). This mechanism can be used either to hide host latencies between the paired custom-calls, or to express the def-use relationship more accurately: typically PerformX is scheduled right after all of its producers have been scheduled, and PerformXDone is scheduled right before its first consumer.
1 parent a2a607d commit f4f3a55

15 files changed

+318
-58
lines changed

tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
135135
bool has_side_effect,
136136
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
137137
output_operand_aliasing,
138-
const Literal* literal) {
138+
const Literal* literal, CustomCallSchedule /*schedule*/) {
139139
if (operand_shapes_with_layout.has_value())
140140
return Unimplemented(
141141
"CustomCall doesn't support operands shapes with layout");

tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class MlirHloBuilder : public XlaBuilder {
138138
bool has_side_effect,
139139
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
140140
output_operand_aliasing,
141-
const Literal* literal) override;
141+
const Literal* literal, CustomCallSchedule schedule) override;
142142

143143
StatusOr<XlaOp> ReduceInternal(
144144
const Shape& shape, absl::Span<const XlaOp> all_operands,

tensorflow/compiler/xla/client/xla_builder.cc

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,7 +1893,7 @@ XlaOp XlaBuilder::CustomCall(
18931893
bool has_side_effect,
18941894
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
18951895
output_operand_aliasing,
1896-
const Literal* literal) {
1896+
const Literal* literal, CustomCallSchedule schedule) {
18971897
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
18981898
if (absl::StartsWith(call_target_name, "$")) {
18991899
return InvalidArgument(
@@ -1926,7 +1926,7 @@ XlaOp XlaBuilder::CustomCall(
19261926
}
19271927
return CustomCallInternal(call_target_name, operands, shape, opaque,
19281928
operand_shapes_with_layout, has_side_effect,
1929-
output_operand_aliasing, literal);
1929+
output_operand_aliasing, literal, schedule);
19301930
});
19311931
}
19321932

@@ -1937,7 +1937,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
19371937
bool has_side_effect,
19381938
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
19391939
output_operand_aliasing,
1940-
const Literal* literal) {
1940+
const Literal* literal, CustomCallSchedule schedule) {
19411941
HloInstructionProto instr;
19421942
*instr.mutable_shape() = shape.ToProto();
19431943
instr.set_custom_call_target(call_target_name);
@@ -1962,6 +1962,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
19621962
aliasing->add_output_shape_index(index);
19631963
}
19641964
}
1965+
instr.set_custom_call_schedule(schedule);
19651966
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
19661967
}
19671968

@@ -1972,7 +1973,7 @@ XlaOp XlaBuilder::CustomCall(
19721973
bool has_side_effect,
19731974
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
19741975
output_operand_aliasing,
1975-
const Literal* literal) {
1976+
const Literal* literal, CustomCallSchedule schedule) {
19761977
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
19771978
HloInstructionProto instr;
19781979
if (absl::StartsWith(call_target_name, "$")) {
@@ -2023,6 +2024,7 @@ XlaOp XlaBuilder::CustomCall(
20232024
aliasing->add_output_shape_index(index);
20242025
}
20252026
}
2027+
instr.set_custom_call_schedule(schedule);
20262028
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
20272029
});
20282030
}
@@ -4192,10 +4194,11 @@ XlaOp CustomCall(
41924194
bool has_side_effect,
41934195
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
41944196
output_operand_aliasing,
4195-
const Literal* literal) {
4197+
const Literal* literal, CustomCallSchedule schedule) {
41964198
return builder->CustomCall(call_target_name, operands, shape, opaque,
41974199
/*operand_shapes_with_layout=*/absl::nullopt,
4198-
has_side_effect, output_operand_aliasing, literal);
4200+
has_side_effect, output_operand_aliasing, literal,
4201+
schedule);
41994202
}
42004203

42014204
XlaOp CustomCallWithComputation(
@@ -4204,11 +4207,11 @@ XlaOp CustomCallWithComputation(
42044207
const Shape& shape, const string& opaque, bool has_side_effect,
42054208
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
42064209
output_operand_aliasing,
4207-
const Literal* literal) {
4208-
return builder->CustomCall(call_target_name, operands, computation, shape,
4209-
opaque,
4210-
/*operand_shapes_with_layout=*/absl::nullopt,
4211-
has_side_effect, output_operand_aliasing, literal);
4210+
const Literal* literal, CustomCallSchedule schedule) {
4211+
return builder->CustomCall(
4212+
call_target_name, operands, computation, shape, opaque,
4213+
/*operand_shapes_with_layout=*/absl::nullopt, has_side_effect,
4214+
output_operand_aliasing, literal, schedule);
42124215
}
42134216

42144217
XlaOp CustomCallWithLayout(
@@ -4218,10 +4221,10 @@ XlaOp CustomCallWithLayout(
42184221
bool has_side_effect,
42194222
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
42204223
output_operand_aliasing,
4221-
const Literal* literal) {
4224+
const Literal* literal, CustomCallSchedule schedule) {
42224225
return builder->CustomCall(call_target_name, operands, shape, opaque,
42234226
operand_shapes_with_layout, has_side_effect,
4224-
output_operand_aliasing, literal);
4227+
output_operand_aliasing, literal, schedule);
42254228
}
42264229

42274230
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,

tensorflow/compiler/xla/client/xla_builder.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ class XlaBuilder {
637637
bool has_side_effect,
638638
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
639639
output_operand_aliasing,
640-
const Literal* literal);
640+
const Literal* literal, CustomCallSchedule schedule);
641641

642642
// Internal version of CustomCall without computation that doesn't do op
643643
// specific error handling and expects arguments to be legal. CustomCall
@@ -649,7 +649,7 @@ class XlaBuilder {
649649
bool has_side_effect,
650650
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
651651
output_operand_aliasing,
652-
const Literal* literal);
652+
const Literal* literal, CustomCallSchedule schedule);
653653

654654
XlaOp CustomCall(
655655
const string& call_target_name, absl::Span<const XlaOp> operands,
@@ -659,7 +659,7 @@ class XlaBuilder {
659659
bool has_side_effect,
660660
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
661661
output_operand_aliasing,
662-
const Literal* literal);
662+
const Literal* literal, CustomCallSchedule schedule);
663663

664664
XlaOp Reduce(XlaOp operand, XlaOp init_value,
665665
const XlaComputation& computation,
@@ -1208,22 +1208,22 @@ class XlaBuilder {
12081208
const string& opaque, bool has_side_effect,
12091209
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
12101210
output_operand_aliasing,
1211-
const Literal* literal);
1211+
const Literal* literal, CustomCallSchedule schedule);
12121212
friend XlaOp CustomCallWithComputation(
12131213
XlaBuilder* builder, const string& call_target_name,
12141214
absl::Span<const XlaOp> operands, const XlaComputation& computation,
12151215
const Shape& shape, const string& opaque, bool has_side_effect,
12161216
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
12171217
output_operand_aliasing,
1218-
const Literal* literal);
1218+
const Literal* literal, CustomCallSchedule schedule);
12191219
friend XlaOp CustomCallWithLayout(
12201220
XlaBuilder* builder, const string& call_target_name,
12211221
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
12221222
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
12231223
bool has_side_effect,
12241224
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
12251225
output_operand_aliasing,
1226-
const Literal* literal);
1226+
const Literal* literal, CustomCallSchedule schedule);
12271227
friend XlaOp Complex(XlaOp real, XlaOp imag,
12281228
absl::Span<const int64> broadcast_dimensions);
12291229
friend XlaOp Conj(XlaOp operand);
@@ -2041,7 +2041,8 @@ XlaOp CustomCall(
20412041
const string& opaque = "", bool has_side_effect = false,
20422042
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
20432043
output_operand_aliasing = {},
2044-
const Literal* literal = nullptr);
2044+
const Literal* literal = nullptr,
2045+
CustomCallSchedule schedule = CustomCallSchedule::NONE);
20452046

20462047
// Overload which constructs a custom call that applies an Xla computation.
20472048
XlaOp CustomCallWithComputation(
@@ -2064,7 +2065,8 @@ XlaOp CustomCallWithLayout(
20642065
const string& opaque = "", bool has_side_effect = false,
20652066
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
20662067
output_operand_aliasing = {},
2067-
const Literal* literal = nullptr);
2068+
const Literal* literal = nullptr,
2069+
CustomCallSchedule schedule = CustomCallSchedule::NONE);
20682070

20692071
// The following methods enqueue element-wise binary arithmetic operations
20702072
// onto the computation. The shapes of the operands have to match unless one

tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121

2222
#include "absl/memory/memory.h"
2323
#include "tensorflow/compiler/xla/service/buffer_value.h"
24+
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
2425
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
2526
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
2627
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
@@ -184,6 +185,84 @@ void BFSLaunchOrder(const HloComputation* computation,
184185
}
185186
}
186187

188+
bool CustomCallWithSchedule(const HloInstruction* instr,
189+
CustomCallSchedule schedule) {
190+
return instr->opcode() == HloOpcode::kCustomCall &&
191+
static_cast<const HloCustomCallInstruction*>(instr)
192+
->custom_call_schedule() == schedule;
193+
}
194+
195+
// Schedules EARLY_AS_POSSIBLE and LATE_AS_POSSIBLE custom-calls. This supports
196+
// a custom-call case, where a logical operation can be lowered into two HLOs
197+
// (e.g., PerformX and PerformXDone). We can utilize this mechanism to either
198+
// hide host latencies between the pair of the custom-calls or the two calls
199+
// can more accurately identify the def-use relationship (typically PerformX is
200+
// scheduled right after all of its producers have been scheduled and
201+
// PerformXDone is scheduled right before its first consumer.)
202+
HloInstructionSequence postprocessor_to_custom_schedule(
203+
const HloInstructionSequence& input) {
204+
// Schedule `EARLY_AS_POSSIBLE`.
205+
std::deque<HloInstruction*> early_as_possible_sched;
206+
{
207+
absl::flat_hash_set<HloInstruction*> scheduled;
208+
const std::vector<HloInstruction*>& instrs = input.instructions();
209+
for (HloInstruction* instr : instrs) {
210+
if (scheduled.contains(instr)) {
211+
continue;
212+
}
213+
214+
early_as_possible_sched.push_back(instr);
215+
scheduled.insert(instr);
216+
217+
for (HloInstruction* user : instr->users()) {
218+
// Schedule any user who has the attribute `early_as_possible` and all
219+
// of its producers have been scheduled.
220+
if (CustomCallWithSchedule(user,
221+
CustomCallSchedule::EARLY_AS_POSSIBLE) &&
222+
absl::c_all_of(user->operands(), [&](const HloInstruction* opnd) {
223+
return scheduled.contains(opnd);
224+
})) {
225+
early_as_possible_sched.push_back(user);
226+
scheduled.insert(user);
227+
}
228+
}
229+
}
230+
}
231+
232+
// Schedule `LATE_AS_POSSIBLE`.
233+
std::deque<HloInstruction*> late_as_possible_sched;
234+
{
235+
absl::flat_hash_set<HloInstruction*> scheduled;
236+
for (auto it = early_as_possible_sched.rbegin();
237+
it != early_as_possible_sched.rend(); it++) {
238+
if (scheduled.contains(*it)) {
239+
continue;
240+
}
241+
242+
late_as_possible_sched.push_front(*it);
243+
scheduled.insert(*it);
244+
245+
for (HloInstruction* opnd : (*it)->unique_operands()) {
246+
// Schedule any opnd who has the attribute `late_as_possible` if all of
247+
// its users have been scheduled.
248+
if (CustomCallWithSchedule(opnd,
249+
CustomCallSchedule::LATE_AS_POSSIBLE) &&
250+
absl::c_all_of(opnd->users(), [&](const HloInstruction* u) {
251+
return scheduled.contains(u);
252+
})) {
253+
late_as_possible_sched.push_front(opnd);
254+
scheduled.insert(opnd);
255+
}
256+
}
257+
}
258+
}
259+
260+
HloInstructionSequence result;
261+
absl::c_for_each(late_as_possible_sched,
262+
[&](HloInstruction* i) { result.push_back(i); });
263+
return result;
264+
}
265+
187266
} // end namespace
188267

189268
GpuHloSchedule::GpuHloSchedule() {}
@@ -206,7 +285,8 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
206285
[pointer_size](const BufferValue& buffer) {
207286
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
208287
},
209-
ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)));
288+
ComputationSchedulerToModuleScheduler(
289+
DefaultMemoryScheduler, postprocessor_to_custom_schedule)));
210290
schedule->thunk_launch_order_ =
211291
sequences.sequence(entry_computation).instructions();
212292
schedule->hlo_ordering_ =

tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,5 +347,58 @@ TEST_F(GpuHloScheduleTest, DISABLED_LatticeMatMul) {
347347
}
348348
}
349349

350+
// Verifies that an EARLY_AS_POSSIBLE/LATE_AS_POSSIBLE custom-call pair is
// hoisted/sunk around independent work: the start call runs right after its
// producer and the done call right before its first consumer.
TEST_F(GpuHloScheduleTest, AsyncCustomCall) {
  HloComputation::Builder builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
  HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
  // Nonblocking call depends on add1 but not add2.
  HloInstruction* nonblocking_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          f32_2x2_, {add1},
          /*custom_call_target=*/"nonblocking-call-start",
          /*opaque=*/""));
  static_cast<HloCustomCallInstruction*>(nonblocking_call)
      ->set_custom_call_schedule(EARLY_AS_POSSIBLE);
  // Blocking call, which only add4 depends on.
  HloInstruction* blocking_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          f32_2x2_, {nonblocking_call},
          /*custom_call_target=*/"blocking-call-done",
          /*opaque=*/""));
  static_cast<HloCustomCallInstruction*>(blocking_call)
      ->set_custom_call_schedule(LATE_AS_POSSIBLE);
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
  HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_2x2_, HloOpcode::kAdd, add3, blocking_call));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build(add4));

  std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);

  auto schedule = BuildGpuHloSchedule(module.get(), *streams);
  auto order = schedule->ConsumeHloOrdering();
  VLOG(2) << order->ToString();

  // The start call is hoisted directly after its only producer add1, ahead of
  // the independent work (add2, add3, add4).
  EXPECT_TRUE(order->ExecutesBefore(add1, nonblocking_call));
  EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add2));
  EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add3));
  EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add4));

  // The done call is sunk past the independent work, directly before its only
  // consumer add4.
  EXPECT_TRUE(order->ExecutesBefore(add1, blocking_call));
  EXPECT_TRUE(order->ExecutesBefore(add2, blocking_call));
  EXPECT_TRUE(order->ExecutesBefore(add3, blocking_call));
  EXPECT_TRUE(order->ExecutesBefore(blocking_call, add4));
}
402+
350403
} // namespace gpu
351404
} // namespace xla

tensorflow/compiler/xla/service/hlo.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ import "tensorflow/compiler/xla/xla_data.proto";
3434

3535
option cc_enable_arenas = true;
3636

37+
// Scheduling hint attached to a custom-call instruction; consumed by the
// GPU HLO scheduler (see HloInstructionProto.custom_call_schedule).
enum CustomCallSchedule {
  // No hint: the scheduler places the custom-call wherever it likes.
  NONE = 0;
  // Place the custom-call as late as legal, i.e. right before its first
  // consumer.
  LATE_AS_POSSIBLE = 1;
  // Place the custom-call as early as legal, i.e. right after its last
  // producer.
  EARLY_AS_POSSIBLE = 2;
}
42+
3743
// Serialization of HloInstruction.
3844
// Next ID: 76
3945
message HloInstructionProto {
@@ -237,6 +243,9 @@ message HloInstructionProto {
237243
repeated xla.CustomCallOutputOperandAliasing
238244
custom_call_output_operand_aliasing = 74;
239245

246+
// Specifies the desired schedule for the custom-call.
247+
CustomCallSchedule custom_call_schedule = 76;
248+
240249
// The delta value for kRngGetAndUpdateState.
241250
int64 delta = 66;
242251

0 commit comments

Comments
 (0)