[XLA] Add simple HLO if conversion pass #22974
Conversation
namespace xla {

StatusOr<bool> DoConditionalToSelect(HloInstruction* conditional) {
Please either make this static or put it in an anonymous namespace.
StatusOr<bool> DoConditionalToSelect(HloInstruction* conditional) {
  // Only allow conditional to select if the called computations
  // do not have side effects.
  for (HloComputation* computation : conditional->called_computations()) {
Up to you, but I think conditional->true_computation()->HasSideEffects() || conditional->false_computation()->HasSideEffects() looks a bit cleaner.
          conditional->false_computation()));
  HloInstruction* select_op =
      computation->AddInstruction(HloInstruction::CreateTernary(
          conditional->shape(), HloOpcode::kSelect,
I suspect this may need to be a kTupleSelect if conditional->shape() is a tuple.
Also, I think we need to broadcast the condition when creating the select if conditional->shape() is not a scalar.
  TF_RETURN_IF_ERROR(
      call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
        std::vector<HloInstruction*> ToInline;
        if (node.context() != CallContext::kParallel) return Status::OK();
We prefer using braces even for single-statement if and for.
Feedback addressed.
namespace xla {

// A pass which transforms conditionals to selects in places where conditionals
// are not allowed to appear (e.g. mapped computation)
Just to be clear, this is a backend constraint right (that we can't codegen this) not a general constraint (i.e. the un-transformed IR still passes the verifier)? If yes, would be nice to note explicitly.
It passes the verifier, but fails in buffer assignment. The backend folks would have to say what the long term plan here is.
Ok, can you instead s/not allowed/legal but not supported/ or something like that? I don't want folks to get misled by this comment about what HLO is allowed.
Looks like there are some whitespace issues with my BUILD changes? Should I update the PR to fix that, or do you want to do that as part of the merge?
Hi Keno, it'd be great if you could fix it on the PR itself; I believe the error message has a diff? If you run into any unforeseen issues, let me know and I'll send you a buildified BUILD file.
Alright. Let's try that. I assume you'll have to re-trigger CI since that doesn't seem to run automatically.
bump
… TPUs. Will be fixed once tensorflow/tensorflow#22974 is deployed to TPUs.
lgtm with comments addressed, but not actually lgtm'ing since I'm not sure if the code will be merged automatically without giving you time to address the comments.
  // which expects the condition to be a scalar.
  if (!ShapeUtil::IsScalar(conditional->shape()) &&
      !ShapeUtil::IsTuple(conditional->shape())) {
    condition = computation->AddInstruction(HloInstruction::CreateBroadcast(
This can be a new utility in tensorflow/compiler/xla/service/hlo_creation_utils.h
I see that it already has MakeSelectHlo. In principle that could be expanded to do this whole rigamarole. Should I do that? Or should I create a separate function?
Making the existing switch between select and tuple-select sgtm.
I've addressed those two comments. @davidel had some thoughts on the split of work here between HLO and the backend, but I think the conclusion ended up being that while we should fix this in the backend also, the HLO-level pass is nevertheless useful.
@ymodak I'm happy to manually pull in the CL to Google and check it in (and resolve the merge conflicts internally). However it is missing the
@Keno do you mind resolving the conflicts? Thank you!
I was under the impression @sanjoy would do that on import.
Let me just do it. Sounds like that's easiest.
thank you
PiperOrigin-RevId: 239060519
Thanks, @sanjoy.
kConditional operations are currently generally disallowed in parallel contexts
(e.g. in mapped computations). The Julia XLA frontend was running into this limitation
quite a bit, because existing Julia code tends to use the ternary operator for select,
e.g. to describe the derivative of a max call (and thus a relu) - see the definitions
of the derivatives of max at
https://github.com/JuliaDiff/DiffRules.jl/blob/master/src/rules.jl#L94

To support these sorts of patterns, add a simple if conversion pass that replaces
conditionals in parallel context with equivalent select operations (which are well
supported).

To keep things simple, this is accomplished by first rewriting the conditional
to two calls and a select and then inlining the individual calls. Naturally,
the transformation is only applied if the called computations do not
have side effects (which they generally don't if they're in parallel
context). In the future, it would be good to let MapInliner further
simplify this to an implicitly mapped select.