Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions aten/src/ATen/native/Unique.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,28 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
if (first == last) {
return last;
}
// save to calculate distance to iterators
ForwardIt begin = first;

// set first inverse index and count
inverse_indices_vec[indices[0]] = 0;
counts[0] += 1;
TORCH_INTERNAL_ASSERT(inverse_indices_vec.is_contiguous(),
"_unique_dim_cpu_impl only support contiguous inverse_indices_vec");
TORCH_INTERNAL_ASSERT(counts.is_contiguous(),
"_unique_dim_cpu_impl only support contiguous counts");

int64_t *indices_data = indices.data();
int64_t *inverse_data = inverse_indices_vec.data_ptr<int64_t>();
int64_t *counts_data = counts.data_ptr<int64_t>();

ForwardIt result = first;
while (++first != last) {
if (!at::equal(*result, *first) && ++result != first) {
*result = std::move(*first);
ForwardIt previous = first;
int64_t *current_counts = counts_data;
for (ForwardIt current = first; current != last; current++) {
if (!at::equal(*current, *result)) {
*(++result) = std::move(*current);
*(current_counts++) = std::distance(previous, current);
previous = current;
}
int64_t idx_result = std::distance(begin, result);
int64_t idx_first = std::distance(begin, first);
inverse_indices_vec[indices[idx_first]] = idx_result;
counts[idx_result] += 1;
inverse_data[*(indices_data++)] = std::distance(first, result);
}

*current_counts = std::distance(previous, last);
return ++result;
}

Expand Down Expand Up @@ -275,7 +279,7 @@ unique_dim_consecutive_cpu(const Tensor& self, const int64_t dim, const bool ret

std::tuple<Tensor, Tensor, Tensor>
unique_consecutive_cpu(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
if (!dim.has_value()) {
if (!dim.has_value() || (dim.value() == 0 && self.dim() == 1)) {
return AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "unique", [&] {
return unique_consecutive_cpu_template<scalar_t>(self, return_inverse, return_counts);
});
Expand Down