New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NNC] Fix lowering of aten::pow #47795
Conversation
💊 CI failures summary and remediationsAs of commit 21c6a78 (more details on the Dr. CI page):
🚧 2 fixed upstream failures:These were probably caused by upstream breakages that were already fixed.
Please rebase on the
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
|
||
@torch.jit.script | ||
def do_exp(x, y, z): | ||
return ((x * y) * 2) * torch.pow(z, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: can the expression be simplified? I guess what really matters here is just torch.pow(z, 2)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could be, but this is equivalent to the repro in the bug report.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: NNC lowering of aten::pow assumes that the types of the exponent is either float or int cast to to float, which doesn't work great with double (or half for that matter). Fixes pytorch#47304 Pull Request resolved: pytorch#47795 Reviewed By: ZolotukhinM Differential Revision: D24904201 Pulled By: nickgg fbshipit-source-id: 43c3ea704399ebb36c33cd222db16c60e5b7ada5
NNC lowering of aten::pow assumes that the types of the exponent is either float or int cast to to float, which doesn't work great with double (or half for that matter).
Fixes #47304