
Fix issue with any/all when all reduce dimensions of input have size 1 #2590

Merged (2 commits) on Nov 11, 2020

Conversation

JackCaoG (Collaborator) commented on Oct 30, 2020

This is to fix #2585. As of pytorch PR pytorch/pytorch#44790, any and all can take non-boolean input. With our current lowering, the reduce does nothing when every reduce dimension of the input has size 1:

```
(Pdb) torch.all(torch.tensor([-147.5]).to(device))
tensor(-147.5000, device='xla:0')
```

Add an xla::Select when all reduce dimensions have size 1 to force the result value to be 1 or 0.
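
A minimal sketch of that idea (the helper name is hypothetical; the actual change lives in torch_xla/csrc/reduction.cpp and may differ), using XLA client-library calls (xla::Zero/xla::One from xla/client/lib/constants.h, plus xla::Ne and xla::Select):

```cpp
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper sketching the fix: when the reduce is a no-op
// (every reduce dimension has size 1), the raw input value would flow
// through unchanged, so normalize it to 1 or 0 with a compare-and-select.
xla::XlaOp NormalizeAllAnyResult(xla::XlaOp result, xla::PrimitiveType type) {
  xla::XlaBuilder* builder = result.builder();
  xla::XlaOp zero = xla::Zero(builder, type);
  xla::XlaOp one = xla::One(builder, type);
  // xla::Ne yields a PRED value; Select maps it back to constants of the
  // input's element type, so torch.all(torch.tensor([-147.5])) yields 1
  // instead of passing -147.5 through.
  return xla::Select(xla::Ne(result, zero), one, zero);
}
```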

davidel (Collaborator) left a comment:

LGTM, provided xla::Zeros() is used.

torch_xla/csrc/reduction.cpp (review thread, resolved)
JackCaoG (Collaborator, Author) commented:

Cannot repro the TestNNDeviceTypeXLA.test_GroupNorm_empty_xla failure; will take another look if it fails again.

JackCaoG (Collaborator, Author) commented:

The test failure might be because the pinned pytorch PR needs a rebase. I manually applied the patch on pytorch head, ran all tests with this PR, and all tests passed.

JackCaoG changed the title from "Fix issue with any/all when input has only one element" to "Fix issue with any/all when all reduce dimensions of input all have size 1" on Nov 2, 2020
JackCaoG (Collaborator, Author) commented on Nov 2, 2020

xla::Ne will change the result type to PRED, but torch.any and torch.all expect the result to be the same type as the input. Printing the result of torch.any/all with this implementation will cause an error (the test performs torch.any().cpu(), so it did not catch this).

Use xla::Select instead to fix this.
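
A sketch of the contrast (not the PR's exact diff; function and variable names are illustrative):

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"

// A bare compare changes the element type: the result of xla::Ne is PRED,
// so a float input would come back as a boolean-typed XLA value.
xla::XlaOp LowerWithNe(xla::XlaOp value, xla::XlaOp zero) {
  return xla::Ne(value, zero);  // element type: PRED
}

// Selecting typed one/zero constants on that predicate keeps the result
// in the input's element type, matching what torch.any/torch.all expect.
xla::XlaOp LowerWithSelect(xla::XlaOp value, xla::XlaOp zero, xla::XlaOp one) {
  return xla::Select(xla::Ne(value, zero), one, zero);
}
```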

JackCaoG changed the title from "Fix issue with any/all when all reduce dimensions of input all have size 1" to "Fix issue with any/all when all reduce dimensions of input have size 1" on Nov 2, 2020
JackCaoG (Collaborator, Author) commented on Nov 2, 2020

All tests passed locally.

torch_xla/csrc/reduction.cpp (3 review threads, resolved)
JackCaoG merged commit a917acb into master on Nov 11, 2020
Successfully merging this pull request may close these issues:

[PT_BREAK] [numpy] torch.{all, any}: Extend Dtype Support