Description
For example, suppose we have I = torch.randint(0, n3, (n1, n2))
and T = torch.rand(n1, n2, n3, n4, n5).
We'd like to compute O[i, j, ...] = T[i, j, I[i, j], ...], so that O has shape (n1, n2, n4, n5).
This pattern comes up fairly often. Currently it requires a combination of multiple unsqueeze calls, expand, gather, and squeeze.
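
For reference, here is a minimal sketch of that workaround for the shapes above. The concrete sizes are arbitrary values picked for illustration, and indexing with None is used as shorthand for the repeated unsqueeze calls.

```python
import torch

# Arbitrary example sizes (assumed for illustration only).
n1, n2, n3, n4, n5 = 2, 3, 4, 5, 6
I = torch.randint(0, n3, (n1, n2))
T = torch.rand(n1, n2, n3, n4, n5)

# gather needs an index with the same number of dims as T, so add
# singleton dims after I and expand them over the trailing dims of T.
idx = I[:, :, None, None, None].expand(n1, n2, 1, n4, n5)

# Pick one entry along dim 2 per (i, j), then drop the singleton dim.
O = T.gather(2, idx).squeeze(2)  # shape: (n1, n2, n4, n5)
```

Getting the expand shape or the gather dim wrong tends to produce either a shape error or a silently wrong result, which is the error-prone part.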
I propose either adding a new helper function that does this in one call (making the pattern less complex and error-prone), or having gather support it directly by relaxing the shape constraints on the index tensor.
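
To make the proposal concrete, here is a rough sketch of what such a helper could look like. The name select_along and its signature are hypothetical and not an existing PyTorch API; the sketch assumes a non-negative dim and an index whose shape equals the leading dims of the input.

```python
import torch


def select_along(T: torch.Tensor, I: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: O[i..., ...] = T[i..., I[i...], ...].

    Assumes dim >= 0 and I.shape == T.shape[:dim].
    """
    if tuple(I.shape) != tuple(T.shape[:dim]):
        raise ValueError("I must have shape T.shape[:dim]")
    trailing = T.shape[dim + 1:]
    # Append one singleton dim for `dim` plus one per trailing dim,
    # then expand over the trailing dims so gather accepts the index.
    idx = I.reshape(*I.shape, 1, *([1] * len(trailing)))
    idx = idx.expand(*I.shape, 1, *trailing)
    return T.gather(dim, idx).squeeze(dim)


# Usage with arbitrary example sizes.
T_demo = torch.rand(2, 3, 4, 5, 6)
I_demo = torch.randint(0, 4, (2, 3))
O_demo = select_along(T_demo, I_demo, dim=2)  # shape: (2, 3, 5, 6)
```

Relaxing the constraints in gather itself would make the wrapper unnecessary, but a helper along these lines would not change the behavior of existing gather calls.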