Add torch.movedim #41480

Closed
Changes from 11 commits

Commits (34):
47b53e2  add moveaxis impl (kshitij12345, Jul 15, 2020)
2ec66e9  add docs (kshitij12345, Jul 15, 2020)
b15a905  add dim uniqueness check (kshitij12345, Jul 15, 2020)
e73f592  remove native namespace qualifier (kshitij12345, Jul 16, 2020)
2206c0e  add signature to _overrides (kshitij12345, Jul 16, 2020)
2e58ae6  add autograd test (kshitij12345, Jul 16, 2020)
b02c386  add test for invalid args (kshitij12345, Jul 16, 2020)
6eccc99  add test (kshitij12345, Jul 16, 2020)
75d46c6  address comments (kshitij12345, Jul 16, 2020)
2cc8692  update overload name (kshitij12345, Jul 16, 2020)
88b7209  Merge branch 'master' into numpy/develop/moveaxis (kshitij12345, Jul 16, 2020)
92f1ac9  update error msg and corresponding test (kshitij12345, Jul 17, 2020)
f3c0976  use self_dim variable (kshitij12345, Jul 17, 2020)
5703583  use DimVector (kshitij12345, Jul 17, 2020)
66960dc  add algorithm walkthrough (kshitij12345, Jul 17, 2020)
345104c  test_autograd: use device in constructing tensor (kshitij12345, Jul 17, 2020)
b54bb51  Merge branch 'master' into numpy/develop/moveaxis (kshitij12345, Jul 18, 2020)
2b3d6bb  remove stray change from merge (kshitij12345, Jul 18, 2020)
978f9d4  changes (kshitij12345, Jul 20, 2020)
22e8f5a  update doc examples (kshitij12345, Jul 20, 2020)
ce3923d  update doc (kshitij12345, Jul 20, 2020)
c5e5da5  update doc (kshitij12345, Jul 20, 2020)
1f6f7cb  changes (kshitij12345, Jul 20, 2020)
e77d322  changes (kshitij12345, Jul 20, 2020)
0b6d29d  add to tensor_view.rst (kshitij12345, Jul 20, 2020)
e88f99d  changes (kshitij12345, Jul 20, 2020)
9fe2399  fix function call (kshitij12345, Jul 20, 2020)
4b48232  changes (kshitij12345, Jul 20, 2020)
184696e  change argument names to match numpy (kshitij12345, Jul 21, 2020)
f4cf3c1  update src dst to source and destination (kshitij12345, Jul 21, 2020)
cc19379  Merge branch 'master' into numpy/develop/moveaxis (kshitij12345, Jul 21, 2020)
7fbd2fc  Merge branch 'master' into numpy/develop/moveaxis (kshitij12345, Jul 21, 2020)
9bcd29c  address changes (kshitij12345, Jul 21, 2020)
3ad19aa  add extra asserts after std::remove (kshitij12345, Jul 22, 2020)
42 changes: 42 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1693,4 +1693,46 @@ Tensor& diag_cpu_out(Tensor &result, const Tensor& self, int64_t dimension) {
return result;
}

Tensor moveaxis(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
  TORCH_CHECK(src.size() == dst.size(), "moveaxis: Invalid source or destination dims: src (",
              src.size(), " dims) should contain the same number of dims as dst (", dst.size(), " dims)");

  auto normalized_src = src.vec();
  maybe_wrap_dims(normalized_src, self.dim());
  auto normalized_dst = dst.vec();
  maybe_wrap_dims(normalized_dst, self.dim());

  auto it_src = std::unique(normalized_src.begin(), normalized_src.end());
  TORCH_CHECK(it_src == normalized_src.end(), "moveaxis: repeated axis in `src` (", src, ")");
  auto it_dst = std::unique(normalized_dst.begin(), normalized_dst.end());
  TORCH_CHECK(it_dst == normalized_dst.end(), "moveaxis: repeated axis in `dst` (", dst, ")");

  std::vector<int64_t> order, source_dims, destination_dims;
  order.resize(self.dim());
  source_dims.resize(self.dim());
  destination_dims.resize(self.dim());

  std::iota(source_dims.begin(), source_dims.end(), 0);
  std::iota(destination_dims.begin(), destination_dims.end(), 0);

  for (int64_t i = 0; i < src.size(); ++i) {
    order[normalized_dst[i]] = normalized_src[i];
    source_dims[normalized_src[i]] = -1;
    destination_dims[normalized_dst[i]] = -1;
  }

  auto source_iter = std::remove(source_dims.begin(), source_dims.end(), -1);
  auto destination_iter = std::remove(destination_dims.begin(), destination_dims.end(), -1);
Comment on lines +1757 to +1758

Contributor:
An internal linter pointed these out. source_iter and destination_iter are never actually used, can we remove them and just do:

std::remove(source_dims.begin(), source_dims.end(), -1);
std::remove(destination_dims.begin(), destination_dims.end(), -1);

?

zou3519 (Contributor), Jul 21, 2020:
Actually, I thought about it a little more. It would be nice to use source_iter and destination_iter to assert that source_dims and destination_dims have the correct number of elements. So something like

TORCH_INTERNAL_ASSERT(std::distance(source_dims.begin(), source_iter) == rest_dim);
TORCH_INTERNAL_ASSERT(std::distance(destination_dims.begin(), destination_iter) == rest_dim);

But either way, we should either use source_iter / destination_iter or delete them.

kshitij12345 (Collaborator, Author):
Makes sense. Will use them as suggested in second comment.

Thanks!

Contributor:
Thanks!

  int64_t rest_dim = self.dim() - src.size();
  for (int64_t i = 0; i < rest_dim; ++i) {
    order[destination_dims[i]] = source_dims[i];
  }

  return self.permute(order);
}

Tensor moveaxis(const Tensor& self, int64_t src, int64_t dst) {
  return at::moveaxis(self, IntArrayRef{src}, IntArrayRef{dst});
}

}} // at::native
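The permutation construction above reads more easily as a high-level sketch. The following Python outline mirrors the C++ logic under the assumption that the dims are already wrapped to non-negative values and unique; it is illustrative only, not the PR's implementation:

    import torch

    def moveaxis_sketch(t, src, dst):
        # Pin each explicitly moved dim at its requested destination slot,
        # fill the remaining slots with the untouched dims in their original
        # order, then apply a single permute.
        ndim = t.dim()
        order = [-1] * ndim
        remaining_src = [d for d in range(ndim) if d not in src]
        remaining_dst = [d for d in range(ndim) if d not in dst]
        for s, d in zip(src, dst):
            order[d] = s
        for s, d in zip(remaining_src, remaining_dst):
            order[d] = s
        return t.permute(order)

    x = torch.randn(3, 2, 1)
    print(moveaxis_sketch(x, (1, 2), (0, 1)).shape)  # torch.Size([2, 1, 3])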
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2159,6 +2159,14 @@
  use_c10_dispatcher: full
  variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.

- func: moveaxis.intlist(Tensor self, int[] src, int[] dst) -> Tensor
  use_c10_dispatcher: full
  variants: function
Collaborator:
No method variant?

kshitij12345 (Collaborator, Author):
Since numpy.moveaxis is function only.

Collaborator:
It's OK not to have a method variant to start, but just because NumPy doesn't have one doesn't mean we shouldn't, either. The closest analogue to this function is probably permute, which is a method.


- func: moveaxis.int(Tensor self, int src, int dst) -> Tensor
  use_c10_dispatcher: full
  variants: function

# Only exposed from C++ -- in Python,
# we expose it as an attribute `T`, not a function.
#
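Taken together, the two schemas expose an int-list overload and a single-int overload of the same function. A short usage sketch (shapes follow NumPy's moveaxis semantics; argument order as declared above):

    import torch

    x = torch.randn(2, 3, 4)
    # moveaxis.int: move one dim.
    print(torch.moveaxis(x, 0, -1).shape)             # torch.Size([3, 4, 2])
    # moveaxis.intlist: move several dims at once.
    print(torch.moveaxis(x, (0, 1), (-1, -2)).shape)  # torch.Size([4, 3, 2])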
10 changes: 10 additions & 0 deletions test/test_autograd.py
@@ -6510,6 +6510,16 @@ def test_strided_leaf_grad_layout(self, device):
        (c * d).sum().backward()
        self.assertEqual(c.grad.stride(), (2, 1))

    def test_moveaxis(self, device):
        x = torch.randn(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

        # Positive axis
        gradcheck(lambda x: torch.moveaxis(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
        gradgradcheck(lambda x: torch.moveaxis(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)

        # Negative axis
        gradcheck(lambda x: torch.moveaxis(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
        gradgradcheck(lambda x: torch.moveaxis(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)

class TestMultithreadAutograd(TestCase):
    def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
51 changes: 51 additions & 0 deletions test/test_torch.py
@@ -18121,6 +18121,57 @@ def test_large_linspace(self, device, dtype):
        x = torch.linspace(start, end, steps, dtype=dtype, device=device)
        self.assertGreater(x[1] - x[0], (end - start) / steps)

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_moveaxis_invalid(self, device, dtype):
        shape = self._rand_shape(4, min_size=5, max_size=10)
        x = self._generate_input(shape, dtype, device, False)

        # Invalid `src` and `dst` dimension
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            torch.moveaxis(x, 5, 0)

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            torch.moveaxis(x, 0, 5)

        # Mismatch in size of `src` and `dst`
        with self.assertRaisesRegex(RuntimeError, "moveaxis: Invalid source or destination dims:"):
            torch.moveaxis(x, (1, 0), (0, ))

        with self.assertRaisesRegex(RuntimeError, "moveaxis: repeated axis in `src`"):
            torch.moveaxis(x, (0, 0), (0, 1))

        with self.assertRaisesRegex(RuntimeError, "moveaxis: repeated axis in `dst`"):
            torch.moveaxis(x, (0, 1), (1, 1))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_moveaxis(self, device, dtype):
        for nd in range(5):
            shape = self._rand_shape(nd, min_size=5, max_size=10)
            x = self._generate_input(shape, dtype, device, with_extremal=False)
            for random_negative in [True, False]:
                for src_dim, dst_dim in permutations(range(nd), r=2):
                    if random_negative:
                        src_dim = src_dim - nd
                    # Integer Inputs
                    torch_fn = partial(torch.moveaxis, src=src_dim, dst=dst_dim)
                    np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

                if nd > 0:
                    for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
                        # Randomly change a dim to a negative dim representation of itself.
                        if random_negative:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = list(src_sequence)
                            src_sequence[random_idx] = src_sequence[random_idx] - nd
                            src_sequence = tuple(src_sequence)
                        # Sequence Inputs
                        dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))
                        torch_fn = partial(torch.moveaxis, src=src_sequence, dst=dst_sequence)
                        np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
                        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

    # NOTE [Linspace+Logspace precision override]
    # Our Linspace and logspace torch.half CUDA kernels are not very precise.
    # Since linspace/logspace are deterministic, we can compute an expected
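For reference, the parity check that compare_with_numpy performs in test_moveaxis above boils down to something like the following standalone snippet; this is an illustration of the idea, not the test suite's helper:

    import numpy as np
    import torch

    x = torch.randn(5, 6, 7)
    src, dst = (0, -1), (-1, 0)
    # torch.moveaxis should agree with np.moveaxis for the same source/destination dims.
    expected = np.moveaxis(x.numpy(), src, dst)
    actual = torch.moveaxis(x, src, dst)
    np.testing.assert_allclose(actual.numpy(), expected)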
1 change: 1 addition & 0 deletions torch/_overrides.py
@@ -407,6 +407,7 @@ def get_testing_overrides():
            dropout, train, bidirectional, batch_sizes, dropout_state: -1),
        torch.mm: lambda input, mat2, out=None: -1,
        torch.mode: lambda input: -1,
        torch.moveaxis: lambda input, src, dst: -1,
        torch.mul: lambda input, other, out=None: -1,
        torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
        torch.mv: lambda input, vec, out=None: -1,
38 changes: 38 additions & 0 deletions torch/_torch_docs.py
@@ -4365,6 +4365,44 @@ def merge_dicts(*dicts):
[1.0311, 0.3901, 0.5049]])
""")

add_docstr(torch.moveaxis,
r"""
moveaxis(input, src, dst) -> Tensor

Move axes of an array to new positions.
Collaborator:
This needs to be clearer about which dimensions are moving where. Also: what does it mean for a dimension to "move?"

Other axes remain in their original order.

Args:
    input (Tensor)
    src (int or tuple of ints): Original positions of the axes to move. These must be unique.
    dst (int or tuple of ints): Destination positions for each of the original axes. These must also be unique.

Example::

    >>> a = torch.randn(3,2,1)
    >>> a
    tensor([[[-0.3362],
             [-0.8437]],

            [[-0.9627],
             [ 0.1727]],

            [[ 0.5173],
             [-0.1398]]])
    >>> torch.moveaxis(a, 1, 0)
    tensor([[[-0.3362],
             [-0.9627],
             [ 0.5173]],

            [[-0.8437],
             [ 0.1727],
             [-0.1398]]])
    >>> torch.moveaxis(a, (1, 2), (0, 1))
    tensor([[[-0.3362, -0.9627,  0.5173]],

            [[-0.8437,  0.1727, -0.1398]]])
""")

add_docstr(torch.narrow,
r"""
narrow(input, dim, start, length) -> Tensor