Iterative horizontal fusion. #48706
1. Extend horizontal fusion to support non-fusion instructions.
2. Enable iterative optimization for horizontal fusion. After each iteration, new horizontal fusion opportunities are exposed because the producers of the previously generated horizontally fused instructions become fusion candidates.
@cheshire could you help review this PR? Thanks!
BTW, I'm currently working on changing the calling convention to allow more than N arguments for calls. If successful, that should enable many more horizontal fusion opportunities (in many cases I have seen, fusion stops because it hits the limit on the number of kernel arguments).
@@ -67,6 +67,25 @@ PrimitiveType GetUniqueOutputTypeOfFusion(const HloInstruction& instr) {
  return first_output_type;
}

size_t GetInstrCountOfFusible(const HloInstruction& instr) {
Is it possible to reduce the duplication with the same function in the other file?
Good catch. Will do.

// Creates a kLoop fusion instruction and fuses `fused` into the created
// fusion instruction.
HloInstruction* MakeLoopFusionInstruction(HloInstruction* fused) {
Same note regarding reducing duplication. Could the fusion type simply be passed as a parameter?
Agreed.
// Convert fusible into fusion_instrs to simplify the implementation of
// `Fuse()`.
std::vector<HloInstruction*> fusion_instrs;
for (auto instr : fusibles) {
Explicit type preferred
Will do.
@@ -493,16 +524,26 @@ StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  auto consumer = def_to_use_order[i];
  HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
  while (true) {
If we are running this in a fixed point, could we remove this while/true loop then?
Theoretically we could, but keeping it this way is better in practice.
Note that this while loop works at a very fine granularity: it only processes the instructions that share the same immediate consumer. We need the loop because we don't want to fuse all of these instructions into a single kernel. Instead, we fuse them into multiple kernels, since a kernel that is too large can be problematic. Theoretically, we could rely on the fixed point and run this pass many times to get these instructions fused, but that would be inefficient.
On the other hand, the fixed point is used to process the fusions newly generated by this pass. For example, in the unit test IterativeHorizontalFusion, the fusion created by fusing fusion.0 and fusion.1 won't be traversed until the next iteration.
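To make the division of labor between the inner while loop and the outer fixed point concrete, here is a minimal toy sketch. It is not the XLA implementation: `Node`, `Graph`, `kMaxBatch`, `RunOnce`, and `RunToFixedPoint` are hypothetical names, a node is just a list of producer ids, and the sketch assumes every fusible producer has a single consumer.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Toy stand-in for an HLO instruction; not the XLA data structure.
struct Node {
  std::vector<int> operands;  // Ids of the producers of this node.
  bool fusible = false;       // May be fused with sibling producers.
  bool dead = false;          // Already absorbed into a fused node.
};
using Graph = std::vector<Node>;  // A node id is its index in the vector.

constexpr std::size_t kMaxBatch = 2;  // Hypothetical cap on fused kernel size.

// Runs one pass over the nodes that exist on entry. Nodes created during the
// pass are not traversed until the next call. Returns true if anything fused.
bool RunOnce(Graph& g) {
  bool changed = false;
  const std::size_t initial_size = g.size();
  for (std::size_t consumer = 0; consumer < initial_size; ++consumer) {
    if (g[consumer].dead) continue;
    // Inner loop: fuse the producers of this consumer in several small
    // batches instead of one oversized kernel.
    while (true) {
      std::vector<int> batch;
      for (int op : g[consumer].operands) {
        const bool seen =
            std::find(batch.begin(), batch.end(), op) != batch.end();
        if (g[op].fusible && !g[op].dead && !seen &&
            batch.size() < kMaxBatch) {
          batch.push_back(op);
        }
      }
      if (batch.size() < 2) break;  // Nothing left to fuse horizontally.

      // The fused node reads everything its members read, so the members'
      // producers become siblings that share one consumer.
      Node fused;
      for (int op : batch) {
        fused.operands.insert(fused.operands.end(), g[op].operands.begin(),
                              g[op].operands.end());
        g[op].dead = true;
      }
      const int fused_id = static_cast<int>(g.size());
      g.push_back(fused);

      // Rewire the consumer to read the fused node instead of the members.
      std::vector<int>& ops = g[consumer].operands;
      ops.erase(std::remove_if(ops.begin(), ops.end(),
                               [&batch](int op) {
                                 return std::find(batch.begin(), batch.end(),
                                                  op) != batch.end();
                               }),
                ops.end());
      ops.push_back(fused_id);
      changed = true;
    }
  }
  return changed;
}

// Outer fixed point: rerun the pass until it stops changing the graph.
// Returns the number of passes that changed something.
int RunToFixedPoint(Graph& g) {
  int passes = 0;
  while (RunOnce(g)) ++passes;
  return passes;
}
```

In this toy model, a consumer with four fusible producers ends up with two fused nodes after a single pass, which is the job of the inner loop. A consumer reading two fusible nodes that each read one fusible producer needs two passes, because the node created in the first pass is only visited in the second.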
Response to the comments. Will revise the code soon.
Cool! Since it is a limitation of the CUDA kernel signature, are you going to pack the arguments into a struct and pass the struct?
I've addressed the review comments. Please take another look. Thanks!
This would not work, as the limit is on the argument size and not the number of arguments. I'm experimenting with passing the buffer table directly instead.
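A small host-side sketch of why packing does not help, under two loudly stated assumptions: the 4096-byte figure is only an assumed limit for illustration (the actual kernel parameter limit depends on the CUDA version and GPU architecture), and "buffer table" is read here as an array of buffer pointers in device memory that the kernel receives through a single pointer. `PackedArgs`, `SeparateArgsBytes`, `BufferTableArgBytes`, and `FitsInParamSpace` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Assumed limit, for illustration only; not taken from the PR.
constexpr std::size_t kAssumedParamLimitBytes = 4096;

// Packing N buffer pointers into one struct argument.
template <std::size_t N>
struct PackedArgs {
  void* buffers[N];
};

// Parameter bytes used by N separate pointer arguments.
constexpr std::size_t SeparateArgsBytes(std::size_t n) {
  return n * sizeof(void*);
}

// Parameter bytes used when the kernel receives one pointer to a table of
// buffer pointers (assumed meaning of "passing the buffer table directly").
constexpr std::size_t BufferTableArgBytes() { return sizeof(void*); }

constexpr bool FitsInParamSpace(std::size_t bytes) {
  return bytes <= kAssumedParamLimitBytes;
}

// The struct carries the same bytes as the separate arguments, so a limit on
// argument size is hit at the same point either way.
static_assert(sizeof(PackedArgs<600>) == SeparateArgsBytes(600),
              "packing does not shrink the argument size");
// A single table pointer stays the same size however many buffers it indexes.
static_assert(FitsInParamSpace(BufferTableArgBytes()),
              "one pointer fits in the assumed parameter space");
```

The trade-off in this reading is an extra indirection inside the kernel to load each buffer address from the table.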
absl::InlinedVector<HloInstruction*, 2> GetOutputsOfFusible(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return {const_cast<HloInstruction*>(&instr)};
Must we const-cast? Why not return a vector of const pointers?
The return type has to align with the return type of HloInstruction::operands(), which is absl::InlinedVector<HloInstruction*, 2>. So returning absl::InlinedVector<const HloInstruction*, 2> would require allocating a new vector. const_cast may be a bit less evil.
Let me know if you have any other ideas.
const_cast seems to be technically UB if modifying methods are then called, right? Allocating 2 pointers on the stack (new vector) should be essentially free, right?
Just did some searching. It is undefined behavior only if the object itself (not the pointer) was actually defined const; modifying such an object through const_cast is then undefined behavior.
My previous statement was wrong, though: we allocate a new vector (on the stack) anyway. So let's remove the const_cast. Please take a look at the latest commit.
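For illustration, a minimal sketch of the const-correct shape discussed here. It is not the code from the latest commit: `Instr` and `GetOutputs` are hypothetical stand-ins, and `std::vector` is used in place of `absl::InlinedVector` so the snippet stays self-contained.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for an instruction; not the XLA class.
struct Instr {
  bool is_tuple = false;          // Whether this node is a tuple of outputs.
  std::vector<Instr*> operands;   // Non-const pointers, as in the thread.
};

// Returns the outputs as const pointers, so no const_cast is needed.
std::vector<const Instr*> GetOutputs(const Instr& root) {
  if (!root.is_tuple) {
    // &root already has type `const Instr*`.
    return {&root};
  }
  // Copies the pointers into a new vector; each `Instr*` converts implicitly
  // to `const Instr*`.
  return std::vector<const Instr*>(root.operands.begin(),
                                   root.operands.end());
}
```

Callers that only inspect the outputs can work with the const pointers directly, and the compiler rejects any attempt to call a mutating method through them.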
new horizontal fusion opportunities are exposed because the producers to
the previously generated horizontally fused instructions will become
fusion candidates. See IterativeHorizontalFusion in the unittest as an example.