[MPS] Add support for aten::masked_select on mps (#119) (#85818)
Reuse the `index.Tensor_out` implementation, since it already expands bool/byte indices to long tensors.
Pull Request resolved: #85818
Approved by: https://github.com/kulinseth
DenisVieriu97 authored and mehtanirav committed Oct 4, 2022
1 parent a64142b commit c473e22
Showing 4 changed files with 52 additions and 4 deletions.
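
To illustrate the approach described in the commit message, here is a rough Python sketch (not part of this commit) of the relationship it relies on: `masked_select` on MPS now goes through the same path as boolean advanced indexing (`index.Tensor`), which already converts a bool/byte mask to long indices.

```python
import torch

# Illustrative sketch only (not from this commit): masked_select now shares the
# index.Tensor_out path on MPS, so its result should match boolean indexing.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(3, 4, device=device)
mask = x.ge(0.5)

via_masked_select = torch.masked_select(x, mask)  # 1-D tensor of selected values
via_bool_indexing = x[mask]                       # same values via index.Tensor
assert torch.equal(via_masked_select.cpu(), via_bool_indexing.cpu())
```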
1 change: 0 additions & 1 deletion aten/src/ATen/mps/MPSFallback.mm
@@ -48,7 +48,6 @@ void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("masked_select", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
}

} // namespace at
38 changes: 37 additions & 1 deletion aten/src/ATen/native/mps/operations/Indexing.mm
Expand Up @@ -146,7 +146,6 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
  return true;
}

-
static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const std::string& op, bool accumulate) {
  using namespace mps;

@@ -186,6 +185,43 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray
  }
}

+static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
+  NoNamesGuard guard;
+
+  TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
+              "masked_select: expected BoolTensor or ByteTensor for mask");
+  TORCH_CHECK(self.scalar_type() == result.scalar_type(),
+              "masked_select(): self and result must have the same scalar type");
+
+  auto mask_temp = (mask.dim() == 0)
+    ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
+    : c10::MaybeOwned<Tensor>::borrowed(mask);
+  auto self_temp = (self.dim() == 0)
+    ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
+    : c10::MaybeOwned<Tensor>::borrowed(self);
+
+  // Cannot reassign to mask_temp and self_temp here! if they are
+  // owning and expand_outplace returns a borrow, the returned borrow
+  // would dangle.
+  auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
+  at::index_out(
+      result, *std::get<1>(mask_self_expanded),
+      c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
+
+  return result;
+}
+
+Tensor masked_select_mps(const Tensor & self, const Tensor & mask) {
+  namedinference::compute_broadcast_outnames(self, mask);
+  Tensor result = at::empty({0}, self.options());
+  return masked_select_out_mps_impl(result, self, mask);
+}
+
+Tensor & masked_select_out_mps(const Tensor & self, const Tensor & mask, Tensor & result) {
+  namedinference::compute_broadcast_outnames(self, mask);
+  return masked_select_out_mps_impl(result, self, mask);
+}
+
Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
  using namespace mps;

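
A small illustration (shown here for context, not taken from the patch) of the broadcasting that `expand_outplace` performs in the new `masked_select_out_mps_impl`: the mask and `self` are broadcast against each other before the `index.Tensor_out` call, so a lower-rank mask is accepted as long as the shapes broadcast.

```python
import torch

# Sketch of the broadcasting that expand_outplace provides in the new kernel:
# a (3,) mask is broadcast against a (2, 3) input before indexing.
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
mask = torch.tensor([True, False, True])

print(torch.masked_select(x, mask))  # tensor([0., 2., 3., 5.])
```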
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -8026,13 +8026,15 @@
  dispatch:
    CPU: masked_select_out_cpu
    CUDA: masked_select_out_cuda
+    MPS: masked_select_out_mps
  tags: dynamic_output_shape

- func: masked_select(Tensor self, Tensor mask) -> Tensor
  variants: method, function
  dispatch:
    CPU: masked_select_cpu
    CUDA: masked_select_cuda
+    MPS: masked_select_mps
  tags: dynamic_output_shape

- func: masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor
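
Both schema entries keep the `dynamic_output_shape` tag because the result size depends on the mask's contents rather than on the input shapes alone; a brief illustration:

```python
import torch

# masked_select's output length depends on how many mask elements are True,
# which is why the schema carries the dynamic_output_shape tag.
x = torch.ones(4)
print(torch.masked_select(x, torch.tensor([True, False, False, True])).shape)  # torch.Size([2])
print(torch.masked_select(x, torch.tensor([True, True, True, True])).shape)    # torch.Size([4])
```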
15 changes: 13 additions & 2 deletions test/test_mps.py
@@ -6041,6 +6041,17 @@ class TestAdvancedIndexing(TestCase):
    supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
    supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]

+    def test_masked_select(self):
+        x = torch.randn(3, 4)
+        x_mps = x.to("mps")
+        mask = x.ge(0.5)
+        mask_mps = x_mps.ge(0.5)
+
+        res = torch.masked_select(x, mask)
+        res_mps = torch.masked_select(x_mps, mask_mps)
+
+        self.assertEqual(res, res_mps)
+
    # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
    def test_indexing_get(self):
        def helper(dtype):
@@ -6383,10 +6394,10 @@ def test_index_put_accumulate_duplicate_indices(self, device="mps"):
            delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)

            # cumsum not supported on 'mps', fallback on 'cpu'
-            indices = delta.to("cpu").cumsum(0).long().to("mps")
+            indices = delta.cpu().cumsum(0).long().to("mps")

            # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
-            input = torch.randn(indices.to("cpu").abs().to("mps").max() + 1, device=device)
+            input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
            values = torch.randn(indices.size(0), device=device)
            output = input.index_put((indices,), values, accumulate=True)

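
The second hunk keeps the CPU round-trips that the inline comments describe (cumsum and int64 abs had no MPS kernels at the time). A minimal sketch of that pattern, assuming those kernels are still unavailable on the device in use:

```python
import torch

# Minimal sketch of the CPU round-trip used in the test above (assumes cumsum
# has no MPS kernel): run the unsupported op on CPU, then move the result back.
delta = torch.empty(8, dtype=torch.float32, device="mps").uniform_(-1, 1)
indices = delta.cpu().cumsum(0).long().to("mps")
```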
