From 0904426d2f223079a04d318d632d6ab1695fbd80 Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Wed, 28 Feb 2018 18:25:50 -0800
Subject: [PATCH 01/17] Initial commit for unique op

---
 aten/src/ATen/native/Unique.cpp            | 0
 aten/src/ATen/native/native_functions.yaml | 3 +++
 test/test_torch.py                         | 3 +++
 3 files changed, 6 insertions(+)
 create mode 100644 aten/src/ATen/native/Unique.cpp

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f7dc2d1203418..fc2737c45355e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -363,6 +363,9 @@
 - func: type_as(Tensor self, Tensor other) -> Tensor
   variants: method
 
+- func: unique(Tensor self, bool return_inverse=false) -> (Tensor, Tensor)
+  variants: method
+
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
   variants: function
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 1632f424fa984..2444e559f3340 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5422,6 +5422,9 @@ def test_set_flush_denormal(self):
         self.assertEqual(double_tensor[2], 0.0, prec=0.0)  # tiny_double to zero
         torch.set_flush_denormal(False)
 
+    def test_unique(self):
+        pass
+
 # Functions to test negative dimension wrapping
 METHOD = 1
 INPLACE_METHOD = 2

From a8d5ffa6ec8095437d802b597dc3e73cb8ef5aa1 Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 10:52:36 -0800
Subject: [PATCH 02/17] Working unique with test

---
 aten/src/ATen/native/Unique.cpp            | 43 ++++++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml |  2 +-
 test/test_torch.py                         | 43 +++++++++++++++++++++-
 3 files changed, 85 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index e69de29bb2d1d..c61589ba54220 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -0,0 +1,43 @@
+#include <unordered_map>
+#include <unordered_set>
+
+#include "ATen/ATen.h"
+
+namespace at {
+namespace native{
+
+std::tuple<Tensor, Tensor> unique_1d(
+    const Tensor& self, const bool sorted, const bool return_inverse) {
+  std::unordered_set<int64_t> set(
+      self.data<int64_t>(), self.data<int64_t>() + self.numel());
+  Tensor output = self.type().tensor({static_cast<int64_t>(set.size())});
+
+  if (!sorted) {
+    std::copy(set.begin(), set.end(), output.data<int64_t>());
+  } else {
+    std::vector<int64_t> vec(set.begin(), set.end());
+    std::sort(vec.begin(), vec.end());
+    std::copy(vec.begin(), vec.end(), output.data<int64_t>());
+  }
+
+  Tensor inverse_indices;
+  if (!return_inverse) {
+    inverse_indices = self.type().toScalarType(kLong).tensor({0});
+
+  } else {
+    inverse_indices = self.type().toScalarType(kLong).tensor({self.numel()});
+    std::unordered_map<int64_t, int64_t> inverse_map;
+    inverse_map.reserve(output.numel());
+    for (int i = 0; i < output.numel(); ++i) {
+      inverse_map[output.data<int64_t>()[i]] = i;
+    }
+    for (int i = 0; i < self.numel(); ++i) {
+      inverse_indices.data<int64_t>()[i] = inverse_map[self.data<int64_t>()[i]];
+    }
+  }
+
+  return std::make_tuple(output, inverse_indices);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fc2737c45355e..5e3ee40fbbd80 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -363,7 +363,7 @@
 - func: type_as(Tensor self, Tensor other) -> Tensor
   variants: method
 
-- func: unique(Tensor self, bool return_inverse=false) -> (Tensor, Tensor)
+- func: unique_1d(IndexTensor self, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
   variants: method
 
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
   variants: function
diff --git a/test/test_torch.py b/test/test_torch.py
index 2444e559f3340..0eae3322e285d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5423,8 +5423,47 @@ def test_set_flush_denormal(self):
         torch.set_flush_denormal(False)
 
     def test_unique(self):
-        pass
-
+        x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
+        expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
+        expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
+        empty_inverse = torch.LongTensor([])
+
+        if TEST_NUMPY:
+            x_unique, x_inverse = x.unique_1d()
+            self.assertEqual(
+                expected_unique.numpy().tolist(),
+                sorted(x_unique.numpy().tolist())
+            )
+            self.assertEqual(empty_inverse, x_inverse)
+
+            x_unique, x_inverse = x.unique_1d(return_inverse=True)
+            self.assertEqual(
+                expected_unique.numpy().tolist(),
+                sorted(x_unique.numpy().tolist())
+            )
+            self.assertEqual(expected_inverse.numel(), x_inverse.numel())
+
+            x_unique, x_inverse = x.unique_1d(sorted=True)
+            self.assertEqual(expected_unique, x_unique)
+            self.assertEqual(empty_inverse, x_inverse)
+
+            x_unique, x_inverse = x.unique_1d(sorted=True, return_inverse=True)
+            self.assertEqual(expected_unique, x_unique)
+            self.assertEqual(expected_inverse, x_inverse)
+
+            # Tests 1-D unique on a higher rank tensor.
+            y = x.view(2, 2, 2)
+            y_unique, y_inverse = y.unique_1d(sorted=True, return_inverse=True)
+            self.assertEqual(expected_unique, y_unique)
+            self.assertEqual(expected_inverse, y_inverse)
+
+            # Tests invalid use cases.
+            self.assertRaises(
+                RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique_1d())
+            self.assertRaises(
+                RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique_1d())
+
+
 # Functions to test negative dimension wrapping
 METHOD = 1
 INPLACE_METHOD = 2

From 55a8b5ae03310c13e7a3a74e33218eaf290c86bc Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 12:19:31 -0800
Subject: [PATCH 03/17] Make inverse indices shape conform to input

---
 aten/src/ATen/native/Unique.cpp            |  6 ++++--
 aten/src/ATen/native/native_functions.yaml |  2 +-
 test/test_torch.py                         | 18 +++++++++---------
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index c61589ba54220..d3c41a209613e 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -1,3 +1,5 @@
+// Returns unique elements of input tensor.
+
 #include <unordered_map>
 #include <unordered_set>
 
@@ -6,7 +8,7 @@
 namespace at {
 namespace native{
 
-std::tuple<Tensor, Tensor> unique_1d(
+std::tuple<Tensor, Tensor> unique(
     const Tensor& self, const bool sorted, const bool return_inverse) {
   std::unordered_set<int64_t> set(
       self.data<int64_t>(), self.data<int64_t>() + self.numel());
@@ -25,7 +27,7 @@ std::tuple<Tensor, Tensor> unique_1d(
     inverse_indices = self.type().toScalarType(kLong).tensor({0});
 
   } else {
-    inverse_indices = self.type().toScalarType(kLong).tensor({self.numel()});
+    inverse_indices = self.type().toScalarType(kLong).tensor(self.sizes());
     std::unordered_map<int64_t, int64_t> inverse_map;
     inverse_map.reserve(output.numel());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5e3ee40fbbd80..8854c1581b41d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -363,7 +363,7 @@
 - func: type_as(Tensor self, Tensor other) -> Tensor
   variants: method
 
-- func: unique_1d(IndexTensor self, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
+- func: unique(IndexTensor self, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
   variants: method
 
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
diff --git a/test/test_torch.py b/test/test_torch.py
index 0eae3322e285d..58f4f4c8eb621 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5429,39 +5429,39 @@ def test_unique(self):
         empty_inverse = torch.LongTensor([])
 
         if TEST_NUMPY:
-            x_unique, x_inverse = x.unique_1d()
+            x_unique, x_inverse = x.unique()
             self.assertEqual(
                 expected_unique.numpy().tolist(),
                 sorted(x_unique.numpy().tolist())
             )
             self.assertEqual(empty_inverse, x_inverse)
 
-            x_unique, x_inverse = x.unique_1d(return_inverse=True)
+            x_unique, x_inverse = x.unique(return_inverse=True)
             self.assertEqual(
                 expected_unique.numpy().tolist(),
                 sorted(x_unique.numpy().tolist())
             )
             self.assertEqual(expected_inverse.numel(), x_inverse.numel())
 
-            x_unique, x_inverse = x.unique_1d(sorted=True)
+            x_unique, x_inverse = x.unique(sorted=True)
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(empty_inverse, x_inverse)
 
-            x_unique, x_inverse = x.unique_1d(sorted=True, return_inverse=True)
+            x_unique, x_inverse = x.unique(sorted=True, return_inverse=True)
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(expected_inverse, x_inverse)
 
-            # Tests 1-D unique on a higher rank tensor.
+            # Tests per-element unique on a higher rank tensor.
             y = x.view(2, 2, 2)
-            y_unique, y_inverse = y.unique_1d(sorted=True, return_inverse=True)
+            y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
             self.assertEqual(expected_unique, y_unique)
-            self.assertEqual(expected_inverse, y_inverse)
+            self.assertEqual(expected_inverse.view(y.size()), y_inverse)
 
             # Tests invalid use cases.
             self.assertRaises(
-                RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique_1d())
+                RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique())
             self.assertRaises(
-                RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique_1d())
+                RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique())
 
 
 # Functions to test negative dimension wrapping

From a197a7baef2f3e4ebe0859229f9b3979e17335f9 Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 12:26:29 -0800
Subject: [PATCH 04/17] flake8 whitespace removal

---
 test/test_torch.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_torch.py b/test/test_torch.py
index 58f4f4c8eb621..3c4f668a7e767 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5433,7 +5433,7 @@ def test_unique(self):
             self.assertEqual(
                 expected_unique.numpy().tolist(),
                 sorted(x_unique.numpy().tolist())
-            ) 
+            )
             self.assertEqual(empty_inverse, x_inverse)
 
             x_unique, x_inverse = x.unique(return_inverse=True)
@@ -5442,7 +5442,7 @@ def test_unique(self):
                 sorted(x_unique.numpy().tolist())
             )
             self.assertEqual(expected_inverse.numel(), x_inverse.numel())
- 
+
             x_unique, x_inverse = x.unique(sorted=True)
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(empty_inverse, x_inverse)
@@ -5456,13 +5456,13 @@ def test_unique(self):
             y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
             self.assertEqual(expected_unique, y_unique)
             self.assertEqual(expected_inverse.view(y.size()), y_inverse)
- 
+
             # Tests invalid use cases.
             self.assertRaises(
                 RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique())
             self.assertRaises(
                 RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique())
- 
+
 
 # Functions to test negative dimension wrapping

From 3c263099ee606ad1e2bb740589603f44ee762d90 Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 13:13:11 -0800
Subject: [PATCH 05/17] address review comment nits

---
 aten/src/ATen/native/Unique.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index d3c41a209613e..7fe2ef5f2513f 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -1,10 +1,10 @@
 // Returns unique elements of input tensor.
 
+#include "ATen/ATen.h"
+
 #include <unordered_map>
 #include <unordered_set>
 
-#include "ATen/ATen.h"
-
 namespace at {
 namespace native{
 
@@ -14,19 +14,16 @@ std::tuple<Tensor, Tensor> unique(
       self.data<int64_t>(), self.data<int64_t>() + self.numel());
   Tensor output = self.type().tensor({static_cast<int64_t>(set.size())});
 
-  if (!sorted) {
-    std::copy(set.begin(), set.end(), output.data<int64_t>());
-  } else {
+  if (sorted) {
     std::vector<int64_t> vec(set.begin(), set.end());
     std::sort(vec.begin(), vec.end());
     std::copy(vec.begin(), vec.end(), output.data<int64_t>());
+  } else {
+    std::copy(set.begin(), set.end(), output.data<int64_t>());
   }
 
   Tensor inverse_indices;
-  if (!return_inverse) {
-    inverse_indices = self.type().toScalarType(kLong).tensor({0});
-
-  } else {
+  if (return_inverse) {
     inverse_indices = self.type().toScalarType(kLong).tensor(self.sizes());
     std::unordered_map<int64_t, int64_t> inverse_map;
     inverse_map.reserve(output.numel());
@@ -36,6 +33,8 @@ std::tuple<Tensor, Tensor> unique(
     for (int i = 0; i < self.numel(); ++i) {
       inverse_indices.data<int64_t>()[i] = inverse_map[self.data<int64_t>()[i]];
     }
+  } else {
+    inverse_indices = self.type().toScalarType(kLong).tensor({0});
   }
 
   return std::make_tuple(output, inverse_indices);

From b48b427606a8d930a60497e84cc1589b7ac52e0c Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 18:08:50 -0800
Subject: [PATCH 06/17] Expose fn and add docs. Explicitly declare no
 gradients

---
 aten/src/ATen/native/native_functions.yaml |  5 +-
 test/test_torch.py                         | 13 ++---
 tools/autograd/derivatives.yaml            |  3 ++
 torch/_tensor_docs.py                      |  7 +++
 torch/_torch_docs.py                       | 56 ++++++++++++++++++++++
 5 files changed, 75 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8854c1581b41d..711dbba396ba2 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -363,9 +363,8 @@
 - func: type_as(Tensor self, Tensor other) -> Tensor
   variants: method
 
-- func: unique(IndexTensor self, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
-  variants: method
-
+- func: unique(Tensor self, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
+
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
   variants: function
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 3c4f668a7e767..3fbe4a1ae6b17 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5429,7 +5429,7 @@ def test_unique(self):
         empty_inverse = torch.LongTensor([])
 
         if TEST_NUMPY:
-            x_unique, x_inverse = x.unique()
+            x_unique, x_inverse = torch.unique(x)
             self.assertEqual(
                 expected_unique.numpy().tolist(),
                 sorted(x_unique.numpy().tolist())
@@ -5447,7 +5447,8 @@ def test_unique(self):
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(empty_inverse, x_inverse)
 
-            x_unique, x_inverse = x.unique(sorted=True, return_inverse=True)
+            x_unique, x_inverse = torch.autograd.Variable.unique(
+                x, sorted=True, return_inverse=True)
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(expected_inverse, x_inverse)
 
@@ -5458,10 +5459,10 @@ def test_unique(self):
             self.assertEqual(expected_unique, y_unique)
             self.assertEqual(expected_inverse.view(y.size()), y_inverse)
 
             # Tests invalid use cases.
-            self.assertRaises(
-                RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique())
-            self.assertRaises(
-                RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique())
+            #self.assertRaises(
+            #    RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique())
+            #self.assertRaises(
+            #    RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique())
 
 
 # Functions to test negative dimension wrapping
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index b53c067198e62..e4391c328714e 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -637,6 +637,9 @@
 - name: uniform_(Tensor self, double from, double to, Generator generator)
   self: zeros_like(grad)
 
+- name: unique(Tensor self, bool sorted, bool return_inverse)
+  self: not_implemented("unique")
+
 - name: _unsafe_view(Tensor self, IntList size)
   self: grad.contiguous().view(self.sizes())
 
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index aa635170510fe..c3fc92f15f458 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1795,6 +1795,13 @@ def callable(a, b) -> number
 P(x) = \dfrac{1}{to - from}
 """)
 
+add_docstr_all('unique',
+               r"""
+unique(sorted=False, return_inverse=False)
+
+See :func:`torch.unique`
+""")
+
 add_docstr_all('unsqueeze',
                r"""
 unsqueeze(dim)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 2518a29286f42..a09e22bbb8dab 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4926,6 +4926,62 @@
 
 """)
 
+add_docstr(torch.unique,
+           r"""
+unique(input, sorted=False, return_inverse=False) -> (Tensor, Tensor)
+
+Returns the unique scalar elements of the input tensor as list.
+
+Args:
+    input (Tensor): the input tensor
+    sorted (bool): Whether to sort the unique elements in ascending order
+        before returning as output.
+    return_inverse (bool): Whether to also return the indices for where
+        elements in the original input ended up in the returned unique list. 
+
+Returns:
+    (Tensor, Tensor): A tuple containing
+
+        - **output** (*Tensor*): the list of unique scalar elements
+        - **inverse_indices** (*Tensor*): the indices (same shape as input)
+          for where elements in the original input map to in the output
+          if ``return_inverse`` is ``True``; otherwise, an empty tensor.
+
+Example::
+
+    >>>> output, inverse_indices = torch.unique(
+        torch.LongTensor([1, 3, 2, 3]), sorted=True, return_inverse=True)
+    >>>> output
+
+     1
+     2
+     3
+    [torch.LongTensor of size (3,)]
+
+    >>>> inverse_indices
+
+     0
+     2
+     1
+     2
+    [torch.LongTensor of size (4,)]
+
+    >>>> output, inverse_indices = torch.unique(
+        torch.LongTensor([[1, 3], [2, 3]]), sorted=True, return_inverse=True)
+    >>>> output
+
+     1
+     2
+     3
+    [torch.LongTensor of size (3,)]
+
+    >>>> inverse_indices
+
+     0  2
+     1  2
+    [torch.LongTensor of size (2,2)]
+""")
+
 add_docstr(torch.unsqueeze,
            r"""
 unsqueeze(input, dim, out=None)

From 00244e8fc4c1f93a519a5bce6c779d79bc1d7754 Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 18:59:02 -0800
Subject: [PATCH 07/17] Trial generic dispatch implementation

---
 aten/src/ATen/native/Unique.cpp | 59 ++++++++++++++++++++++-----------
 1 file changed, 40 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index 7fe2ef5f2513f..d944ada90c99a 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -1,6 +1,9 @@
 // Returns unique elements of input tensor.
 
 #include "ATen/ATen.h"
+#include "ATen/Dispatch.h"
+#include "ATen/ExpandUtils.h"
+#include "ATen/NativeFunctions.h"
 
 #include <unordered_map>
 #include <unordered_set>
@@ -8,35 +11,53 @@
 namespace at {
 namespace native{
 
-std::tuple<Tensor, Tensor> unique(
-    const Tensor& self, const bool sorted, const bool return_inverse) {
-  std::unordered_set<int64_t> set(
-      self.data<int64_t>(), self.data<int64_t>() + self.numel());
-  Tensor output = self.type().tensor({static_cast<int64_t>(set.size())});
-
+namespace {
+template <typename scalar_t>
+std::tuple<Tensor, Tensor> unique_template(
+    const Tensor& self,
+    const bool sorted,
+    const bool return_inverse,
+    Tensor* output,
+    Tensor* inverse_indices) {
+  std::unordered_set<scalar_t> set(
+      self.data<scalar_t>(), self.data<scalar_t>() + self.numel());
+  output->resize_({static_cast<int64_t>(set.size())});
+
   if (sorted) {
-    std::vector<int64_t> vec(set.begin(), set.end());
+    std::vector<scalar_t> vec(set.begin(), set.end());
     std::sort(vec.begin(), vec.end());
-    std::copy(vec.begin(), vec.end(), output.data<int64_t>());
+    std::copy(vec.begin(), vec.end(), output->data<scalar_t>());
   } else {
-    std::copy(set.begin(), set.end(), output.data<int64_t>());
+    std::copy(set.begin(), set.end(), output->data<scalar_t>());
  }
 
-  Tensor inverse_indices;
   if (return_inverse) {
-    inverse_indices = self.type().toScalarType(kLong).tensor(self.sizes());
-    std::unordered_map<int64_t, int64_t> inverse_map;
-    inverse_map.reserve(output.numel());
-    for (int i = 0; i < output.numel(); ++i) {
-      inverse_map[output.data<int64_t>()[i]] = i;
+    inverse_indices->resize_(self.sizes());
+    std::unordered_map<scalar_t, int64_t> inverse_map;
+    inverse_map.reserve(output->numel());
+    for (int i = 0; i < output->numel(); ++i) {
+      inverse_map[output->data<scalar_t>()[i]] = i;
     }
     for (int i = 0; i < self.numel(); ++i) {
-      inverse_indices.data<int64_t>()[i] = inverse_map[self.data<int64_t>()[i]];
+      inverse_indices->data<int64_t>()[i] =
+          inverse_map[self.data<scalar_t>()[i]];
     }
-  } else {
-    inverse_indices = self.type().toScalarType(kLong).tensor({0});
   }
-
+}
+} // namespace
+
+std::tuple<Tensor, Tensor>
+unique(const Tensor& self, const bool sorted, const bool return_inverse) {
+  // output will be resized in unique_template once we know how big it is.
+  // inverse_indices may also be resized depending on return_inverse.
+  Tensor output = self.type().tensor({0});
+  Tensor inverse_indices = self.type().toScalarType(kLong).tensor({0});
+
+  AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
+    unique_template<scalar_t>(
+        self, sorted, return_inverse, &output, &inverse_indices);
+  });
+
   return std::make_tuple(output, inverse_indices);
 }

From bd49d0ad3c00e4cc70c2f4da8f06782f4f546f7a Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 19:54:18 -0800
Subject: [PATCH 08/17] Add tests for generics

---
 aten/src/ATen/native/Unique.cpp |  8 +++++---
 test/test_torch.py              | 26 +++++++++++++++++++++-----
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index d944ada90c99a..5e824e2420424 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -2,18 +2,20 @@
 
 #include "ATen/ATen.h"
 #include "ATen/Dispatch.h"
-#include "ATen/ExpandUtils.h"
-#include "ATen/NativeFunctions.h"
 
 #include <unordered_map>
 #include <unordered_set>
+#include <iostream>
 
+using std::cout;
+using std::endl;
+
 namespace at {
 namespace native{
 
 namespace {
 template <typename scalar_t>
-std::tuple<Tensor, Tensor> unique_template(
+void unique_template(
diff --git a/test/test_torch.py b/test/test_torch.py
index 3fbe4a1ae6b17..97033c7c5e8eb 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5458,11 +5458,27 @@ def test_unique(self):
             self.assertEqual(expected_unique, y_unique)
             self.assertEqual(expected_inverse.view(y.size()), y_inverse)
 
-            # Tests invalid use cases.
-            #self.assertRaises(
-            #    RuntimeError, lambda: torch.IntTensor([1, 2, 3]).unique())
-            #self.assertRaises(
-            #    RuntimeError, lambda: torch.FloatTensor([1., 2.5, 3.5]).unique())
+            # Tests unique on other types.
+            int_unique, int_inverse = torch.unique(
+                torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True)
+            self.assertEqual(torch.IntTensor([1, 2]), int_unique)
+            self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
+
+            double_unique, double_inverse = torch.unique(
+                torch.DoubleTensor([2., 1.5, 2.1, 2.]),
+                sorted=True,
+                return_inverse=True,
+            )
+            self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
+            self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
+
+            byte_unique, byte_inverse = torch.unique(
+                torch.ByteTensor([133, 7, 7, 7, 42, 128]),
+                sorted=True,
+                return_inverse=True,
+            )
+            self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
+            self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
 
 
 # Functions to test negative dimension wrapping

From a3dc7b23e8efad32c99507eb24fc35c6c0ccdec4 Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Thu, 1 Mar 2018 20:42:31 -0800
Subject: [PATCH 09/17] flake8 whitespace

---
 torch/_torch_docs.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index a09e22bbb8dab..3b799cafd6a3a 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4937,7 +4937,7 @@
     sorted (bool): Whether to sort the unique elements in ascending order
         before returning as output.
     return_inverse (bool): Whether to also return the indices for where
-        elements in the original input ended up in the returned unique list. 
+        elements in the original input ended up in the returned unique list.
 
 Returns:
@@ -4952,12 +4952,12 @@
     >>>> output, inverse_indices = torch.unique(
         torch.LongTensor([1, 3, 2, 3]), sorted=True, return_inverse=True)
     >>>> output
-    
+
      1
      2
      3
     [torch.LongTensor of size (3,)]
-    
+
     >>>> inverse_indices

From a016ea268bed7294ced66f93f4b94b9bfc7f1e8b Mon Sep 17 00:00:00 2001
From: Wei Ho
Date: Fri, 2 Mar 2018 13:25:37 -0800
Subject: [PATCH 10/17] Add basic CUDA error throwing and templateize set

---
 aten/src/ATen/native/Unique.cpp            | 40 ++++++++++++----------
 aten/src/ATen/native/cuda/Unique.cu        | 17 +++++++++
 aten/src/ATen/native/native_functions.yaml |  5 ++-
 test/test_torch.py                         | 14 ++++++--
 4 files changed, 52 insertions(+), 24 deletions(-)
 create mode 100644 aten/src/ATen/native/cuda/Unique.cu

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index 5e824e2420424..354588bc6a61d 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -3,35 +3,26 @@
 
 #include "ATen/ATen.h"
 #include "ATen/Dispatch.h"
 
+#include <set>
+#include <tuple>
 #include <unordered_map>
 #include <unordered_set>
-#include <iostream>
 
-using std::cout;
-using std::endl;
-
 namespace at {
 namespace native{
 
 namespace {
-template <typename scalar_t>
-void unique_template(
+
+template