Enabled where for bool tensor on CUDA (#26430)
Summary:
Enabled "where_cuda" for bool tensors on CUDA
Fixing #26247
Tested via unit tests
Pull Request resolved: #26430

Differential Revision: D17464181

Pulled By: izdeby

fbshipit-source-id: cbb09925753b2e6f35e7400da3243d4d3fc86b69
izdeby authored and facebook-github-bot committed Sep 19, 2019
1 parent aad8738 commit f673def
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion in aten/src/ATen/native/cuda/TensorCompare.cu

@@ -48,7 +48,7 @@ Tensor _s_where_cuda(
     const Tensor& self,
     const Tensor& other) {
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.scalar_type(), "where_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
     where_cuda<scalar_t>(ret, condition, self, other);
   });
   return ret;
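The one-line change above widens the dispatch macro from `AT_DISPATCH_ALL_TYPES_AND` (all standard types plus Half) to `AT_DISPATCH_ALL_TYPES_AND2` (plus Half and Bool), so `where_cuda` is now instantiated for bool tensors as well. A minimal sketch of the user-facing behavior this enables, selecting between two bool tensors with `torch.where`; it runs on CPU here so it works on any build, while on a CUDA build `device="cuda"` would exercise the new path:

```python
import torch

# torch.where(cond, a, b) picks a[i] where cond[i] is True, else b[i].
# With this commit, a and b may themselves be bool tensors on CUDA.
cond = torch.tensor([True, False, True])
a = torch.tensor([True, True, True])
b = torch.tensor([False, False, False])
out = torch.where(cond, a, b)
print(out)  # tensor([ True, False,  True])
```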
6 changes: 6 additions & 0 deletions in test/test_torch.py

@@ -955,6 +955,12 @@ def test(size):
         test((10,))
         test((5, 5))
 
+    def test_where_bool_tensor(self):
+        for d in torch.testing.get_all_device_types():
+            a = torch.tensor([True, False], device=d)
+            res = torch.where(a > 0)
+            self.assertEqual(1, len(res))
+
     def test_all_any_with_dim(self):
         def test(x):
             r1 = x.prod(dim=0, keepdim=False).byte()
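The new unit test exercises the single-argument form, `torch.where(condition)`, which returns a tuple containing one index tensor per dimension of the condition; for a 1-D input the tuple has length 1. A small CPU-only sketch of that behavior:

```python
import torch

# Single-argument torch.where returns the indices of True elements,
# one index tensor per dimension (like nonzero with as_tuple=True).
a = torch.tensor([True, False, True, False])
res = torch.where(a)
print(len(res))  # 1, since a is one-dimensional
print(res[0])    # tensor([0, 2])
```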
