[numpy] torch.{all, any} : Extend Dtype Support #44790
Conversation
💊 CI failures summary and remediations. As of commit 1b242ef (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build (1/2), step: "Checkout pytorch/builder repo"
LGTM. Thanks for this PR, it's a very welcome change.
Note: Now that torch.all and torch.any support all dtypes, we should document it in the public APIs as mentioned in #44779, but this can be a separate PR.
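As a quick illustration of the extended dtype support (a minimal sketch; the tensors below are illustrative assumptions, not taken from the PR's test suite, and the result dtype question is discussed later in the thread):
>>> import torch
>>> x = torch.tensor([0.0, 1.5, 2.0])            # float input, previously unsupported by all/any
>>> torch.any(x)                                 # logically True: at least one element is nonzero
>>> torch.all(x)                                 # logically False: the first element is zero
>>> torch.all(torch.tensor([1 + 2j, 2 + 3j]))    # complex inputs are covered as well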
@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
It looks like the XLA failures are related to the changes. Could you look into what's causing them, please?
@heitorschueroff Thanks for looking at it. As for XLA, I'm not really sure what is happening. Thanks!
@kshitij12345 The XLA change is ready; I will merge it when this PR is merged.
@JackCaoG Thanks for updating XLA. @kshitij12345 Could you rebase, please? I'll merge it then.
@heitorschueroff I have fixed the conflict. The ROCm failure looks unrelated.
@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Thank you for this contribution, I'm importing your changes now.
It looks like I missed some details in my review. Our internal tests on Phabricator are complaining. I left some comments from Phabricator; they should be fairly quick to fix, and then I can land it without problems.
return c;
},
/*ident=*/true);
if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) {
include_bool -> includeBool
return c;
},
/*ident=*/false);
if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) {
include_bool -> includeBool
if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) {
  binary_kernel_reduce_vec(
      iter,
      [=](uint8_t a, uint8_t b) -> uint8_t { return a && b; },
Avoid the implicit cast with:
[=](uint8_t a, uint8_t b) -> uint8_t { return ((a && b) ? 1 : 0); },
if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) {
  binary_kernel_reduce_vec(
      iter,
      [=](uint8_t a, uint8_t b) -> uint8_t { return a || b; },
Avoid the implicit cast with:
[=](uint8_t a, uint8_t b) -> uint8_t { return ((a || b) ? 1 : 0); },
// true/false.
Vec256<uint8_t> c = Vec256<uint8_t>();
for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
  c[i] = a[i] && b[i];
Avoid implicit cast with:
c[i] = ((a[i] && b[i]) ? 1 : 0);
[=](Vec256<uint8_t> a, Vec256<uint8_t> b) {
  Vec256<uint8_t> c = Vec256<uint8_t>();
  for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
    c[i] = a[i] || b[i];
Avoid implicit cast with:
c[i] = ((a[i] || b[i]) ? 1 : 0);
@heitorschueroff Done.
@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@heitorschueroff merged this pull request in 6575e67.
@JackCaoG Please merge the XLA fix. Thanks!
@kshitij12345 There are a few problems with this PR.
Can you please work on fixing 2) and 3)?
I merged the XLA change. If this PR gets reverted I will revert the XLA PR as well. Otherwise I will work on a companion PR to fix the result type.
I'd look at the …
Also, please don't forget to update the documentation.
@ngimel Behaviour in the previous version (1.5.1):
>>> import torch
>>> torch.__version__
'1.5.1+cu101'
>>> x = torch.zeros(3,3)
>>> x.to(torch.uint8).all()
tensor(0, dtype=torch.uint8)
>>> x.to(torch.bool).all()
tensor(False)
cc @mruberry for deprecating the return type for uint8. In any case, for all other types there are no BC-breaking concerns, so we should implement the correct behavior.
I would try to update the uint8 behavior to be consistent and document the change as BC-breaking. If a scripted network relies on the current behavior (extremely unlikely), we can write an upgrader.
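To make the agreed direction concrete, a minimal sketch of the intended behavior after the follow-up change (outputs are assumptions based on the discussion above, mirroring the 1.5.1 snippet; they are not copied from the PR's tests):
>>> import torch
>>> x = torch.zeros(3, 3)
>>> x.to(torch.uint8).all()      # previously tensor(0, dtype=torch.uint8)
tensor(False)
>>> x.to(torch.float32).any()    # newly supported dtype; result is bool, as in NumPy
tensor(False)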
Summary:
BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.)
PR summary: #44790 (comment). Fixes 2 and 3. Also fixes #48352.
Changes:
* Output dtype is always `bool` (consistent with numpy), **BC-breaking** (previously used to match the input dtype)
* Uses vectorized version for all dtypes on CPU
* Enables test for complex
* Update doc for `torch.all` and `torch.any`
TODO:
* [x] Update docs
* [x] Benchmark
* [x] Raise issue on XLA
Pull Request resolved: #47878
Reviewed By: H-Huang
Differential Revision: D25421263
Pulled By: mruberry
fbshipit-source-id: c6c681ef94004d2bcc787be61a72aa059b333e69
Summary:
BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.)
PR summary: #44790 (comment). Fixes 2 and 3. Also fixes #48352.
Changes:
* Output dtype is always `bool` (consistent with numpy), **BC-breaking** (previously used to match the input dtype)
* Uses vectorized version for all dtypes on CPU
* Enables test for complex
* Update doc for `torch.all` and `torch.any`
TODO:
* [x] Update docs
* [x] Benchmark
* [x] Raise issue on XLA
Pull Request resolved: #47878
Reviewed By: albanD
Differential Revision: D25714324
Pulled By: mruberry
fbshipit-source-id: a87345f725297524242d69402dfe53060521ea5d
Reference #44779