aten/src/ATen/native/cuda/IndexKernel.cu (2 changes: 0 additions, 2 deletions)

@@ -97,8 +97,6 @@ static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntAr
 static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
   NoNamesGuard guard;
 
-  TORCH_CHECK(self.scalar_type() != ScalarType::BFloat16,
-              "masked_select: bfloat16 not supported for CUDA implementation");
   TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
               "masked_select: expected BoolTensor or ByteTensor for mask");
   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
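
Not part of the diff, but for context: with the TORCH_CHECK removed, a call like the following sketch should now dispatch on CUDA for bfloat16 instead of raising the error above. The values are illustrative only.

import torch

# Sketch only: masked_select on a bfloat16 CUDA tensor, which the removed
# check used to reject with "bfloat16 not supported for CUDA implementation".
if torch.cuda.is_available():
    x = torch.ones(4, device='cuda', dtype=torch.bfloat16)
    mask = torch.tensor([True, False, True, False], device='cuda')
    out = x.masked_select(mask)
    print(out.dtype, out.shape)  # expected: torch.bfloat16 torch.Size([2])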
test/test_torch.py (12 changes: 9 additions, 3 deletions)

@@ -11256,7 +11256,6 @@ def test_masked_scatter_bool_tensor(self, device):
         dst = dst.masked_scatter(mask, src)
         self.assertEqual(dst, torch.tensor([True, True, True], device=device))
 
-    @dtypesIfCUDA(*[dtype for dtype in torch.testing.get_all_dtypes() if dtype not in [torch.bfloat16]])
     @dtypes(*torch.testing.get_all_dtypes())
     def test_masked_select(self, device, dtype):
         if device == 'cpu':
@@ -11292,7 +11291,11 @@ def test_masked_select(self, device, dtype):
             return
 
         # Ensure that masks are expanded to match tensor properly
-        a = torch.rand(100, 100, device=device).mul(100).to(dtype)
+        if IS_WINDOWS and dtype == torch.bfloat16 and torch.device(device).type == 'cuda':
+            # TODO .to() for bfloat16 does not work on windows
+            a = torch.ones(100, 100, device=device, dtype=dtype)
+        else:
+            a = torch.rand(100, 100, device=device).mul(100).to(dtype)
         mask_first_el_each_row = torch.zeros(100, device=device).bool()
         mask_first_el_each_row[0] = True
         a_masked = a.masked_select(mask_first_el_each_row)
@@ -11304,7 +11307,10 @@ def test_masked_select(self, device, dtype):
         self.assertEqual(a_masked, a[0, :])
 
         # Ensure that tensor is expanded to match mask properly
-        a = torch.rand(100, device=device).mul(100).to(maskType)
+        if IS_WINDOWS and dtype == torch.bfloat16 and torch.device(device).type == 'cuda':
+            a = torch.ones(100, device=device, dtype=dtype)
+        else:
+            a = torch.rand(100, device=device).mul(100).to(maskType)
         mask_copy_3_times = torch.tensor([[True], [True], [False], [True]], device=device)
         a_masked = a.masked_select(mask_copy_3_times)
         self.assertEqual(a_masked, a.unsqueeze(0).expand(3, 100).flatten())
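
For reviewers unfamiliar with these cases, a rough standalone sketch (not from this PR) of the two broadcasting behaviors the test exercises; float32 is used here so it also runs on the Windows CUDA setups affected by the bfloat16 .to() issue noted in the TODO above.

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Mask expanded to match the tensor: a 1-D mask broadcasts across rows,
# so a True in position 0 selects the first element of every row.
a = torch.arange(12., device=device).reshape(3, 4)
col_mask = torch.zeros(4, dtype=torch.bool, device=device)
col_mask[0] = True
print(a.masked_select(col_mask))   # tensor([0., 4., 8.])

# Tensor expanded to match the mask: a (3, 1) mask broadcasts the 1-D tensor
# to shape (3, 4), so each True row selects a full copy of b.
b = torch.arange(4., device=device)
row_mask = torch.tensor([[True], [False], [True]], device=device)
print(b.masked_select(row_mask))   # tensor([0., 1., 2., 3., 0., 1., 2., 3.])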