Implement ravel (more tests)
ghstack-source-id: 6bc7834b0099a8d2e7f1fd4fc0c32ecfc59ddd71
Pull Request resolved: #46098
ejguan committed Oct 10, 2020
1 parent 4c87d33 commit c6b39b7
Showing 10 changed files with 105 additions and 7 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -565,6 +565,7 @@ _(aten, randn_like) \
_(aten, random) \
_(aten, randperm) \
_(aten, range) \
_(aten, ravel) \
_(aten, reciprocal) \
_(aten, reflection_pad1d) \
_(aten, reflection_pad1d_backward) \
6 changes: 5 additions & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -1624,6 +1624,10 @@ Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim);
}

Tensor ravel(const Tensor& self) {
return self.reshape(-1);
}

Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
dim = maybe_wrap_dim(dim, self.dim());

@@ -1633,7 +1637,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int64_t>());
if (self.has_names()) {
TORCH_CHECK(numel == self.size(dim),
"unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
"unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
dim, " (", self.names()[dim], ": ", self.size(dim), ") in Tensor", self.names());
TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes");
} else {
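Since the new `ravel` above simply forwards to `reshape(-1)`, it returns a view when the input is contiguous and falls back to a copy otherwise. A minimal sketch of that behavior (assuming a torch build that includes this operator):

import torch

t = torch.randn(2, 3)                        # contiguous, not itself a view

flat = t.ravel()                             # reshape(-1) can alias: no copy
print(flat.shape)                            # torch.Size([6])
print(flat.data_ptr() == t.data_ptr())       # True: storage is shared

nc = t.t()                                   # a transpose is non-contiguous
nc_flat = nc.ravel()                         # reshape(-1) must materialize a copy
print(nc_flat.data_ptr() == nc.data_ptr())   # False: fresh buffer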
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2869,6 +2869,10 @@
CPU: range_cpu_out
CUDA: range_cuda_out

- func: ravel(Tensor(a) self) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method

- func: reciprocal(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
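In the schema above, `Tensor(a) self -> Tensor(a)` records that the output may alias the input, and `variants: function, method` generates both spellings of the call. A quick sketch of the two forms (for illustration only):

import torch

x = torch.arange(6).reshape(2, 3)
a = torch.ravel(x)        # function variant
b = x.ravel()             # method variant
print(torch.equal(a, b))  # True: both are generated from the same yaml entry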
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -495,6 +495,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: q_per_channel_axis
.. automethod:: rad2deg
.. automethod:: random_
.. automethod:: ravel
.. automethod:: reciprocal
.. automethod:: reciprocal_
.. automethod:: record_stream
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -462,6 +462,7 @@ Other Operations
meshgrid
lcm
logcumsumexp
ravel
renorm
repeat_interleave
roll
71 changes: 65 additions & 6 deletions test/test_torch.py
@@ -1102,7 +1102,7 @@ def compare(t, k, dim, dir):
def test_topk_arguments(self):
q = torch.randn(10, 2, 10)
# Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
self.assertRaises(TypeError, lambda: q.topk(4, True))

def test_mode(self):
x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
@@ -1887,6 +1887,65 @@ def _test_gather(self, cast, test_bounds=True):
def test_gather(self):
self._test_gather(self, lambda t: t)

@staticmethod
def _test_ravel(self, tensors, size, nc=False):
for src in tensors:
# Contiguous Tensor -> View
flat = src.ravel()
self.assertEqual(flat.shape, torch.Size([size]))
self.assertEqual(src.view(-1), flat)
self.assertEqual(flat._base, src)

# Non-contiguous Tensor -> Copy
if nc:
nc_src = src.t()
nc_flat = nc_src.ravel()
self.assertEqual(nc_flat.shape, torch.Size([size]))
self.assertEqual(nc_src.reshape(-1), nc_flat)
self.assertTrue(nc_flat._base != nc_src)

def test_ravel(self):
# Test that ravel returns a 1-dim tensor when given a 0-dim tensor
zero_dim_tensor = torch.tensor(123)
flat0 = zero_dim_tensor.ravel()
one_dim_tensor = torch.tensor([123])
flat1 = one_dim_tensor.ravel()

self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
self.assertEqual(flat0.shape, torch.Size([1]))
self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
self.assertEqual(flat1.shape, torch.Size([1]))
self.assertEqual(flat0, one_dim_tensor)
self.assertEqual(flat0, flat1)
self.assertEqual(flat0.shape, flat1.shape)

# Test both float tensor and quantized tensor
tensors = [torch.randn(5, 5, 5, 5),
torch._empty_affine_quantized([5, 5, 5, 5],
scale=2,
zero_point=3,
dtype=torch.quint8)]
self._test_ravel(self, tensors, 625)

tensors = [torch.randn(0, 2, 3),
torch.randn(3, 0, 2),
torch._empty_affine_quantized([0, 2, 3],
scale=2,
zero_point=3,
dtype=torch.quint8),
torch._empty_affine_quantized([3, 0, 2],
scale=2,
zero_point=3,
dtype=torch.quint8)]
self._test_ravel(self, tensors, 0)

tensors = [torch.randn(5, 5),
torch._empty_affine_quantized([5, 5],
scale=2,
zero_point=3,
dtype=torch.quint8)]
self._test_ravel(self, tensors, 25, True)
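The `_test_ravel` helper above relies on the private `_base` attribute to tell views from copies: a view reports the tensor it was derived from, while a freshly allocated result reports `None`. A small illustration of that check, not part of the test suite:

import torch

src = torch.randn(5, 5)

view = src.view(-1)             # flattening a contiguous tensor aliases it
copy = src.t().reshape(-1)      # a transpose is non-contiguous, so reshape copies

print(view._base is src)        # True: `view` borrows src's storage
print(copy._base is None)       # True: `copy` owns its own storage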

@staticmethod
def _test_scatter_add_mult_index_base(self, cast):
m, n = 30, 40
@@ -6377,7 +6436,7 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan
# Tests clamp and its alias, clip
@dtypes(torch.int64, torch.float32)
def test_clamp(self, device, dtype):
op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
torch.clip, torch.Tensor.clip, torch.Tensor.clip_)

# min/max argument product
@@ -6405,7 +6464,7 @@ def test_clamp(self, device, dtype):
self.assertEqual(Y_expected, Y_out)

def test_clamp_propagates_nans(self, device):
op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
torch.clip, torch.Tensor.clip, torch.Tensor.clip_)

# min/max argument product
@@ -6416,9 +6475,9 @@ def test_clamp_propagates_nans(self, device):
if min_val is None and max_val is None:
continue

X, Y_expected = self.generate_clamp_baseline(device, torch.float,
min_vals=min_val,
max_vals=max_val,
with_nans=True)
Y_expected = torch.isnan(Y_expected)

7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
@@ -2670,6 +2670,13 @@ def callable(a, b) -> number
In-place version of :meth:`~Tensor.deg2rad`
""")

add_docstr_all('ravel',
r"""
ravel() -> Tensor

See :func:`torch.ravel`
""")

add_docstr_all('reciprocal',
r"""
reciprocal() -> Tensor
19 changes: 19 additions & 0 deletions torch/_torch_docs.py
@@ -6490,6 +6490,25 @@ def merge_dicts(*dicts):
tensor([ 1.0000, 1.5000, 2.0000])
""".format(**factory_common_args))

add_docstr(torch.ravel,
r"""
ravel(input) -> Tensor

Return a contiguous flattened tensor. A copy is made only if needed.

Args:
    {input}

Example::

    >>> t = torch.tensor([[[1, 2],
    ...                    [3, 4]],
    ...                   [[5, 6],
    ...                    [7, 8]]])
    >>> torch.ravel(t)
    tensor([1, 2, 3, 4, 5, 6, 7, 8])
""".format(**common_args))

add_docstr(torch.remainder,
r"""
remainder(input, other, *, out=None) -> Tensor
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -680,6 +680,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.ravel: lambda input: -1,
torch.real: lambda input, out=None: -1,
torch.vdot: lambda mat1, mat2: -1,
torch.view_as_real: lambda input: -1,
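The `torch.ravel: lambda input: -1` entry above registers the operator's signature with the `__torch_function__` override machinery. A hedged sketch of a wrapper type that intercepts the call through that protocol (the `Wrapper` class below is illustrative, not from this commit):

import torch

class Wrapper:
    # Minimal tensor-like wrapper; torch functions dispatch to __torch_function__.
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap Wrapper arguments, log the call, then run the real op.
        unwrapped = tuple(a.data if isinstance(a, Wrapper) else a for a in args)
        print(f"dispatching {func.__name__}")
        return func(*unwrapped, **kwargs)

out = torch.ravel(Wrapper(torch.arange(6).reshape(2, 3)))   # prints "dispatching ravel"
print(out)                                                   # tensor([0, 1, 2, 3, 4, 5])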
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -590,6 +590,7 @@ def method_tests():
('view', (S,), (S,), '1d', (False,)),
('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
('view', (), (1,), 'scalar_to_1d', (False,)),
('ravel', (S, S, S), NO_ARGS, '', (False,)),
('reshape', (S, S, S), (S * S, S), '', (False,)),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
('reshape', (S,), (S,), '1d', (False,)),
