From c6b39b73f4a6341a056c7f43b1497b40d358e7e8 Mon Sep 17 00:00:00 2001
From: Erjia Guan
Date: Sat, 10 Oct 2020 16:58:24 -0700
Subject: [PATCH] Implement ravel (more tests)

ghstack-source-id: 6bc7834b0099a8d2e7f1fd4fc0c32ecfc59ddd71
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46098
---
 aten/src/ATen/core/aten_interned_strings.h   |  1 +
 aten/src/ATen/native/TensorShape.cpp         |  6 +-
 aten/src/ATen/native/native_functions.yaml   |  4 ++
 docs/source/tensors.rst                      |  1 +
 docs/source/torch.rst                        |  1 +
 test/test_torch.py                           | 71 +++++++++++++++++--
 torch/_tensor_docs.py                        |  7 ++
 torch/_torch_docs.py                         | 19 +++++
 torch/overrides.py                           |  1 +
 .../_internal/common_methods_invocations.py  |  1 +
 10 files changed, 105 insertions(+), 7 deletions(-)

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index ce780f58f3a6..da259e82990a 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -565,6 +565,7 @@ _(aten, randn_like) \
 _(aten, random) \
 _(aten, randperm) \
 _(aten, range) \
+_(aten, ravel) \
 _(aten, reciprocal) \
 _(aten, reflection_pad1d) \
 _(aten, reflection_pad1d_backward) \
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 5aac8e9a1715..a42e90f399d9 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1624,6 +1624,10 @@ Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
   return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim);
 }
 
+Tensor ravel(const Tensor& self) {
+  return self.reshape(-1);
+}
+
 Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
   dim = maybe_wrap_dim(dim, self.dim());
 
@@ -1633,7 +1637,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
   auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int64_t>());
   if (self.has_names()) {
     TORCH_CHECK(numel == self.size(dim),
-      "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", 
+      "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
       dim, " (", self.names()[dim], ": ", self.size(dim), ") in Tensor", self.names());
     TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes");
   } else {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f35845d8cea1..9888b21082aa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2869,6 +2869,10 @@
     CPU: range_cpu_out
     CUDA: range_cuda_out
 
+- func: ravel(Tensor(a) self) -> Tensor(a)
+  use_c10_dispatcher: full
+  variants: function, method
+
 - func: reciprocal(Tensor self) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index a4d2ac805a8f..421445093a66 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -495,6 +495,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: q_per_channel_axis
    .. automethod:: rad2deg
    .. automethod:: random_
+   .. automethod:: ravel
    .. automethod:: reciprocal
    .. automethod:: reciprocal_
   .. automethod:: record_stream
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 1fc46780e6ac..b3c8410300c6 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -462,6 +462,7 @@ Other Operations
     meshgrid
     lcm
     logcumsumexp
+    ravel
     renorm
     repeat_interleave
     roll
diff --git a/test/test_torch.py b/test/test_torch.py
index bd159a25b0bd..260d45c641b4 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1102,7 +1102,7 @@ def compare(t, k, dim, dir):
     def test_topk_arguments(self):
         q = torch.randn(10, 2, 10)
         # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
-        self.assertRaises(TypeError, lambda: q.topk(4, True)) 
+        self.assertRaises(TypeError, lambda: q.topk(4, True))
 
     def test_mode(self):
         x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
@@ -1887,6 +1887,65 @@ def _test_gather(self, cast, test_bounds=True):
     def test_gather(self):
         self._test_gather(self, lambda t: t)
 
+    @staticmethod
+    def _test_ravel(self, tensors, size, nc=False):
+        for src in tensors:
+            # Contiguous Tensor -> View
+            flat = src.ravel()
+            self.assertEqual(flat.shape, torch.Size([size]))
+            self.assertEqual(src.view(-1), flat)
+            self.assertEqual(flat._base, src)
+
+            # Non-contiguous Tensor -> Copy
+            if nc:
+                nc_src = src.t()
+                nc_flat = nc_src.ravel()
+                self.assertEqual(nc_flat.shape, torch.Size([size]))
+                self.assertEqual(nc_src.reshape(-1), nc_flat)
+                self.assertTrue(nc_flat._base != nc_src)
+
+    def test_ravel(self):
+        # Test that ravel returns a 1-dim tensor when given a 0-dim tensor
+        zero_dim_tensor = torch.tensor(123)
+        flat0 = zero_dim_tensor.ravel()
+        one_dim_tensor = torch.tensor([123])
+        flat1 = one_dim_tensor.ravel()
+
+        self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
+        self.assertEqual(flat0.shape, torch.Size([1]))
+        self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
+        self.assertEqual(flat1.shape, torch.Size([1]))
+        self.assertEqual(flat0, one_dim_tensor)
+        self.assertEqual(flat0, flat1)
+        self.assertEqual(flat0.shape, flat1.shape)
+
+        # Test both float tensor and quantized tensor
+        tensors = [torch.randn(5, 5, 5, 5),
+                   torch._empty_affine_quantized([5, 5, 5, 5],
+                                                 scale=2,
+                                                 zero_point=3,
+                                                 dtype=torch.quint8)]
+        self._test_ravel(self, tensors, 625)
+
+        tensors = [torch.randn(0, 2, 3),
+                   torch.randn(3, 0, 2),
+                   torch._empty_affine_quantized([0, 2, 3],
+                                                 scale=2,
+                                                 zero_point=3,
+                                                 dtype=torch.quint8),
+                   torch._empty_affine_quantized([3, 0, 2],
+                                                 scale=2,
+                                                 zero_point=3,
+                                                 dtype=torch.quint8)]
+        self._test_ravel(self, tensors, 0)
+
+        tensors = [torch.randn(5, 5),
+                   torch._empty_affine_quantized([5, 5],
+                                                 scale=2,
+                                                 zero_point=3,
+                                                 dtype=torch.quint8)]
+        self._test_ravel(self, tensors, 25, True)
+
     @staticmethod
     def _test_scatter_add_mult_index_base(self, cast):
         m, n = 30, 40
@@ -6377,7 +6436,7 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan
     # Tests clamp and its alias, clip
     @dtypes(torch.int64, torch.float32)
     def test_clamp(self, device, dtype):
-        op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, 
+        op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
                    torch.clip, torch.Tensor.clip, torch.Tensor.clip_)
 
         # min/max argument product
@@ -6405,7 +6464,7 @@ def test_clamp(self, device, dtype):
         self.assertEqual(Y_expected, Y_out)
 
     def test_clamp_propagates_nans(self, device):
-        op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, 
+        op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
                    torch.clip, torch.Tensor.clip, torch.Tensor.clip_)
 
         # min/max argument product
@@ -6416,9 +6475,9 @@ def test_clamp_propagates_nans(self, device):
             if min_val is None and max_val is None:
                 continue
 
-            X, Y_expected = self.generate_clamp_baseline(device, torch.float, 
-                                                         min_vals=min_val, 
-                                                         max_vals=max_val, 
+            X, Y_expected = self.generate_clamp_baseline(device, torch.float,
+                                                         min_vals=min_val,
+                                                         max_vals=max_val,
                                                          with_nans=True)
             Y_expected = torch.isnan(Y_expected)
 
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 33d2249af284..208cd5805c4b 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2670,6 +2670,13 @@ def callable(a, b) -> number
 In-place version of :meth:`~Tensor.deg2rad`
 """)
 
+add_docstr_all('ravel',
+               r"""
+ravel() -> Tensor
+
+See :func:`torch.ravel`
+""")
+
 add_docstr_all('reciprocal',
                r"""
 reciprocal() -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 1998672d6b34..757140cf9391 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6490,6 +6490,25 @@ def merge_dicts(*dicts):
 tensor([ 1.0000, 1.5000, 2.0000])
 """.format(**factory_common_args))
 
+add_docstr(torch.ravel,
+           r"""
+ravel(input) -> Tensor
+
+Return a contiguous flattened tensor. A copy is made only if needed.
+
+Args:
+    {input}
+
+Example::
+
+    >>> t = torch.tensor([[[1, 2],
+    ...                    [3, 4]],
+    ...                   [[5, 6],
+    ...                    [7, 8]]])
+    >>> torch.ravel(t)
+    tensor([1, 2, 3, 4, 5, 6, 7, 8])
+""".format(**common_args))
+
 add_docstr(torch.remainder,
            r"""
 remainder(input, other, *, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 8b363b9f2bf0..d64d7a4f37a4 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -680,6 +680,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
         torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
         torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
+        torch.ravel: lambda input: -1,
         torch.real: lambda input, out=None: -1,
         torch.vdot: lambda mat1, mat2: -1,
         torch.view_as_real: lambda input: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index af08a7f9d953..f92e2e96f360 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -590,6 +590,7 @@ def method_tests():
        ('view', (S,), (S,), '1d', (False,)),
        ('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
        ('view', (), (1,), 'scalar_to_1d', (False,)),
+       ('ravel', (S, S, S), NO_ARGS, '', (False,)),
        ('reshape', (S, S, S), (S * S, S), '', (False,)),
        ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
        ('reshape', (S,), (S,), '1d', (False,)),
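
For reviewers trying the change locally, a minimal doctest-style sketch of the view-vs-copy behaviour the new tests assert, in the same style as the torch.ravel docstring example above. It assumes a build with this patch applied (torch.ravel is not available before it lands), and it relies on the private `_base` attribute only because the tests above already do:

    >>> import torch
    >>> t = torch.tensor([[0, 1, 2], [3, 4, 5]])
    >>> flat = t.ravel()      # contiguous input: reshape(-1) returns a view
    >>> flat
    tensor([0, 1, 2, 3, 4, 5])
    >>> flat._base is t       # shares storage with t
    True
    >>> nc = t.t()            # transpose makes the tensor non-contiguous
    >>> nc_flat = nc.ravel()  # non-contiguous input: reshape(-1) copies
    >>> nc_flat
    tensor([0, 3, 1, 4, 2, 5])
    >>> nc_flat._base is nc   # no longer a view of the transposed tensor
    False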