Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Faster
index_select
for sparse COO tensors on CPU. (#72710)
Fixes #72212. This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible. Benchmark results. <details> <summary>Testing script</summary> ```python import torch import math from IPython import get_ipython from itertools import product import pickle from torch.utils.benchmark import Timer, Compare torch.manual_seed(13) #torch.set_num_threads(1) ipython = get_ipython() index_sizes = (100, 1000, 10000) # specifies (n, nnz) problem_dims = ( # n > nnz (10000, 100), (100000, 1000), (1000000, 10000), # n < nnz (10, 100), (10, 1000), (10, 10000), (100, 1000), (100, 10000), (1000, 10000), (1000, 100000), (1000, 1000000), #(1000000, 1000000000), ) def f(t, d, index): s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t) ss = s.index_select(d, index) return ss.coo() name = "PR" results = [] for (n, nnz), m in product(problem_dims, index_sizes): for d in (0, 1): if nnz < n: shape = (n, n) else: shape = (n, nnz // n) if d == 0 else (nnz // n, n) nrows, ncols = shape rowidx = torch.randint(low=0, high=nrows, size=(nnz,)) colidx = torch.randint(low=0, high=ncols, size=(nnz,)) itemidx = torch.vstack((rowidx, colidx)) xvalues = torch.randn(nnz) index = torch.randint(low=0, high=n, size=(m,)) SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce() smtp = "SparseX.index_select(d, index)" timer = Timer(smtp, globals=globals(), label="coo.index_select", description=f"{name}: coo.index_select", sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}", num_threads=torch.get_num_threads()) results.append(timer.blocked_autorange()) compare = Compare(results) compare.trim_significant_figures() compare.print() with open(f"{name}_index_select.pickle", 'wb') as f: pickle.dump(results, f) ``` </details> <details> <summary>Gather results</summary> ```python import pickle from torch.utils.benchmark import Timer, Compare files = [ "PR", "torch_sparse", "master" ] timers = [] for name in files: with open("{}_index_select.pickle".format(name), 'rb') as f: timers += pickle.load(f) compare = Compare(timers) compare.trim_significant_figures() compare.print() ``` </details> <details> <summary>PR/torch_sparse/master runtime comparison</summary> ``` [----------------------------------- coo.index_select ----------------------------------] | PR | torch_sparse | master 32 threads: ----------------------------------------------------------------------------- n=10000, nnz=100, index_len=100, dim=0 | 14 | 140 | 10 n=10000, nnz=100, index_len=100, dim=1 | 14 | 200 | 10 n=10000, nnz=100, index_len=1000, dim=0 | 30 | 180 | 38 n=10000, nnz=100, index_len=1000, dim=1 | 34 | 240 | 38 n=10000, nnz=100, index_len=10000, dim=0 | 278 | 460 | 330 n=10000, nnz=100, index_len=10000, dim=1 | 275 | 516 | 330 n=100000, nnz=1000, index_len=100, dim=0 | 16 | 290 | 31 n=100000, nnz=1000, index_len=100, dim=1 | 26 | 390 | 31 n=100000, nnz=1000, index_len=1000, dim=0 | 45 | 405 | 263 n=100000, nnz=1000, index_len=1000, dim=1 | 73 | 500 | 261 n=100000, nnz=1000, index_len=10000, dim=0 | 444 | 783 | 2570 n=100000, nnz=1000, index_len=10000, dim=1 | 470 | 890 | 2590 n=1000000, nnz=10000, index_len=100, dim=0 | 25 | 2400 | 270 n=1000000, nnz=10000, index_len=100, dim=1 | 270 | 4000 | 269 n=1000000, nnz=10000, index_len=1000, dim=0 | 74 | 2600 | 2620 n=1000000, nnz=10000, index_len=1000, dim=1 | 464 | 3600 | 2640 n=1000000, nnz=10000, index_len=10000, dim=0 | 635 | 3300 | 26400 n=1000000, nnz=10000, index_len=10000, dim=1 | 1000 | 3960 | 26400 n=10, nnz=100, index_len=100, dim=0 | 16 | 137 | 16 n=10, nnz=100, index_len=100, dim=1 | 16 | 220 | 16 n=10, nnz=100, index_len=1000, dim=0 | 63 | 238 | 81 n=10, nnz=100, index_len=1000, dim=1 | 60 | 698 | 78 n=10, nnz=100, index_len=10000, dim=0 | 480 | 940 | 862 n=10, nnz=100, index_len=10000, dim=1 | 330 | 4930 | 1070 n=10, nnz=1000, index_len=100, dim=0 | 60 | 200 | 73 n=10, nnz=1000, index_len=100, dim=1 | 56 | 683 | 70 n=10, nnz=1000, index_len=1000, dim=0 | 480 | 530 | 1050 n=10, nnz=1000, index_len=1000, dim=1 | 330 | 4550 | 1368 n=10, nnz=1000, index_len=10000, dim=0 | 3100 | 2900 | 9300 n=10, nnz=1000, index_len=10000, dim=1 | 3400 | 46000 | 9100 n=10, nnz=10000, index_len=100, dim=0 | 400 | 453 | 857 n=10, nnz=10000, index_len=100, dim=1 | 400 | 4070 | 1730 n=10, nnz=10000, index_len=1000, dim=0 | 2840 | 2600 | 13900 n=10, nnz=10000, index_len=1000, dim=1 | 3700 | 40600 | 16000 n=10, nnz=10000, index_len=10000, dim=0 | 83200 | 67400 | 160000 n=10, nnz=10000, index_len=10000, dim=1 | 68000 | 528000 | 190000 n=100, nnz=1000, index_len=100, dim=0 | 46 | 148 | 31 n=100, nnz=1000, index_len=100, dim=1 | 45 | 242 | 37 n=100, nnz=1000, index_len=1000, dim=0 | 68 | 248 | 240 n=100, nnz=1000, index_len=1000, dim=1 | 66 | 755 | 290 n=100, nnz=1000, index_len=10000, dim=0 | 370 | 802 | 2250 n=100, nnz=1000, index_len=10000, dim=1 | 372 | 5430 | 2770 n=100, nnz=10000, index_len=100, dim=0 | 82 | 210 | 224 n=100, nnz=10000, index_len=100, dim=1 | 74 | 986 | 270 n=100, nnz=10000, index_len=1000, dim=0 | 350 | 618 | 2600 n=100, nnz=10000, index_len=1000, dim=1 | 370 | 4660 | 4560 n=100, nnz=10000, index_len=10000, dim=0 | 3000 | 3400 | 41680 n=100, nnz=10000, index_len=10000, dim=1 | 5000 | 47500 | 30400 n=1000, nnz=10000, index_len=100, dim=0 | 71 | 160 | 185 n=1000, nnz=10000, index_len=100, dim=1 | 64 | 516 | 190 n=1000, nnz=10000, index_len=1000, dim=0 | 100 | 249 | 1740 n=1000, nnz=10000, index_len=1000, dim=1 | 98 | 1030 | 1770 n=1000, nnz=10000, index_len=10000, dim=0 | 600 | 808 | 18300 n=1000, nnz=10000, index_len=10000, dim=1 | 663 | 5300 | 18500 n=1000, nnz=100000, index_len=100, dim=0 | 160 | 258 | 1890 n=1000, nnz=100000, index_len=100, dim=1 | 200 | 3620 | 2050 n=1000, nnz=100000, index_len=1000, dim=0 | 500 | 580 | 18700 n=1000, nnz=100000, index_len=1000, dim=1 | 640 | 7550 | 30000 n=1000, nnz=100000, index_len=10000, dim=0 | 3400 | 3260 | 186000 n=1000, nnz=100000, index_len=10000, dim=1 | 3600 | 49600 | 194000 n=1000, nnz=1000000, index_len=100, dim=0 | 517 | 957 | 18700 n=1000, nnz=1000000, index_len=100, dim=1 | 680 | 39600 | 37600 n=1000, nnz=1000000, index_len=1000, dim=0 | 3600 | 4500 | 186000 n=1000, nnz=1000000, index_len=1000, dim=1 | 5800 | 76400 | 190000 n=1000, nnz=1000000, index_len=10000, dim=0 | 50000 | 67900 | 1800000 n=1000, nnz=1000000, index_len=10000, dim=1 | 45000 | 570000 | 1900000 Times are in microseconds (us). ``` </details> Pull Request resolved: #72710 Approved by: https://github.com/pearu, https://github.com/cpuhrsch
- Loading branch information