Implement ravel #46098

Status: Closed. Wants to merge 5 commits. Showing changes from 3 commits.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -565,6 +565,7 @@ _(aten, randn_like) \
_(aten, random) \
_(aten, randperm) \
_(aten, range) \
_(aten, ravel) \
_(aten, reciprocal) \
_(aten, reflection_pad1d) \
_(aten, reflection_pad1d_backward) \
6 changes: 5 additions & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -1624,6 +1624,10 @@ Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim);
}

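// Flatten self into a 1-D tensor; reshape(-1) returns a view when possible, otherwise a copy.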
Tensor ravel(const Tensor& self) {
  return self.reshape(-1);
}

Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
dim = maybe_wrap_dim(dim, self.dim());

@@ -1633,7 +1637,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int64_t>());
if (self.has_names()) {
TORCH_CHECK(numel == self.size(dim),
"unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
"unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
dim, " (", self.names()[dim], ": ", self.size(dim), ") in Tensor", self.names());
TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes");
} else {
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2865,6 +2865,10 @@
CPU: range_cpu_out
CUDA: range_cuda_out

- func: ravel(Tensor(a) self) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
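
The "variants: function, method" line in the new entry exposes ravel both as a free function (torch.ravel) and as a Tensor method (Tensor.ravel). A minimal sketch of that equivalence, assuming the schema is registered as written; the snippet is illustrative, not from the PR:

import torch

t = torch.arange(6).reshape(2, 3)
# The function and method forms generated from the yaml entry should agree.
assert torch.equal(torch.ravel(t), t.ravel())
assert t.ravel().shape == (6,)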

- func: reciprocal(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -495,6 +495,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: q_per_channel_axis
.. automethod:: rad2deg
.. automethod:: random_
.. automethod:: ravel
.. automethod:: reciprocal
.. automethod:: reciprocal_
.. automethod:: record_stream
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -462,6 +462,7 @@ Other Operations
meshgrid
lcm
logcumsumexp
ravel
renorm
repeat_interleave
roll
47 changes: 41 additions & 6 deletions test/test_torch.py
@@ -1102,7 +1102,7 @@ def compare(t, k, dim, dir):
def test_topk_arguments(self):
q = torch.randn(10, 2, 10)
# Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
self.assertRaises(TypeError, lambda: q.topk(4, True))

def test_mode(self):
x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
@@ -1887,6 +1887,41 @@ def _test_gather(self, cast, test_bounds=True):
def test_gather(self):
self._test_gather(self, lambda t: t)

def test_ravel(self):
    # Test that ravel returns a 1-dim tensor when given a 0-dim tensor
    zero_dim_tensor = torch.tensor(123)
    flat0 = zero_dim_tensor.ravel()
    one_dim_tensor = torch.tensor([123])
    flat1 = one_dim_tensor.ravel()

    self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
    self.assertEqual(flat0.shape, torch.Size([1]))
    self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
    self.assertEqual(flat1.shape, torch.Size([1]))
    self.assertEqual(flat0, one_dim_tensor)
    self.assertEqual(flat0, flat1)
    self.assertEqual(flat0.shape, flat1.shape)

    # Test both float tensor and quantized tensor
    tensors = [torch.randn(5, 5),
Review comment (Contributor): nit: We should probably test more shapes. The implementation code for ravel is simple enough that I can believe it works for everything, but there's no harm in testing more. (A sketch along these lines follows the test below.)

Review comment (Contributor): One edge case we should test is a tensor with a zero-sized dimension. For example, the following should be true:

x = torch.randn([0, 2, 3])  # tensor of size [0, 2, 3]
y = x.ravel()
assert y.shape == (0,)  # ravel always returns a 1-dim tensor

               torch._empty_affine_quantized([5, 5],
                                             scale=2,
                                             zero_point=3,
                                             dtype=torch.quint8),]
    for src in tensors:
        # Contiguous tensor -> view
        flat = src.ravel()
        self.assertEqual(flat.shape, torch.Size([25]))
        self.assertEqual(src.view(-1), flat.view(-1))
        self.assertEqual(flat._base, src)

        # Non-contiguous tensor -> copy
        nc_src = src.t()
        nc_flat = nc_src.ravel()
        self.assertEqual(nc_flat.shape, torch.Size([25]))
        self.assertEqual(nc_src.reshape(-1), nc_flat.reshape(-1))
        self.assertTrue(nc_flat._base != nc_src)

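Following the review suggestions above, a minimal sketch of how additional shapes, including a zero-sized dimension, could be covered; the helper name _check_ravel_shape is hypothetical:

def _check_ravel_shape(self, shape):
    src = torch.randn(shape)
    flat = src.ravel()
    # ravel always returns a 1-dim tensor whose length equals the element count
    self.assertEqual(flat.shape, torch.Size([src.numel()]))
    self.assertEqual(flat, src.reshape(-1))

# Example shapes to exercise: (), (3,), (2, 3, 4), (0, 2, 3)
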
@staticmethod
def _test_scatter_add_mult_index_base(self, cast):
m, n = 30, 40
@@ -6377,7 +6412,7 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan
# Tests clamp and its alias, clip
@dtypes(torch.int64, torch.float32)
def test_clamp(self, device, dtype):
op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
torch.clip, torch.Tensor.clip, torch.Tensor.clip_)

# min/max argument product
@@ -6405,7 +6440,7 @@ def test_clamp(self, device, dtype):
self.assertEqual(Y_expected, Y_out)

def test_clamp_propagates_nans(self, device):
op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
torch.clip, torch.Tensor.clip, torch.Tensor.clip_)

# min/max argument product
@@ -6416,9 +6451,9 @@ def test_clamp_propagates_nans(self, device):
if min_val is None and max_val is None:
continue

X, Y_expected = self.generate_clamp_baseline(device, torch.float,
min_vals=min_val,
max_vals=max_val,
with_nans=True)
Y_expected = torch.isnan(Y_expected)

7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
@@ -2670,6 +2670,13 @@ def callable(a, b) -> number
In-place version of :meth:`~Tensor.deg2rad`
""")

add_docstr_all('ravel',
r"""
ravel() -> Tensor

See :func:`torch.ravel`
""")

add_docstr_all('reciprocal',
r"""
reciprocal() -> Tensor
19 changes: 19 additions & 0 deletions torch/_torch_docs.py
@@ -6490,6 +6490,25 @@ def merge_dicts(*dicts):
tensor([ 1.0000, 1.5000, 2.0000])
""".format(**factory_common_args))

add_docstr(torch.ravel,
r"""
ravel(input) -> Tensor

Return a contiguous flattened tensor. A copy is made only if needed.

Args:
{input}

Example::

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.ravel(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
""".format(**common_args))

add_docstr(torch.remainder,
r"""
remainder(input, other, *, out=None) -> Tensor
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -680,6 +680,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.ravel: lambda input: -1,
torch.real: lambda input, out=None: -1,
torch.vdot: lambda mat1, mat2: -1,
torch.view_as_real: lambda input: -1,
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -590,6 +590,7 @@ def method_tests():
('view', (S,), (S,), '1d', (False,)),
('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
('view', (), (1,), 'scalar_to_1d', (False,)),
('ravel', (S, S, S), NO_ARGS, '', (True,)),
('reshape', (S, S, S), (S * S, S), '', (False,)),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
('reshape', (S,), (S,), '1d', (False,)),