[XLA/GPU] Add CustomCallSchedule to give schedule hints to custom-calls. #48806

Merged
4 changes: 3 additions & 1 deletion tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -135,7 +135,7 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
const Literal* literal, CustomCallSchedule schedule) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
@@ -145,6 +145,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
TF_RET_CHECK(literal == nullptr)
<< "MLIR CustomCallOp does not support literal yet";
TF_RET_CHECK(schedule == CustomCallSchedule::NONE)
<< "MLIR CustomCallOp does not support custom-call-schedule yet";
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -138,7 +138,7 @@ class MlirHloBuilder : public XlaBuilder {
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) override;
const Literal* literal, CustomCallSchedule schedule) override;

StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
29 changes: 16 additions & 13 deletions tensorflow/compiler/xla/client/xla_builder.cc
@@ -1893,7 +1893,7 @@ XlaOp XlaBuilder::CustomCall(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
const Literal* literal, CustomCallSchedule schedule) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@@ -1926,7 +1926,7 @@
}
return CustomCallInternal(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
output_operand_aliasing, literal);
output_operand_aliasing, literal, schedule);
});
}

@@ -1937,7 +1937,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
const Literal* literal, CustomCallSchedule schedule) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
@@ -1962,6 +1962,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
aliasing->add_output_shape_index(index);
}
}
instr.set_custom_call_schedule(schedule);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}

@@ -1972,7 +1973,7 @@ XlaOp XlaBuilder::CustomCall(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
const Literal* literal, CustomCallSchedule schedule) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -2023,6 +2024,7 @@ XlaOp XlaBuilder::CustomCall(
aliasing->add_output_shape_index(index);
}
}
instr.set_custom_call_schedule(schedule);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -4192,10 +4194,11 @@ XlaOp CustomCall(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
const Literal* literal, CustomCallSchedule schedule) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
has_side_effect, output_operand_aliasing, literal);
has_side_effect, output_operand_aliasing, literal,
schedule);
}

XlaOp CustomCallWithComputation(
@@ -4204,11 +4207,11 @@ XlaOp CustomCallWithComputation(
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
return builder->CustomCall(call_target_name, operands, computation, shape,
opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
has_side_effect, output_operand_aliasing, literal);
const Literal* literal, CustomCallSchedule schedule) {
return builder->CustomCall(
call_target_name, operands, computation, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt, has_side_effect,
output_operand_aliasing, literal, schedule);
}

XlaOp CustomCallWithLayout(
@@ -4218,10 +4221,10 @@ XlaOp CustomCallWithLayout(
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal) {
const Literal* literal, CustomCallSchedule schedule) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
output_operand_aliasing, literal);
output_operand_aliasing, literal, schedule);
}

XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
18 changes: 10 additions & 8 deletions tensorflow/compiler/xla/client/xla_builder.h
@@ -637,7 +637,7 @@ class XlaBuilder {
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal);
const Literal* literal, CustomCallSchedule schedule);

// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
@@ -649,7 +649,7 @@
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal);
const Literal* literal, CustomCallSchedule schedule);

XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
Expand All @@ -659,7 +659,7 @@ class XlaBuilder {
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal);
const Literal* literal, CustomCallSchedule schedule);

XlaOp Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
@@ -1208,22 +1208,22 @@ class XlaBuilder {
const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal);
const Literal* literal, CustomCallSchedule schedule);
friend XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal);
const Literal* literal, CustomCallSchedule schedule);
friend XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing,
const Literal* literal);
const Literal* literal, CustomCallSchedule schedule);
friend XlaOp Complex(XlaOp real, XlaOp imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(XlaOp operand);
@@ -2041,7 +2041,8 @@ XlaOp CustomCall(
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {},
const Literal* literal = nullptr);
const Literal* literal = nullptr,
CustomCallSchedule schedule = CustomCallSchedule::NONE);

// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
@@ -2064,7 +2065,8 @@ XlaOp CustomCallWithLayout(
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {},
const Literal* literal = nullptr);
const Literal* literal = nullptr,
CustomCallSchedule schedule = CustomCallSchedule::NONE);

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
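
For reference, a minimal client-side sketch of the new parameter. The call targets ("perform-x-start"/"perform-x-done"), the shape, and the wrapper function are hypothetical; the argument order follows the declarations above, and only the trailing schedule argument is new in this change:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void BuildAsyncPair() {
  xla::XlaBuilder b("async_example");
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  xla::XlaOp in = xla::Parameter(&b, 0, shape, "in");
  // EARLIEST: hint the scheduler to hoist the start call right after its
  // producers, so host-side work can overlap with unrelated device work.
  xla::XlaOp start = xla::CustomCall(
      &b, "perform-x-start", {in}, shape, /*opaque=*/"",
      /*has_side_effect=*/false, /*output_operand_aliasing=*/{},
      /*literal=*/nullptr, xla::CustomCallSchedule::EARLIEST);
  // LATEST: hint the scheduler to sink the done call right before its
  // first consumer.
  xla::XlaOp done = xla::CustomCall(
      &b, "perform-x-done", {start}, shape, /*opaque=*/"",
      /*has_side_effect=*/false, /*output_operand_aliasing=*/{},
      /*literal=*/nullptr, xla::CustomCallSchedule::LATEST);
  xla::Add(done, done);  // downstream consumer
}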
84 changes: 83 additions & 1 deletion tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,6 +21,7 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
@@ -184,6 +185,86 @@ void BFSLaunchOrder(const HloComputation* computation,
}
}

bool CustomCallWithSchedule(const HloInstruction* instr,
CustomCallSchedule schedule) {
return instr->opcode() == HloOpcode::kCustomCall &&
static_cast<const HloCustomCallInstruction*>(instr)
->custom_call_schedule() == schedule;
}

[Review] Member: Conceptually this is a bit dodgy: we are exposing the schedule to all clients, but only the GPU backend uses it. Should we explicitly error out on other backends, or ignore it?
Author: The current behavior is that if a client sets the schedule attribute on a backend other than GPU, the attribute is simply ignored, since the other backends' postprocessors are empty. An alternative is to make a default postprocessor that emits a warning, e.g., "the schedule is set but ignored on this backend." What do you think?
Author: Note that the current implementation in effect ignores the schedule attribute on other backends, which is safe and correct by all means. After some thought, I would leave this as is (no postprocessors on other backends, so the schedule attribute is effectively ignored). Let me know if you have other thoughts or if this doesn't address your concern.
[Review] Member: Style guide says functions are UpperCase.
Author: Right, I don't know why I made it lower case. Will revise it. Done.

// Schedules EARLIEST and LATEST custom-calls. This supports a custom-call use
// case where a logical operation is lowered into two HLOs (e.g., PerformX and
// PerformXDone). We use this mechanism either to hide host latencies between
// the pair of custom-calls, or to more accurately identify the def-use
// relationship of the two calls (typically, PerformX is scheduled right after
// all of its producers have been scheduled, and PerformXDone right before its
// first consumer).
HloInstructionSequence postprocessor_to_custom_schedule(
const HloInstructionSequence& input) {
// Schedule `EARLIEST`.
std::deque<HloInstruction*> earliest_scheduled;
{
absl::flat_hash_set<HloInstruction*> scheduled;
auto is_scheduled = [&](const HloInstruction* instr) -> bool {
return scheduled.contains(instr);
};
const std::vector<HloInstruction*>& instrs = input.instructions();
for (HloInstruction* instr : instrs) {
[Review] Member: Optional nitpick: maybe just use input.instructions() inline?
Author: Updated.

if (scheduled.contains(instr)) {
[Review] Member: Haven't we just defined the lambda checking precisely this? Or, more concretely: do we need the lambda at all, then?
Author: The lambda is needed because the absl::c_all_of() calls below require a callable:
absl::c_all_of(user->operands(), is_scheduled) && absl::c_all_of(user->control_predecessors(), is_scheduled)
It is cleaner to also use the lambda here. Will update it. Updated.

continue;
}

earliest_scheduled.push_back(instr);
scheduled.insert(instr);

for (HloInstruction* user : instr->users()) {
// Schedule any user that has the `EARLIEST` attribute once all of
// its producers and control predecessors have been scheduled.
if (CustomCallWithSchedule(user, CustomCallSchedule::EARLIEST) &&
absl::c_all_of(user->operands(), is_scheduled) &&
absl::c_all_of(user->control_predecessors(), is_scheduled)) {
earliest_scheduled.push_back(user);
scheduled.insert(user);
}
}
}
}

// Schedule `LATEST`.
std::deque<HloInstruction*> latest_scheduled;
{
absl::flat_hash_set<HloInstruction*> scheduled;
auto is_scheduled = [&](const HloInstruction* instr) -> bool {
return scheduled.contains(instr);
};
for (auto it = earliest_scheduled.rbegin(); it != earliest_scheduled.rend();
it++) {
if (scheduled.contains(*it)) {
continue;
}

latest_scheduled.push_front(*it);
scheduled.insert(*it);

for (HloInstruction* opnd : (*it)->unique_operands()) {
// Schedule any operand that has the `LATEST` attribute once all of
// its users and control successors have been scheduled.
if (CustomCallWithSchedule(opnd, CustomCallSchedule::LATEST) &&
absl::c_all_of(opnd->users(), is_scheduled) &&
absl::c_all_of(opnd->control_successors(), is_scheduled)) {
latest_scheduled.push_front(opnd);
scheduled.insert(opnd);
}
}
}
}

HloInstructionSequence result;
absl::c_for_each(latest_scheduled,
[&](HloInstruction* i) { result.push_back(i); });
return result;
}

} // end namespace

GpuHloSchedule::GpuHloSchedule() {}
@@ -206,7 +287,8 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
[pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
},
ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)));
ComputationSchedulerToModuleScheduler(
DefaultMemoryScheduler, postprocessor_to_custom_schedule)));
schedule->thunk_launch_order_ =
sequences.sequence(entry_computation).instructions();
schedule->hlo_ordering_ =
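
To make the effect concrete, here is a small sketch of feeding a sequence through the postprocessor. The HloInstruction* handles (add0, add1, add2, start, done, out) are hypothetical, built along the lines of the test in the next file:

// Dependencies: start = custom-call(add0) with schedule=EARLIEST;
// done = custom-call(start) with schedule=LATEST; out = add(add2, done).
// add1 is independent work; add2 feeds out.
HloInstructionSequence input;
for (HloInstruction* i : {add0, add1, start, done, add2, out}) {
  input.push_back(i);
}
HloInstructionSequence reordered = postprocessor_to_custom_schedule(input);
// Expected order: add0, start, add1, add2, done, out. The EARLIEST pass
// hoists `start` to just after its producer add0; the LATEST pass sinks
// `done` to just before its first consumer out, so add1 and add2 can
// overlap with the asynchronous work in between.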
62 changes: 62 additions & 0 deletions tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -347,5 +347,67 @@ TEST_F(GpuHloScheduleTest, DISABLED_LatticeMatMul) {
}
}

TEST_F(GpuHloScheduleTest, AsyncCustomCall) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
HloInstruction* add0 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add0, y));
HloInstruction* add2 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, z));
// Create nonblocking_call(add0).
HloInstruction* nonblocking_call =
builder.AddInstruction(HloInstruction::CreateCustomCall(
f32_2x2_, {add0},
/*custom_call_target=*/"nonblocking-call-start",
/*opaque=*/""));
static_cast<HloCustomCallInstruction*>(nonblocking_call)
->set_custom_call_schedule(EARLIEST);
// In addition, add a control dependency: add1 -> nonblocking_call.
TF_CHECK_OK(add1->AddControlDependencyTo(nonblocking_call));
// Blocking call, which only add4 depends on.
HloInstruction* blocking_call =
builder.AddInstruction(HloInstruction::CreateCustomCall(
f32_2x2_, {nonblocking_call},
/*custom_call_target=*/"blocking-call-done",
/*opaque=*/""));
static_cast<HloCustomCallInstruction*>(blocking_call)
->set_custom_call_schedule(LATEST);
HloInstruction* add3 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
f32_2x2_, HloOpcode::kAdd, add3, blocking_call));

auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build(add4));

std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);

auto schedule = BuildGpuHloSchedule(module.get(), *streams);
auto order = schedule->ConsumeHloOrdering();
VLOG(2) << order->ToString();

// Order constrained by data dependency.
EXPECT_TRUE(order->ExecutesBefore(add0, nonblocking_call));
// Order constrained by control dependency.
EXPECT_TRUE(order->ExecutesBefore(add1, nonblocking_call));
// Test that nonblocking_call is scheduled before add2, so that we know
// EARLIEST is in effect.
EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add2));
EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add3));
EXPECT_TRUE(order->ExecutesBefore(nonblocking_call, add4));

// Test that blocking_call is scheduled after add3, so that we know
// LATEST is in effect.
EXPECT_TRUE(order->ExecutesBefore(add3, blocking_call));
EXPECT_TRUE(order->ExecutesBefore(blocking_call, add4));
}

} // namespace gpu
} // namespace xla
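
For completeness, a hedged host-side sketch of what such a start/done pair might bind to. The registration macro and the GPU calling convention below follow XLA's custom-call target registry of this era, the function bodies are placeholders, and the target names match the hypothetical ones used in the test above; this is a sketch, not part of this change:

#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "third_party/gpus/cuda/include/cuda.h"  // for CUstream

// Start: enqueue host/device work without blocking the stream. buffers[]
// holds device pointers for operands followed by results.
void NonblockingCallStart(CUstream stream, void** buffers,
                          const char* opaque, size_t opaque_len) {
  // e.g., kick off a host callback or async copy keyed on `stream`.
}

// Done: block until the work started above has completed.
void BlockingCallDone(CUstream stream, void** buffers,
                      const char* opaque, size_t opaque_len) {
  // e.g., wait on the event or flag set by NonblockingCallStart.
}

XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("nonblocking-call-start",
                                         NonblockingCallStart, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blocking-call-done",
                                         BlockingCallDone, "CUDA");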
9 changes: 9 additions & 0 deletions tensorflow/compiler/xla/service/hlo.proto
@@ -34,6 +34,12 @@ import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

enum CustomCallSchedule {
NONE = 0;
LATEST = 1;
EARLIEST = 2;
}

// Serialization of HloInstruction.
// Next ID: 76
message HloInstructionProto {
@@ -237,6 +243,9 @@ message HloInstructionProto {
repeated xla.CustomCallOutputOperandAliasing
custom_call_output_operand_aliasing = 74;

// Specifies the desired schedule for the custom-call.
[Review] Member: Specify that the field is only present for custom calls?
Author: Will add comments to make it clear. Added.

CustomCallSchedule custom_call_schedule = 76;

// The delta value for kRngGetAndUpdateState.
int64 delta = 66;
