
sparse COO tensors index_select is way too slow to be any kind of useful #54561

Closed
wanjunhong0 opened this issue Mar 24, 2021 · 3 comments

Labels
module: performance · module: sparse · triaged

Comments


wanjunhong0 commented Mar 24, 2021

🐛 Bug

torch.index_select() on sparse COO tensors is way too slow

To Reproduce

Steps to reproduce the behavior:

import time
import torch


a = torch.eye(10000).to_sparse().coalesce()

# Time index_select on the sparse tensor.
t = time.time()
c = a.index_select(0, torch.arange(1000))
print(time.time() - t)

# Time the equivalent row-by-row Python loop.
t = time.time()
b = []
for i in range(1000):
    b.append(a[i])
b = torch.stack(b)
print(time.time() - t)
print((b.to_dense() == c.to_dense()).all())

Output:

8.997999906539917
0.08900022506713867
tensor(True)

Expected behavior

index_select is about 100 times slower than an explicit Python loop over the rows. It is way too slow to be of any practical use.

Environment

Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp
GPU 4: TITAN Xp
GPU 5: TITAN Xp
GPU 6: TITAN Xp
GPU 7: TITAN Xp

Nvidia driver version: 450.80.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] numpy==1.19.2
[pip] torch==1.7.0
[pip] torch-cluster==1.5.8
[pip] torch-geometric==1.6.3
[pip] torch-scatter==2.0.5
[pip] torch-sparse==0.6.8
[pip] torch-spline-conv==1.2.0
[pip] torchelastic==0.2.1
[pip] torchvision==0.8.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.0.221             h6bb024c_0  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.2.0            py38h23d657b_0  
[conda] mkl_random                1.1.1            py38h0573a6f_0  
[conda] numpy                     1.19.2           py38h54aff64_0  
[conda] numpy-base                1.19.2           py38hfa32c7d_0  
[conda] pytorch                   1.7.0           py3.8_cuda11.0.221_cudnn8.0.3_0    pytorch
[conda] torch-cluster             1.5.8                    pypi_0    pypi
[conda] torch-geometric           1.6.3                    pypi_0    pypi
[conda] torch-scatter             2.0.5                    pypi_0    pypi
[conda] torch-sparse              0.6.8                    pypi_0    pypi
[conda] torch-spline-conv         1.2.0                    pypi_0    pypi
[conda] torchelastic              0.2.1                    pypi_0    pypi
[conda] torchvision               0.8.0                py38_cu110    pytorch

cc @VitalyFedyunin @ngimel @aocsa @nikitaved @pearu @mruberry

@ailzhang added the module: performance, module: sparse, and triaged labels on Mar 26, 2021
@IAmKohlton

Hey, I didn't see that there was already an issue open for this. I made issue #61788, which contains a proposed algorithm to speed up index_select. My version was faster than the old one by a factor of a few hundred to a few thousand, depending on the sparse tensor.
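
A minimal sketch of this general idea, for context: select along dim 0 of a coalesced 2-D COO tensor by filtering its indices directly instead of looping over rows. This is an illustration, not the algorithm proposed in #61788; the helper name is hypothetical, and it assumes the requested indices contain no duplicates, so it is not a drop-in replacement for index_select.

import torch

def select_rows_coo(a, idx):
    # Hypothetical helper: pick rows `idx` from a coalesced 2-D sparse COO
    # tensor by filtering its indices rather than calling a.index_select(0, idx).
    # Assumes `idx` has no repeated entries.
    a = a.coalesce()
    row, col = a.indices()
    vals = a.values()

    # Map each original row number to its position in `idx`, or -1 if not selected.
    remap = torch.full((a.size(0),), -1, dtype=torch.long, device=row.device)
    remap[idx] = torch.arange(idx.numel(), device=row.device)

    keep = remap[row] >= 0  # keep only the nonzeros whose row was requested
    new_indices = torch.stack([remap[row[keep]], col[keep]])
    return torch.sparse_coo_tensor(new_indices, vals[keep],
                                   (idx.numel(), a.size(1))).coalesce()

a = torch.eye(10000).to_sparse().coalesce()
c = select_rows_coo(a, torch.arange(1000))

The whole selection is a handful of vectorized operations over the nnz entries, which is what makes this style of approach so much faster than per-row indexing in Python.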

@pearu added this to "To do" in the Sparse tensors project on Aug 10, 2021
@cpuhrsch
Contributor

A first step towards fixing this recently landed in #63008. @wanjunhong0, could you check whether this resolves your issue?

@wanjunhong0
Author

A first step towards fixing this recently landed in #63008. @wanjunhong0, could you check whether this resolves your issue?

Sorry for the late reply. I ran the test again, index-selecting 1000 rows from a 10000×10000 matrix:
Before: 3.1917645931243896
After: 0.004959821701049805

Over 600 times faster. Thanks for the contribution. I am closing this issue.
