Random Shuffle along Axis #71409

Open
adgaudio opened this issue Jan 18, 2022 · 5 comments
Labels
feature A request for a proper, new feature. module: random Related to random number generation in PyTorch (rng generator) triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@adgaudio

adgaudio commented Jan 18, 2022

🚀 The feature, motivation and pitch

Dear PyTorch Devs,

Thank you for your hard work and dedication to creating a great ecosystem of tools and community of users.

This feature request proposes adding a standard lib function to shuffle "rows" across an axis of a tensor. We will consider two ways to do this (use-case 1 and use-case 2). Consider that an axis partitions a tensor into sub-tensors. For instance a tensor of shape (2,3,4) has 3 sub-tensors at axis=1.

Example input:

>>>  x = torch.arange(2*3*4).reshape(2,3,4)
>>>  x
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
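
To make the "sub-tensor" terminology concrete, here is a minimal sketch using torch.unbind, which splits a tensor along one axis. For the example input, axis=1 yields 3 sub-tensors, each of shape (2, 4).

```python
import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

# Split x along axis 1 into its sub-tensors.
subs = torch.unbind(x, dim=1)
print(len(subs))       # 3
print(subs[0].shape)   # torch.Size([2, 4])
print(subs[0])
# tensor([[ 0,  1,  2,  3],
#         [12, 13, 14, 15]])
```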

Use-case 1: Randomly shuffle an axis the same way for all sub-tensors. (This case is trivially "solved" below.)

>>> axis = 1  # choose an axis
>>> x.index_select(axis, torch.randperm(x.shape[axis]))
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11],
         [ 4,  5,  6,  7]],

        [[12, 13, 14, 15],
         [20, 21, 22, 23],
         [16, 17, 18, 19]]])

Use-case 2: Randomly shuffle an axis differently for each index of the leading axes (each row stays intact)

>>> shufflerow(x, axis=1)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[16, 17, 18, 19],
         [12, 13, 14, 15],
         [20, 21, 22, 23]]])

Where the shufflerow function (for use-case 2) could be implemented like this

def shufflerow(tensor, axis):
    row_perm = torch.rand(tensor.shape[:axis+1]).argsort(axis)  # get permutation indices
    for _ in range(tensor.ndim-axis-1): row_perm.unsqueeze_(-1)
    row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor.shape[axis+1:]))  # reformat this for the gather operation
    return tensor.gather(axis, row_perm)

It would be convenient to have a simple function like shuffle(tensor, axis, mode:str) for both use cases, where for example, we might say mode='same' (use case 1) or mode='different' (use case 2).
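
A rough sketch of what such a function might look like is given here. The name shuffle, the mode strings, and the optional generator argument are assumptions made for illustration; none of this is an existing PyTorch API. The 'different' branch follows the same idea as shufflerow, but broadcasts the permutation with expand instead of a loop and repeat.

```python
import torch


def shuffle(tensor, axis, mode="same", generator=None):
    """Hypothetical helper, not a PyTorch API: shuffle `tensor` along `axis`."""
    axis = axis % tensor.ndim  # allow negative axes
    if mode == "same":
        # Use-case 1: one permutation applied to every sub-tensor.
        perm = torch.randperm(tensor.shape[axis], generator=generator)
        return tensor.index_select(axis, perm.to(tensor.device))
    if mode == "different":
        # Use-case 2: one permutation per index of the leading axes.
        lead_shape = tensor.shape[: axis + 1]
        perm = torch.rand(lead_shape, generator=generator).argsort(dim=axis)
        # Add trailing singleton dims, then broadcast (a view, no copy).
        perm = perm.reshape(*lead_shape, *([1] * (tensor.ndim - axis - 1)))
        return tensor.gather(axis, perm.expand(tensor.shape).to(tensor.device))
    raise ValueError(f"unknown mode: {mode!r}")


x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
y_same = shuffle(x, axis=1, mode="same")
y_diff = shuffle(x, axis=1, mode="different")
```

In both modes the rows are only reordered, so sorting the result along axis 1 gives back x for this particular input.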

Alternatives

I do not know of a better shuffle implementation for use-case 2, but my proposed implementation does not seem efficient. An improvement might be an implementation in C or CUDA that covers use-case 2 without the for loop and without the repeat and gather.

Currently, we have torch.randperm, which can be used to shuffle one axis the same way across all sub-tensors (use-case 1).

Perhaps off topic comment: I also wish PyTorch (and NumPy) had a toolkit dedicated to sampling, such as reservoir sampling across minibatches. Sampling often introduces subtle bugs.

Additional context

Variations of this feature request on the internet:

When axis=-1:

For 2-d tensor:

ChannelShuffle:

Thank you!

cc @pbelevich

@samdow samdow added feature A request for a proper, new feature. module: random Related to random number generation in PyTorch (rng generator) triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jan 18, 2022
@vadimkantorov
Contributor

Related request: #42502

@MrMois

MrMois commented Jan 21, 2022

+1. I have needed this so many times; a built-in function would be awesome.

@aiqc

aiqc commented May 9, 2022

Related: https://discuss.pytorch.org/t/shuffle-a-tensor-a-long-a-certain-dimension/129798/3

This is a big gap in NumPy parity.

@ksurya

ksurya commented Aug 29, 2023

+1

@cabralpinto

cabralpinto commented Oct 5, 2023

I think I found the solution for this. Take the following tensor as an example:

x = torch.arange(4 * 5).view(4, 5)
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])

Then, we can use the following code to shuffle along dimension d independently:

d = 1
x.take_along_dim(torch.sort(torch.rand(*x.shape))[1], d)
tensor([[ 2,  3,  0,  1,  4],
        [ 8,  7,  9,  6,  5],
        [10, 13, 14, 11, 12],
        [19, 18, 15, 16, 17]])
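
One caveat about the snippet above: torch.sort sorts along the last dimension by default, so it only matches d when d is the last dimension (as it is here). A possible generalization passes dim=d to argsort. The helper name shuffle_along_dim is made up for this sketch. Note that it shuffles every 1-D slice along d independently, so for tensors with more than two dimensions it does not keep rows together the way shufflerow does.

```python
import torch


def shuffle_along_dim(x, d):
    """Hypothetical helper: independently shuffle every 1-D slice of x along dim d."""
    # Random keys with x's shape; argsort along d gives one permutation per slice.
    idx = torch.rand(x.shape, device=x.device).argsort(dim=d)
    return x.take_along_dim(idx, dim=d)


x = torch.arange(4 * 5).view(4, 5)
y = shuffle_along_dim(x, 0)  # shuffle within each column
# Each column of y holds the same values as the matching column of x.
print(torch.equal(y.sort(dim=0).values, x))  # True
```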


7 participants