[Perf request] Make index_select on sparse COO tensors as fast as that from rusty1s/pytorch_sparse (1000x) #72212

Closed
tvercaut opened this issue Feb 2, 2022 · 15 comments
Labels
module: sparse Related to torch.sparse triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

@tvercaut

tvercaut commented Feb 2, 2022

🚀 The feature, motivation and pitch

The native index_select on sparse COO tensors has been reported to be slow in the past (see #61788 and #54561). While improvements were made with #63008, it remains slower than expected even with PyTorch 1.10.0 and 1.10.1, which, per the 1.10.0 release notes, I understand include the #63008 patch.

Alternatives

The index_select implementation from rusty1s/torch_sparse can be used in lieu of the native index_select, and it can be up to 1000x faster.

Additional context

Below is the simple test case I used as a quick and dirty "benchmark" (also on colab):

import torch
print(torch.__version__)

torchdevice = torch.device('cpu')
if torch.cuda.is_available():
  torchdevice = torch.device('cuda')
  print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running on ' + str(torchdevice))

!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
import torch_sparse

# Convenience wrapper around torch_sparse index_select
def ts_index_select(A,sdim,idx):
  Ats = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(A)
  Ats_select = torch_sparse.index_select(Ats,sdim,idx)
  row, col, value = Ats_select.coo()
  As_select = torch.sparse_coo_tensor(torch.stack([row, col], dim=0), value, (Ats_select.size(0), Ats_select.size(1)))
  return As_select

# Dimension of the square sparse matrix
n = 1000000
# Number of non-zero elements (up to duplicates)
nnz = 100000
# Number of selected indices (up to duplicates)
m = 10000

rowidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
colidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
itemidx = torch.vstack((rowidx,colidx))
xvalues = torch.randn(nnz, device=torchdevice)
SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=(n,n)).coalesce()
print('SparseX:',SparseX)

selectrowidx = torch.unique(torch.randint(low=0, high=n, size=(m,), device=torchdevice), sorted=True)

print('\nRunning index_select from PyTorch')
%timeit SparseXsub1 = SparseX.index_select(0,selectrowidx)

print('\nRunning index_select from torch_sparse')
%timeit SparseXsub2 = ts_index_select(SparseX,0,selectrowidx)

with output

Running index_select from PyTorch
1 loop, best of 5: 2.15 s per loop

Running index_select from torch_sparse
1000 loops, best of 5: 1.69 ms per loop

#61788 (comment) suggested reopening that issue if performance bottlenecks remained. As I am not the OP of #61788, I cannot reopen it myself, hence this new feature request.

cc @nikitaved @pearu @cpuhrsch

@ngimel ngimel added module: sparse Related to torch.sparse triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Feb 2, 2022
@nikitaved
Collaborator

nikitaved commented Feb 3, 2022

I did have a quick look at the implementation. It appears to be quite serial, and maybe we could make it parallelized and/or use more efficient data structures...

@nikitaved
Collaborator

nikitaved commented Feb 3, 2022

@tvercaut, are these times for the CPU or the GPU? If for the GPU, could you please append torch.cuda.synchronize() to the timeit statement, like this: t.index_select(...); torch.cuda.synchronize()? Could you also show the results for both the CPU and the GPU?
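
For readers following along, here is a minimal sketch of why the synchronize call matters when timing. CUDA kernels are launched asynchronously, so the clock should only be stopped after the device has finished. The helper name `timed` and the tiny 3x3 tensor are illustrative assumptions, not part of the original notebook.

```python
import time
import torch

def timed(fn, device, repeats=5):
    """Return the best wall-clock time (seconds) of fn() over `repeats` runs."""
    best = float("inf")
    for _ in range(repeats):
        if device.type == "cuda":
            torch.cuda.synchronize()  # drain pending kernels before starting the clock
        start = time.perf_counter()
        fn()
        if device.type == "cuda":
            torch.cuda.synchronize()  # wait for the queued kernels to finish
        best = min(best, time.perf_counter() - start)
    return best

# Hypothetical toy input; falls back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
indices = torch.tensor([[0, 1, 2], [2, 0, 1]], device=device)
values = torch.tensor([1.0, 2.0, 3.0], device=device)
t = torch.sparse_coo_tensor(indices, values, size=(3, 3)).coalesce()
rows = torch.tensor([0, 2], device=device)

elapsed = timed(lambda: t.index_select(0, rows), device)
print(f"index_select on {device}: {elapsed * 1e6:.1f} us")
```

On the CPU the two synchronize calls are skipped, so the same helper can be used for both columns of a CPU/GPU comparison.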

@nikitaved
Collaborator

nikitaved commented Feb 3, 2022

For the CPU the complexity of the implemented algorithm seems like nnz * len(index), and I think we can do that in nnz + len(index).
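
To illustrate the complexity remark, here is a pure-Python sketch of one way to select rows of a COO matrix in roughly nnz + len(index) steps (plus the size of the output): bucket the nonzeros by row once, then look each requested row up in the buckets. This is an assumption about what such an algorithm could look like, not the code that PyTorch uses, and the function name is hypothetical.

```python
from collections import defaultdict

def coo_index_select_rows(rows, cols, vals, index):
    """Select rows of a COO matrix given as three parallel lists.

    Output row k corresponds to input row index[k]; repeated entries
    in `index` are allowed. Returns (new_rows, new_cols, new_vals).
    """
    # Single pass over the nonzeros: bucket entry positions by row, O(nnz).
    by_row = defaultdict(list)
    for pos, r in enumerate(rows):
        by_row[r].append(pos)

    out_rows, out_cols, out_vals = [], [], []
    # Single pass over the index: O(len(index) + number of output entries).
    for new_r, r in enumerate(index):
        for pos in by_row.get(r, ()):
            out_rows.append(new_r)
            out_cols.append(cols[pos])
            out_vals.append(vals[pos])
    return out_rows, out_cols, out_vals

rows = [0, 0, 2, 3]
cols = [1, 3, 2, 0]
vals = [10.0, 20.0, 30.0, 40.0]
print(coo_index_select_rows(rows, cols, vals, [2, 0, 2]))
# ([0, 1, 1, 2], [2, 1, 3, 2], [30.0, 10.0, 20.0, 30.0])
```

A nested loop that scans all nonzeros for every requested row would instead cost nnz * len(index), which matches the slowdown reported above.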

@tvercaut
Author

tvercaut commented Feb 3, 2022

The reported timings were for the GPU, but I get roughly the same numbers on the CPU for this toy example. I have updated my quick and dirty colab notebook to run on both GPU and CPU, and added torch.cuda.synchronize() to the %timeit line for the GPU:
https://colab.research.google.com/drive/1-tYbh_KP50NDhBuJ_m_DulRwCAi_ewZO?usp=sharing

Default GPU is Tesla K80
Available CPU is Intel(R) Xeon(R) CPU @ 2.30GHz

Running index_select from PyTorch on cuda
1 loop, best of 5: 2.27 s per loop

Running index_select from torch_sparse on cuda
1000 loops, best of 5: 1.79 ms per loop

Running index_select from PyTorch on cpu
1 loop, best of 5: 2.28 s per loop

Running index_select from torch_sparse on cpu
100 loops, best of 5: 5.52 ms per loop

@cpuhrsch
Contributor

cpuhrsch commented Feb 3, 2022

This is a reasonable request. We can find someone to work on it, but it'll take a while. If it's to be done sooner we're also happy to accept a pull request.

@nikitaved
Collaborator

I can work on this for a change.

@nikitaved
Collaborator

nikitaved commented Feb 11, 2022

@tvercaut , FYI, the CUDA backend converts tensors to the CPU, runs the algo, then converts back.
I am actively working on the issue. The CPU part is more or less clear to me, for CUDA I will stick to the CUDA-CPU-CUDA paradigm for now...

@nikitaved
Collaborator

nikitaved commented Mar 3, 2022

@tvercaut, I have another question for you. In your use cases you only index along dim=0, right? If so, are you aware that the combination coalesce + select at dim=0 is the fastest, provided the implementation exploits the sorted structure of coalesced tensors?
When indexing with dim=1 and ndim=2, I am not sure we will be able to match torch_sparse's performance, as we do not have CSC support yet AFAIK. This means that sorting along the given dimension is unavoidable and will be slow for large nnz.
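
As a rough illustration of the point about sorted structure: in a coalesced 2-D COO tensor the entries are ordered by row first, so all nonzeros of a given row form one contiguous slice that a binary search can locate. The pure-Python sketch below assumes the row indices are already sorted and uses hypothetical names; it is not the PyTorch implementation.

```python
from bisect import bisect_left, bisect_right

def select_rows_sorted(rows, cols, vals, index):
    """Row selection on COO data whose `rows` list is sorted ascending."""
    out_rows, out_cols, out_vals = [], [], []
    for new_r, r in enumerate(index):
        # All nonzeros of row r live in the contiguous slice [lo, hi).
        lo = bisect_left(rows, r)
        hi = bisect_right(rows, r)
        out_rows.extend([new_r] * (hi - lo))
        out_cols.extend(cols[lo:hi])
        out_vals.extend(vals[lo:hi])
    return out_rows, out_cols, out_vals

rows = [0, 0, 2, 3]          # sorted, as after coalesce()
cols = [1, 3, 2, 0]
vals = [10.0, 20.0, 30.0, 40.0]
print(select_rows_sorted(rows, cols, vals, [3, 0]))
# ([0, 1, 1], [0, 1, 3], [40.0, 10.0, 20.0])
```

The same trick does not carry over to dim=1 directly, because the column indices are only sorted within each row, which is why a re-sort (or a CSC layout) comes up in that case.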

@tvercaut
Author

tvercaut commented Mar 3, 2022

@nikitaved, in the actual use case I was looking at, I am interested in indexing with dim=0 only but it's a bit specific and I assume indexing with dim=1 (and dim=2) as well would serve other use cases.

More specifically, my use case involves a sparse matrix L and I am interested in (conceptually) extracting a sub-matrix from L@L.T using the same indices for dim=0 and dim=1. In my current approach, I get Lsub by indexing L for dim=0 and then create the sparse symmetric sub-matrix I want as Lsub@Lsub.T.

The toy example shared here was my attempt at providing a minimal example of the performance issue I was facing.
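
The identity behind this use case, (L@L.T)[idx][:, idx] == Lsub@Lsub.T with Lsub = L[idx], can be checked on a small dense example. The sketch below uses plain Python lists and made-up values purely for illustration; in practice L would be a sparse tensor.

```python
def matmul_transpose(a):
    """Return a @ a.T for a dense matrix given as a list of rows."""
    return [[sum(x * y for x, y in zip(ri, rj)) for rj in a] for ri in a]

# Hypothetical 4x3 matrix and index set.
L = [
    [1, 0, 2],
    [0, 3, 0],
    [4, 0, 0],
    [0, 5, 6],
]
idx = [0, 3]

# Route 1: form the full product, then pick rows and columns idx.
full = matmul_transpose(L)
sub_of_full = [[full[i][j] for j in idx] for i in idx]

# Route 2: select rows of L first, then form the (much smaller) product.
L_sub = [L[i] for i in idx]
from_sub = matmul_transpose(L_sub)

print(sub_of_full == from_sub)  # True
```

Route 2 only needs index_select along dim=0, which is why the row-selection speed matters for this workflow.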

@cpuhrsch
Contributor

cpuhrsch commented Mar 3, 2022

@nikitaved - on CSC support @pearu was planning to work on it next.

@freddy5566

freddy5566 commented Apr 8, 2022

Hi @tvercaut,

Have you tested this on an extremely large table? When I used index_select with your approach, I ran into an OOM issue. Also, it cannot handle a one-dimensional sparse tensor: I had to make it two-dimensional beforehand, then index_select along dim=1 and compute the mean over dim 0. It works, but it is not very elegant.

Or do you know a better way to get the values from a sparse tensor given a sequence of indices? The original index_select is too slow to use. Something like tensor[sequence] is what I want.
Thank you in advance.

@nikitaved
Collaborator

nikitaved commented May 9, 2022

Master now contains an updated and faster version. Note, however, that just like in the previous version, CUDA inputs are converted to CPU, the algorithm runs on the CPU, and the result is moved back to CUDA once it is done.
Some benchmarks are available here: #72710 (comment)

pytorchmergebot pushed a commit that referenced this issue May 10, 2022
Fixes #72212.

This PR improves on the complexity of the previous algorithm. It also exploits the structure of the problem and parallelizes computations where possible.

Benchmark results.

<details>

<summary>Testing script</summary>

```python
import pickle
from itertools import product

import torch
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

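# torch_sparse baseline: only called when benchmarking torch_sparse,
# and requires `import torch_sparse` at the top of the script.
def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()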

name = "PR"
results = []

for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        stmt = "SparseX.index_select(d, index)"
        timer = Timer(stmt,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Use a distinct handle name so the helper function f() is not shadowed.
with open(f"{name}_index_select.pickle", 'wb') as fh:
    pickle.dump(results, fh)

```

</details>

<details>

<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
        "PR",
        "torch_sparse",
        "master"
        ]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()

```

</details>

<details>

<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).

```

</details>
Pull Request resolved: #72710
Approved by: https://github.com/pearu, https://github.com/cpuhrsch
facebook-github-bot pushed a commit that referenced this issue May 10, 2022
Summary:
Fixes #72212.

Pull Request resolved: #72710

Reviewed By: samdow

Differential Revision: D36282349

Pulled By: malfet

fbshipit-source-id: 3679ea4ebeeda4d200a441aef6d45b98303bc0c0
facebook-github-bot pushed a commit that referenced this issue May 13, 2022
Summary:
Fixes #72212.

<details>

<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).

```

</details>
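
The raw timings above are easier to compare as ratios against the first timing column. The following is a minimal sketch that assumes each row has the form `config | t1 | t2 | t3`, with times in microseconds. The helper names `parse_row` and `ratios_vs_first` are hypothetical, the three sample rows are copied from the table, and which implementation each column refers to is given by the table header.

```python
def parse_row(line):
    """Split one table row into its config label and its timings (in us)."""
    label, *cells = [part.strip() for part in line.split("|")]
    return label, [float(c) for c in cells]


def ratios_vs_first(line):
    """Return the label and each later timing divided by the first timing."""
    label, times = parse_row(line)
    base = times[0]
    return label, [t / base for t in times[1:]]


# Sample rows copied verbatim from the benchmark table.
rows = [
    "n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270",
    "n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16",
    "n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000",
]

for line in rows:
    label, ratios = ratios_vs_first(line)
    print(f"{label}: " + ", ".join(f"{r:.1f}x" for r in ratios))
```

A ratio above 1 means the corresponding column is slower than the first column for that configuration.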

Pull Request resolved: #72710

Reviewed By: samdow

Differential Revision: D36282349

Pulled By: malfet

fbshipit-source-id: 3679ea4ebeeda4d200a441aef6d45b98303bc0c0
@cpuhrsch
Contributor

cpuhrsch commented Jun 3, 2022

@nikitaved also just merged a native CUDA implementation for this with even more improvements.

@tvercaut - the latest nightly should now contain a much more efficient implementation of index_select. Would you mind trying it for your use case again?

@tvercaut
Author

tvercaut commented Jun 6, 2022

Looks great! I confirm the speedup is massive. Many thanks. I'm looking forward to seeing this in 1.12.

@cpuhrsch
Contributor

cpuhrsch commented Jun 8, 2022

@tvercaut - I'm very happy to hear that! Unfortunately, the CUDA performance improvements didn't make the branch cut in time, so they are expected in 1.13. The CPU performance improvements will be available in 1.12.
