
[proposal] Let's unify the various gather/scatter ops. #64936

Open
Chillee opened this issue Sep 13, 2021 · 5 comments
Labels
better-engineering (Relatively self-contained tasks for better engineering contributors) · module: advanced indexing (Related to x[i] = y, index functions) · module: scatter & gather ops · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@Chillee
Contributor

Chillee commented Sep 13, 2021

Proposal: Let's make all of the gather-type ops (index.Tensor, index_select, gather) composite, and have them all call into the same underlying operator (gather_general). This will (1) reduce the effort for transformations/backends that need to handle every operator, (2) reduce the inefficiencies that come from users picking the wrong operator, and (3) let us focus optimization work on a single operator.

Context

Motivated by needing to write batching rules for index.Tensor, I realized that there are often multiple ways of writing the same operation with the gather/scatter ops. For example, say we want to index_select along the first dimension. Here are three ways of writing it:

def index_select1(x, y):
    # the purpose-built op
    return torch.index_select(x, 0, y)

def index_select2(x, y):
    # gather needs an index tensor with the same rank as x, so
    # unsqueeze y along every other dimension and expand it out
    dim = 0
    for ii in range(len(x.shape)):
        if ii != dim:
            y = y.unsqueeze(ii)
    expanse = list(x.shape)
    expanse[dim] = -1
    return torch.gather(x, dim, y.expand(expanse))

def index_select3(x, y):
    # advanced indexing; calls index.Tensor under the hood
    return x[y]
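
A quick sanity check confirms that the three spellings agree (a minimal sketch; the shapes here are arbitrary):

import torch

x = torch.randn(10, 100)
y = torch.randint(0, 10, (50,))

# all three implementations compute the same result
assert torch.equal(index_select1(x, y), index_select2(x, y))
assert torch.equal(index_select1(x, y), index_select3(x, y))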

Now, which one is the fastest? This ... is a very annoying question. The answer is ... it depends on your rank, your shape, and your device.

Let's define the function we're benchmarking as torch.index_select(A, 0, B), and say we sweep every dimension over the shape range [10, 100] with a rank of 2. That gives 8 configurations to benchmark: A: [10, 10], B: [10]; A: [10, 10], B: [100]; A: [10, 100], B: [10]; and so on. Then we count how often implementation 1 (torch.index_select), implementation 2 (torch.gather), and implementation 3 (indexing) is the fastest.
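
A minimal sketch of such a sweep, using torch.utils.benchmark (the actual harness isn't shown in this issue, so the details below are illustrative):

import itertools
import torch
import torch.utils.benchmark as benchmark

def sweep(shapes, rank, device="cpu"):
    # count how often each implementation is the fastest across every
    # combination of A's dimensions and B's length
    impls = {"index_select": index_select1, "gather": index_select2, "index": index_select3}
    wins = {name: 0 for name in impls}
    for dims in itertools.product(shapes, repeat=rank):
        for n in shapes:
            A = torch.randn(*dims, device=device)
            B = torch.randint(0, dims[0], (n,), device=device)
            times = {
                name: benchmark.Timer(stmt="f(A, B)", globals={"f": f, "A": A, "B": B})
                .blocked_autorange(min_run_time=0.1)
                .median
                for name, f in impls.items()
            }
            wins[min(times, key=times.get)] += 1
    return wins

print(sweep([10, 100, 1000], rank=3))  # 81 configurations on CPU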

Let's start off by going over the shape range [10, 100, 1000] with a rank of 3 on the CPU. Of the 81 configurations, index_select was the fastest 22 times, gather was never the fastest, and index was the fastest 59 times, with an average margin of 36% (i.e., 1.36 seconds vs. 1.0 seconds). OK, that's somewhat reasonable, but perhaps index_select should be calling into index in some cases?

Now, let's check the same shapes/ranks, but on GPU.

shapes: [10, 100, 1000], rank: 3
index_select: 28
gather: 30
index: 23
average margin: 15%

Oh... now we see that gather has thrown its hat into the ring and is sometimes the fastest. But if we limit the sweep to fairly big shapes, we see...

shapes: [100, 1000], rank: 3
index_select: 0
gather: 14
index: 2
average margin: 13%

Ok, so now for fairly big shapes on CUDA, gather seems to be faster than everything else.

So we have a situation where, for the same operation, there are multiple ways of writing it, and it's non-obvious which one will be fastest. Judging from these results, it seems very possible that for large enough GPU indexing, torch.index_select should just call into gather instead of running its own kernel.

This hurts 1) people writing transformations, 2) users writing code, and 3) PyTorch developers.

For people writing transformations (like vmap), I don't know whether I should call into gather, index_select, or index.Tensor, and I'm worried that choosing the wrong one will hurt performance.

Users need to worry about the same thing. Moreover, most of them probably have no idea how these ops differ in performance, and so they might be leaving performance on the table by choosing the wrong option.

For PyTorch developers, this situation means that any optimizations/improvements they make to one of these ops have limited applicability, since users might be on a completely different code path.

I propose that we make all of these ops (gather, index_select, and index.Tensor) composite, and have them all call into the same unified implementation. It's not totally clear to me what the required semantics are, but XLA's gather may be a good place to start: https://www.tensorflow.org/xla/operation_semantics#gather
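
To make the shape of the proposal concrete, here is a hypothetical sketch; gather_general is a made-up name and signature, and the actual semantics of the unified primitive (XLA-style or otherwise) would still need to be designed:

import torch

def gather_general(x, dim, index):
    # the single unified primitive: all heuristics and device-specific
    # kernels would live behind this one op
    raise NotImplementedError

def index_select(x, dim, index):
    # composite: index_select is gather with the 1-D index broadcast
    # along every dimension other than `dim`
    for ii in range(x.dim()):
        if ii != dim:
            index = index.unsqueeze(ii)
    shape = list(x.shape)
    shape[dim] = -1
    return gather_general(x, dim, index.expand(shape))

def gather(x, dim, index):
    # composite: already in the primitive's form
    return gather_general(x, dim, index)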

I've mainly talked about gather ops here, but I suspect our scatter ops are in a similar situation.

@ezyang
Contributor

ezyang commented Sep 13, 2021

One thing I do worry about a little is whether there are cases where you can be faster with a direct implementation.

@jbschlosser added the module: advanced indexing, module: scatter & gather ops, triage review, and better-engineering labels Sep 13, 2021
@Chillee
Contributor Author

Chillee commented Sep 13, 2021

> One thing I do worry about a little is whether there are cases where you can be faster with a direct implementation.

Couldn't we always just add a check in the general op that specializes it to a more direct implementation? I think it's better that we do these kinds of checks rather than leaving them to the user.
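
For instance, a Python-level sketch of what such a check could look like (the real dispatch would live in C++, and the device/size cutoffs below are made-up placeholders, not measured thresholds):

import torch

def smart_index_select(x, dim, index):
    # route to whichever kernel the heuristics predict is fastest,
    # so the caller never has to choose
    if x.is_cuda and x.numel() >= 1_000_000:
        # big CUDA inputs: the benchmarks above favor gather
        idx = index
        for ii in range(x.dim()):
            if ii != dim:
                idx = idx.unsqueeze(ii)
        shape = list(x.shape)
        shape[dim] = -1
        return torch.gather(x, dim, idx.expand(shape))
    if dim == 0:
        # on CPU, advanced indexing was most often the fastest
        return x[index]
    return torch.index_select(x, dim, index)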

@ngimel
Collaborator

ngimel commented Sep 13, 2021

We started by making scatter call into the general index_put, but that turned out to be 3-4x slower than the more specialized implementation, at least on some benchmarks: #31662 (comment). So for scatter/gather/index/index_put there is some code reuse going on, but for perf reasons they are not backed by a single implementation (and of course, since a lot of the decisions there are heuristics-based, sometimes we also get the heuristics wrong).
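
For reference, the two ops are interchangeable in principle. A sketch of writing a dim-0 scatter_ through index_put_ on a 2-D tensor (the indices are chosen without duplicates within a column, since overlapping writes are nondeterministic in both ops):

import torch

x = torch.zeros(5, 4)
idx = torch.tensor([[0, 1, 2, 3],
                    [4, 3, 1, 2]])  # same shape as src
src = torch.randn(2, 4)

# scatter_ along dim 0: x[idx[i, j], j] = src[i, j]
a = x.clone().scatter_(0, idx, src)

# the same write via index_put_, by materializing the column index
# that scatter_ supplies implicitly
cols = torch.arange(x.shape[1]).expand_as(idx)
b = x.clone().index_put_((idx, cols), src)

assert torch.equal(a, b)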

@Chillee
Contributor Author

Chillee commented Sep 14, 2021

> and of course, since a lot of the decisions there are heuristics-based, sometimes we also get the heuristics wrong

Can we not hide all of the heuristics inside of a single op?

@VitalyFedyunin added the triaged label and removed the triage review label Oct 4, 2021
@vadimkantorov
Contributor

vadimkantorov commented May 4, 2023

Also, flash-attention reports that naively indexing the first dimension is a bit slow: https://github.com/HazyResearch/flash-attention/blob/ad113948a6c3864fbe48156a9857e97a38ce758c/flash_attn/bert_padding.py#L9

They also find torch.gather/torch.scatter_ preferable to __getitem__. It would be interesting to collect their experience.
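
The trick in that file appears to boil down to routing a first-axis index through gather. A minimal sketch of the pattern (not the exact flash-attention code, which also wraps it in a custom autograd.Function for the backward):

import torch

def index_first_axis(x, indices):
    # equivalent to x[indices], but expressed as a gather along dim 0
    # on a flattened view
    other = x.shape[1:]
    flat = x.reshape(x.shape[0], -1)                      # (N, rest)
    idx = indices.unsqueeze(1).expand(-1, flat.shape[1])  # (k, rest)
    return torch.gather(flat, 0, idx).reshape(-1, *other)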

Also, related: #64208
