Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port put_ and take from TH to ATen (#53356)
Summary: The two ports were don together, as they can be implemented with the same kernel. In TH, they were already implemented with the same kernel. Resolves #24751 Resolves #24614 Resolves #24640 Resolves #24772 This port makes sure that it interacts correctly with the "deterministic algorithms" flag, as done in #51388 This PR also makes these two functions correct in the following aspects (all of them added to the tests as well): - Support for complex numbers - Correct handling of scalar inputs and zero-dimensional inputs - Implementation that does not do any copies nor sorting of any of the input tensors - Faster and more correct implementation of the backwards (now it works as it should when `source.shape() != index.shape()`) - Now `put_(..., accumulate=True)` is implemented correctly with atomic operations on GPU / CPU (when possible) and is deterministic (modulo the loss of precision that might happen due to the reordering of a sum of floats) - Adds the `torch.put` function that was missing, (`index_put` exists, for example) - Corrected docs It also adds a much more thorough testing to the operations and their gradients. There is a BC-breaking change, and that is that now we check that the inputs do not overlap in the `put_` operation. This was handled (some of the cases, other cases were wrong) in the TH implementation by making contiguous copies of the inputs. How should we handle this one? **Edit.** Benchmarks: <details> <summary>Script</summary> ```python from IPython import get_ipython import torch from itertools import product torch.manual_seed(13) torch.set_num_threads(1) ipython = get_ipython() cpu = torch.device('cpu') cuda = torch.device('cuda') def run_test(ndims, size, index_len, device, cmd): print(f"cmd: {cmd}, ndims: {ndims}, tensor_size: {size}, index_len: {index_len}, device: {device}") large_tensor = torch.rand(*([size] * ndims), device=device) small_tensor = torch.rand((index_len,), device=device) index = torch.randint(size * ndims, (index_len,), dtype=torch.long, device=device) if cmd == "put": command = "large_tensor.put_(index, small_tensor, accumulate=False)" if device == cuda: command += "; torch.cuda.synchronize()" elif cmd == "accumulate": command = "large_tensor.put_(index, small_tensor, accumulate=True)" if device == cuda: command += "; torch.cuda.synchronize()" elif cmd == "take": command = "torch.take(large_tensor, index)" if device == cuda: command += "; torch.cuda.synchronize()" ipython.magic(f"timeit {command}") print() for method, device in product(["accumulate", "put", "take"], [cpu, cuda]): run_test(3, 1000, 10, device, method) run_test(3, 1000, 1000, device, method) run_test(3, 1000, 10000, device, method) run_test(2, 10000, 100000, device, method) ``` </details> ```python put_(accumulate=False) ``` <details> <summary>ATen CPU (1.5x - 2x speedup)</summary> ```python cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu 1.05 µs ± 2.35 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu 3.15 µs ± 5.13 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu 21.6 µs ± 13.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu 238 µs ± 781 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` </details> <details> <summary>TH CPU</summary> ```python cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu 722 ns ± 2.67 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu 4.89 µs ± 18.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu 42.5 µs ± 96.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu 428 µs ± 774 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` </details> <details> <summary>ATen GPU (same speed)</summary> ```python cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda 8.99 µs ± 16 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda 10.4 µs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda 10.4 µs ± 11.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda 15.6 µs ± 1.12 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) ``` </details> <details> <summary>TH GPU</summary> ```python cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda 8.44 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda 9.09 µs ± 4.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda 9.77 µs ± 0.998 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda 15.8 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) ``` </details> ```python put_(accumulate=True) ``` <details> <summary>ATen CPU (x2 speedup)</summary> ```python cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu 1.12 µs ± 2.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu 3.14 µs ± 2.05 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu 20.8 µs ± 25.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu 264 µs ± 263 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` </details> <details> <summary>TH CPU</summary> ```python cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu 814 ns ± 1.87 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu 5.11 µs ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu 43.9 µs ± 49.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu 442 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` </details> <details> <summary>ATen GPU (3x - 11x speedup)</summary> ```python cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda 9.01 µs ± 14.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda 10.4 µs ± 15.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda 10.3 µs ± 44.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda 12.6 µs ± 19 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) ``` </details> <details> <summary>TH GPU</summary> ```python cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda 34.7 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda 38.2 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda 61.2 µs ± 50.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda 140 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` </details> ```python take() ``` <details> <summary>ATen CPU (1.1x speedup)</summary> ```python cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu 1.18 µs ± 2.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu 2.79 µs ± 2.96 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu 16.6 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu 161 µs ± 984 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` </details> <details> <summary>TH CPU</summary> ```python cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu 1.1 µs ± 3.14 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu 2.93 µs ± 7.31 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu 18.6 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu 178 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` </details> <details> <summary>ATen GPU (same speed)</summary> ```python cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda 9.38 µs ± 23.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda 10.7 µs ± 9.77 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda 10.6 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda 11.5 µs ± 21.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) ``` </details> <details> <summary>TH GPU</summary> ```python cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda 9.31 µs ± 7.57 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda 9.52 µs ± 5.78 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda 9.73 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda 11.7 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) ``` </details> cc mruberry Pull Request resolved: #53356 Reviewed By: mruberry Differential Revision: D27520243 Pulled By: ngimel fbshipit-source-id: e3979349c2c62d2949e09fb05e5fd4883fbc9093
- Loading branch information