[XLA/GPU] Add CustomCallSchedule to give schedule hints to custom-calls.
Add schedule hints EARLY_AS_POSSIBLE and LATE_AS_POSSIBLE to custom-calls.
This supports the case where a logical operation is lowered into two HLOs
(e.g., PerformX and PerformXDone). The hints can be used either to hide host
latencies between the pair of custom-calls or to express their def-use
relationship more precisely: typically, PerformX is scheduled right after all
of its producers and PerformXDone right before its first consumer.
trentlo committed Apr 28, 2021
1 parent a2a607d commit f4f3a55
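
To illustrate the pattern described in the commit message, here is a minimal client-side sketch. It is not part of this commit: the call targets "PerformX"/"PerformXDone" and the helper AsyncPerformX are hypothetical, and only the trailing schedule argument is new relative to the existing CustomCall API.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: lowers one logical operation into a start/done pair of
// custom-calls so the scheduler can hide the host latency between them.
xla::XlaOp AsyncPerformX(xla::XlaBuilder* b, xla::XlaOp input,
                         const xla::Shape& result_shape) {
  // Hinted to be scheduled right after all of its producers.
  xla::XlaOp start = xla::CustomCall(
      b, /*call_target_name=*/"PerformX", /*operands=*/{input},
      /*shape=*/result_shape, /*opaque=*/"", /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*schedule=*/xla::CustomCallSchedule::EARLY_AS_POSSIBLE);
  // Hinted to be scheduled right before its first consumer; independent work
  // placed in between overlaps with the asynchronous host work.
  return xla::CustomCall(
      b, /*call_target_name=*/"PerformXDone", /*operands=*/{start},
      /*shape=*/result_shape, /*opaque=*/"", /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*schedule=*/xla::CustomCallSchedule::LATE_AS_POSSIBLE);
}

Note that the pair is connected by a data dependency (start feeds done), so the hints only tighten, and can never violate, the def-use order.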
Showing 15 changed files with 318 additions and 58 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -135,7 +135,7 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
+ const Literal* literal, CustomCallSchedule /*schedule*/) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -138,7 +138,7 @@ class MlirHloBuilder : public XlaBuilder {
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) override;
+ const Literal* literal, CustomCallSchedule schedule) override;

StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
29 changes: 16 additions & 13 deletions tensorflow/compiler/xla/client/xla_builder.cc
@@ -1893,7 +1893,7 @@ XlaOp XlaBuilder::CustomCall(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
+ const Literal* literal, CustomCallSchedule schedule) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@@ -1926,7 +1926,7 @@ XlaOp XlaBuilder::CustomCall(
}
return CustomCallInternal(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
- output_operand_aliasing, literal);
+ output_operand_aliasing, literal, schedule);
});
}

@@ -1937,7 +1937,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
+ const Literal* literal, CustomCallSchedule schedule) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
@@ -1962,6 +1962,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
aliasing->add_output_shape_index(index);
}
}
+ instr.set_custom_call_schedule(schedule);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}

@@ -1972,7 +1973,7 @@ XlaOp XlaBuilder::CustomCall(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
+ const Literal* literal, CustomCallSchedule schedule) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -2023,6 +2024,7 @@ XlaOp XlaBuilder::CustomCall(
aliasing->add_output_shape_index(index);
}
}
+ instr.set_custom_call_schedule(schedule);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -4192,10 +4194,11 @@ XlaOp CustomCall(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
+ const Literal* literal, CustomCallSchedule schedule) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
- has_side_effect, output_operand_aliasing, literal);
+ has_side_effect, output_operand_aliasing, literal,
+ schedule);
}

XlaOp CustomCallWithComputation(
@@ -4204,11 +4207,11 @@ XlaOp CustomCallWithComputation(
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
- return builder->CustomCall(call_target_name, operands, computation, shape,
-                            opaque,
-                            /*operand_shapes_with_layout=*/absl::nullopt,
-                            has_side_effect, output_operand_aliasing, literal);
+ const Literal* literal, CustomCallSchedule schedule) {
+ return builder->CustomCall(
+     call_target_name, operands, computation, shape, opaque,
+     /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect,
+     output_operand_aliasing, literal, schedule);
}

XlaOp CustomCallWithLayout(
@@ -4218,10 +4221,10 @@ XlaOp CustomCallWithLayout(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal) {
+ const Literal* literal, CustomCallSchedule schedule) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
- output_operand_aliasing, literal);
+ output_operand_aliasing, literal, schedule);
}

XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
18 changes: 10 additions & 8 deletions tensorflow/compiler/xla/client/xla_builder.h
@@ -637,7 +637,7 @@ class XlaBuilder {
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal);
+ const Literal* literal, CustomCallSchedule schedule);

// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
@@ -649,7 +649,7 @@
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal);
+ const Literal* literal, CustomCallSchedule schedule);

XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
@@ -659,7 +659,7 @@
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal);
+ const Literal* literal, CustomCallSchedule schedule);

XlaOp Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
@@ -1208,22 +1208,22 @@ class XlaBuilder {
const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal);
+ const Literal* literal, CustomCallSchedule schedule);
friend XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal);
+ const Literal* literal, CustomCallSchedule schedule);
friend XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
- const Literal* literal);
+ const Literal* literal, CustomCallSchedule schedule);
friend XlaOp Complex(XlaOp real, XlaOp imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(XlaOp operand);
@@ -2041,7 +2041,8 @@ XlaOp CustomCall(
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {},
- const Literal* literal = nullptr);
+ const Literal* literal = nullptr,
+ CustomCallSchedule schedule = CustomCallSchedule::NONE);

// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
@@ -2064,7 +2065,8 @@ XlaOp CustomCallWithLayout(
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {},
- const Literal* literal = nullptr);
+ const Literal* literal = nullptr,
+ CustomCallSchedule schedule = CustomCallSchedule::NONE);

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
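Because the new parameter is defaulted, pre-existing call sites keep compiling unchanged; a sketch (the call target and the builder/operand/out_shape variables are hypothetical):

// `schedule` silently defaults to CustomCallSchedule::NONE, so the GPU
// scheduler postprocessor leaves this custom-call wherever the memory
// scheduler placed it.
xla::XlaOp out = xla::CustomCall(builder, "my-target", {operand}, out_shape);
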
82 changes: 81 additions & 1 deletion tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,6 +21,7 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
@@ -184,6 +185,84 @@ void BFSLaunchOrder(const HloComputation* computation,
}
}

bool CustomCallWithSchedule(const HloInstruction* instr,
CustomCallSchedule schedule) {
return instr->opcode() == HloOpcode::kCustomCall &&
static_cast<const HloCustomCallInstruction*>(instr)
->custom_call_schedule() == schedule;
}

// Schedules EARLY_AS_POSSIBLE and LATE_AS_POSSIBLE custom-calls. This supports
// the case where a logical operation is lowered into two HLOs (e.g., PerformX
// and PerformXDone). The hints can be used either to hide host latencies
// between the pair of custom-calls or to express their def-use relationship
// more precisely: typically, PerformX is scheduled right after all of its
// producers and PerformXDone right before its first consumer.
HloInstructionSequence postprocessor_to_custom_schedule(
const HloInstructionSequence& input) {
// Schedule `EARLY_AS_POSSIBLE`.
std::deque<HloInstruction*> early_as_possible_sched;
{
absl::flat_hash_set<HloInstruction*> scheduled;
const std::vector<HloInstruction*>& instrs = input.instructions();
for (HloInstruction* instr : instrs) {
if (scheduled.contains(instr)) {
continue;
}

early_as_possible_sched.push_back(instr);
scheduled.insert(instr);

for (HloInstruction* user : instr->users()) {
// Schedule any user that carries the EARLY_AS_POSSIBLE hint once all of
// its operands have been scheduled.
if (CustomCallWithSchedule(user,
CustomCallSchedule::EARLY_AS_POSSIBLE) &&
absl::c_all_of(user->operands(), [&](const HloInstruction* opnd) {
return scheduled.contains(opnd);
})) {
early_as_possible_sched.push_back(user);
scheduled.insert(user);
}
}
}
}

// Schedule `LATE_AS_POSSIBLE`.
std::deque<HloInstruction*> late_as_possible_sched;
{
absl::flat_hash_set<HloInstruction*> scheduled;
for (auto it = early_as_possible_sched.rbegin();
it != early_as_possible_sched.rend(); it++) {
if (scheduled.contains(*it)) {
continue;
}

late_as_possible_sched.push_front(*it);
scheduled.insert(*it);

for (HloInstruction* opnd : (*it)->unique_operands()) {
// Schedule any operand that carries the LATE_AS_POSSIBLE hint once all of
// its users have been scheduled.
if (CustomCallWithSchedule(opnd,
CustomCallSchedule::LATE_AS_POSSIBLE) &&
absl::c_all_of(opnd->users(), [&](const HloInstruction* u) {
return scheduled.contains(u);
})) {
late_as_possible_sched.push_front(opnd);
scheduled.insert(opnd);
}
}
}
}

HloInstructionSequence result;
absl::c_for_each(late_as_possible_sched,
[&](HloInstruction* i) { result.push_back(i); });
return result;
}

} // end namespace

GpuHloSchedule::GpuHloSchedule() {}
@@ -206,7 +285,8 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
[pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
},
- ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)));
+ ComputationSchedulerToModuleScheduler(
+     DefaultMemoryScheduler, postprocessor_to_custom_schedule)));
schedule->thunk_launch_order_ =
sequences.sequence(entry_computation).instructions();
schedule->hlo_ordering_ =
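To make the two passes concrete, here is a hypothetical trace (parameters omitted; start/done are the EARLY_AS_POSSIBLE/LATE_AS_POSSIBLE custom-calls of the test below, with start consuming add1 and add4 consuming add3 and done):

  input from the memory scheduler:  add1, add2, start, done, add3, add4
  after the EARLY_AS_POSSIBLE pass: add1, start, add2, done, add3, add4
  after the LATE_AS_POSSIBLE pass:  add1, start, add2, add3, done, add4

The first pass hoists start to just after its producer add1; the second pass sinks done to just before its first consumer add4, so add2 and add3 now overlap with the asynchronous host work between the pair.
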
53 changes: 53 additions & 0 deletions tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -347,5 +347,58 @@ TEST_F(GpuHloScheduleTest, DISABLED_LatticeMatMul) {
}
}

TEST_F(GpuHloScheduleTest, AsyncCustomCall) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
HloInstruction* add2 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
// Nonblocking call depends on add1 but not add2.
HloInstruction* nonblocking_call =
builder.AddInstruction(HloInstruction::CreateCustomCall(
f32_2x2_, {add1},
/*custom_call_target=*/"nonblocking-call-start",
/*opaque=*/""));
static_cast<HloCustomCallInstruction*>(nonblocking_call)
->set_custom_call_schedule(EARLY_AS_POSSIBLE);
// Blocking call, which only add4 depends on.
HloInstruction* blocking_call =
builder.AddInstruction(HloInstruction::CreateCustomCall(
f32_2x2_, {nonblocking_call},
/*custom_call_target=*/"blocking-call-done",
/*opaque=*/""));
static_cast<HloCustomCallInstruction*>(blocking_call)
->set_custom_call_schedule(LATE_AS_POSSIBLE);
HloInstruction* add3 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
f32_2x2_, HloOpcode::kAdd, add3, blocking_call));

auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build(add4));

std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);

auto schedule = BuildGpuHloSchedule(module.get(), *streams);
auto order = schedule->ConsumeHloOrdering();
VLOG(2) << order->ToString();

EXPECT_TRUE(order->ExecutesBefore(add1, nonblocking_call));
EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add2));
EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add3));
EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add4));

EXPECT_TRUE(order->ExecutesBefore(add1, blocking_call));
EXPECT_TRUE(order->ExecutesBefore(add2, blocking_call));
EXPECT_TRUE(order->ExecutesBefore(add3, blocking_call));
EXPECT_TRUE(order->ExecutesBefore(blocking_call, add4));
}

} // namespace gpu
} // namespace xla
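
To run the new test in isolation, something like the following should work (assuming the conventional Bazel target name for this file):

bazel test //tensorflow/compiler/xla/service/gpu:gpu_hlo_schedule_test \
    --test_filter=GpuHloScheduleTest.AsyncCustomCall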
9 changes: 9 additions & 0 deletions tensorflow/compiler/xla/service/hlo.proto
@@ -34,6 +34,12 @@ import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

+ enum CustomCallSchedule {
+   NONE = 0;
+   LATE_AS_POSSIBLE = 1;
+   EARLY_AS_POSSIBLE = 2;
+ }

// Serialization of HloInstruction.
// Next ID: 76
message HloInstructionProto {
@@ -237,6 +243,9 @@ message HloInstructionProto {
repeated xla.CustomCallOutputOperandAliasing
custom_call_output_operand_aliasing = 74;

+ // Specifies the desired schedule for the custom-call.
+ CustomCallSchedule custom_call_schedule = 76;

// The delta value for kRngGetAndUpdateState.
int64 delta = 66;

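End to end, the hint flows as follows: XlaBuilder records it on the instruction proto via instr.set_custom_call_schedule(schedule); it is serialized in the new custom_call_schedule field (proto3 yields NONE, the zero value, for modules serialized before this change); and the GPU scheduler reads it back through HloCustomCallInstruction::custom_call_schedule() in the postprocessor above.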
