diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index d9bd94e1f7810b9..d5ff300c0dd9e26 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -47,6 +47,82 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
   }
   return std::make_tuple(output, inverse_indices);
 }
+
+template <class ForwardIt>
+ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
+  std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
+    if (first == last) {
+      return last;
+    }
+    // save to calculate distance to iterators
+    ForwardIt begin = first;
+
+    // set first inverse index
+    inverse_indices_vec[indices[0]] = 0;
+
+    ForwardIt result = first;
+    while (++first != last) {
+      if (!at::equal(*result, *first) && ++result != first) {
+        *result = std::move(*first);
+      }
+      int64_t idx_result = std::distance(begin, result);
+      int64_t idx_first = std::distance(begin, first);
+      inverse_indices_vec[indices[idx_first]] = idx_result;
+    }
+
+    return ++result;
+  }
+
+template <typename scalar_t>
+std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
+    const Tensor& self,
+    const int64_t dim,
+    const bool return_inverse) {
+  // reshape tensor as [dim, -1]
+  Tensor input_flat = self.transpose(dim, 0);
+  auto orig_sizes = input_flat.sizes().vec();
+  input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
+
+  std::vector<int64_t> indices(input_flat.size(0));
+  std::iota(indices.begin(), indices.end(), 0);
+  int64_t numel = input_flat.size(1);
+  scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());
+
+  // sort indices using data
+  std::sort(indices.begin(), indices.end(),
+    [&](int64_t a, int64_t b) -> bool {
+      for (int64_t i = 0; i < numel; ++i) {
+        scalar_t lhs = input_flat_ptr[i + a * numel];
+        scalar_t rhs = input_flat_ptr[i + b * numel];
+        if (lhs < rhs) {
+          return true;
+        } else if (lhs > rhs) {
+          return false;
+        }
+      }
+      return false;
+    });
+
+  Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.type());
+  for (int i = 0; i < indices.size(); ++i) {
+    input_sorted[i] = input_flat[indices[i]];
+  }
+
+  Tensor inverse_indices = at::empty(indices.size(), self.type().toScalarType(kLong));
+  std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
+  auto last = _unique_dim_cpu_impl(
+    input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
+  input_unbind.erase(last, input_unbind.end());
+
+  // reshape back
+  auto output = at::stack(input_unbind, 0);
+  auto new_sizes = std::vector<int64_t>(orig_sizes);
+  new_sizes[0] = -1;
+  output = output.view(new_sizes);
+  output = output.transpose(0, dim);
+
+  return std::make_tuple(output, inverse_indices);
+}
 } // namespace

 std::tuple<Tensor, Tensor>
@@ -56,5 +132,13 @@ _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
   });
 }

+std::tuple<Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+    // The current implementation using `dim` always sorts, since tensors are not hashable
+    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
+  });
+}
+
 } // namespace native
 } // namespace at
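Note on the CPU algorithm above: it reduces unique-over-a-dimension to unique-over-rows by moving `dim` to the front and flattening the rest, lexicographically sorting row indices, then doing a std::unique-style consecutive compaction that records, for every original slice, the output slot it collapsed into. A rough Python model of that control flow, for illustration only (the helper name `unique_dim_reference` is made up, not part of this PR):

import torch

def unique_dim_reference(x, dim):
    # Move `dim` to the front and flatten everything else, as the kernel does.
    flat = x.transpose(dim, 0).contiguous()
    flat2d = flat.view(flat.size(0), -1)
    # Lexicographic sort of row indices, mirroring the std::sort comparator.
    order = sorted(range(flat2d.size(0)), key=lambda r: flat2d[r].tolist())
    inverse = torch.empty(flat2d.size(0), dtype=torch.long)
    uniq = []
    for pos, r in enumerate(order):
        # Consecutive compaction: keep a row only if it differs from the last kept one.
        if not uniq or not torch.equal(uniq[-1], flat2d[r]):
            uniq.append(flat2d[r])
        inverse[r] = len(uniq) - 1  # output slot this original slice collapsed into
    # Reshape back and undo the transpose.
    out = torch.stack(uniq).view(-1, *flat.shape[1:]).transpose(0, dim)
    return out, inverse

Because the compaction runs over sorted rows, the output slices always come back in lexicographic order, which is why the dispatch comment says the dim path always sorts.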
diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu
index f2e13b4c708b621..c29337f90f1347f 100644
--- a/aten/src/ATen/native/cuda/Unique.cu
+++ b/aten/src/ATen/native/cuda/Unique.cu
@@ -69,6 +69,92 @@ template <typename scalar_t>
   return std::tuple<Tensor, Tensor>(output, inverse_indices);
 }
+
+template <typename scalar_t>
+  std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
+    const Tensor& self,
+    const int64_t dim,
+    const bool return_inverse) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+    auto policy = thrust::cuda::par(allocator).on(stream);
+
+    Tensor input_flat = self.transpose(dim, 0);
+    auto orig_sizes = input_flat.sizes().vec();
+    input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
+
+    scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
+
+    Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
+    int64_t* indices_ptr = indices.data<int64_t>();
+    int64_t numel = input_flat.size(1);
+
+    // sort indices using data
+    thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
+      [=] __device__ (int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < numel; ++i) {
+          scalar_t lhs = input_flat_ptr[i + a * numel];
+          scalar_t rhs = input_flat_ptr[i + b * numel];
+          if (lhs < rhs) {
+            return true;
+          } else if (lhs > rhs) {
+            return false;
+          }
+        }
+        return false;
+      });
+
+    Tensor input_sorted = input_flat.index_select(0, indices);
+
+    // get unique tensors
+    scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
+    Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
+    int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
+    auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
+      [=] __device__ (int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < numel; ++i) {
+          scalar_t lhs = input_sorted_ptr[i + a * numel];
+          scalar_t rhs = input_sorted_ptr[i + b * numel];
+          if (lhs != rhs) {
+            return false;
+          }
+        }
+        return true;
+      });
+    input_sorted_indices.resize_(last - input_sorted_indices_ptr);
+    Tensor output = input_sorted.index_select(0, input_sorted_indices);
+
+    // reshape back
+    auto new_sizes = std::vector<int64_t>(orig_sizes);
+    new_sizes[0] = -1;
+    output = output.view(new_sizes);
+    output = output.transpose(0, dim);
+
+    // calculate inverse indices
+    Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+    if (return_inverse) {
+      int64_t size = self.size(dim);
+      inverse_indices.resize_(size);
+      Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
+      mask[0] = 1;
+      for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
+        if (!at::equal(input_sorted[i], input_sorted[i + 1])) {
+          mask[i + 1] = 1;
+        } else {
+          mask[i + 1] = 0;
+        }
+      }
+
+      Tensor imask = at::cumsum(mask, 0) - 1;
+      for (int i = 0; i < indices.size(0); ++i) {
+        inverse_indices[indices[i]] = imask[i];
+      }
+    }
+
+    THCudaCheck(cudaGetLastError());
+    return std::tuple<Tensor, Tensor>(output, inverse_indices);
+  }
 } // namespace

 #endif
@@ -86,5 +172,16 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
   #endif
 }

+std::tuple<Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+  #ifndef __HIP_PLATFORM_HCC__
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
+  });
+  #else
+  AT_ERROR("unique_dim_cuda: HIP not supported");
+  #endif
+}
+
 } // namespace native
 } // namespace at
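The CUDA path derives inverse indices differently from the CPU path: it builds a 0/1 "row changed" mask over the sorted rows, takes cumsum(mask) - 1 to number each sorted row with its unique slot, and scatters those slot ids back through the sort permutation. A minimal standalone sketch of just that step (not the kernel itself; the data is the dim=1 slices from the test case below):

import torch

# input_sorted: rows already lexicographically sorted; indices: sort permutation,
# i.e. the original row index of each sorted row.
input_sorted = torch.tensor([[0., 1.], [0., 1.], [1., 1.], [2., 1.]])
indices = torch.tensor([1, 3, 0, 2])

# 1 where a sorted row differs from its predecessor; the first row is always "new".
mask = torch.ones(input_sorted.size(0), dtype=torch.long)
mask[1:] = (input_sorted[1:] != input_sorted[:-1]).any(dim=1).long()

imask = torch.cumsum(mask, 0) - 1   # unique-slot id of each sorted row: [0, 0, 1, 2]
inverse = torch.empty_like(indices)
inverse[indices] = imask            # scatter back to original row order
print(inverse)                      # tensor([1, 0, 2, 0])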
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d8b632efc98bf04..b8a7d1cf50c610f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1634,6 +1634,11 @@
     CPU: _unique_cpu
     CUDA: _unique_cuda

+- func: _unique_dim(Tensor self, int64_t dim, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
+  dispatch:
+    CPU: _unique_dim_cpu
+    CUDA: _unique_dim_cuda
+
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
   variants: function

diff --git a/test/test_torch.py b/test/test_torch.py
index 20e1f59c84e43a8..bb256466a093026 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8331,6 +8331,67 @@ def test_unique(self):
         self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
         self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)

+    def test_unique_dim(self):
+        def run_test(dtype=torch.float):
+            x = torch.tensor([[[1., 1.],
+                               [0., 1.],
+                               [2., 1.],
+                               [0., 1.]],
+                              [[1., 1.],
+                               [0., 1.],
+                               [2., 1.],
+                               [0., 1.]]], dtype=dtype)
+            expected_unique_dim0 = torch.tensor([[[1., 1.],
+                                                  [0., 1.],
+                                                  [2., 1.],
+                                                  [0., 1.]]], dtype=dtype)
+            expected_inverse_dim0 = torch.tensor([0, 0])
+            expected_unique_dim1 = torch.tensor([[[0., 1.],
+                                                  [1., 1.],
+                                                  [2., 1.]],
+                                                 [[0., 1.],
+                                                  [1., 1.],
+                                                  [2., 1.]]], dtype=dtype)
+            expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
+            expected_unique_dim2 = torch.tensor([[[1., 1.],
+                                                  [0., 1.],
+                                                  [2., 1.],
+                                                  [0., 1.]],
+                                                 [[1., 1.],
+                                                  [0., 1.],
+                                                  [2., 1.],
+                                                  [0., 1.]]], dtype=dtype)
+            expected_inverse_dim2 = torch.tensor([0, 1])
+
+            # dim0
+            x_unique = torch.unique(x, dim=0)
+            self.assertEqual(expected_unique_dim0, x_unique)
+
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
+            self.assertEqual(expected_unique_dim0, x_unique)
+            self.assertEqual(expected_inverse_dim0, x_inverse)
+
+            # dim1
+            x_unique = torch.unique(x, dim=1)
+            self.assertEqual(expected_unique_dim1, x_unique)
+
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
+            self.assertEqual(expected_unique_dim1, x_unique)
+            self.assertEqual(expected_inverse_dim1, x_inverse)
+
+            # dim2
+            x_unique = torch.unique(x, dim=2)
+            self.assertEqual(expected_unique_dim2, x_unique)
+
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
+            self.assertEqual(expected_unique_dim2, x_unique)
+            self.assertEqual(expected_inverse_dim2, x_inverse)
+
+        run_test(torch.float)
+        run_test(torch.double)
+        run_test(torch.long)
+        run_test(torch.uint8)
+
     @staticmethod
     def _test_bincount(self, device):
         # negative input throws

diff --git a/torch/functional.py b/torch/functional.py
index e6a2ee21208c6e4..8efb7451dc2dc77 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -314,7 +314,7 @@ def isnan(tensor):
     return tensor != tensor


-def unique(input, sorted=False, return_inverse=False):
+def unique(input, sorted=False, return_inverse=False, dim=None):
     r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.

     Arguments:
@@ -356,11 +356,19 @@ def unique(input, sorted=False, return_inverse=False):
                 [ 1, 2]])

     """
-    output, inverse_indices = torch._unique(
-        input,
-        sorted=sorted,
-        return_inverse=return_inverse,
-    )
+    if dim is not None:
+        output, inverse_indices = torch._unique_dim(
+            input,
+            dim,
+            sorted=sorted,
+            return_inverse=return_inverse
+        )
+    else:
+        output, inverse_indices = torch._unique(
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+        )
     if return_inverse:
         return output, inverse_indices
     else:
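For reference, what the new `dim` argument looks like from Python (same slice data as the dim=1 case in the test above, shown on a 2-D tensor; repr formatting approximate):

>>> x = torch.tensor([[1., 1.], [0., 1.], [2., 1.], [0., 1.]])
>>> output, inverse = torch.unique(x, return_inverse=True, dim=0)
>>> output
tensor([[0., 1.],
        [1., 1.],
        [2., 1.]])
>>> inverse
tensor([1, 0, 2, 0])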
diff --git a/torch/tensor.py b/torch/tensor.py
index 3fa47cbcb86514f..5be865bd6adb531 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -305,13 +305,22 @@ def masked_scatter(self, mask, tensor):
     def masked_fill(self, mask, value):
         return self.clone().masked_fill_(mask, value)

-    def unique(self, sorted=False, return_inverse=False):
+    def unique(self, sorted=False, return_inverse=False, dim=None):
         r"""Returns the unique scalar elements of the tensor as a 1-D tensor.

         See :func:`torch.unique`
         """
-        output, inverse_indices = self._unique(
-            sorted=sorted, return_inverse=return_inverse)
+        if dim is not None:
+            output, inverse_indices = self._unique_dim(
+                sorted=sorted,
+                return_inverse=return_inverse,
+                dim=dim
+            )
+        else:
+            output, inverse_indices = self._unique(
+                sorted=sorted,
+                return_inverse=return_inverse
+            )
         if return_inverse:
             return output, inverse_indices
         else:
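One caller-facing caveat that follows from the dispatch comment in Unique.cpp: the dim path always sorts, because tensor slices cannot be hashed into an unordered set, so `sorted=False` is effectively ignored when `dim` is given. A quick hypothetical session showing the expected behavior (repr formatting approximate):

>>> x = torch.tensor([[1., 1.], [0., 1.], [2., 1.], [0., 1.]])
>>> x.unique(dim=0)  # sorted defaults to False, but rows come back sorted anyway
tensor([[0., 1.],
        [1., 1.],
        [2., 1.]])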