
[feature request] "Batched" index_select (i.e. simplified torch.gather without specifying the full index) #64208

Open
@vadimkantorov

Description

For example, given I = torch.randint(0, n3, (n1, n2)) and T = torch.rand(n1, n2, n3, n4, n5),
we'd like to compute O[i, j, ...] = T[i, j, I[i, j], ...], so that O has shape (n1, n2, n4, n5).
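A minimal sketch of the intended semantics, with small arbitrary sizes chosen only for illustration: an explicit loop, and the equivalent advanced-indexing form in which arange grids over the first two dims broadcast against I.

```python
import torch

# Small arbitrary sizes, for illustration only
n1, n2, n3, n4, n5 = 2, 3, 4, 5, 6
I = torch.randint(0, n3, (n1, n2))
T = torch.rand(n1, n2, n3, n4, n5)

# Reference semantics: O[i, j, ...] = T[i, j, I[i, j], ...]
O_loop = torch.empty(n1, n2, n4, n5)
for i in range(n1):
    for j in range(n2):
        O_loop[i, j] = T[i, j, int(I[i, j])]

# Same result with advanced indexing: the arange grids broadcast against I
rows = torch.arange(n1)[:, None]  # shape (n1, 1)
cols = torch.arange(n2)[None, :]  # shape (1, n2)
O_adv = T[rows, cols, I]          # shape (n1, n2, n4, n5)

assert torch.equal(O_loop, O_adv)
```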

This pattern comes up fairly often. Currently it requires chaining several unsqueeze calls with expand, gather and squeeze.
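A sketch of that workaround (sizes are arbitrary, for illustration): gather needs an index with the same number of dimensions as T, so I is unsqueezed to five dims, expanded over the trailing dims, gathered along dim 2, and the resulting singleton dim is squeezed out.

```python
import torch

# Small arbitrary sizes, for illustration only
n1, n2, n3, n4, n5 = 2, 3, 4, 5, 6
I = torch.randint(0, n3, (n1, n2))
T = torch.rand(n1, n2, n3, n4, n5)

# gather needs an index with as many dims as T: (n1, n2) -> (n1, n2, 1, 1, 1)
idx = I.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# Expand the trailing singleton dims to T's sizes: (n1, n2, 1, n4, n5)
idx = idx.expand(-1, -1, -1, n4, n5)
# Gather along dim 2 gives (n1, n2, 1, n4, n5); squeeze away the selected dim
O = T.gather(2, idx).squeeze(2)  # shape (n1, n2, n4, n5)
```

Every step is shape bookkeeping: the number of unsqueeze calls and the expand arguments both depend on how many dims follow the indexed one, which is what makes the pattern easy to get wrong.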

I propose either a new helper function that does this in one go (making the pattern less complex and less error-prone), or making gather support it directly by relaxing the shape constraints on the index tensor.
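One possible shape for such a helper, as a sketch only: batched_index_select is a hypothetical name and signature, not an existing PyTorch API. It assumes index has shape input.shape[:dim] and wraps the unsqueeze + expand + gather + squeeze sequence so callers don't have to count singleton dims.

```python
import torch


def batched_index_select(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not a PyTorch API).

    Computes out[b..., rest...] = input[b..., index[b...], rest...], where
    index has shape input.shape[:dim]. The selected dimension is removed.
    """
    if dim < 0:
        dim += input.dim()
    if tuple(index.shape) != tuple(input.shape[:dim]):
        raise ValueError(
            f"index shape {tuple(index.shape)} must equal "
            f"input.shape[:dim] = {tuple(input.shape[:dim])}"
        )
    trailing = input.shape[dim + 1:]  # sizes of the dims after `dim`
    # (b...) -> (b..., 1, 1, ..., 1): one singleton for `dim`, one per trailing dim
    idx = index.reshape(*index.shape, 1, *([1] * len(trailing)))
    # Broadcast the trailing singletons to input's sizes; expand returns a view
    idx = idx.expand(*index.shape, 1, *trailing)
    # gather along `dim` leaves size 1 there; squeeze it away
    return input.gather(dim, idx).squeeze(dim)


# Usage with the shapes from the issue (sizes are arbitrary)
n1, n2, n3, n4, n5 = 2, 3, 4, 5, 6
I = torch.randint(0, n3, (n1, n2))
T = torch.rand(n1, n2, n3, n4, n5)
O = batched_index_select(T, 2, I)  # shape (n1, n2, n4, n5)
```

Since expand returns a stride-0 view, the expanded index does not allocate memory proportional to n4 * n5; only gather's output is materialized.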

cc @ezyang @gchanan @zou3519 @mikaylagawarecki

Metadata

Labels: enhancement, module: scatter & gather ops, triaged
