[proposal] Let's unify the various gather/scatter ops. #64936
Labels: better-engineering, module: advanced indexing, module: scatter & gather ops, triaged
Proposal: Let's make all of the `gather`-type ops (`index.Tensor`, `index_select`, `gather`) composite, and have them all call into the same underlying operator (`gather_general`). This will 1. reduce effort for transformations/backends that need to handle every operator, 2. reduce inefficiencies from having users use the wrong operator, and 3. allow us to focus on optimizing a single operator.

### Context
Motivated by needing to write batching rules for `index.Tensor`, I realized that there are often multiple ways of doing the same operation for the scatter ops. For example, let's say we want to `index_select` along the first dimension. Here are 3 ways of writing it:
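Concretely, the three candidates look something like this (a minimal sketch; the shapes here are illustrative):

```python
import torch

A = torch.randn(10, 100)
B = torch.randint(0, A.shape[0], (10,))

# Implementation 1: index_select
out1 = torch.index_select(A, 0, B)

# Implementation 2: gather. gather requires an index tensor with the
# same rank as A, so broadcast B across the trailing dimension.
out2 = torch.gather(A, 0, B.unsqueeze(-1).expand(B.shape[0], A.shape[1]))

# Implementation 3: advanced indexing (dispatches to index.Tensor)
out3 = A[B]

assert torch.equal(out1, out2) and torch.equal(out1, out3)
```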
Now, which one is the fastest? This ... is a very annoying question. The answer is ... it depends on your rank, your shape, and your device.

Let's define the function we're benchmarking as `torch.index_select(A, 0, B)`, and let's say we're going over the shape range `[10, 100]`, with a rank of 2. That means we're going to benchmark 8 results: `A: [10, 10], B: [10]`, `A: [10, 10], B: [100]`, `A: [10, 100], B: [10]`, and so on. Then, we're going to count how often each of implementation 1 (`torch.index_select`), implementation 2 (`torch.gather`), and implementation 3 (indexing) is the fastest.
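A minimal sketch of such a sweep, using `torch.utils.benchmark` (illustrative, not the exact harness behind the numbers below):

```python
import itertools
import torch
from torch.utils import benchmark

# Illustrative parameters: sizes=[10, 100, 1000] at rank=3 yields the
# 81 configurations discussed below. Set device="cuda" for the GPU runs;
# torch.utils.benchmark handles CUDA synchronization.
sizes, rank, device = [10, 100, 1000], 3, "cpu"

impls = {
    "index_select": lambda A, B: torch.index_select(A, 0, B),
    # gather needs an index with the same rank as A, so broadcast B
    # across all trailing dimensions.
    "gather": lambda A, B: torch.gather(
        A, 0, B.view(-1, *([1] * (A.dim() - 1))).expand(len(B), *A.shape[1:])
    ),
    "index": lambda A, B: A[B],  # dispatches to index.Tensor
}

wins = dict.fromkeys(impls, 0)
for a_shape in itertools.product(sizes, repeat=rank):
    for b_len in sizes:
        A = torch.randn(a_shape, device=device)
        B = torch.randint(0, a_shape[0], (b_len,), device=device)
        medians = {
            name: benchmark.Timer("fn(A, B)", globals=dict(fn=fn, A=A, B=B))
            .blocked_autorange(min_run_time=0.1)
            .median
            for name, fn in impls.items()
        }
        wins[min(medians, key=medians.get)] += 1

print(wins)  # how often each implementation was the fastest
```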
Let's start off by going over the shape range `[10, 100, 1000]` with a rank of 3 on the CPU. We see that of the 81 results, `index_select` was the fastest 22 times, `gather` was never the fastest, and `index` was the fastest 59 times, with an average margin of 36% (i.e. 1.36 seconds vs 1.0 seconds). Ok, that's somewhat reasonable, but perhaps `index_select` should be calling into `index` in some cases?

Now, let's check the same shapes/ranks, but on GPU.
Oh... now we see that `gather` has thrown its hat into the ring, and is sometimes the fastest. But... if we try limiting only to fairly big shapes, we see...

Ok, so now for fairly big shapes on CUDA, `gather` seems to be faster than everything else.

So, we have a situation where for the same operation, there are multiple ways of writing it, and it's non-obvious which one will be the fastest. Judging from these results, it seems very possible that for large enough GPU indexing, `torch.index_select` should just call into `gather` instead of running it itself.

This hurts 1. people writing transformations, 2. users writing code, and 3. PyTorch developers.
For people writing transformations (like vmap), I don't know whether I should call into `gather`, `index_select`, or `index.Tensor`, and I'm worried that choosing the wrong one will hurt performance.

For users, they need to worry about the same thing. Moreover, most of them probably have no idea how these differ in performance, and thus might be leaving performance on the table by choosing the wrong option.
For PyTorch developers, what this situation means is that any optimizations/improvements they make to these ops have limited applicability, since users might be using a completely different code path.
I propose that we make all of these ops (`gather`, `index_select`, and `index.Tensor`) composite, and have them all call into the same unified implementation. It's not totally clear to me what the required semantics are, but XLA's gather may be a good place to start: https://www.tensorflow.org/xla/operation_semantics#gather

I've mainly talked about gather ops here, but I suspect our scatter ops are in a similar situation.
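To make that concrete, here's an illustrative sketch of `index_select` as a composite op over a hypothetical `gather_general` primitive. The name comes from the proposal above; the signature is an assumption, and it's stubbed out with `torch.gather` semantics (the real primitive would presumably be richer, closer to XLA's gather):

```python
import torch

def gather_general(input, dim, index):
    # Hypothetical unified primitive, stubbed with torch.gather semantics:
    # `index` has the same rank as `input` and selects along `dim`.
    return torch.gather(input, dim, index)

def index_select(input, dim, index):
    # index_select as a composite op: broadcast the 1-D `index` across
    # every non-indexed dimension, then call the unified primitive.
    view = [1] * input.dim()
    view[dim] = index.numel()
    shape = list(input.shape)
    shape[dim] = index.numel()
    return gather_general(input, dim, index.view(*view).expand(*shape))

A = torch.randn(4, 5, 6)
B = torch.tensor([2, 0, 3])
assert torch.equal(index_select(A, 1, B), torch.index_select(A, 1, B))
```

Under a scheme like this, `index.Tensor` and `gather` would decompose similarly, so transformations and backends would only ever see `gather_general`.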