Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ Tensor reshape(const Tensor& self, IntList proposed_shape) {
return at::_unsafe_view(self.clone(), shape);
}

Tensor reshape_as(const Tensor& self, const Tensor& other) {
return self.reshape(other.sizes());
}

Tensor select(const Tensor& self, int64_t dim, int64_t index) {
int64_t ndim = self.dim();
AT_CHECK(ndim > 0, "select() cannot be applied to a 0-dim tensor.");
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,9 @@

- func: reshape(Tensor self, IntList shape) -> Tensor

- func: reshape_as(Tensor self, Tensor other) -> Tensor
variants: method

- func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor)
variants: function
dispatch:
Expand Down
1 change: 1 addition & 0 deletions docs/source/tensors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: repeat
.. automethod:: requires_grad_
.. automethod:: reshape
.. automethod:: reshape_as
.. automethod:: resize_
.. automethod:: resize_as_
.. automethod:: round
Expand Down
3 changes: 3 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2620,6 +2620,9 @@ class dont_convert(tuple):
('reshape', (S,), (S,), '1d'),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
('reshape', (), (1,), 'scalar_to_1d'),
('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
('flip', (S, S, S), ([0],), 'd0'),
('flip', (S, S, S), ([0, 1, 2],), 'd012'),
('flip', (S, S, S), ([0, 2],), 'd02'),
Expand Down
5 changes: 5 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6027,6 +6027,11 @@ def test_reshape(self):
self.assertEqual(empty.reshape([1, -1]).shape, (0,))
self.assertRaises(RuntimeError, lambda: empty.reshape(1))

x = torch.randn(3, 3)
self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10)))

@skipIfNoZeroSize
def test_empty_reshape(self):
x = torch.randn(0, 6)
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@
self: repeat_backward(grad, self.dim(), repeats)

# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view, and somtimes not.
# reshape is special in that it sometimes returns a view, and sometimes not.
# Defining a backward will make codegen spit out the forward call as
# as_variable(baseType->reshape(self)),
# making it impossible (hard) to detect when it is actually a view.
Expand Down
42 changes: 42 additions & 0 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,20 @@ def callable(a, b) -> number
See :func:`torch.reshape`
""")

add_docstr_all('reshape_as',
r"""
reshape_as(other) -> Tensor

Returns this tensor as the same shape as :attr:`other`.
``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``.

Please see :meth:`~Tensor.reshape` for more information about ``reshape``.

Args:
other (:class:`torch.Tensor`): The result tensor has the same shape
as :attr:`other`.
""")

add_docstr_all('resize_',
r"""
resize_(*sizes) -> Tensor
Expand Down Expand Up @@ -2407,6 +2421,20 @@ def callable(a, b) -> number

""")

add_docstr_all('view_as',
r"""
view_as(other) -> Tensor

View this tensor as the same size as :attr:`other`.

This comment was marked as off-topic.

``self.view_as(other)`` is equivalent to ``self.view(other.size())``.

Please see :meth:`~Tensor.view` for more information about ``view``.

Args:
other (:class:`torch.Tensor`): The result tensor has the same size
as :attr:`other`.
""")

add_docstr_all('expand',
r"""
expand(*sizes) -> Tensor
Expand Down Expand Up @@ -2445,6 +2473,20 @@ def callable(a, b) -> number
[ 3, 3, 3, 3]])
""")

add_docstr_all('expand_as',
r"""
expand_as(other) -> Tensor

Expand this tensor to the same size as :attr:`other`.
``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``.

Please see :meth:`~Tensor.expand` for more information about ``expand``.

Args:
other (:class:`torch.Tensor`): The result tensor has the same size
as :attr:`other`.
""")

add_docstr_all('zero_',
r"""
zero_() -> Tensor
Expand Down
12 changes: 0 additions & 12 deletions torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,6 @@ def share_memory_(self):
self.storage().share_memory_()
return self

def view_as(self, tensor):
r"""view_as(other) -> Tensor

View this tensor as the same size as :attr:`other`.
``self.view_as(other)`` is equivalent to ``self.view(other.size())``.

Args:
other (:class:`torch.Tensor`): The result tensor has the same size
as :attr:`other.size()`.
"""
return self.view(tensor.size())

def __reversed__(self):
r"""Reverses the tensor along dimension 0."""
if self.dim() == 0:
Expand Down