[Perf request] Make index_select on sparse COO tensors as fast as that from rusty1s/pytorch_sparse (1000x) #72212
Comments
I did have a quick look at the implementation. It appears to be quite serial; maybe we could parallelize it and/or use more efficient data structures...
@tvercaut, are these times for the CPU or the GPU? If for the GPU, could you please append ...
For the CPU, the complexity of the implemented algorithm seems like ...
The reported timings were for the GPU, but I get roughly the same numbers on the CPU for this toy example. I have updated my quick and dirty Colab notebook to run both GPU and CPU and also add ...
This is a reasonable request. We can find someone to work on it, but it'll take a while. If it's to be done sooner, we're also happy to accept a pull request.
I can work on this for a change.
@tvercaut, FYI, the CUDA backend converts tensors to the CPU, runs the algorithm, then converts back.
@tvercaut, I have another question for you. In your use cases, you only index for ...
@nikitaved, in the actual use case I was looking at, I am interested in indexing with ... More specifically, my use case involves a sparse matrix L, and I am interested in (conceptually) extracting a sub-matrix from L@L.T using the same indices for ... The toy example shared here was my attempt at providing a minimal example of the performance issue I was facing.
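For concreteness, here is a minimal sketch of that pattern (the names `L`, `idx` and all sizes are illustrative stand-ins, not the original code): since `(L @ L.T)[idx][:, idx]` equals `L[idx] @ L[idx].T`, a single `index_select` along dim 0 is enough.

```python
import torch

# Illustrative sketch only: "L" and "idx" are hypothetical stand-ins for the
# data in the use case described above.
n, k, nnz, m = 10_000, 64, 50_000, 128
L = torch.sparse_coo_tensor(
    torch.stack([torch.randint(0, n, (nnz,)), torch.randint(0, k, (nnz,))]),
    torch.randn(nnz),
    size=(n, k),
).coalesce()
idx = torch.randint(0, n, (m,))

# (L @ L.T)[idx][:, idx] == L[idx] @ L[idx].T, so one index_select on dim 0
# avoids materialising the full n-by-n Gram matrix.
L_sub = L.index_select(0, idx)                            # sparse COO, shape (m, k)
sub_gram = torch.sparse.mm(L_sub, L_sub.to_dense().t())   # dense, shape (m, m)
```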
@nikitaved - on CSC support, @pearu was planning to work on it next.
Hi @tvercaut, have you tested this on an extremely large table? When I used index_select in your approach, I encountered an OOM issue. Plus, it cannot deal with a one-dimensional sparse tensor; I had to make it two-dimensional beforehand, then index_select on dim=1 and calculate the mean on dim 0. Though it works, it is not very elegant. Do you know a better way to get the values from a sparse tensor by a sequence of indices? The original index_select is too slow to use. Something like ...
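If I understand that workaround correctly, it amounts to something like the sketch below (sizes and names are illustrative; the mean over the singleton dim 0 reduces to a sum here):

```python
import torch

# A sketch of the workaround described above, assuming the goal is to gather
# values from a 1-D sparse vector at given positions.
n, m = 1_000_000, 4096
v = torch.sparse_coo_tensor(
    torch.randint(0, n, (1, 1000)), torch.randn(1000), size=(n,)
).coalesce()
idx = torch.randint(0, n, (m,))

# Lift the vector to shape (1, n) so index_select can work along dim=1 ...
v2d = torch.sparse_coo_tensor(
    torch.vstack([torch.zeros_like(v.indices()[0]), v.indices()[0]]),
    v.values(),
    size=(1, n),
)
picked = v2d.index_select(1, idx)  # sparse COO, shape (1, m)
# ... then collapse the singleton dim (sum == mean here) to get a dense result;
# positions without a stored value come out as zero.
gathered = torch.sparse.sum(picked, dim=0).to_dense()  # shape (m,)
```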
Master now contains an updated and faster version. Note, however, that just like the previous version, CUDA inputs are converted to CPU, the algorithm is run on the CPU, and the result is moved back to CUDA once it is done.
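Given that round trip, one way to keep the device transfers explicit on the caller's side is a small wrapper like the following (a sketch, not the library's internal code):

```python
import torch

# Sketch: make the CPU round trip explicit so the device copies show up in
# profiles and can be amortised when many selections reuse the same tensor.
def index_select_via_cpu(sparse_t: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    out = sparse_t.cpu().index_select(dim, index.cpu())
    return out.to(sparse_t.device)
```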
Fixes #72212. This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible. Benchmark results.

<details>
<summary>Testing script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []
for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))
        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        smtp = "SparseX.index_select(d, index)"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)
```

</details>

<details>
<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
    "PR",
    "torch_sparse",
    "master"
]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```

</details>

<details>
<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                 |  PR   | torch_sparse | master
32 threads: -----------------------------------------------------------------------------
  n=10000, nnz=100, index_len=100, dim=0 | 14 | 140 | 10
  n=10000, nnz=100, index_len=100, dim=1 | 14 | 200 | 10
  n=10000, nnz=100, index_len=1000, dim=0 | 30 | 180 | 38
  n=10000, nnz=100, index_len=1000, dim=1 | 34 | 240 | 38
  n=10000, nnz=100, index_len=10000, dim=0 | 278 | 460 | 330
  n=10000, nnz=100, index_len=10000, dim=1 | 275 | 516 | 330
  n=100000, nnz=1000, index_len=100, dim=0 | 16 | 290 | 31
  n=100000, nnz=1000, index_len=100, dim=1 | 26 | 390 | 31
  n=100000, nnz=1000, index_len=1000, dim=0 | 45 | 405 | 263
  n=100000, nnz=1000, index_len=1000, dim=1 | 73 | 500 | 261
  n=100000, nnz=1000, index_len=10000, dim=0 | 444 | 783 | 2570
  n=100000, nnz=1000, index_len=10000, dim=1 | 470 | 890 | 2590
  n=1000000, nnz=10000, index_len=100, dim=0 | 25 | 2400 | 270
  n=1000000, nnz=10000, index_len=100, dim=1 | 270 | 4000 | 269
  n=1000000, nnz=10000, index_len=1000, dim=0 | 74 | 2600 | 2620
  n=1000000, nnz=10000, index_len=1000, dim=1 | 464 | 3600 | 2640
  n=1000000, nnz=10000, index_len=10000, dim=0 | 635 | 3300 | 26400
  n=1000000, nnz=10000, index_len=10000, dim=1 | 1000 | 3960 | 26400
  n=10, nnz=100, index_len=100, dim=0 | 16 | 137 | 16
  n=10, nnz=100, index_len=100, dim=1 | 16 | 220 | 16
  n=10, nnz=100, index_len=1000, dim=0 | 63 | 238 | 81
  n=10, nnz=100, index_len=1000, dim=1 | 60 | 698 | 78
  n=10, nnz=100, index_len=10000, dim=0 | 480 | 940 | 862
  n=10, nnz=100, index_len=10000, dim=1 | 330 | 4930 | 1070
  n=10, nnz=1000, index_len=100, dim=0 | 60 | 200 | 73
  n=10, nnz=1000, index_len=100, dim=1 | 56 | 683 | 70
  n=10, nnz=1000, index_len=1000, dim=0 | 480 | 530 | 1050
  n=10, nnz=1000, index_len=1000, dim=1 | 330 | 4550 | 1368
  n=10, nnz=1000, index_len=10000, dim=0 | 3100 | 2900 | 9300
  n=10, nnz=1000, index_len=10000, dim=1 | 3400 | 46000 | 9100
  n=10, nnz=10000, index_len=100, dim=0 | 400 | 453 | 857
  n=10, nnz=10000, index_len=100, dim=1 | 400 | 4070 | 1730
  n=10, nnz=10000, index_len=1000, dim=0 | 2840 | 2600 | 13900
  n=10, nnz=10000, index_len=1000, dim=1 | 3700 | 40600 | 16000
  n=10, nnz=10000, index_len=10000, dim=0 | 83200 | 67400 | 160000
  n=10, nnz=10000, index_len=10000, dim=1 | 68000 | 528000 | 190000
  n=100, nnz=1000, index_len=100, dim=0 | 46 | 148 | 31
  n=100, nnz=1000, index_len=100, dim=1 | 45 | 242 | 37
  n=100, nnz=1000, index_len=1000, dim=0 | 68 | 248 | 240
  n=100, nnz=1000, index_len=1000, dim=1 | 66 | 755 | 290
  n=100, nnz=1000, index_len=10000, dim=0 | 370 | 802 | 2250
  n=100, nnz=1000, index_len=10000, dim=1 | 372 | 5430 | 2770
  n=100, nnz=10000, index_len=100, dim=0 | 82 | 210 | 224
  n=100, nnz=10000, index_len=100, dim=1 | 74 | 986 | 270
  n=100, nnz=10000, index_len=1000, dim=0 | 350 | 618 | 2600
  n=100, nnz=10000, index_len=1000, dim=1 | 370 | 4660 | 4560
  n=100, nnz=10000, index_len=10000, dim=0 | 3000 | 3400 | 41680
  n=100, nnz=10000, index_len=10000, dim=1 | 5000 | 47500 | 30400
  n=1000, nnz=10000, index_len=100, dim=0 | 71 | 160 | 185
  n=1000, nnz=10000, index_len=100, dim=1 | 64 | 516 | 190
  n=1000, nnz=10000, index_len=1000, dim=0 | 100 | 249 | 1740
  n=1000, nnz=10000, index_len=1000, dim=1 | 98 | 1030 | 1770
  n=1000, nnz=10000, index_len=10000, dim=0 | 600 | 808 | 18300
  n=1000, nnz=10000, index_len=10000, dim=1 | 663 | 5300 | 18500
  n=1000, nnz=100000, index_len=100, dim=0 | 160 | 258 | 1890
  n=1000, nnz=100000, index_len=100, dim=1 | 200 | 3620 | 2050
  n=1000, nnz=100000, index_len=1000, dim=0 | 500 | 580 | 18700
  n=1000, nnz=100000, index_len=1000, dim=1 | 640 | 7550 | 30000
  n=1000, nnz=100000, index_len=10000, dim=0 | 3400 | 3260 | 186000
  n=1000, nnz=100000, index_len=10000, dim=1 | 3600 | 49600 | 194000
  n=1000, nnz=1000000, index_len=100, dim=0 | 517 | 957 | 18700
  n=1000, nnz=1000000, index_len=100, dim=1 | 680 | 39600 | 37600
  n=1000, nnz=1000000, index_len=1000, dim=0 | 3600 | 4500 | 186000
  n=1000, nnz=1000000, index_len=1000, dim=1 | 5800 | 76400 | 190000
  n=1000, nnz=1000000, index_len=10000, dim=0 | 50000 | 67900 | 1800000
  n=1000, nnz=1000000, index_len=10000, dim=1 | 45000 | 570000 | 1900000

Times are in microseconds (us).
```

</details>

Pull Request resolved: #72710
Approved by: https://github.com/pearu, https://github.com/cpuhrsch
@nikitaved also just merged a native CUDA implementation for this with even more improvements. @tvercaut - the latest nightly should now contain a much more efficient implementation of index_select. Would you mind trying it for your use case again?
Looks great! I confirm the speed-up is massive. Many thanks. I'm looking forward to seeing this in 1.12.
@tvercaut - I'm very happy to hear that! Unfortunately the CUDA performance improvements didn't make the branch cut in time, so they are to be expected in 1.13, but the CPU performance improvements will be available in 1.12.
🚀 The feature, motivation and pitch
The native `index_select` on sparse COO tensors has been reported to be slow in the past (see #61788 and #54561). While improvements have been made with #63008, it remains slower than expected even with PyTorch 1.10.0 and 1.10.1 which, as per the 1.10.0 release notes, I understand include the #63008 patch.

Alternatives
The `index_select` implementation from rusty1s/torch_sparse can be used in lieu of the native `index_select`, and it is sometimes up to 1000x faster.
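A drop-in usage sketch, modelled on the helper used in the benchmarking script shown earlier in this thread (it assumes `torch_sparse` is installed):

```python
import torch
import torch_sparse  # rusty1s/pytorch_sparse, assumed to be installed

# Convert the COO tensor once, select with torch_sparse, and read the result
# back as (row, col, value) COO components, as in the benchmarking script.
def ts_index_select(t_coo: torch.Tensor, dim: int, index: torch.Tensor):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t_coo)
    ss = s.index_select(dim, index)
    return ss.coo()  # (row, col, value) tuple
```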
Additional context
Below is the simple test case I used as a quick and dirty "benchmark" (also on Colab), with its output.
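A minimal sketch of such a comparison (illustrative sizes, `torch_sparse` assumed installed; not the original notebook code) could look like:

```python
import torch
import torch_sparse  # rusty1s/pytorch_sparse, assumed to be installed
from torch.utils.benchmark import Timer

# Illustrative problem sizes; not the original notebook's values.
n, nnz, m = 100_000, 10_000, 1_000
A = torch.sparse_coo_tensor(
    torch.randint(0, n, (2, nnz)), torch.randn(nnz), size=(n, n)
).coalesce()
idx = torch.randint(0, n, (m,))
A_ts = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(A)

# Time the native sparse COO index_select against the torch_sparse one.
for label, stmt in [("native coo", "A.index_select(0, idx)"),
                    ("torch_sparse", "A_ts.index_select(0, idx)")]:
    print(label, Timer(stmt, globals=globals()).blocked_autorange())
```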
Despite the suggestion in #61788 (comment) to reopen that issue if performance bottlenecks remained, I cannot reopen it myself as I am not the OP of #61788, hence this new feature request.
cc @nikitaved @pearu @cpuhrsch