
[feature request] index_select is very slow on sparse tensors (and my proposed algorithm to fix it) #61788

Closed
IAmKohlton opened this issue Jul 16, 2021 · 16 comments
Labels
module: performance Issues related to performance, either of kernel code or framework glue · module: sparse Related to torch.sparse · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@IAmKohlton

IAmKohlton commented Jul 16, 2021

torch.index_select is supposed to work on both dense and sparse tensors. For dense tensors it's pretty amazing, but for sparse tensors it's painfully slow. Here's an example I ran in a Jupyter notebook that shows this:

import torch
from earth_mesh.entity.sparse_tensor import SparseTensor 
# SparseTensor is a class that has some custom functions I wrote that operate on sparse tensors

dense_range = torch.arange(100 * 100 * 100).view(100, 100, 100)
sparse_range = dense_range.to_sparse()

indices = torch.randint(0, 100, (100,))

%time torch.index_select(dense_range, 0, indices)
%time torch.index_select(sparse_range, 0, indices)
%time SparseTensor._index_select(SparseTensor(sparse_range), 0, indices)

This had the result:

CPU times: user 12.5 ms, sys: 0 ns, total: 12.5 ms
Wall time: 513 µs
CPU times: user 26.9 s, sys: 5.26 ms, total: 26.9 s
Wall time: 26.8 s
CPU times: user 254 ms, sys: 47.8 ms, total: 301 ms
Wall time: 83.8 ms

As you can see, index_select on a dense tensor runs incredibly quickly, but PyTorch's index_select is atrociously slow for sparse tensors. However, the algorithm I wrote was several hundred times faster. For the use case I originally built this for it was several thousand times faster! I was wondering if my algorithm could replace the current one that computes index_select for sparse tensors.

The only catch is that the space complexity isn't amazing. If n = len(indices) and m = the number of occurrences of the most common index in tensor._indices()[dim], then the space complexity is O(nm). This might seem bad, but O(nm) is also the worst-case size of the output. (The caveat is that in my algorithm the best case is very close to the worst case, which I'm sure isn't true of the current index_select.)

My function is as follows:

@classmethod
def _index_select(cls, tensor, dim, indices):
    # Make sure the nnz entries are sorted along `dim` so searchsorted can be used.
    is_sorted = torch.all(torch.diff(tensor._indices()[dim]) >= 0)
    if not is_sorted:
        sort_indices = torch.argsort(tensor._indices()[dim])
        tensor = torch.sparse_coo_tensor(tensor._indices()[:, sort_indices], tensor._values()[sort_indices], tensor.shape)

    # For each requested index, the half-open range [left, right) of nnz positions
    # whose coordinate along `dim` matches it.
    search_result_left = torch.searchsorted(tensor._indices()[dim], indices, right=False)
    search_result_right = torch.searchsorted(tensor._indices()[dim], indices, right=True)
    num_results = search_result_right - search_result_left

    # Grid of (offset within a match range, position in `indices`).
    index_mesh_grid = torch.stack(torch.meshgrid(
        torch.arange(torch.max(num_results)),
        torch.arange(indices.shape[0])
    ))

    # Mark grid cells that fall outside their match range with -1 ...
    desired_index_info = torch.where(
        index_mesh_grid[0] < num_results,
        index_mesh_grid,
        torch.full_like(index_mesh_grid, -1)
    ).transpose(1, 2)

    # ... and drop them, leaving (offset, output position) pairs.
    filtered_index_info = desired_index_info[desired_index_info != -1].view(2, -1)

    # Gather the matching nnz entries and renumber the `dim` coordinate to the
    # position of the selected index within `indices`.
    tensor_indices = search_result_left[filtered_index_info[1]] + filtered_index_info[0]
    return_index = tensor._indices()[:, tensor_indices]
    return_index[dim] = filtered_index_info[1]
    new_shape = *tensor.shape[:dim], indices.shape[0], *tensor.shape[dim + 1:]
    return torch.sparse_coo_tensor(return_index, tensor._values()[tensor_indices], new_shape)

If you're interested in the algorithm but it's a bit too dense to follow, I'm more than happy to explain it!
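
To give a feel for the core trick, here is a tiny toy sketch added for illustration (not part of the original code): on a sorted coordinate row, the left/right searchsorted results bound exactly the nnz positions that match each requested index.

import torch

# Sorted coordinates along the selected dimension (one entry per nnz value)
coords = torch.tensor([0, 0, 2, 2, 2, 5, 7])
indices = torch.tensor([2, 3, 7])                          # indices to select

left = torch.searchsorted(coords, indices, right=False)   # tensor([2, 5, 6])
right = torch.searchsorted(coords, indices, right=True)   # tensor([5, 5, 7])
num_results = right - left                                 # tensor([3, 0, 1])
# nnz positions 2..4 match index 2, nothing matches 3, and position 6 matches 7.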

cc @VitalyFedyunin @ngimel @heitorschueroff @nikitaved @pearu @cpuhrsch @IvanYashchuk

@IAmKohlton IAmKohlton changed the title [feature request] index_select is very slow on sparse tensors [feature request] index_select is very slow on sparse tensors (and my proposed algorithm to fix it) Jul 16, 2021
@anjali411 anjali411 added module: sparse Related to torch.sparse module: performance Issues related to performance, either of kernel code or framework glue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jul 19, 2021
@cpuhrsch
Contributor

Hello @IAmKohlton,

Thank you for posting this issue! Just to get some more context, what is your intended use case for this operator?

Thanks,
Christian

@IAmKohlton
Author

Hey Christian,

In terms of its outputs it should have exactly the same use case as the current index_select for sparse tensors. I can't give too many specifics, but in abstract terms I needed a robust way to index into a sparse tensor. For example, you can index into a 4-D dense tensor with tensor[[4, 6, 1, 5, 1], :, 3, [3, 5, 1, 3, 6]], but there's no equivalent way to do that for a sparse tensor. When writing that robust indexing algorithm, index_select on a sparse tensor was the most important and most time-consuming operation.
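
For illustration only (a toy example, not from the original report), here is roughly what that gap looks like: dense tensors accept mixed advanced indexing directly, while the closest sparse building block is index_select applied one dimension at a time, which cannot express the paired index lists.

import torch

d = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Dense tensors support mixed advanced indexing directly:
out_dense = d[[0, 1, 1], :, 3, [4, 0, 2]]                 # shape (3, 3)

# For a sparse COO tensor, the closest built-in is index_select,
# one dimension at a time, which can't pair up the two index lists above:
s = d.to_sparse()
s_rows = torch.index_select(s, 0, torch.tensor([0, 1, 1]))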

@cpuhrsch
Contributor

cpuhrsch commented Aug 4, 2021

Thanks for the info @IAmKohlton,

Are you allowed to share the general category of models (e.g. GNNs) or rough area of application (e.g. image classification) for this? We're currently revisiting torch.sparse as a whole and collecting explicit use cases is very important for that.

Thanks,
Christian

@IAmKohlton
Author

Oh, I see what you mean now. The use case is actually only tangentially related to machine learning. We use PyTorch because of its ability to work with hybrid tensors. If you need more, I can ask my manager whether I can share more specific details (I'm doing this for work, but I got permission to share everything I've said so far).

@IAmKohlton
Author

I guess I can say that we're simulating some real-world phenomena with hybrid tensors, and the results of the simulations will potentially be used for some form of ML.

@cpuhrsch
Contributor

cpuhrsch commented Aug 4, 2021

Thanks for the details! What do you mean by hybrid tensors? More details such as the type of model or the general area would be useful. We could also throw it under the broad stroke of "scientific computing" if that's useful. If you like you could also add a broader paragraph about what you want to use sparsity for or other ops and your general experience with torch.sparse.

@IAmKohlton
Author

"scientific computing" is a decent description. By hybrid tensor I mean a tensor where the first few dimensions are sparse, and the last few dimensions are dense. For what I'm working it's very useful to think about it like we have many points in some N dimensional space, and at each point we have a vector/matrix of useful information. Because of this we basically always make the assumption that when a value isn't specified it isn't zero, it just doesn't exist!

I'm really glad you asked! There are roughly two categories of things I would really like added/changed. The first category has to do with making sparse tensors work more like dense tensors. Most of the things I'd put in this category are possible to do, but hard/tedious/repetitive. One example I kind of already talked about: you can't index into a sparse tensor like you can a dense tensor. For a dense tensor you can just do tensor[3:5, 3, [4, 6, 3, 2], [7, 2, 6, 7]], but when I needed to do the same thing for a sparse tensor it took quite a bit of time and creativity. Another example that comes up in my use case a lot is filtration operations. For example, if I need to get all values in my tensor that are larger than 20, I need to do something like

mask = old_sparse._values() > 20
new_sparse = torch.sparse_coo_tensor(old_sparse._indices()[:, mask], old_sparse._values()[mask], old_sparse.shape)

which is really long and tedious compared to what it could be: new_sparse = old_sparse[old_sparse > 20]
I can think of a lot more examples that fall into this category, but those two are the most glaring.
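
For comparison, the wish above can at least be wrapped into a small helper today; this is just a sketch using the same private _indices()/_values() accessors as the snippet above, and keep_where is a hypothetical name:

import torch

def keep_where(sparse, predicate):
    # Keep only the nnz entries of a COO tensor whose values pass `predicate`.
    mask = predicate(sparse._values())
    return torch.sparse_coo_tensor(
        sparse._indices()[:, mask],
        sparse._values()[mask],
        sparse.shape,
    )

# new_sparse = keep_where(old_sparse, lambda v: v > 20)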

The second category is operations where the non-zero values in the sparse tensor affect each other. The main examples of this are the reduction operations. If I want to calculate the product/max/min/average over a sparse dimension (any reduction op except addition; I should note you can actually get average working with a bit of creativity), I need to do something weird, custom, and slow. I actually had a PR open here that would introduce torch.sparse.max, but I've been kind of procrastinating on making changes to it.
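
For what it's worth, here is a rough sketch (added for illustration, not the PR's implementation) of a max reduction over one sparse dimension of a 2-D COO tensor, assuming a PyTorch version that has Tensor.scatter_reduce_; the -inf placeholder makes the "unspecified values don't exist" assumption explicit.

import torch

def sparse_max(sp, dim):
    # Max over sparse dim `dim` of a 2-D COO tensor, ignoring unspecified entries.
    sp = sp.coalesce()
    keep = 1 - dim                                     # the dimension that survives
    out = torch.full((sp.shape[keep],), float("-inf"))
    # Fold every value into the slot of its surviving coordinate, keeping the max.
    out.scatter_reduce_(0, sp.indices()[keep], sp.values(), reduce="amax")
    return out                                         # -inf marks "no specified entry"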

@cpuhrsch
Contributor

cpuhrsch commented Aug 4, 2021

@IAmKohlton

Thank you for the additional context and comments!

Another reason I was asking for more specifics on your application is that when they overlap with those of other customers it's easier to justify broad investments into them.

However, the shortcomings you mention are known and valid. Indeed it would be useful to extend the semantics and operator coverage of sparse Tensors to be closer to dense Tensors, and the two operations you mention are quite common.

On the topic of reductions, I know you mention that values "don't exist" at unspecified points. Would you say this agrees with the concept of NumPy's masked arrays? In particular, for reductions, sparse Tensors in the scientific sense should behave equivalently to their dense counterparts. That means the max may be 0 if all non-zero elements are less than 0, even if the user might want those 0s to be ignored. But a masked Tensor actually allows you to define reductions over only the specified values. If you were to convert, say, all 0s to masked-out values and then apply the reduction, you'd get the aforementioned behavior and consistent semantics.
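
A tiny illustration of that distinction (an added example, not from the comment above):

import torch

i = torch.tensor([[0, 2]])
v = torch.tensor([-3.0, -1.0])
s = torch.sparse_coo_tensor(i, v, (5,))

# Dense-equivalent semantics: unspecified positions count as zeros,
# so the max is 0 even though every specified value is negative.
print(s.to_dense().max())           # tensor(0.)

# Masked semantics would reduce over the specified values only:
print(s.coalesce().values().max())  # tensor(-1.)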

@IAmKohlton
Author

I think I can give a way that might tell you more about the application from an abstract perspective. Like I said, we're doing different kinds of simulations. In these simulations we have objects in 3D space that belong to some set X. Every element of set X has a set of subobjects Y (not every X has the same number of subobjects). Every y in Y cares about some set of locations in 3D space, but not others. Denote the set of locations in 3D space by Z (to be clear, the locations in Z aren't the same as the set of locations that make up X). The 3D locations specified in Z are not regularly spaced, and we can assume they're basically randomly placed. For every x, y, z index we have a vector of values that can range from a single value to thousands of values. This can be modeled with a sparse tensor of shape (len(X), max(len(Y)), len(Z), num_values_we_care_about). So a single entry in our sparse tensor represents: the index of an object x, the index of a subobject y that belongs to x, and the index of a location in 3D space z. Note that if you look in the sparse tensor and see many of the same z, you know they all represent the same location, since the locations are conceptually disconnected from the objects or subobjects. However, if you see many of the same y they may or may not be the same subobject, since x=1, y=2 is a different subobject than x=5, y=2.

Basically our tensor tracks how all the different x, y pairs interact with the different z's. Each simulation sees what happens when the different x's move around in 3D space. We observe which x, y pairs interact with which z's and what the values are at the z's when there is an interaction. I hope that tells you a bit more about how we're using torch sparse tensors.

I would say that my use case is basically a masked sparse tensor, since an interaction is either happening or it's not. I've thought about the max(-1, -2) = 0 thing before, and based on what you're saying, the masked sparse tensor seems like it would solve that issue while still maintaining mathematical correctness. It does seem like you would need two entirely different types of tensors to do operations that are really very similar to each other, though. To me it seems simpler to have an optional boolean on the reduction operations rather than two entirely different types of tensors that do the same thing a very large portion of the time, but that's just my opinion after a few seconds of thought.

@cpuhrsch
Contributor

cpuhrsch commented Aug 5, 2021

Thank you for the detailed explanation, @IAmKohlton!

@cpuhrsch
Contributor

cpuhrsch commented Aug 30, 2021

@IAmKohlton - a first, obvious performance improvement was just merged. Please feel encouraged to reopen this issue if it's not fast enough.

@tvercaut

tvercaut commented Jan 7, 2022

@cpuhrsch I am using PyTorch 1.10, which according to the release notes includes the improvement you refer to in #61788 (comment), but I am still facing performance issues with index_select on sparse tensors.

For the sake of comparison, I tried the implementation from torch_sparse and it is sometimes up to 1000x faster.

Should this issue be reopened, or should I create a new one?

Below is the simple test case I used (also on Colab):

import torch
print(torch.__version__)

torchdevice = torch.device('cpu')
if torch.cuda.is_available():
  torchdevice = torch.device('cuda')
  print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running on ' + str(torchdevice))

!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
import torch_sparse

# Convenience wrapper around torch_sparse index_select
def ts_index_select(A,sdim,idx):
  Ats = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(A)
  Ats_select = torch_sparse.index_select(Ats,sdim,idx)
  row, col, value = Ats_select.coo()
  As_select = torch.sparse_coo_tensor(torch.stack([row, col], dim=0), value, (Ats_select.size(0), Ats_select.size(1)))
  return As_select

# Dimension of the square sparse matrix
n = 1000000
# Number of non-zero elements (up to duplicates)
nnz = 100000
# Number of selected indices (up to duplicates)
m = 10000

rowidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
colidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
itemidx = torch.vstack((rowidx,colidx))
xvalues = torch.randn(nnz, device=torchdevice)
SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=(n,n)).coalesce()
print('SparseX:',SparseX)

selectrowidx = torch.unique(torch.randint(low=0, high=n, size=(m,), device=torchdevice), sorted=True)

print('\nRunning index_select from PyTorch')
%timeit SparseXsub1 = SparseX.index_select(0,selectrowidx)

print('\nRunning index_select from torch_sparse')
%timeit SparseXsub2 = ts_index_select(SparseX,0,selectrowidx)

@nikitaved
Collaborator

nikitaved commented Feb 3, 2022

@IAmKohlton , your algorithm could be improved further if the binary search over sorted arrays were replaced with a hash-map inclusion test, but this would require modifying the kernel on the C++ side. The CPU code has O(nnz * len(index)) complexity and can be improved to O(nnz + len(index)), with the caveat of needing a thread-safe hash map under multithreading.
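
To make that concrete, here is a rough pure-Python sketch of the hash-map idea (the real change would live in the C++ kernel; the helper name is hypothetical):

from collections import defaultdict
import torch

def index_select_hashed(sp, dim, index):
    # O(nnz + len(index)) sketch: bucket nnz positions by their coordinate
    # along `dim`, then gather the matches for each requested index.
    sp = sp.coalesce()
    coords, vals = sp.indices(), sp.values()

    buckets = defaultdict(list)                  # coordinate value -> nnz positions
    for pos, c in enumerate(coords[dim].tolist()):
        buckets[c].append(pos)

    keep, new_coord = [], []
    for out_i, wanted in enumerate(index.tolist()):
        for pos in buckets.get(wanted, ()):      # expected O(1) lookup per index
            keep.append(pos)
            new_coord.append(out_i)

    new_idx = coords[:, keep].clone()
    new_idx[dim] = torch.tensor(new_coord, dtype=torch.long)
    new_shape = list(sp.shape)
    new_shape[dim] = len(index)
    return torch.sparse_coo_tensor(new_idx, vals[keep], new_shape)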

@april211

> @cpuhrsch I am using PyTorch 1.10 … [quoting @tvercaut's comment and test case above in full]

Thank you for this info! Now I use torch_sparse instead to do the indexing of a sparse matrix, and it's ~200x faster than the original for-loop-based method in my use case!

@nikitaved
Collaborator

@april211 , which PyTorch version are you using? We enabled much faster kernels in 1.12.

@april211

> @april211 , which PyTorch version are you using? We enabled much faster kernels in 1.12.

I currently use torch 1.9.1+cu111 to reproduce a deep learning paper; sorry, I forgot to mention it.

IMHO, the latest PyTorch docs for this function should add this info as a reminder, for users' convenience :)
