flip a Tensor #229

Closed
glample opened this issue Nov 18, 2016 · 34 comments · Fixed by #7873

Comments

@glample (Contributor) commented Nov 18, 2016

Sometimes it's convenient to do stuff like y = x[::-1] to reverse the order of the elements of a tensor, or on another axis, like y = x[:, ::-1]
Would it be possible to add this feature to pytorch?
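
(Editor's note: for reference, this is the NumPy behaviour being asked for; the snippet below is only an illustration of the semantics, not existing PyTorch functionality at the time.)

import numpy as np

x = np.arange(6).reshape(2, 3)
print(x[::-1])     # rows reversed
print(x[:, ::-1])  # last axis reversed; both results are views, not copies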

@apaszke (Member) commented Nov 18, 2016

Some stride modifications should do the trick. I'll add it.

@apaszke (Member) commented Nov 21, 2016

Unfortunately it appears that it will require some additional changes in TH, so it's going to be delayed a bit.

@jekbradbury (Contributor) commented Dec 16, 2016

Giving a +1 to this being a useful addition

@rtqichen (Contributor) commented Feb 14, 2017

To add another use case for strides other than 1 (not just -1): accessing the diagonal of an n-by-n matrix X would be very easy with X.view(-1)[::n+1]. The current torch.diag function returns a copy of the diagonal.
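
(Editor's note: a minimal sketch of a strided diagonal view using as_strided, assuming a contiguous n-by-n matrix; illustrative only, not what the comment proposes.)

import torch

n = 4
X = torch.arange(n * n, dtype=torch.float).view(n, n)
diag_view = X.as_strided((n,), (n + 1,))  # size n, stride n+1 -> the main diagonal
diag_view.zero_()                         # writes through to X, since it is a view

Newer PyTorch versions also provide torch.diagonal, which returns a view.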

@fmassa (Member) commented Feb 14, 2017

@rtqichen I believe we can pass a stride parameter to the tensor constructor, meaning that from a storage you can create a tensor with the sizes and strides that you want (as long as they are positive). I just think that the slicing operator in pytorch doesn't allow that, but I'd need to double-check. I think it's something like

torch.Tensor(storage, sizes, strides)
@rtqichen (Contributor) commented Feb 14, 2017

@fmassa Thanks, but I think that's also not implemented yet. #215

@fmassa (Member) commented Feb 14, 2017

I see, but as mentioned in the issue you pointed out, it is possible to use set_ for that.
Anyway, I agree that allowing strides in slicing is going to be handy
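
(Editor's note: a rough sketch of the set_ route mentioned above; exact argument handling has varied across versions, so treat it as illustrative.)

import torch

flat = torch.arange(6.)  # six floats in one contiguous storage
strided = torch.empty(0).set_(flat.storage(), storage_offset=0,
                              size=(2, 3), stride=(3, 1))
# strided is now a 2x3 view over the same storage as flat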

@apaszke (Member) commented Feb 14, 2017

Non-negative strides should be easy to add. I'll do that today.

@apaszke apaszke self-assigned this Feb 14, 2017
@soumith (Member) commented Feb 21, 2017

non-negative strides are now supported in master
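
(Editor's note: for example, positive-step slicing works after that change, while negative steps still raise an error.)

import torch

x = torch.arange(10)
print(x[::2])   # tensor([0, 2, 4, 6, 8])
print(x[1::3])  # tensor([1, 4, 7])
# x[::-1] still raises an error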

@apaszke apaszke removed their assignment Feb 21, 2017
@soumith soumith added the 24hr+ label Apr 18, 2017
@aosokin (Contributor) commented May 5, 2017

@soumith, @apaszke Without negative strides, is there any relatively efficient way to reverse a 1-dimensional tensor (sequence)?

@fmassa (Member) commented May 5, 2017

@aosokin the most efficient way at the moment is to use index_select on the tensor, which will make a copy of it.

inv_idx = torch.arange(tensor.size(0)-1, -1, -1).long()
# or equivalently torch.range(tensor.size(0)-1, 0, -1).long()
inv_tensor = tensor.index_select(0, inv_idx)
# or equivalently
inv_tensor = tensor[inv_idx]
@aosokin (Contributor) commented May 5, 2017

@fmassa Thanks!

@authman commented Jul 16, 2017

Still holding out hope.

@soumith soumith changed the title inverse tensor flip a Tensor Jul 21, 2017
@ruotianluo (Contributor) commented Aug 6, 2017

It seems it's still not supported in 0.2?

@soumith (Member) commented Aug 6, 2017

it is not :( sorry.

@mariogeiger commented Aug 14, 2017

How do I reverse a CUDA tensor?
The method given by fmassa works well on CPU but not on CUDA (#2412).

EDIT: inv_tensor = tensor.index_select(0, inv_idx) works fine 😄

@soumith soumith added this to Low Priority in Issue Status Aug 23, 2017
@dmarnerides commented Sep 9, 2017

Is there a better (more efficient) way of flipping an arbitrary dimension than this?

def flip(x, dim):
    dim = x.dim() + dim if dim < 0 else dim
    return x[tuple(slice(None, None) if i != dim
             else torch.arange(x.size(i)-1, -1, -1).long()
             for i in range(x.dim()))]

Code to test it with:

a = torch.Tensor([range(1, 25)]).view(1, 2, 3, 4)
print(a)
print(flip(a, 0)) # Or -4
print(flip(a, 1)) # Or -3
print(flip(a, 2)) # Or -2
print(flip(a, 3)) # Or -1
@soumith soumith added this to nn / autograd / torch in Issue Categories Sep 13, 2017
@soumith soumith moved this from nn / autograd / torch to torch /autograd in Issue Categories Sep 20, 2017
@wassname commented Nov 21, 2017

Here's @dmarnerides' code but with CUDA support:

# https://github.com/pytorch/pytorch/issues/229
def flip(x, dim):
    dim = x.dim() + dim if dim < 0 else dim
    inds = tuple(slice(None, None) if i != dim
             else x.new(torch.arange(x.size(i)-1, -1, -1).tolist()).long()
             for i in range(x.dim()))
    return x[inds]

# Code to test it with cpu
a = torch.Tensor([range(1, 25)]).view(1, 2, 3, 4)
print(a)
print(flip(a, 0)) # Or -4
print(flip(a, 1)) # Or -3
print(flip(a, 2)) # Or -2
print(flip(a, 3)) # Or -1

# Code to test it with cuda
a = torch.Tensor([range(1, 25)]).view(1, 2, 3, 4).cuda()
print(a)
print(flip(a, 0)) # Or -4
print(flip(a, 1)) # Or -3
print(flip(a, 2)) # Or -2
print(flip(a, 3)) # Or -1
@Evpok (Contributor) commented Nov 21, 2017

The solution could come from #3750, though negative strides are not implemented atm.

@sytrus-in-github (Contributor) commented Dec 7, 2017

To make it also work with torch.autograd.Variable (which does not have the new method), and to avoid the RuntimeError Assertion 'ndim <= MAX_ADVINDEX_CALC_DIMS' failed. for inputs with more than 6 dimensions in the CUDA case, I had to change the code from @dmarnerides and @wassname as follows:

# https://github.com/pytorch/pytorch/issues/229
import torch
from torch.autograd import Variable
def flip(x, dim):
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim
    x = x.view(-1, *xsize[dim:])
    x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, 
                      -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
    return x.view(xsize)

# Code to test it with cuda Variable
a = Variable(torch.Tensor([range(1, 25)]).view(1, 2, 3, 4, 1, 1).cuda())
print(a)
print(flip(a, 0)) # Or -6
print(flip(a, 1)) # Or -5
print(flip(a, 2)) # Or -4
print(flip(a, 3)) # Or -3
print(flip(a, 4)) # Or -2
print(flip(a, -1)) # Or 5

A PyTorch 0.4.0+ version:

# https://github.com/pytorch/pytorch/issues/229
import torch
def flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

Of course, it would be nicer if we could use negative strides directly. This is a useful operation. At least I personally need it for a project that I am working on.

@alok commented Mar 9, 2018

Any updates for this?

@MaximArtemev commented Apr 8, 2018

Still waiting :(

@ezyang (Contributor) commented Apr 12, 2018

We're going to make sure negative strides work in the C10 rewrite of the core tensor class.

Warning: negative strides are a BC-breaking change; in the old days you could have used -1 to mean "fill this stride in with whatever the contiguous stride should have been". But this didn't even work correctly in all situations:

>>> torch.randn(2,3).as_strided((2,3), (2,1)).set_(torch.randn(2,3).storage(), 0, (2,3), (-1,1)).stride()
(2, 1)
>>> torch.randn(2,3).as_strided((2,3), (2,1)).as_strided((2,3), (-1,1)).stride()
(3, 1)

so I don't think anyone should be affected by this fix. Holler if you are.

@ivan-bilan commented May 23, 2018

@sytrus-in-github thank you for the flip function. How would I adapt it to get a[..., ::-1, :] for a tensor a of shape [batch_size, x_size, y_size]? I would need to flip x and y and leave the batch dimension untouched (full example).

@sytrus-in-github (Contributor) commented May 23, 2018

@ivan-bilan does this answer your question?

@wangg12 (Contributor) commented Dec 11, 2018

Any update on this?

@dmarnerides commented Dec 11, 2018

Hi @wangg12, there is now a torch.flip function you can use.
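
(Editor's note: a quick usage sketch of torch.flip, which reverses along the listed dims and returns a copy.)

import torch

x = torch.arange(24).view(2, 3, 4)
y = torch.flip(x, dims=(1,))    # reverse along dim 1
z = torch.flip(x, dims=(0, 2))  # reverse along dims 0 and 2 at once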

@soumith (Member) commented Dec 12, 2018

closed via #7873

@soumith soumith reopened this Dec 12, 2018
@soumith soumith closed this Dec 12, 2018
@Oktai15 commented Jun 4, 2019

@soumith torch.flip changes the order of dims in a tensor, but doesn't reverse the order along a given axis, like torch.tensor([1,2,3]) -> torch.tensor([3,2,1]). Why did you close this issue?

@willprice commented Jun 4, 2019

>>> import torch
>>> xs = torch.rand(5)
>>> xs
tensor([0.1820, 0.4174, 0.6118, 0.2575, 0.5187])
>>> torch.flip(xs, (0,))
tensor([0.5187, 0.2575, 0.6118, 0.4174, 0.1820])

What is it that you expect flip to do? Flip reverses the order of elements along a dimension. Do you want to reverse the order of dimensions, e.g. THWC -> CWHT?

@dmarnerides commented Jun 4, 2019

@Oktai15 I think the function that changes the order of the dimensions is torch.permute and not torch.flip.

@Oktai15 commented Jun 4, 2019

@dmarnerides @willprice yes, you are right; sorry, my misunderstanding.

@drscotthawley commented Oct 30, 2019

I know this is closed, but I'm finding torch.flip() to be a very compute-intensive operation -- profiling my code shows that roughly 30% of the compute goes to flip().

@soumith Is this operation actually copying data? If so, would it be more efficient to use a view?

Related: np.flip() claims that it's producing a view, and yet it runs at least 5 times slower than just using negative step slicing (e.g. "a[::-1]"). cf. https://stackoverflow.com/questions/6771428/most-efficient-way-to-reverse-a-numpy-array
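
(Editor's note: a rough way to time the copy-based options on your own setup; the shape and iteration count below are arbitrary and results will vary by hardware.)

import timeit
import torch

x = torch.randn(1000, 1000)
idx = torch.arange(x.size(0) - 1, -1, -1)

print(timeit.timeit(lambda: torch.flip(x, (0,)), number=1000))
print(timeit.timeit(lambda: x.index_select(0, idx), number=1000))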

I tried using negative step slicing but I get the error

ValueError: negative step not yet supported

@ezyang Any update on the negative stride support?

Thanks!

@soumith (Member) commented Oct 30, 2019

torch.flip() is copying data, as we don't have negative stride support.
I think it's unlikely that we'll ever get to adding negative strides.
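
(Editor's note: a quick check that torch.flip materialises a new tensor rather than returning a view.)

import torch

x = torch.arange(5)
y = torch.flip(x, (0,))
y[0] = 100
print(x)  # tensor([0, 1, 2, 3, 4]) -- unchanged, so y does not share x's memory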
