Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster index_select for sparse COO tensors on CPU. #72710

Closed
wants to merge 106 commits into from

Conversation

nikitaved
Copy link
Collaborator

@nikitaved nikitaved commented Feb 11, 2022

Fixes #72212.

This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible.

Benchmark results.

Testing script
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []

for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        smtp = "SparseX.index_select(d, index)"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)
Gather results
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
        "PR",
        "torch_sparse",
        "master"
        ]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
PR/torch_sparse/master runtime comparison
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).


@nikitaved nikitaved added module: sparse Related to torch.sparse topic: performance topic category labels Feb 11, 2022
@pytorch-bot
Copy link

pytorch-bot bot commented Feb 11, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/c0e285267a0df47dac3ac5cf0ad2b9ebebedb87e/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk, ciflow/xla ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Feb 11, 2022

🔗 Helpful links

❌ 1 New Failures, 1 Flaky Failures

As of commit e1a0978 (more details on the Dr. CI page):

Expand to see more

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-10T16:10:32.8198769Z RuntimeError: test_ops failed! Received signal: SIGIOT
2022-05-10T16:10:31.1736713Z   test_variant_consistency_eager_dsplit_cuda_complex64 (__main__.TestCommonCUDA) ... ok (0.013s)
2022-05-10T16:10:31.1822785Z   test_variant_consistency_eager_dsplit_cuda_float32 (__main__.TestCommonCUDA) ... ok (0.008s)
2022-05-10T16:10:31.1902452Z   test_variant_consistency_eager_dstack_cuda_complex64 (__main__.TestCommonCUDA) ... ok (0.008s)
2022-05-10T16:10:31.1966031Z   test_variant_consistency_eager_dstack_cuda_float32 (__main__.TestCommonCUDA) ... ok (0.006s)
2022-05-10T16:10:31.2142444Z   test_variant_consistency_eager_eig_cuda_complex64 (__main__.TestCommonCUDA) ... python: /opt/conda/conda-bld/magma-cuda113_1619629459349/work/interface_cuda/interface.cpp:806: void magma_queue_create_internal(magma_device_t, magma_queue**, const char*, const char*, int): Assertion `queue->dAarray__ != __null' failed.
2022-05-10T16:10:32.8192065Z Traceback (most recent call last):
2022-05-10T16:10:32.8192504Z   File "test/run_test.py", line 1072, in <module>
2022-05-10T16:10:32.8195622Z     main()
2022-05-10T16:10:32.8196154Z   File "test/run_test.py", line 1050, in main
2022-05-10T16:10:32.8198379Z     raise RuntimeError(err_message)
2022-05-10T16:10:32.8198769Z RuntimeError: test_ops failed! Received signal: SIGIOT
2022-05-10T16:10:34.1009633Z + cleanup
2022-05-10T16:10:34.1010222Z + retcode=1
2022-05-10T16:10:34.1010635Z + set +x
2022-05-10T16:10:34.1061888Z ##[error]Process completed with exit code 1.
2022-05-10T16:10:34.1118991Z ##[group]Run pytorch/pytorch/.github/actions/get-workflow-job-id@master
2022-05-10T16:10:34.1119356Z with:
2022-05-10T16:10:34.1120015Z   github-token: ***
2022-05-10T16:10:34.1120273Z env:
2022-05-10T16:10:34.1120483Z   IN_CI: 1
2022-05-10T16:10:34.1120720Z   IS_GHA: 1

❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See GitHub Actions build trunk / linux-bionic-rocm5.1-py3.7-distributed / test (distributed, 1, 2, linux.rocm.gpu) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun) ❄️

2022-05-10T17:28:48.8927975Z RuntimeError: Proc...ated or timed out after 100.03153944015503 seconds
2022-05-10T17:28:48.8914260Z ======================================================================
2022-05-10T17:28:48.8915363Z ERROR [100.495s]: test_forward_overlap (__main__.TestForwardOverlapWorldSizeOne)
2022-05-10T17:28:48.8917265Z ----------------------------------------------------------------------
2022-05-10T17:28:48.8918363Z Traceback (most recent call last):
2022-05-10T17:28:48.8920143Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 494, in wrapper
2022-05-10T17:28:48.8921333Z     self._join_processes(fn)
2022-05-10T17:28:48.8922951Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 717, in _join_processes
2022-05-10T17:28:48.8924208Z     self._check_return_codes(elapsed_time)
2022-05-10T17:28:48.8925819Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 769, in _check_return_codes
2022-05-10T17:28:48.8926972Z     i, elapsed_time
2022-05-10T17:28:48.8927975Z RuntimeError: Process 0 terminated or timed out after 100.03153944015503 seconds
2022-05-10T17:28:48.8928642Z 
2022-05-10T17:28:48.8929463Z ----------------------------------------------------------------------
2022-05-10T17:28:48.8930458Z Ran 5 tests in 247.779s
2022-05-10T17:28:48.8930925Z 
2022-05-10T17:28:48.8931344Z FAILED (errors=1, unexpected successes=3)
2022-05-10T17:28:48.8931900Z 
2022-05-10T17:28:48.8932255Z Generating XML reports...
2022-05-10T17:28:48.9043272Z Generated XML report: test-reports/python-unittest/distributed.fsdp.test_fsdp_overlap/TEST-TestForwardOverlapWorldSizeOne-20220510172441.xml
2022-05-10T17:28:48.9045963Z Generated XML report: test-reports/python-unittest/distributed.fsdp.test_fsdp_overlap/TEST-TestForwardOverlapWorldSizeTwo-20220510172441.xml
2022-05-10T17:28:49.9994439Z Traceback (most recent call last):

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@nikitaved nikitaved changed the title [WIP] Faster index_select for sparse COO tensors. Faster index_select for sparse COO tensors. Mar 3, 2022
@nikitaved nikitaved requested a review from cpuhrsch March 3, 2022 11:45
@nikitaved
Copy link
Collaborator Author

nikitaved commented Mar 3, 2022

@cpuhrsch , could you please have a look? I will also put up some benchmarking results, but it is much faster than the previous implementation.

@dagitses dagitses added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 3, 2022
@cpuhrsch
Copy link
Contributor

cpuhrsch commented Mar 3, 2022

@nikitaved - Thanks for sending this! Looks great! Do you have some timings for some sample inputs to verify the perf gains not just analytically?

@github-actions
Copy link

github-actions bot commented May 9, 2022

Hey @nikitaved.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@malfet
Copy link
Contributor

malfet commented May 10, 2022

@pytorchbot revert this, as it breaks internal builds by introducing unused capture:

stderr: buck-out/gen/fe3a39b8/aten/src/ATen/native/TensorShape.cpp:1433:50: error: lambda capture 'self' is not used [-Werror,-Wunused-lambda-capture]
    const auto nneg_index = [&index, index_len, &self, size, dim](void) -> Tensor {
                                              ~~~^~~~
buck-out/gen/fe3a39b8/aten/src/ATen/native/TensorShape.cpp:1433:62: error: lambda capture 'dim' is not used [-Werror,-Wunused-lambda-capture]
    const auto nneg_index = [&index, index_len, &self, size, dim](void) -> Tensor {
                                                           ~~^~~

@malfet
Copy link
Contributor

malfet commented May 10, 2022

@pytorchbot revert this as it breaks internal builds

@malfet
Copy link
Contributor

malfet commented May 10, 2022

I can reproduce the failure by passing -DSTRIP_ERROR_MESSAGES, which is a reasonable thing to do for mobile builds, which does not seem to be covered by OSS CI matrix.

pytorchmergebot added a commit that referenced this pull request May 10, 2022
@malfet malfet reopened this May 10, 2022
Tentative fix for internal builds
@facebook-github-bot
Copy link
Contributor

@malfet has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@malfet
Copy link
Contributor

malfet commented May 10, 2022

Importing to check for internal build failures with suggested changes, will reland if passes

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge this

(Initiating merge automatically since Phabricator Diff has merged)

facebook-github-bot pushed a commit that referenced this pull request May 10, 2022
Summary:
Fixes #72212.

This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible.

Benchmark results.

<details>

<summary>Testing script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []

for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        smtp = "SparseX.index_select(d, index)"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)

```

</details>

<details>

<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
        "PR",
        "torch_sparse",
        "master"
        ]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()

```

</details>

<details>

<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).

```

</details>

Pull Request resolved: #72710

Reviewed By: samdow

Differential Revision: D36282349

Pulled By: malfet

fbshipit-source-id: 3679ea4ebeeda4d200a441aef6d45b98303bc0c0
facebook-github-bot pushed a commit that referenced this pull request May 13, 2022
Summary:
Fixes #72212.

This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible.

Benchmark results.

<details>

<summary>Testing script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []

for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        smtp = "SparseX.index_select(d, index)"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)

```

</details>

<details>

<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
        "PR",
        "torch_sparse",
        "master"
        ]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()

```

</details>

<details>

<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).

```

</details>

Pull Request resolved: #72710

Reviewed By: samdow

Differential Revision: D36282349

Pulled By: malfet

fbshipit-source-id: 3679ea4ebeeda4d200a441aef6d45b98303bc0c0
pytorchmergebot pushed a commit that referenced this pull request Jun 1, 2022
Brings a native CUDA implementation for `index_select`. Master silently converts CUDA tensors to CPU for CUDA support.

Case `nnz >> size` could be optimized similar to how #72710 is doing that.

Some benchmarks:
<details>

<summary>PR/torch_sparse/master</summary>

```
[------------------------------- cuda coo.index_select -------------------------------]
                                                    |   PR   |  torch_sparse  |  master
32 threads: ---------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |    96  |       327      |     70
      n=10000, nnz=100, index_len=100, dim=1        |   120  |       505      |     74
      n=10000, nnz=100, index_len=1000, dim=0       |    90  |       333      |     93
      n=10000, nnz=100, index_len=1000, dim=1       |   120  |       499      |     98
      n=10000, nnz=100, index_len=10000, dim=0      |    92  |       331      |    350
      n=10000, nnz=100, index_len=10000, dim=1      |   100  |       506      |    352
      n=100000, nnz=1000, index_len=100, dim=0      |    53  |       274      |     60
      n=100000, nnz=1000, index_len=100, dim=1      |    90  |       368      |     71
      n=100000, nnz=1000, index_len=1000, dim=0     |    93  |       332      |    100
      n=100000, nnz=1000, index_len=1000, dim=1     |   130  |       501      |    140
      n=100000, nnz=1000, index_len=10000, dim=0    |   100  |       341      |    522
      n=100000, nnz=1000, index_len=10000, dim=1    |   130  |       530      |    549
      n=1000000, nnz=10000, index_len=100, dim=0    |    90  |       429      |    110
      n=1000000, nnz=10000, index_len=100, dim=1    |   296  |       810      |    355
      n=1000000, nnz=10000, index_len=1000, dim=0   |   100  |       435      |    170
      n=1000000, nnz=10000, index_len=1000, dim=1   |   309  |       830      |    548
      n=1000000, nnz=10000, index_len=10000, dim=0  |   110  |       446      |    750
      n=1000000, nnz=10000, index_len=10000, dim=1  |   310  |       830      |   1000
      n=10, nnz=100, index_len=100, dim=0           |    90  |       333      |     74
      n=10, nnz=100, index_len=100, dim=1           |   100  |       497      |     78
      n=10, nnz=100, index_len=1000, dim=0          |    90  |       329      |    140
      n=10, nnz=100, index_len=1000, dim=1          |   100  |       800      |    100
      n=10, nnz=100, index_len=10000, dim=0         |    93  |       340      |    900
      n=10, nnz=100, index_len=10000, dim=1         |   120  |       800      |    489
      n=10, nnz=1000, index_len=100, dim=0          |    90  |       321      |    140
      n=10, nnz=1000, index_len=100, dim=1          |   100  |       680      |    140
      n=10, nnz=1000, index_len=1000, dim=0         |   110  |       349      |    670
      n=10, nnz=1000, index_len=1000, dim=1         |   130  |       740      |    800
      n=10, nnz=1000, index_len=10000, dim=0        |   302  |       503      |   4882
      n=10, nnz=1000, index_len=10000, dim=1        |   325  |      2257      |   5262
      n=10, nnz=10000, index_len=100, dim=0         |   229  |       349      |    810
      n=10, nnz=10000, index_len=100, dim=1         |   433  |       870      |    700
      n=10, nnz=10000, index_len=1000, dim=0        |   666  |       502      |   5581
      n=10, nnz=10000, index_len=1000, dim=1        |   826  |      2379      |   4820
      n=10, nnz=10000, index_len=10000, dim=0       |  2534  |      2700      |  80000
      n=10, nnz=10000, index_len=10000, dim=1       |  2723  |     18540      |  80000
      n=100, nnz=1000, index_len=100, dim=0         |    94  |       324      |    110
      n=100, nnz=1000, index_len=100, dim=1         |   100  |       499      |    110
      n=100, nnz=1000, index_len=1000, dim=0        |    96  |       337      |    150
      n=100, nnz=1000, index_len=1000, dim=1        |   130  |       800      |    140
      n=100, nnz=1000, index_len=10000, dim=0       |   100  |       346      |    900
      n=100, nnz=1000, index_len=10000, dim=1       |   130  |       760      |    900
      n=100, nnz=10000, index_len=100, dim=0        |    90  |       323      |    190
      n=100, nnz=10000, index_len=100, dim=1        |   279  |       800      |    180
      n=100, nnz=10000, index_len=1000, dim=0       |   110  |       339      |    781
      n=100, nnz=10000, index_len=1000, dim=1       |   294  |       870      |    800
      n=100, nnz=10000, index_len=10000, dim=0      |   315  |       505      |   6264
      n=100, nnz=10000, index_len=10000, dim=1      |   497  |      2398      |   5404
      n=1000, nnz=10000, index_len=100, dim=0       |    90  |       333      |    160
      n=1000, nnz=10000, index_len=100, dim=1       |   279  |       635      |    150
      n=1000, nnz=10000, index_len=1000, dim=0      |   100  |       328      |    215
      n=1000, nnz=10000, index_len=1000, dim=1      |   287  |       810      |    207
      n=1000, nnz=10000, index_len=10000, dim=0     |   100  |       339      |    900
      n=1000, nnz=10000, index_len=10000, dim=1     |   291  |       880      |   1000
      n=1000, nnz=100000, index_len=100, dim=0      |    92  |       358      |    435
      n=1000, nnz=100000, index_len=100, dim=1      |   302  |       900      |    530
      n=1000, nnz=100000, index_len=1000, dim=0     |   130  |       360      |   1000
      n=1000, nnz=100000, index_len=1000, dim=1     |   329  |       930      |   1200
      n=1000, nnz=100000, index_len=10000, dim=0    |   343  |       530      |   7000
      n=1000, nnz=100000, index_len=10000, dim=1    |   545  |      2446      |   6100
      n=1000, nnz=1000000, index_len=100, dim=0     |   355  |       394      |   2210
      n=1000, nnz=1000000, index_len=100, dim=1     |  1660  |      2276      |   2674
      n=1000, nnz=1000000, index_len=1000, dim=0    |   877  |       574      |   6700
      n=1000, nnz=1000000, index_len=1000, dim=1    |  2449  |      3782      |   9000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  3112  |      2931      |  57000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  7340  |     20220      |  65700

Times are in microseconds (us).

```

</details>

Pull Request resolved: #77551
Approved by: https://github.com/cpuhrsch
bigfootjon pushed a commit that referenced this pull request Jun 2, 2022
Summary:
Brings a native CUDA implementation for `index_select`. Master silently converts CUDA tensors to CPU for CUDA support.

Case `nnz >> size` could be optimized similar to how #72710 is doing that.

Some benchmarks:
<details>

<summary>PR/torch_sparse/master</summary>

```
[------------------------------- cuda coo.index_select -------------------------------]
                                                    |   PR   |  torch_sparse  |  master
32 threads: ---------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |    96  |       327      |     70
      n=10000, nnz=100, index_len=100, dim=1        |   120  |       505      |     74
      n=10000, nnz=100, index_len=1000, dim=0       |    90  |       333      |     93
      n=10000, nnz=100, index_len=1000, dim=1       |   120  |       499      |     98
      n=10000, nnz=100, index_len=10000, dim=0      |    92  |       331      |    350
      n=10000, nnz=100, index_len=10000, dim=1      |   100  |       506      |    352
      n=100000, nnz=1000, index_len=100, dim=0      |    53  |       274      |     60
      n=100000, nnz=1000, index_len=100, dim=1      |    90  |       368      |     71
      n=100000, nnz=1000, index_len=1000, dim=0     |    93  |       332      |    100
      n=100000, nnz=1000, index_len=1000, dim=1     |   130  |       501      |    140
      n=100000, nnz=1000, index_len=10000, dim=0    |   100  |       341      |    522
      n=100000, nnz=1000, index_len=10000, dim=1    |   130  |       530      |    549
      n=1000000, nnz=10000, index_len=100, dim=0    |    90  |       429      |    110
      n=1000000, nnz=10000, index_len=100, dim=1    |   296  |       810      |    355
      n=1000000, nnz=10000, index_len=1000, dim=0   |   100  |       435      |    170
      n=1000000, nnz=10000, index_len=1000, dim=1   |   309  |       830      |    548
      n=1000000, nnz=10000, index_len=10000, dim=0  |   110  |       446      |    750
      n=1000000, nnz=10000, index_len=10000, dim=1  |   310  |       830      |   1000
      n=10, nnz=100, index_len=100, dim=0           |    90  |       333      |     74
      n=10, nnz=100, index_len=100, dim=1           |   100  |       497      |     78
      n=10, nnz=100, index_len=1000, dim=0          |    90  |       329      |    140
      n=10, nnz=100, index_len=1000, dim=1          |   100  |       800      |    100
      n=10, nnz=100, index_len=10000, dim=0         |    93  |       340      |    900
      n=10, nnz=100, index_len=10000, dim=1         |   120  |       800      |    489
      n=10, nnz=1000, index_len=100, dim=0          |    90  |       321      |    140
      n=10, nnz=1000, index_len=100, dim=1          |   100  |       680      |    140
      n=10, nnz=1000, index_len=1000, dim=0         |   110  |       349      |    670
      n=10, nnz=1000, index_len=1000, dim=1         |   130  |       740      |    800
      n=10, nnz=1000, index_len=10000, dim=0        |   302  |       503      |   4882
      n=10, nnz=1000, index_len=10000, dim=1        |   325  |      2257      |   5262
      n=10, nnz=10000, index_len=100, dim=0         |   229  |       349      |    810
      n=10, nnz=10000, index_len=100, dim=1         |   433  |       870      |    700
      n=10, nnz=10000, index_len=1000, dim=0        |   666  |       502      |   5581
      n=10, nnz=10000, index_len=1000, dim=1        |   826  |      2379      |   4820
      n=10, nnz=10000, index_len=10000, dim=0       |  2534  |      2700      |  80000
      n=10, nnz=10000, index_len=10000, dim=1       |  2723  |     18540      |  80000
      n=100, nnz=1000, index_len=100, dim=0         |    94  |       324      |    110
      n=100, nnz=1000, index_len=100, dim=1         |   100  |       499      |    110
      n=100, nnz=1000, index_len=1000, dim=0        |    96  |       337      |    150
      n=100, nnz=1000, index_len=1000, dim=1        |   130  |       800      |    140
      n=100, nnz=1000, index_len=10000, dim=0       |   100  |       346      |    900
      n=100, nnz=1000, index_len=10000, dim=1       |   130  |       760      |    900
      n=100, nnz=10000, index_len=100, dim=0        |    90  |       323      |    190
      n=100, nnz=10000, index_len=100, dim=1        |   279  |       800      |    180
      n=100, nnz=10000, index_len=1000, dim=0       |   110  |       339      |    781
      n=100, nnz=10000, index_len=1000, dim=1       |   294  |       870      |    800
      n=100, nnz=10000, index_len=10000, dim=0      |   315  |       505      |   6264
      n=100, nnz=10000, index_len=10000, dim=1      |   497  |      2398      |   5404
      n=1000, nnz=10000, index_len=100, dim=0       |    90  |       333      |    160
      n=1000, nnz=10000, index_len=100, dim=1       |   279  |       635      |    150
      n=1000, nnz=10000, index_len=1000, dim=0      |   100  |       328      |    215
      n=1000, nnz=10000, index_len=1000, dim=1      |   287  |       810      |    207
      n=1000, nnz=10000, index_len=10000, dim=0     |   100  |       339      |    900
      n=1000, nnz=10000, index_len=10000, dim=1     |   291  |       880      |   1000
      n=1000, nnz=100000, index_len=100, dim=0      |    92  |       358      |    435
      n=1000, nnz=100000, index_len=100, dim=1      |   302  |       900      |    530
      n=1000, nnz=100000, index_len=1000, dim=0     |   130  |       360      |   1000
      n=1000, nnz=100000, index_len=1000, dim=1     |   329  |       930      |   1200
      n=1000, nnz=100000, index_len=10000, dim=0    |   343  |       530      |   7000
      n=1000, nnz=100000, index_len=10000, dim=1    |   545  |      2446      |   6100
      n=1000, nnz=1000000, index_len=100, dim=0     |   355  |       394      |   2210
      n=1000, nnz=1000000, index_len=100, dim=1     |  1660  |      2276      |   2674
      n=1000, nnz=1000000, index_len=1000, dim=0    |   877  |       574      |   6700
      n=1000, nnz=1000000, index_len=1000, dim=1    |  2449  |      3782      |   9000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  3112  |      2931      |  57000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  7340  |     20220      |  65700

Times are in microseconds (us).

```

</details>

Pull Request resolved: #77551
Approved by: https://github.com/cpuhrsch

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/03cf01bdc03a631a1ab521e27b6523bca1a57f0d

Reviewed By: b0noI

Differential Revision: D36854233

Pulled By: b0noI

fbshipit-source-id: 9c665baf72fbb5530b450af0d768d0761b1a5c73
@malfet malfet deleted the nikitaved/coo_index_select branch July 17, 2022 18:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed Merged module: sparse Related to torch.sparse open source release notes: sparse release notes category Reverted topic: performance topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module with-ssh
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Perf request] Make index_select on sparse COO tensors as fast as that from rusty1s/pytorch_sparse (1000x)
10 participants