Skip to content

Commit

Permalink
Faster index_select for sparse COO tensors on CPU. (#72710)
Browse files Browse the repository at this point in the history
Summary:
Fixes #72212.

This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible.

Benchmark results.

<details>

<summary>Testing script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []

for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        smtp = "SparseX.index_select(d, index)"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)

```

</details>

<details>

<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
        "PR",
        "torch_sparse",
        "master"
        ]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()

```

</details>

<details>

<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).

```

</details>

Pull Request resolved: #72710

Reviewed By: samdow

Differential Revision: D36282349

Pulled By: malfet

fbshipit-source-id: 3679ea4ebeeda4d200a441aef6d45b98303bc0c0
  • Loading branch information
nikitaved authored and facebook-github-bot committed May 10, 2022
1 parent 35dc643 commit dfb217c
Show file tree
Hide file tree
Showing 4 changed files with 709 additions and 63 deletions.
Loading

0 comments on commit dfb217c

Please sign in to comment.