Fix auto exponent issue for torch.pow #47024
Conversation
💊 CI failures summary and remediations — as of commit cf7edfa (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages: pytorch_linux_xenial_py3_clang5_asan_test2 (1/1), step "Run tests".
Thanks!
Fixes #46936. Stack from [ghstack](https://github.com/ezyang/ghstack): * **#47024 Fix auto exponent issue for torch.pow**
@@ -88,6 +88,17 @@ class C10_API Scalar {

  Scalar operator-() const;

  template<typename T>
  bool equal(T num) const {
cc @ezyang to verify if it's ok to add `equal` for `Scalar`.
This seems fine.
auto grad_lambda = [](Tensor a, Scalar b) {
  return AT_DISPATCH_DOUBLE_COMPLEXDOUBLE(b.type(), "scalar_val", ([&] {
    scalar_t val = b.to<scalar_t>();
    return (a * std::log(val)).conj();
This seems a little goofy. Why not also add a `log()` method on `Scalar`?
Should `Scalar` have more operations defined on it? For example, in `pow_backward`, we would need to define an `operator-` for `Scalar` to avoid the dispatch.
@ezyang added `log`. I think it makes sense to support more math operations for `Scalar`. For now, I added `log` and `equal` in `Scalar.h`, but we should add a new file `ScalarMath.{h,cpp}` if we plan to add more ops for `Scalar`.
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    auto out = grad * (exponent * self.pow(exponent - 1)).conj();
    auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); };
Time to read up on auto type deduction...
According to https://en.cppreference.com/w/cpp/language/lambda, `grad_lambda` is a generic lambda, which behaves like a template with one parameter. So this should work. It's too bad that we can't actually write a test for this.
Fixes #46936. Differential Revision: [D24698027](https://our.internmc.facebook.com/intern/diff/D24698027)
test/cpp/api/autograd.cpp (Outdated)

// auto y = x.pow(1.5);
// auto gr =
//     grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
// ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan");
Can you try adding the grad_output as `torch::tensor({0.0})` here? That should make it `nan` as you expect.
yup will do
@anjali411 merged this pull request in 8ef7ccd.
Unlanding. This appears to have broken pytorch_linux_xenial_py3_clang5_asan_test2.

This pull request has been reverted by 013e6a3.
Fixes #46936
Stack from ghstack:
Differential Revision: D24698027