[XLA/GPU] rsqrt is cheap and should be fused. #40998
Conversation
Looks fine and it worked in my env. It further optimizes the TF module for GPU usage.
@sanjoy @thomasjoerg
Hi Trent,
Can you share some specific cases where this helps? The last time I wanted to make this change I concluded 379268e was a more principled fix instead.
[edit: To clarify, I'm implying that tuning the heuristic added in https://github.com/tensorflow/tensorflow/commit/379268e9f4cbccfc46827408a0e67896c75af5b4 might be more effective.]
The heuristic is good (I like it). However, rsqrt (or div) maps to just one hardware instruction, unlike other intrinsics, which get expanded into a bunch of instructions when libdevice is linked in. So I think marking cheap instructions as cheap is orthogonal to the heuristic (which better handles genuinely expensive instructions).
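To make the argument concrete, here is a minimal sketch of the kind of change under discussion, assuming it lands in GpuInstructionFusion::IsExpensive in xla/service/gpu/instruction_fusion.cc; the exact opcode list and helper logic are illustrative, not copied from the PR.

/*static*/ bool GpuInstructionFusion::IsExpensive(
    const HloInstruction& instruction) {
  // Cheap-math treatment only makes sense for f32/f16, where the GPU has
  // fast (often single-instruction) lowerings; f64 goes through slower
  // libdevice code paths.
  PrimitiveType elem = instruction.shape().element_type();
  const bool fast_element_type = (elem == F32 || elem == F16);
  switch (instruction.opcode()) {
    case HloOpcode::kDivide:
    case HloOpcode::kSqrt:
    case HloOpcode::kRsqrt:
    case HloOpcode::kExp:
      if (fast_element_type) {
        return false;  // Cheap: do not let cost block fusion/duplication.
      }
      break;
    default:
      break;
  }
  // Everything else defers to the generic cost model.
  return InstructionFusion::IsExpensive(instruction);
}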
@@ -29,12 +29,23 @@ limitations under the License.
namespace xla {
namespace gpu {

bool ElementIsF32OrF16(const Shape& shape) {
Make it static or put it under an anonymous namespace.
My oversight. Thanks for the catch. Will update it soon.
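For illustration, the fix being agreed on might look like the following (a sketch; the actual follow-up commit may differ):

namespace {

// Anonymous namespace gives the helper internal linkage, so it cannot
// collide with symbols in other translation units (marking it static
// would achieve the same).
bool ElementIsF32OrF16(const Shape& shape) {
  PrimitiveType type = shape.element_type();
  return type == F32 || type == F16;
}

}  // namespace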
// We say that some floating-point math ops are cheap on the GPU.
switch (instruction.opcode()) {
Please add the rationale you mentioned in the PR (that these lower to single instructions).
Makes sense. Will add.
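A rough sketch of the kind of rationale comment being requested (wording is illustrative, not taken from the PR):

// These ops lower to a single (or very few) hardware instruction(s) on the
// GPU, unlike most transcendental intrinsics, which expand into many
// instructions when libdevice is linked in, so treat them as cheap to fuse.
switch (instruction.opcode()) {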
Will update the PR soon.
Also polished the comments in instruction_fusion.cc.
Updated. Please take another look. Thanks!
Hi @trentlo, this seems to regress a variant of ResNet, only on V100, by around 10%. Here is the pre-optimization HLO: https://gist.github.com/sanjoy/8161733b3e8f303d2f81b38814661f9a Can you PTAL? Let me know if you can't reproduce the regression.
I'd guess that it interacts with the fusion heuristic and produces a surprising fusion result. I will take a look.
@sanjoy, I instead see a 1% speedup with this PR on V100 (according to the perf numbers reported by xla_profile). See the attached log file for more details. Are you sure the regression is related to this PR? Also, I wonder whether you see any perf gain.
Could have been operator error, trying again.
Please help review the code. Thanks.
The pattern is observed (at least) in BERT.