Adds dim argument to torch.unique #10423

Closed · wants to merge 13 commits
84 changes: 84 additions & 0 deletions aten/src/ATen/native/Unique.cpp
@@ -47,6 +47,82 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
}
return std::make_tuple(output, inverse_indices);
}

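// Analogue of std::unique over a range of row tensors: collapses consecutive
// equal rows in place and, via the sort permutation stored in `indices`,
// records for every original row the index of the unique row it maps to.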
template<class ForwardIt>
ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
if (first == last) {
return last;
}
// remember the start of the range so we can compute iterator offsets below
ForwardIt begin = first;

// set first inverse index
inverse_indices_vec[indices[0]] = 0;

ForwardIt result = first;
while (++first != last) {
if (!at::equal(*result, *first) && ++result != first) {
*result = std::move(*first);
}
int64_t idx_result = std::distance(begin, result);
int64_t idx_first = std::distance(begin, first);
inverse_indices_vec[indices[idx_first]] = idx_result;
}

return ++result;
}

template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse) {
// move `dim` to the front and flatten the remaining dimensions: [self.size(dim), -1]
Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
input_flat = input_flat.contiguous().view({input_flat.size(0), -1});

std::vector<int64_t> indices(input_flat.size(0));
std::iota(indices.begin(), indices.end(), 0);
int64_t numel = input_flat.size(1);
scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());

// sort indices using data
std::sort(indices.begin(), indices.end(),
[&](int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_flat_ptr[i + a * numel];
scalar_t rhs = input_flat_ptr[i + b * numel];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
return false;
}
}
return false;
});

Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.type());
for (int i = 0; i < indices.size(); ++i) {
input_sorted[i] = input_flat[indices[i]];
}

Tensor inverse_indices = at::empty(indices.size(), self.type().toScalarType(kLong));
std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
auto last = _unique_dim_cpu_impl(
input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
input_unbind.erase(last, input_unbind.end());

// reshape back
auto output = at::stack(input_unbind, 0);
auto new_sizes = std::vector<int64_t>(orig_sizes);
new_sizes[0] = -1;
output = output.view(new_sizes);
output = output.transpose(0, dim);

return std::make_tuple(output, inverse_indices);
}
} // namespace

std::tuple<Tensor, Tensor>
@@ -56,5 +132,13 @@ _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
});
}

std::tuple<Tensor, Tensor>
_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
// The `dim` path currently always sorts, regardless of `sorted`, because tensor slices are not hashable
return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
});
}

} // namespace native
} // namespace at
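For readers skimming the diff, the CPU kernel above boils down to "lexicographically sort the slices along `dim`, drop consecutive duplicates, and remember which unique slice each original slice maps to". A rough pure-Python sketch of the same idea (the helper name and the use of Python's built-in sort are illustrative assumptions, not part of the PR):

```python
import torch

def unique_dim_sketch(x, dim, return_inverse=False):
    # Move `dim` to the front and flatten every slice along it into a row.
    transposed = x.transpose(dim, 0).contiguous()
    flat = transposed.view(transposed.size(0), -1)

    # Lexicographic sort of the rows; the C++ code sorts an index array with a
    # handwritten element-wise comparator, which is equivalent.
    order = sorted(range(flat.size(0)), key=lambda i: flat[i].tolist())

    uniques = []
    inverse = [0] * flat.size(0)
    for pos, idx in enumerate(order):
        # start a new unique row whenever the sorted row differs from its predecessor
        if pos == 0 or not torch.equal(flat[idx], flat[order[pos - 1]]):
            uniques.append(flat[idx])
        inverse[idx] = len(uniques) - 1

    # Reshape back: restore the trailing dimensions, then undo the transpose.
    new_sizes = list(transposed.size())
    new_sizes[0] = -1
    out = torch.stack(uniques, 0).view(new_sizes).transpose(0, dim)
    return (out, torch.tensor(inverse)) if return_inverse else out
```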
97 changes: 97 additions & 0 deletions aten/src/ATen/native/cuda/Unique.cu
@@ -69,6 +69,92 @@ template <typename scalar_t>
return std::tuple<Tensor, Tensor>(output, inverse_indices);

}

template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse) {

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
input_flat = input_flat.contiguous().view({input_flat.size(0), -1});

scalar_t* input_flat_ptr = input_flat.data<scalar_t>();

Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
int64_t* indices_ptr = indices.data<int64_t>();
int64_t numel = input_flat.size(1);

// sort indices using data
thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_flat_ptr[i + a * numel];
scalar_t rhs = input_flat_ptr[i + b * numel];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
return false;
}
}
return false;
});

Tensor input_sorted = input_flat.index_select(0, indices);

// keep one index per group of equal consecutive rows, then gather the unique rows
scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_sorted_ptr[i + a * numel];
scalar_t rhs = input_sorted_ptr[i + b * numel];
if (lhs != rhs) {
return false;
}
}
return true;
});
input_sorted_indices.resize_(last - input_sorted_indices_ptr);
Tensor output = input_sorted.index_select(0, input_sorted_indices);

// reshape back
auto new_sizes = std::vector<int64_t>(orig_sizes);
new_sizes[0] = -1;
output = output.view(new_sizes);
output = output.transpose(0, dim);

// calculate inverse indices: mark the first row of each group of equal sorted rows,
// cumsum the mask into per-row group ids, then scatter the ids back to the original order
Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
if (return_inverse) {
int64_t size = self.size(dim);
inverse_indices.resize_(size);
Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
mask[0] = 1;
for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
if (!at::equal(input_sorted[i], input_sorted[i+1])) {
mask[i+1] = 1;
} else {
mask[i+1] = 0;
}
}

Tensor imask = at::cumsum(mask, 0) - 1;
for (int i = 0; i < indices.size(0); ++i) {
inverse_indices[indices[i]] = imask[i];
}
}

THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
} // namespace

#endif
@@ -86,5 +172,16 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
#endif
}

std::tuple<Tensor, Tensor>
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
#ifndef __HIP_PLATFORM_HCC__
return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
});
#else
AT_ERROR("unique_dim_cuda: HIP not supported");
#endif
}

} // namespace native
} // namespace at
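The inverse-index computation at the end of the CUDA template relies on a change mask plus a cumulative sum rather than re-comparing rows per query. A small host-side sketch of that trick (the function and tensor names here are made up for illustration):

```python
import torch

def inverse_from_sorted(sorted_rows, perm):
    # sorted_rows: [n, k] rows in sorted order; perm[i] = original index of sorted row i
    n = sorted_rows.size(0)
    mask = torch.ones(n, dtype=torch.long)
    # 1 exactly where a sorted row starts a new group of equal rows
    mask[1:] = (sorted_rows[1:] != sorted_rows[:-1]).any(dim=1).long()
    group_id = torch.cumsum(mask, 0) - 1      # unique-group id of each sorted row
    inverse = torch.empty(n, dtype=torch.long)
    inverse[perm] = group_id                  # scatter back to original row order
    return inverse
```

With `sorted_rows` playing the role of `input_sorted` and `perm` the sorted `indices`, this reproduces the `inverse_indices` that the element-wise loop above fills in.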
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1634,6 +1634,11 @@
CPU: _unique_cpu
CUDA: _unique_cuda

- func: _unique_dim(Tensor self, int64_t dim, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
dispatch:
CPU: _unique_dim_cpu
CUDA: _unique_dim_cuda

- func: _unsafe_view(Tensor self, IntList size) -> Tensor
variants: function

61 changes: 61 additions & 0 deletions test/test_torch.py
@@ -8331,6 +8331,67 @@ def test_unique(self):
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)

def test_unique_dim(self):
def run_test(dtype=torch.float):
x = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]],
[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_unique_dim0 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_inverse_dim0 = torch.tensor([0, 0])
expected_unique_dim1 = torch.tensor([[[0., 1.],
[1., 1.],
[2., 1.]],
[[0., 1.],
[1., 1.],
[2., 1.]]], dtype=dtype)
expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
expected_unique_dim2 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]],
[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_inverse_dim2 = torch.tensor([0, 1])

# dim0
x_unique = torch.unique(x, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)

x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
self.assertEqual(expected_inverse_dim0, x_inverse)

# dim1
x_unique = torch.unique(x, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)

x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
self.assertEqual(expected_inverse_dim1, x_inverse)

# dim2
x_unique = torch.unique(x, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)

x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
self.assertEqual(expected_inverse_dim2, x_inverse)

run_test(torch.float)
run_test(torch.double)
run_test(torch.long)
run_test(torch.uint8)

@staticmethod
def _test_bincount(self, device):
# negative input throws
20 changes: 14 additions & 6 deletions torch/functional.py
@@ -314,7 +314,7 @@ def isnan(tensor):
return tensor != tensor


def unique(input, sorted=False, return_inverse=False):
def unique(input, sorted=False, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.

Arguments:
@@ -356,11 +356,19 @@ def unique(input, sorted=False, return_inverse=False):
[ 1, 2]])

"""
output, inverse_indices = torch._unique(
input,
sorted=sorted,
return_inverse=return_inverse,
)
if dim is not None:
output, inverse_indices = torch._unique_dim(
input,
dim,
sorted=sorted,
return_inverse=return_inverse
)
else:
output, inverse_indices = torch._unique(
input,
sorted=sorted,
return_inverse=return_inverse,
)
if return_inverse:
return output, inverse_indices
else:
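With this change the public entry point accepts `dim`. A small usage sketch — the expected values follow the semantics exercised in `test_unique_dim` above, and (per the comment in Unique.cpp) the `dim` path always returns sorted slices regardless of `sorted`:

```python
import torch

x = torch.tensor([[1., 1.],
                  [0., 1.],
                  [1., 1.]])

rows = torch.unique(x, dim=0)
# rows == tensor([[0., 1.],
#                 [1., 1.]])

rows, inv = torch.unique(x, dim=0, return_inverse=True)
# inv == tensor([1, 0, 1]): original rows 0 and 2 map to unique row 1, row 1 maps to row 0
```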
15 changes: 12 additions & 3 deletions torch/tensor.py
@@ -305,13 +305,22 @@ def masked_scatter(self, mask, tensor):
def masked_fill(self, mask, value):
return self.clone().masked_fill_(mask, value)

def unique(self, sorted=False, return_inverse=False):
def unique(self, sorted=False, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.

See :func:`torch.unique`
"""
output, inverse_indices = self._unique(
sorted=sorted, return_inverse=return_inverse)
if dim is not None:
output, inverse_indices = self._unique_dim(
sorted=sorted,
return_inverse=return_inverse,
dim=dim
)
else:
output, inverse_indices = self._unique(
sorted=sorted,
return_inverse=return_inverse
)
if return_inverse:
return output, inverse_indices
else:
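The Tensor method mirrors the functional form. A sketch with `dim=1` (unique columns); the input values are chosen only for illustration, and the outputs follow the same sorted-slice semantics as the tests above:

```python
import torch

y = torch.tensor([[1., 1., 0.],
                  [2., 2., 0.]])

cols, inv = y.unique(dim=1, return_inverse=True)
# cols == tensor([[0., 1.],
#                 [0., 2.]])   # the two distinct columns, in sorted order
# inv  == tensor([1, 1, 0])    # columns 0 and 1 map to unique column 1, column 2 to 0
```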