
Commit

Add tests and fix doc
soulitzer committed Dec 23, 2020
1 parent f474ffa commit dada6f2
Showing 2 changed files with 75 additions and 1 deletion.
61 changes: 61 additions & 0 deletions test/test_view_ops.py
@@ -100,6 +100,12 @@ def is_view_of(self, base, other):

        return True

    # Returns True if v1 and v2 are views of the same base
    def is_view_of_same_base(self, v1, v2):
        if not v1._is_view() or v1 is v2:
            return False
        return self.is_view_of(v1._base, v2)

    # Performs transpose if contiguous=True, else returns the input tensor as is
    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
        if contiguous:
@@ -457,6 +463,61 @@ def test_reshape_nonview(self, device):
        nv[6] = 0
        self.assertNotEqual(t[1, 1], nv[6])

    def test_flatten_view(self, device):
        def assert_is_view(t, v, t_is_view=False):
            idx_t = (0,) * t.ndim
            idx_v = (0,) * v.ndim
            if not t_is_view:
                self.assertTrue(self.is_view_of(t, v))
            else:
                self.assertTrue(self.is_view_of_same_base(t, v))
            v[idx_v] = 0
            self.assertEqual(t[idx_t], v[idx_v])

        t = torch.ones(1, 2, 3, 4, device=device)
        v = t.flatten()
        assert_is_view(t, v)

        # zero-dimensional tensor
        t = torch.tensor(1, device=device)
        v = t.flatten()
        assert_is_view(t, v)

        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
        v = t.flatten(0, 1)
        assert_is_view(t, v, True)

        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
        t = torch.ones(720, device=device) \
            .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
        #               [--1--|---2---|-3-]  [--1--|----2---|-3-]
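        # With these sizes and strides, the condition holds within each group:
        #   dims 0-1: 6 == 2 * 3,  dims 2-3: 15 == 5 * 3,  dims 3-4: 5 == 1 * 5,
        # but fails across the group boundaries (2 != 15 * 2 and 1 != 0 * 4),
        # so flattening within a single group can still produce a view.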
        v1 = t.flatten(0, 1)
        v2 = v1.flatten(1, 3)
        v3 = v2.flatten(2, 2)
        assert_is_view(t, v1, True)
        assert_is_view(t, v2, True)
        assert_is_view(t, v3, True)

    def test_flatten_nonview(self, device):
        def assert_is_nonview(t, nv):
            idx_t = (0,) * t.ndim
            idx_nv = (0,) * nv.ndim
            self.assertTrue(not nv._is_view())
            nv[idx_nv] = 0
            self.assertNotEqual(t[idx_t], nv[idx_nv])

        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
        nv = t.flatten(1, 3)
        assert_is_nonview(t, nv)

        t = torch.ones(2, 2, device=device).T
        nv = t.flatten()
        assert_is_nonview(t, nv)

        # flatten returns the original object if start_dim == end_dim
        t = torch.ones(2, 2, device=device)
        nv = t.flatten(1, 1)
        self.assertTrue(t is nv)

    def test_basic_indexing_slice_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t[:2, :3]
15 changes: 14 additions & 1 deletion torch/_torch_docs.py
@@ -3091,7 +3091,20 @@ def merge_dicts(*dicts):
r"""
flatten(input, start_dim=0, end_dim=-1) -> Tensor
Flattens a contiguous range of dims in a tensor.
By default, reshapes :attr:`input` into a one-dimensional tensor while preserving its row-major order.
If :attr:`start_dim` or :attr:`end_dim` are passed, :attr:`end_dim` cannot come before :attr:`start_dim`,
and only dimensions starting at :attr:`start_dim` and ending at :attr:`end_dim` are flattened.
If `start_dim < end_dim`, flattening an :attr:`input` with shape :math:`(l*, f+, r*)` produces an output
shape of :math:`(l*, \prod f+, r*)` where :math:`l*` is the sizes of the zero or more "left" dimensions
before :attr:`start_dim`, :math:`f+` is the sizes of the one or more dimensions between :attr:`start_dim` and
:attr:`end_dim` that will be "flattened", and :math:`r*` is the sizes of the zero or more "right" dimensions
after :attr:`end_dim`. It is equivalent to calling :math:`\text{torch.reshape}(\textit{input}, (l*, \prod f+, r*))`
and like :func:`torch.reshape` returns a view when possible. See :meth:`torch.Tensor.view` for details on
when a view will be returned.
If `start_dim == end_dim` then this returns the original object :attr:`input` if :attr:`input` has one or more
dimensions, and a one-dimensional view if :attr:`input` is a zero-dimensional tensor.
Args:
{input}
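For reference, a short sketch of the documented behavior, assuming a PyTorch build that includes this
change (`_base` and `_is_view` are internal helpers also used by the tests above):

    import torch

    t = torch.ones(2, 3, 4)
    v = t.flatten()          # contiguous input: a view with shape (24,)
    print(v.shape)           # torch.Size([24])
    print(v._base is t)      # True, writes to v are visible in t

    w = t.flatten(1, 2)      # flatten only dims 1..2: shape (2, 12)
    print(w.shape)           # torch.Size([2, 12])

    nt = t.transpose(0, 2)   # non-contiguous layout
    nv = nt.flatten()        # cannot be expressed as a view, so a copy is returned
    print(nv._is_view())     # False

    same = t.flatten(1, 1)   # start_dim == end_dim returns the original object
    print(same is t)         # True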
