[proposal] batch mode for randperm #42502
Labels
enhancement
Not as big of a feature, but technically not a bug. Should be easy to fix
module: random
Related to random number generation in PyTorch (rng generator)
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 Feature
I'd like to implement some features for torch.randperm. What I've thought of so far:
- a batch parameter, allowing multiple permutations to be sampled at the same time
- partial or k-permutations

These would be accessible via optional arguments whose default behavior matches the current behavior (i.e. batch=1, k=None).
I'm uncertain about what options should be given for controlling the shape of the output (e.g. column vs. row permutations, automatic squeezing). One possibility is an optional size parameter for the output, plus a dim parameter specifying which axis the permutations lie along. If size is None, the current behavior is used.
I'm also open to additional features.
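For concreteness, the proposed batched semantics can already be emulated with existing ops. This is only an illustrative sketch, not the proposed implementation, and `batched_randperm` is a hypothetical helper name: argsorting i.i.d. uniform noise along the last dimension yields one independent permutation per row.

```python
import torch

def batched_randperm(batch: int, n: int) -> torch.Tensor:
    # Hypothetical helper: each row of the result is an independent
    # permutation of 0..n-1, obtained by argsorting uniform noise.
    return torch.argsort(torch.rand(batch, n), dim=-1)

perms = batched_randperm(4, 10)
print(perms.shape)  # torch.Size([4, 10])
```

A native implementation could avoid the extra noise tensor and the O(n log n) sort per row, but this shows the intended output layout: one permutation per row along the last axis.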
Motivation
Whenever I've used randperm I've used it in a batched context, so it seems like a natural feature.
Pitch
I will write CPU and GPU replacements for the current randperm implementations to support the new features.
Alternatives
Currently batching can be done with a for loop, which is tedious and slow.
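The for-loop workaround mentioned above might look like the following sketch; note that it makes one `randperm` call (and hence one kernel launch on GPU) per permutation.

```python
import torch

n, batch = 10, 4
# One randperm call per batch element: correct, but launch overhead
# dominates for large batch counts and small n.
perms = torch.stack([torch.randperm(n) for _ in range(batch)])
print(perms.shape)  # torch.Size([4, 10])
```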
Partial permutations can easily be obtained by slicing the return value, but this is inefficient when n is large and k is small, since the full permutation is still generated.
Additional context
cc @pbelevich