[XLA/GPU] Add CustomCallSchedule to give schedule hints to custom-calls. #48806
Changes from 5 commits
@@ -21,6 +21,7 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
@@ -184,6 +185,86 @@ void BFSLaunchOrder(const HloComputation* computation,
  }
}

bool CustomCallWithSchedule(const HloInstruction* instr,
                            CustomCallSchedule schedule) {
  return instr->opcode() == HloOpcode::kCustomCall &&
         static_cast<const HloCustomCallInstruction*>(instr)
                 ->custom_call_schedule() == schedule;
}

// Schedules EARLIEST and LATEST custom-calls. This supports a custom-call use
// case, where a logical operation is lowered into two HLOs (e.g., PerformX and
// PerformXDone). We utilize this mechanism to either hide host latencies
// between the pair of the custom-calls or more accurately identify the def-use
// relationship of the two calls (typically PerformX is scheduled right after
// all of its producers have been scheduled and PerformXDone is scheduled right
// before its first consumer.)
HloInstructionSequence postprocessor_to_custom_schedule(
    const HloInstructionSequence& input) {

Reviewer: Style guide says functions are UpperCase.
Author: Right. I don't know why I made it lower case. Will revise it.
Author: Done.

  // Schedule `EARLIEST`.
  std::deque<HloInstruction*> earliest_scheduled;
  {
    absl::flat_hash_set<HloInstruction*> scheduled;
    auto is_scheduled = [&](const HloInstruction* instr) -> bool {
      return scheduled.contains(instr);
    };
    const std::vector<HloInstruction*>& instrs = input.instructions();
    for (HloInstruction* instr : instrs) {

Reviewer: Optional nitpick: maybe just use
Author: Updated.

      if (scheduled.contains(instr)) {
        continue;
      }

Reviewer: Haven't we just defined the lambda checking precisely this? Or more concretely: do we need the lambda at all then?
Author: The lambda is needed because the absl::c_all_of() below requires a function form. It is cleaner to also use the lambda here. Will update it.
Author: Updated.

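As a side note on the exchange above, here is a minimal, self-contained sketch (not part of this PR; the function and variable names are invented for illustration) showing how one lambda can serve both a direct membership check and the predicate passed to absl::c_all_of:

#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"

// Returns true if every item is already in `scheduled`.
bool AllScheduled(const std::vector<int>& items,
                  const absl::flat_hash_set<int>& scheduled) {
  auto is_scheduled = [&](int x) { return scheduled.contains(x); };
  // The same callable serves a one-off check...
  if (!items.empty() && !is_scheduled(items.front())) {
    return false;
  }
  // ...and the range-based check, which is why the lambda form is handy here.
  return absl::c_all_of(items, is_scheduled);
}
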
      earliest_scheduled.push_back(instr);
      scheduled.insert(instr);

      for (HloInstruction* user : instr->users()) {
        // Schedule any user that has the `EARLIEST` attribute once all of
        // its producers and control_predecessors have been scheduled.
        if (CustomCallWithSchedule(user, CustomCallSchedule::EARLIEST) &&
            absl::c_all_of(user->operands(), is_scheduled) &&
            absl::c_all_of(user->control_predecessors(), is_scheduled)) {
          earliest_scheduled.push_back(user);
          scheduled.insert(user);
        }
      }
    }
  }

  // Schedule `LATEST`.
  std::deque<HloInstruction*> latest_scheduled;
  {
    absl::flat_hash_set<HloInstruction*> scheduled;
    auto is_scheduled = [&](const HloInstruction* instr) -> bool {
      return scheduled.contains(instr);
    };
    for (auto it = earliest_scheduled.rbegin(); it != earliest_scheduled.rend();
         it++) {
      if (scheduled.contains(*it)) {
        continue;
      }

      latest_scheduled.push_front(*it);
      scheduled.insert(*it);

      for (HloInstruction* opnd : (*it)->unique_operands()) {
        // Schedule any operand that has the `LATEST` attribute once all of
        // its users and control_successors have been scheduled.
        if (CustomCallWithSchedule(opnd, CustomCallSchedule::LATEST) &&
            absl::c_all_of(opnd->users(), is_scheduled) &&
            absl::c_all_of(opnd->control_successors(), is_scheduled)) {
          latest_scheduled.push_front(opnd);
          scheduled.insert(opnd);
        }
      }
    }
  }

  HloInstructionSequence result;
  absl::c_for_each(latest_scheduled,
                   [&](HloInstruction* i) { result.push_back(i); });
  return result;
}

}  // end namespace

GpuHloSchedule::GpuHloSchedule() {}

@@ -206,7 +287,8 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
          [pointer_size](const BufferValue& buffer) {
            return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
          },
-         ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)));
+         ComputationSchedulerToModuleScheduler(
+             DefaultMemoryScheduler, postprocessor_to_custom_schedule)));
  schedule->thunk_launch_order_ =
      sequences.sequence(entry_computation).instructions();
  schedule->hlo_ordering_ =

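To make the effect of the new postprocessor concrete, here is a small, self-contained toy in plain C++. It is not XLA code; the names Node, ScheduleEarliest, producer, unrelated, PerformX, and PerformXDone are all invented for illustration. It mimics only the EARLIEST pass above: a node tagged as earliest is hoisted so it runs as soon as its producers have been emitted, which is how a PerformX custom-call would land right after its operands while its PerformXDone counterpart stays near its first consumer.

#include <deque>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  bool earliest = false;        // stands in for CustomCallSchedule::EARLIEST
  std::vector<Node*> operands;  // producers
  std::vector<Node*> users;     // consumers
};

// Mirrors the EARLIEST pass: emit nodes in input order, but pull any
// earliest-tagged user forward as soon as all of its operands are emitted.
std::deque<Node*> ScheduleEarliest(const std::vector<Node*>& input) {
  std::deque<Node*> out;
  std::unordered_set<Node*> scheduled;
  auto is_scheduled = [&](Node* n) { return scheduled.count(n) > 0; };
  for (Node* n : input) {
    if (is_scheduled(n)) continue;
    out.push_back(n);
    scheduled.insert(n);
    for (Node* user : n->users) {
      bool operands_ready = true;
      for (Node* opnd : user->operands) {
        operands_ready = operands_ready && is_scheduled(opnd);
      }
      if (user->earliest && operands_ready) {
        out.push_back(user);
        scheduled.insert(user);
      }
    }
  }
  return out;
}

int main() {
  Node producer, perform_x, unrelated, done;
  producer.name = "producer";
  perform_x.name = "PerformX";
  unrelated.name = "unrelated";
  done.name = "PerformXDone";

  perform_x.earliest = true;
  perform_x.operands = {&producer};
  producer.users = {&perform_x};
  done.operands = {&perform_x};
  perform_x.users = {&done};

  // Input order leaves PerformX late; the pass hoists it to follow `producer`,
  // so any host-side work it kicks off can overlap with `unrelated`.
  std::vector<Node*> input = {&producer, &unrelated, &perform_x, &done};
  for (Node* n : ScheduleEarliest(input)) std::cout << n->name << "\n";
  // Prints: producer, PerformX, unrelated, PerformXDone.
  return 0;
}
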
@@ -34,6 +34,12 @@ import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

enum CustomCallSchedule {
  NONE = 0;
  LATEST = 1;
  EARLIEST = 2;
}

// Serialization of HloInstruction.
// Next ID: 76
message HloInstructionProto {

@@ -237,6 +243,9 @@ message HloInstructionProto {
  repeated xla.CustomCallOutputOperandAliasing
      custom_call_output_operand_aliasing = 74;

  // Specifies the desired schedule for the custom-call.
  CustomCallSchedule custom_call_schedule = 76;

Reviewer: Specify that the field is only present for custom calls?
Author: Will add comments to make it clear.
Author: Added.

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

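As a rough usage sketch, a client building the proto for a custom-call could record the hint as shown below. This assumes the standard protoc-generated C++ accessors for the new field; the include path is a guess rather than something stated in this diff, and, per the review note above, the field is only meaningful on custom-call instructions.

#include "tensorflow/compiler/xla/service/hlo.pb.h"  // assumed generated header

// Marks a serialized custom-call so a scheduler that honors the hint tries to
// place it as early as its operands allow. Other opcodes simply ignore it.
void TagAsEarliest(xla::HloInstructionProto* instr) {
  instr->set_custom_call_schedule(xla::CustomCallSchedule::EARLIEST);
}
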
Reviewer: Conceptually this is a bit dodgy: we are exposing the schedule to all clients, but only the GPU backend uses it. Should we explicitly error out on other backends? Or ignore it?
Author: The current behavior is that if some clients add the schedule attributes on backends other than GPU, the attributes are simply ignored, as their postprocessors are empty. An alternative is to make a default postprocessor that gives warnings, e.g., "the schedule is set but ignored on this backend." What do you think?
Author: Note that the current implementation in effect ignores the schedule attributes if they are used on other backends. This is safe and correct behavior by all means. After some thought, I would prefer to leave it the current way (i.e., no postprocessors on other backends, which effectively ignores the schedule attributes). Let me know if you have other thoughts or if I haven't addressed your concern.
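For completeness, here is a hypothetical sketch of the "default postprocessor that warns" alternative floated above. It is not part of this PR; it reuses only calls already visible in the diff (instructions(), push_back(), custom_call_schedule()) plus a standard LOG(WARNING), and it assumes the same includes as the scheduler file above.

// Hypothetical: a backend that does not act on the hint could install this
// postprocessor instead of none, so users learn the attribute is ignored.
HloInstructionSequence WarnOnIgnoredCustomCallSchedule(
    const HloInstructionSequence& input) {
  HloInstructionSequence result;
  for (HloInstruction* instr : input.instructions()) {
    if (instr->opcode() == HloOpcode::kCustomCall &&
        static_cast<const HloCustomCallInstruction*>(instr)
                ->custom_call_schedule() != CustomCallSchedule::NONE) {
      LOG(WARNING) << "custom_call_schedule is set on " << instr->name()
                   << " but this backend ignores it.";
    }
    result.push_back(instr);  // keep the original order untouched
  }
  return result;
}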