trying to make pow work for tensor raised to the power of a scalar #46185
Conversation
unit test?
aten/src/ATen/native/Pow.cpp
Outdated
@@ -62,6 +62,9 @@ Tensor& pow_(Tensor& base, Scalar alpha) {
}

Tensor pow(const Tensor& base, const Tensor& exp) {
  if (exp.dim() == 0) {
can you add a test in test/test_torch.py?
💊 CI failures summary (as of commit f3765e9, via Dr. CI): 💚 Looks good so far! There are no failures yet. 💚
@janeyx99 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I'm guessing there's already a test for this?
I see you added a test for this 👍
There's no test for this, right? Should one be added?
@samestep:
Fixes #46037
I'm not sure this is the most performant solution, but this works:
- `torch.pow(cuda_tensor, 5)` should work, and worked before.
- `torch.pow(cuda_tensor, torch.tensor(5))` should work, and works now!
- `torch.pow(cuda_tensor, torch.tensor((5,)))` should NOT work, complaining that the tensors are on different devices, and it indeed continues to complain.
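The three cases above can be illustrated with a minimal plain-Python sketch of the dispatch rule (`FakeTensor`, `tensor_pow`, and the device check are hypothetical stand-ins for illustration, not PyTorch's actual ATen implementation):

```python
from dataclasses import dataclass, field

@dataclass
class FakeTensor:
    # Stand-in for a tensor: `data` is either a scalar (0-dim) or a list (1-dim).
    data: object
    device: str = "cpu"

    def dim(self) -> int:
        return 0 if not isinstance(self.data, list) else 1

def tensor_pow(base: FakeTensor, exp: FakeTensor) -> FakeTensor:
    if exp.dim() == 0:
        # A 0-dim exponent is treated like a plain scalar, so it is
        # broadcast to every element and mixed devices are allowed.
        return FakeTensor([b ** exp.data for b in base.data], base.device)
    if exp.device != base.device:
        # A 1-dim (or higher) exponent must live on the same device.
        raise RuntimeError("Expected all tensors to be on the same device")
    return FakeTensor([b ** e for b, e in zip(base.data, exp.data)],
                      base.device)
```

With this sketch, a 0-dim exponent succeeds against a "CUDA" base, while a 1-element 1-dim exponent on a different device raises, mirroring the behavior described above.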