From fd02fc5d715a7647631c5806db736794edc2a52f Mon Sep 17 00:00:00 2001
From: lezcano
Date: Mon, 5 Apr 2021 18:03:59 -0700
Subject: [PATCH] Port put_ and take from TH to ATen (#53356)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
The two ports were done together, as they can be implemented with the same kernel; they were already implemented with a shared kernel in TH.

Resolves https://github.com/pytorch/pytorch/issues/24751
Resolves https://github.com/pytorch/pytorch/issues/24614
Resolves https://github.com/pytorch/pytorch/issues/24640
Resolves https://github.com/pytorch/pytorch/issues/24772

This port makes sure that these operations interact correctly with the "deterministic algorithms" flag, as done in https://github.com/pytorch/pytorch/pull/51388.

This PR also makes these two functions correct in the following aspects (all of them covered by the tests as well):
- Support for complex numbers
- Correct handling of scalar inputs and zero-dimensional inputs
- An implementation that does not copy or sort any of the input tensors
- A faster and more correct backward (it now works as it should when `source.shape() != index.shape()`)
- `put_(..., accumulate=True)` is now implemented correctly with atomic operations on GPU / CPU (when possible) and is deterministic (modulo the loss of precision that may happen due to the reordering of a sum of floats)
- Adds the missing `torch.put` function (`index_put` already exists, for example)
- Corrected docs

It also adds much more thorough testing of the operations and their gradients. A short usage and gradient sketch is included after the benchmarks below.

There is one BC-breaking change: we now check that the inputs to `put_` do not overlap. The TH implementation handled some of these cases (and got others wrong) by making contiguous copies of the inputs. How should we handle this one?

**Edit.** Benchmarks:
Script
```python
from IPython import get_ipython
import torch
from itertools import product

torch.manual_seed(13)
torch.set_num_threads(1)
ipython = get_ipython()

cpu = torch.device('cpu')
cuda = torch.device('cuda')

def run_test(ndims, size, index_len, device, cmd):
    print(f"cmd: {cmd}, ndims: {ndims}, tensor_size: {size}, index_len: {index_len}, device: {device}")
    large_tensor = torch.rand(*([size] * ndims), device=device)
    small_tensor = torch.rand((index_len,), device=device)
    index = torch.randint(size * ndims, (index_len,), dtype=torch.long, device=device)

    if cmd == "put":
        command = "large_tensor.put_(index, small_tensor, accumulate=False)"
        if device == cuda:
            command += "; torch.cuda.synchronize()"
    elif cmd == "accumulate":
        command = "large_tensor.put_(index, small_tensor, accumulate=True)"
        if device == cuda:
            command += "; torch.cuda.synchronize()"
    elif cmd == "take":
        command = "torch.take(large_tensor, index)"
        if device == cuda:
            command += "; torch.cuda.synchronize()"

    ipython.magic(f"timeit {command}")
    print()

for method, device in product(["accumulate", "put", "take"], [cpu, cuda]):
    run_test(3, 1000, 10, device, method)
    run_test(3, 1000, 1000, device, method)
    run_test(3, 1000, 10000, device, method)
    run_test(2, 10000, 100000, device, method)
```
```python
put_(accumulate=False)
```

ATen CPU (1.5x - 2x speedup)
```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.05 µs ± 2.35 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
3.15 µs ± 5.13 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
21.6 µs ± 13.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
238 µs ± 781 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

TH CPU
```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
722 ns ± 2.67 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
4.89 µs ± 18.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
42.5 µs ± 96.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
428 µs ± 774 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

ATen GPU (same speed)
```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
8.99 µs ± 16 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
10.4 µs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
10.4 µs ± 11.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
15.6 µs ± 1.12 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

TH GPU
```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
8.44 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
9.09 µs ± 4.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
9.77 µs ± 0.998 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
15.8 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
```python
put_(accumulate=True)
```

ATen CPU (x2 speedup)
```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.12 µs ± 2.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
3.14 µs ± 2.05 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
20.8 µs ± 25.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
264 µs ± 263 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

TH CPU
```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
814 ns ± 1.87 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
5.11 µs ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
43.9 µs ± 49.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
442 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

ATen GPU (3x - 11x speedup)
```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
9.01 µs ± 14.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
10.4 µs ± 15.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
10.3 µs ± 44.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
12.6 µs ± 19 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

TH GPU
```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
34.7 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
38.2 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
61.2 µs ± 50.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
140 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
```python
take()
```

ATen CPU (1.1x speedup)
```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.18 µs ± 2.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
2.79 µs ± 2.96 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
16.6 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
161 µs ± 984 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

TH CPU
```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.1 µs ± 3.14 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
2.93 µs ± 7.31 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
18.6 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
178 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

ATen GPU (same speed)
```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
9.38 µs ± 23.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
10.7 µs ± 9.77 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
10.6 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
11.5 µs ± 21.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

TH GPU
```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
9.31 µs ± 7.57 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
9.52 µs ± 5.78 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
9.73 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
11.7 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
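The sketch below is a minimal illustration of the ported API and of the backward claim above (`source.shape() != index.shape()`). It is not part of the PR or its test suite; it assumes a build that includes this patch (for the out-of-place `Tensor.put`) and `torch.use_deterministic_algorithms`.

```python
import torch
from torch.autograd import gradcheck

# --- Basic usage of take / put_ / put -------------------------------------
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
idx = torch.tensor([0, 5, -1])        # linear indices into x; negative values wrap around
src = torch.tensor([10., 20., 30.])

torch.take(x, idx)                    # tensor([ 0.,  5., 11.]); result has idx's shape

y = x.clone()
y.put_(idx, src, accumulate=False)    # in-place overwrite at the given linear indices
z = x.clone()
z.put_(idx, src, accumulate=True)     # in-place accumulation (duplicate indices are summed)
w = x.put(idx, src, accumulate=True)  # out-of-place variant added by this PR

# With the deterministic-algorithms flag set, the accumulating CPU kernel
# falls back to serial execution, as described above
torch.use_deterministic_algorithms(True)
z2 = x.clone().put_(idx, src, accumulate=True)
torch.use_deterministic_algorithms(False)

# --- Gradient sanity check with source.shape != index.shape ---------------
a = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
s = torch.randn(2, 2, dtype=torch.double, requires_grad=True)  # same numel as index, different shape
index = torch.tensor([0, 2, 3, 11])

assert gradcheck(lambda a, s: a.put(index, s, accumulate=True), (a, s))
assert gradcheck(lambda a: torch.take(a, index), (a,))
```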
cc mruberry Pull Request resolved: https://github.com/pytorch/pytorch/pull/53356 Reviewed By: mruberry Differential Revision: D27520243 Pulled By: ngimel fbshipit-source-id: e3979349c2c62d2949e09fb05e5fd4883fbc9093 --- BUILD.bazel | 1 - aten/src/ATen/LegacyTHFunctionsCPU.cpp | 66 ------- aten/src/ATen/LegacyTHFunctionsCPU.h | 1 - aten/src/ATen/LegacyTHFunctionsCUDA.h | 1 - aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp | 73 ------- .../src/ATen/cuda/detail/OffsetCalculator.cuh | 6 +- .../ATen/native/TensorAdvancedIndexing.cpp | 187 +++++++++--------- aten/src/ATen/native/TensorAdvancedIndexing.h | 4 + aten/src/ATen/native/cpu/IndexKernel.cpp | 122 ++++++++++++ aten/src/ATen/native/cuda/IndexKernel.cu | 159 ++++++++------- aten/src/ATen/native/native_functions.yaml | 16 +- aten/src/ATen/test/scalar_tensor_test.cpp | 186 +++++++++-------- aten/src/TH/generic/THTensorEvenMoreMath.cpp | 55 ------ aten/src/THC/CMakeLists.txt | 3 - aten/src/THC/THCTensorIndex.cu | 119 ----------- aten/src/THC/THCTensorMath.h | 9 - aten/src/THC/generic/THCTensorIndex.cu | 147 -------------- aten/src/THC/generic/THCTensorIndex.h | 23 --- .../check_backward_compatibility.py | 1 + test/test_autograd.py | 26 --- test/test_torch.py | 183 +++++++++++++---- tools/autograd/derivatives.yaml | 6 +- tools/autograd/gen_variable_type.py | 2 +- tools/code_analyzer/default_op_deps.yaml | 6 - torch/__init__.py | 4 + torch/_tensor_docs.py | 25 ++- torch/overrides.py | 1 + .../_internal/common_methods_invocations.py | 82 +++++++- torch/testing/_internal/common_utils.py | 6 +- 29 files changed, 656 insertions(+), 864 deletions(-) delete mode 100644 aten/src/THC/THCTensorIndex.cu delete mode 100644 aten/src/THC/generic/THCTensorIndex.cu delete mode 100644 aten/src/THC/generic/THCTensorIndex.h diff --git a/BUILD.bazel b/BUILD.bazel index 395e74f057ad..9f0d759fce5c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -381,7 +381,6 @@ filegroup( "aten/src/THC/THCStorageCopy.cu.cc", "aten/src/THC/THCTensor.cu.cc", "aten/src/THC/THCTensorCopy.cu.cc", - "aten/src/THC/THCTensorIndex.cu.cc", "aten/src/THC/THCTensorMath.cu.cc", "aten/src/THC/THCTensorMathMagma.cu.cc", "aten/src/THC/THCTensorMathPairwise.cu.cc", diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp index 2c02795c5be6..f0c61cef5bab 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp +++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp @@ -165,72 +165,6 @@ Tensor _th_nonzero(const Tensor & self) { } return result; } -Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Bool: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THBoolTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); 
- THByteTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THCharTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THIntTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THLongTensor_put(self_, index_, source_, accumulate); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); - THShortTensor_put(self_, index_, source_, accumulate); - break; - } - default: - AT_ERROR("_th_put_ not supported on CPUType for ", dispatch_scalar_type); - } - return self; -} Tensor _th_var(const Tensor & self, bool unbiased) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h index 7086ece6a335..c89e2a8d5f3d 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.h +++ b/aten/src/ATen/LegacyTHFunctionsCPU.h @@ -22,7 +22,6 @@ Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source); Tensor& _th_nonzero_out(const Tensor& self, Tensor& result); Tensor _th_nonzero(const 
Tensor & self); -Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate); Tensor _th_var(const Tensor & self, bool unbiased); Tensor _th_std(const Tensor & self, bool unbiased); Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result); diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h index deb18d5a3908..98c587e812db 100644 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h @@ -20,7 +20,6 @@ namespace cuda { Tensor & _th_masked_fill_(Tensor & self, const Tensor & mask, const Scalar& value); Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, const Scalar& value); -Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate); std::tuple _th_sort_out(const Tensor & self, int64_t dim, bool descending, Tensor & values, Tensor & indices); std::tuple _th_sort(const Tensor & self, int64_t dim, bool descending); std::tuple _th_sort_out_stable(const Tensor & self, c10::optional stable, int64_t dim, bool descending, Tensor & values, Tensor & indices); diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp index a6ca71cf9a77..9c2e4d5859bf 100644 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp @@ -39,79 +39,6 @@ namespace { return at::cuda::getCUDADeviceAllocator(); } } -Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Bool: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaBoolTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", 
false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long); - auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate); - break; - } - default: - AT_ERROR("_th_put_ not supported on CUDAType for ", dispatch_scalar_type); - } - return self; -} std::tuple _th_sort_out_stable(const Tensor & self, c10::optional stable, int64_t dim, bool descending, Tensor & values, Tensor & indices) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh index 1d435988f8e9..4cc5a8d0c895 100644 --- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh +++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh @@ -7,8 +7,10 @@ #include #include -/// OffsetCalculator calculates the offset in bytes of a linear index for NARGS -/// operands that share the same shape, but may have different strides. +// If element_sizes is nullptr, then the strides will be in bytes, otherwise +// the strides will be in # of elements. 
+// Operands that share the same shape, but may have different strides. +// OffsetCalculator iterates the tensor in a column-major order #ifdef __HIP_PLATFORM_HCC__ constexpr int MAX_DIMS = 16; diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index b192f0b932cf..fd3c8d0e8831 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -75,6 +75,8 @@ DEFINE_DISPATCH(index_fill_stub); DEFINE_DISPATCH(index_copy_stub); DEFINE_DISPATCH(index_put_stub); DEFINE_DISPATCH(index_put_accum_stub); +DEFINE_DISPATCH(put_stub); +DEFINE_DISPATCH(take_stub); DEFINE_DISPATCH(masked_fill_stub); REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn); DEFINE_DISPATCH(masked_select_serial_stub); @@ -141,26 +143,6 @@ static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t return index.reshape(shape); } -static ptrdiff_t dataOffset(const Tensor& tensor, ptrdiff_t linearIndex) { - auto size = tensor.sizes(); - auto stride = tensor.strides(); - int nDim = tensor.dim(); - ptrdiff_t dataOffset = 0; - for (int i = nDim - 1; i >= 0; i--) { - dataOffset += (linearIndex % size[i]) * stride[i]; - linearIndex /= size[i]; - } - return dataOffset; -} - -static inline int64_t wrapLinearIndex(int64_t linearIndex, int64_t numel) { - return linearIndex < 0 ? linearIndex + numel : linearIndex; -} - -static inline void checkLinearIndex(int64_t linearIndex, int64_t numel) { - TORCH_CHECK(linearIndex < numel && linearIndex >= -numel, "out of range: ", linearIndex, " out of ", numel); -} - AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) { int64_t element_size_bytes = src.element_size(); @@ -333,6 +315,52 @@ Tensor& index_out(Tensor& result, const Tensor & self, const torch::List>& indices, const Tensor & value, bool accumulate) { return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate); } @@ -352,7 +380,7 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List } } - if (accumulate && self.device().type() == kCUDA) { + if (accumulate && self.device().type() == DeviceType::CUDA) { TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ", value.device(), " for value tensor"); index_put_accum_stub(self.device().type(), self, indices, value, unsafe); @@ -365,6 +393,45 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List return self; } +Tensor& take_out(const Tensor& self, const Tensor& index, Tensor& out) { + // Type and device checks + TORCH_CHECK(index.scalar_type() == ScalarType::Long, "take(): Expected a long tensor for index, but got ", index.scalar_type()) + TORCH_CHECK(self.scalar_type() == out.scalar_type(), "take(): self and out expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and out.dtype = ", out.scalar_type()); + TORCH_CHECK(self.device() == out.device() && self.device() == index.device(), + "take(): self, index and out expected to be in the same device, but got self.device = ", + self.device(), ", index.device = ", index.device(), ", and out.device = ", out.device()); + + // index checks + TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "take(): tried to take from an empty tensor"); + + at::assert_no_internal_overlap(out); + at::assert_no_overlap(out, index); + at::assert_no_overlap(out, self); + + // Do not iterate over self, we will compute the offsets manually + // out is resized inside tensor_iterator + auto iter = 
TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .add_output(out) + .add_input(index) + .build(); + + // Early return after out has been resized + if (index.numel() == 0) { + return out; + } + + take_stub(iter.device_type(), iter, self); + + return out; +} + +Tensor take(const Tensor& self, const Tensor& index) { + auto out = at::empty(index.sizes(), self.options()); + at::native::take_out(self, index, out); + return out; +} Tensor & index_put_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate) { return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false); @@ -1163,84 +1230,6 @@ Tensor masked_select_backward(const Tensor& grad, const Tensor& input, const Ten return result.masked_scatter_(mask, grad); } -void take_out_cpu_template( - Tensor& output, - Tensor const& input, - Tensor const& index) -{ - TORCH_CHECK(output.device().type() == at::kCPU, "device type of output (", output.device().type(), ") is not CPU"); - TORCH_CHECK(input.device().type() == at::kCPU, "device type of input (", input.device().type(), ") is not CPU"); - TORCH_CHECK(index.device().type() == at::kCPU, "device type of index (", index.device().type(), ") is not CPU"); - - TORCH_CHECK(output.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", - output.layout(), " on output tensor"); - TORCH_CHECK(input.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", - input.layout(), " on input tensor"); - TORCH_CHECK(index.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", - index.layout(), " on index tensor"); - - TORCH_CHECK(output.scalar_type() == input.scalar_type(), "output and input scalar type must match.", - "But got different types: ", output.scalar_type(), " and ", input.scalar_type()); - TORCH_CHECK(index.scalar_type() == kLong, "index must be an int64 tensor"); - - output.resize_(index.sizes()); - auto output_contiguous = output.contiguous(); - auto index_continuous = index.contiguous(); - bool is_contiguous = input.is_contiguous(); - auto input_size = input.numel(); - at::assert_no_internal_overlap(output); - at::assert_no_partial_overlap(output, index); - at::assert_no_overlap(output, input); - - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cpu", [&] { - auto output_data = output_contiguous.data_ptr(); - auto input_data = input.data_ptr(); - auto index_data = index.data_ptr(); - - // Exceptions must not be thrown across parallel sections, so we - // record the position of the invalid index and throw the exception after the - // loop. 
- std::atomic invalidIdxPos(-1); - - at::parallel_for(0, index.numel(), at::internal::GRAIN_SIZE, - [&](int64_t start, int64_t end) { - for (auto i = start; i < end; i++) { - int64_t idx = index_data[i]; - if (idx < input_size && idx >= -input_size) { - idx = wrapLinearIndex(idx, input_size); - if (is_contiguous) { - output_data[i] = input_data[idx]; - } else { - output_data[i] = input_data[dataOffset(input, idx)]; - } - } else { - int64_t tmp = -1; - invalidIdxPos.compare_exchange_strong(tmp, i); - } - } - }); - - if (invalidIdxPos >= 0) { - checkLinearIndex(index_data[invalidIdxPos], input_size); - } - }); -} - -Tensor take_cpu(const Tensor& self, const Tensor& index) { - auto output = at::empty(index.sizes(), self.options()); - take_out_cpu_template(output, self, index); - return output; -} - -Tensor& take_out_cpu(const Tensor& self, const Tensor& index, Tensor& out) { - take_out_cpu_template(out, self, index); - return out; -} - -Tensor take_backward(const Tensor& grad, const Tensor& input, const Tensor& index) { - return at::zeros_like(input).put_(index, grad, true); -} - namespace { inline std::tuple _take_along_dim_helper( diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h index bb84f869d7bc..2d20c8694a29 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.h +++ b/aten/src/ATen/native/TensorAdvancedIndexing.h @@ -19,6 +19,8 @@ using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_d using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); using index_put_accum_fn = void(*)(Tensor &, const c10::List> &, const Tensor &, bool unsafe); using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar); +using put_fn = void(*)(TensorIterator & iter, const Tensor& self, const bool accumulate); +using take_fn = void(*)(TensorIterator & iter, const Tensor& input); using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); using masked_scatter_fn = void(*)(TensorIterator &, const Tensor &); @@ -36,6 +38,8 @@ DECLARE_DISPATCH(index_fill_fn, index_fill_stub); DECLARE_DISPATCH(index_copy_fn, index_copy_stub); DECLARE_DISPATCH(index_put_fn, index_put_stub); DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub); +DECLARE_DISPATCH(put_fn, put_stub); +DECLARE_DISPATCH(take_fn, take_stub); DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub); DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub); DECLARE_DISPATCH(masked_select_fn, masked_select_stub); diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index dff341a7d733..256575eaeb91 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -105,6 +105,126 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde }); } +// Given a linear index, returns the offset of the tensor. 
+// Implements the same algorithm as its (legacy) GPU version cuda::detail::IndexToOffset +// OffsetCalculator implements yet again the same algorithm but in a column-major order +struct IndexToOffset { + const IntArrayRef sizes; + const IntArrayRef strides; + const int ndim; + IndexToOffset(const Tensor & tensor) : sizes(tensor.sizes()), + strides(tensor.strides()), + ndim(tensor.dim()) {} + + int64_t get(int64_t linear_index) const { + int64_t offset = 0; + for (int i = ndim - 1; i > 0; i--) { + offset += (linear_index % sizes[i]) * strides[i]; + linear_index /= sizes[i]; + } + return offset + linear_index * strides[0]; + } +}; + +template +void cpu_take_put_kernel( + TensorIterator& iter, + const Tensor& indexed, + const func_t& f, + bool serial_execution=false) { + // This kernel follows the same strategy as `cpu_index_kernel` + // Even though the indexed_tensor is const, we modify it through the data_ptr + // This is a bit dirty, but otherwise it would be necessary to innecessarily add tensor + // with zero strides to `iter` which would not be much better + + // When launch the parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE + // to make the whole available thread numbers get more balanced work load and a better cache location. + // The grain size here is chosen by the op benchmark to overcome the thread launch overhead + // Perhaps tweak this number for `put_`? This number was tweaked for `index_put` + constexpr int parallel_grain_size = 3000; + const bool is_contiguous = indexed.is_contiguous(); + const auto numel = indexed.numel(); + const auto offset_indexed = IndexToOffset(indexed); + + auto* indexed_data = indexed.data_ptr(); + auto loop = [&](char** data, const int64_t* strides, int64_t n) { + auto* iterated_data_bytes = data[0]; + auto* index_data_bytes = data[1]; + for (int64_t elem = 0; elem < n; ++elem) { + auto idx = *reinterpret_cast(index_data_bytes); + auto& iterated = *reinterpret_cast(iterated_data_bytes); + + TORCH_CHECK_INDEX(idx >= -numel && idx < numel, + "out of range: tried to access index ", + idx, " on a tensor of ", numel, " elements."); + if (idx < 0) { + idx += numel; + } + if (!is_contiguous) { + idx = offset_indexed.get(idx); + } + f(iterated, indexed_data, idx); + iterated_data_bytes += strides[0]; + index_data_bytes += strides[1]; + } + }; + if (serial_execution) { + iter.serial_for_each(loop, {0, iter.numel()}); + } else { + iter.for_each(loop, parallel_grain_size); + } +} + +void put_kernel( + TensorIterator& iter, + const Tensor & self, + const bool accumulate) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, + iter.dtype(), "take_put_cpu", [&] { + // iter could be const, but for_each does not have a const version + if (accumulate) { + // nb. This deterministic issue the same as that of `index_put_kernel` + // See Note [Enabling Deterministic Operations] + // Parallel cpu_put_kernel with accumulation is nondeterministic, so we + // must enable serial execution if deterministic algorithms are enabled. 
+ bool is_deterministic = at::globalContext().deterministicAlgorithms(); + bool use_parallel_for = (!is_deterministic) && ( + (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1)); + if (use_parallel_for && iter.dtype() == ScalarType::Float) { + cpu_take_put_kernel(iter, self, + [](float& iterated, float* indexed, const int64_t idx) { + cpu_atomic_add_float(indexed+idx, iterated); + }); + } else { + // TODO: investigate parallelization of the accumulate kernel. + // Unlike the non-accumulate case, this needs to be thread-safe. + cpu_take_put_kernel(iter, self, + [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) { + indexed[idx] += iterated; + }, + /*serial_execution=*/true); + } + } else { + cpu_take_put_kernel(iter, self, + [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) { + indexed[idx] = iterated; + }); + } + }); +} + +void take_kernel( + TensorIterator& iter, + const Tensor & input) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, + iter.dtype(), "take_cpu", [&] { + cpu_take_put_kernel(iter, input, + [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) { + iterated = indexed[idx]; + }); + }); +} + void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) { // NOTE: duplicate indices are only supported if accumulate is true. AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, @@ -414,6 +534,8 @@ REGISTER_DISPATCH(index_stub, &index_kernel); REGISTER_DISPATCH(index_fill_stub, &index_fill_kernel); REGISTER_DISPATCH(index_copy_stub, &index_copy_kernel); REGISTER_DISPATCH(index_put_stub, &index_put_kernel); +REGISTER_DISPATCH(put_stub, &put_kernel); +REGISTER_DISPATCH(take_stub, &take_kernel); REGISTER_DISPATCH(masked_fill_stub, &masked_fill_kernel); REGISTER_DISPATCH(masked_select_serial_stub, &masked_select_serial_kernel); REGISTER_DISPATCH(masked_select_stub, &masked_select_kernel); diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 89d3233d6f9f..859b46d03f2b 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include #include @@ -24,54 +26,6 @@ static constexpr int launch_bound2 = 4; static constexpr int launch_size_nd = 128; -template -__device__ __forceinline__ IndexType indexToOffset( - const cuda::detail::TensorInfo& info, - int64_t index, - IndexType size) { - IndexType linearIndex = static_cast(index); - CUDA_KERNEL_ASSERT(linearIndex < size && linearIndex >= -size); - if (linearIndex < 0) { - linearIndex += size; - } - return cuda::detail::IndexToOffset::get(linearIndex, info); -} - -template -void dispatchTakePutImpl(const Tensor& input, Tensor& output, const Tensor& index) { - auto inputInfo = cuda::detail::getTensorInfo(input); - inputInfo.collapseDims(); - auto numel = input.numel(); - if (inputInfo.isContiguous()) { - cuda::CUDA_tensor_apply2( - output, - index, - [inputInfo, numel] __device__ ( - T & out, const int64_t& idx) { - auto offset = indexToOffset<-2, T, IndexType>(inputInfo, idx, numel); - out = inputInfo.data[offset]; - }); - } else { - cuda::CUDA_tensor_apply2( - output, - index, - [inputInfo, numel] __device__ ( - T & out, const int64_t& idx) { - auto offset = indexToOffset<-1, T, IndexType>(inputInfo, idx, numel); - out = inputInfo.data[offset]; - 
}); - } -} - -template -void dispatchTakePut(const Tensor& input, Tensor& output, const Tensor& index) { - if (cuda::detail::canUse32BitIndexMath(input)) { - dispatchTakePutImpl(input, output, index); - } else { - dispatchTakePutImpl(input, output, index); - } -} - template C10_LAUNCH_BOUNDS_2(nt, launch_bound2) __global__ void index_elementwise_kernel(int N, func_t f) { @@ -314,50 +268,91 @@ Tensor & masked_select_out_cuda(const Tensor & self, const Tensor & mask, Tensor return masked_select_out_cuda_impl(result, self, mask); } -void take_out_cuda_template(Tensor& output, const Tensor& input, const Tensor& index) { - TORCH_CHECK(output.device().type() == at::kCUDA, "device type of output (", output.device().type(), ") is not GPU"); - TORCH_CHECK(input.device().type() == at::kCUDA, "device type of input (", input.device().type(), ") is not GPU"); - TORCH_CHECK(index.device().type() == at::kCUDA, "device type of index (", index.device().type(), ") is not GPU"); - - TORCH_CHECK(output.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", output.layout(), " on output tensor"); - TORCH_CHECK(input.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", input.layout(), " on input tensor"); - TORCH_CHECK(index.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", index.layout(), " on index tensor"); +template +void cuda_take_put_kernel( + TensorIterator& iter, + const Tensor& indexed, + const func_t& f) { + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + cuda_take_put_kernel(sub_iter, indexed, f); + } + return; + } - TORCH_CHECK(output.scalar_type() == input.scalar_type(), - "output and input scalar type must match. but got different types: ", output.scalar_type(), " and ", input.scalar_type()); - TORCH_CHECK(index.scalar_type() == kLong, "index must be an int64 tensor"); + const auto numel = indexed.numel(); + const bool is_contiguous = indexed.is_contiguous(); - TensorArg output_arg{ output, "output", 1 }; - TensorArg input_arg{ input, "input", 2 }; - TensorArg index_arg{ index, "index", 3 }; - checkAllSameGPU("take", {output_arg, input_arg, index_arg}); + char* __restrict__ iterated_ptr = reinterpret_cast(iter.data_ptr(0)); + char* __restrict__ idx_ptr = reinterpret_cast(iter.data_ptr(1)); - TORCH_CHECK(input.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); - TORCH_CHECK(output.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); - TORCH_CHECK(index.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); + const auto offset_calc = make_offset_calculator<2>(iter); + using uindex_t = std::make_unsigned_t; - TORCH_CHECK(!(input.numel() == 0 && index.numel() != 0), "tried to take from an empty tensor"); + // OffsetCalculator needs the sizes and strides reveresed + const auto indexed_sizes = std::vector(indexed.sizes().rbegin(), indexed.sizes().rend()); + const auto indexed_strides = std::vector(indexed.strides().rbegin(), indexed.strides().rend()); + const auto* indexed_strides_data = indexed_strides.data(); + const auto offset_indexed = OffsetCalculator<1, uindex_t>(indexed.dim(), + indexed_sizes.data(), + &indexed_strides_data); - at::assert_no_internal_overlap(output); - at::assert_no_partial_overlap(output, index); - at::assert_no_overlap(output, input); + auto loop = [=]C10_DEVICE(int i) { + auto offsets = offset_calc.get(i); - output.resize_(index.sizes()); + auto& iterated = *reinterpret_cast(iterated_ptr + offsets[0]); + const auto idx = *reinterpret_cast(idx_ptr + 
offsets[1]); + CUDA_KERNEL_ASSERT(idx < numel && idx >= -numel && "cuda_take_put_kernel() index out of bounds"); + index_t offset = static_cast(idx); + if (offset < 0) { + offset += numel; + } + if (!is_contiguous) { + offset = offset_indexed.get(offset)[0]; + } - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cuda", [&] { - dispatchTakePut(input, output, index); - }); + f(iterated, offset); + }; + launch_kernel(iter.numel(), loop); } -Tensor take_cuda(const Tensor& self, const Tensor& index) { - auto out = at::empty(index.sizes(), self.options()); - take_out_cuda_template(out, self, index); - return out; +void put_kernel(TensorIterator& iter, const Tensor& output, const bool accumulate) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "put_cuda", [&] { + // Cannot use `OpaqueType`, as we need the actual type for `fastSpecializedgpuAtomicAdd` + AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(output) ? ScalarType::Int : ScalarType::Long, + "put_cuda_index", [&] { + auto* __restrict__ indexed_ptr = output.template data(); + if (accumulate) { + const auto numel = output.numel(); + cuda_take_put_kernel(iter, output, + [numel, indexed_ptr] __device__(scalar_t& iterated, const index_t offset) { + fastSpecializedAtomicAdd(indexed_ptr, offset, numel, iterated); + }); + } + else { + cuda_take_put_kernel(iter, output, + [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) { + indexed_ptr[offset] = iterated; + }); + } + }); + }); } -Tensor& take_out_cuda(const Tensor& self, const Tensor& index, Tensor& out) { - take_out_cuda_template(out, self, index); - return out; +void take_kernel( + TensorIterator& iter, + const Tensor& input) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "take_cuda", [&] { + // Cannot use `OpaqueType`, as Tensor::data_ptr> is not implemented + AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(input) ? ScalarType::Int : ScalarType::Long, + "take_cuda_index", [&] { + const auto* __restrict__ indexed_ptr = input.template data(); + cuda_take_put_kernel(iter, input, + [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) { + iterated = indexed_ptr[offset]; + }); + }); + }); } namespace { @@ -464,5 +459,7 @@ REGISTER_DISPATCH(index_stub, &index_kernel); REGISTER_DISPATCH(index_fill_stub, &index_fill_kernel); REGISTER_DISPATCH(index_copy_stub, &index_copy_kernel); REGISTER_DISPATCH(index_put_stub, &index_put_kernel); +REGISTER_DISPATCH(put_stub, &put_kernel); +REGISTER_DISPATCH(take_stub, &take_kernel); }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b26d7eddc6e5..d7e5b083b8f4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4724,8 +4724,10 @@ - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method dispatch: - CPU: legacy::cpu::_th_put_ - CUDA: legacy::cuda::_th_put_ + CPU, CUDA: put_ + +- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + variants: function, method - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) variants: method @@ -5424,18 +5426,12 @@ - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU: take_out_cpu - CUDA: take_out_cuda + CPU, CUDA: take_out - func: take(Tensor self, Tensor index) -> Tensor variants: method, function dispatch: - CPU: take_cpu - CUDA: take_cuda - -- func: take_backward(Tensor grad, Tensor input, Tensor index) -> Tensor - variants: function - device_guard: False + CPU, CUDA: take - func: take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) diff --git a/aten/src/ATen/test/scalar_tensor_test.cpp b/aten/src/ATen/test/scalar_tensor_test.cpp index 8bf8225d9312..36236cfd3867 100644 --- a/aten/src/ATen/test/scalar_tensor_test.cpp +++ b/aten/src/ATen/test/scalar_tensor_test.cpp @@ -163,109 +163,121 @@ void test(DeprecatedTypeProperties &T) { } } // forced size functions (resize_, resize_as, set_) - {// resize_ - {auto lhs = ones(*lhs_it, T); - auto rhs = ones(*rhs_it, T); - lhs.resize_(*rhs_it); - require_equal_size_dim(lhs, rhs); - } - // resize_as_ - { - auto lhs = ones(*lhs_it, T); - auto rhs = ones(*rhs_it, T); - lhs.resize_as_(rhs); - require_equal_size_dim(lhs, rhs); - } - // set_ - { + // resize_ { - // with tensor - auto lhs = ones(*lhs_it, T); - auto rhs = ones(*rhs_it, T); - lhs.set_(rhs); - require_equal_size_dim(lhs, rhs); + { + auto lhs = ones(*lhs_it, T); + auto rhs = ones(*rhs_it, T); + lhs.resize_(*rhs_it); + require_equal_size_dim(lhs, rhs); + } + // resize_as_ + { + auto lhs = ones(*lhs_it, T); + auto rhs = ones(*rhs_it, T); + lhs.resize_as_(rhs); + require_equal_size_dim(lhs, rhs); + } + // set_ + { + { + // with tensor + auto lhs = ones(*lhs_it, T); + auto rhs = ones(*rhs_it, T); + lhs.set_(rhs); + require_equal_size_dim(lhs, rhs); + } + { + // with storage + auto lhs = ones(*lhs_it, T); + auto rhs = ones(*rhs_it, T); + lhs.set_(rhs.storage()); + // should not be dim 0 because an empty storage is dim 1; all other + // storages aren't scalars + ASSERT_NE(lhs.dim(), 0); + } + { + // with storage, offset, sizes, strides + auto lhs = ones(*lhs_it, T); + auto rhs = ones(*rhs_it, T); + lhs.set_(rhs.storage(), rhs.storage_offset(), rhs.sizes(), rhs.strides()); + require_equal_size_dim(lhs, rhs); + } + } } + + // view { - // with storage auto lhs = ones(*lhs_it, T); auto rhs = ones(*rhs_it, T); - lhs.set_(rhs.storage()); - // should not be dim 0 because an empty storage is dim 1; all other - // storages aren't scalars - ASSERT_NE(lhs.dim(), 0); + auto rhs_size = *rhs_it; + TRY_CATCH_ELSE(auto result = lhs.view(rhs_size), + ASSERT_NE(lhs.numel(), rhs.numel()), + ASSERT_EQ(lhs.numel(), rhs.numel()); + require_equal_size_dim(result, rhs);); } + + // take { - // with storage, offset, sizes, strides auto lhs = ones(*lhs_it, T); - auto rhs = ones(*rhs_it, T); - lhs.set_(rhs.storage(), rhs.storage_offset(), rhs.sizes(), rhs.strides()); - require_equal_size_dim(lhs, rhs); + auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long); + TRY_CATCH_ELSE(auto result = lhs.take(rhs), + ASSERT_EQ(lhs.numel(), 0); ASSERT_NE(rhs.numel(), 0), + require_equal_size_dim(result, rhs)); } - } - } - - // view - { - auto lhs = ones(*lhs_it, T); - auto rhs = ones(*rhs_it, T); - auto rhs_size = *rhs_it; - TRY_CATCH_ELSE(auto result = lhs.view(rhs_size), - ASSERT_NE(lhs.numel(), rhs.numel()), - ASSERT_EQ(lhs.numel(), rhs.numel()); - require_equal_size_dim(result, rhs);); - } - // take - { - auto lhs = ones(*lhs_it, T); - auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long); - TRY_CATCH_ELSE(auto result = lhs.take(rhs), ASSERT_EQ(lhs.numel(), 0); - ASSERT_NE(rhs.numel(), 0), - require_equal_size_dim(result, rhs)); - } + // 
put + { + auto lhs = ones(*lhs_it, T); + auto rhs1 = zeros(*rhs_it, T).toType(ScalarType::Long); + auto rhs2 = zeros(*rhs_it, T); + TRY_CATCH_ELSE(auto result = lhs.put(rhs1, rhs2), + ASSERT_EQ(lhs.numel(), 0); ASSERT_NE(rhs1.numel(), 0), + require_equal_size_dim(result, lhs)); + } - // ger - { - auto lhs = ones(*lhs_it, T); - auto rhs = ones(*rhs_it, T); - TRY_CATCH_ELSE(auto result = lhs.ger(rhs), - ASSERT_TRUE( - (lhs.numel() == 0 || rhs.numel() == 0 || - lhs.dim() != 1 || rhs.dim() != 1)), - [&]() { - int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0); - int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0); - require_equal_size_dim( - result, at::empty({dim0, dim1}, result.options())); - }();); - } + // ger + { + auto lhs = ones(*lhs_it, T); + auto rhs = ones(*rhs_it, T); + TRY_CATCH_ELSE(auto result = lhs.ger(rhs), + ASSERT_TRUE( + (lhs.numel() == 0 || rhs.numel() == 0 || + lhs.dim() != 1 || rhs.dim() != 1)), + [&]() { + int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0); + int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0); + require_equal_size_dim( + result, at::empty({dim0, dim1}, result.options())); + }();); + } - // expand - { - auto lhs = ones(*lhs_it, T); - auto lhs_size = *lhs_it; - auto rhs = ones(*rhs_it, T); - auto rhs_size = *rhs_it; - bool should_pass = should_expand(lhs_size, rhs_size); - TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size), - ASSERT_FALSE(should_pass), - ASSERT_TRUE(should_pass); - require_equal_size_dim(result, rhs);); + // expand + { + auto lhs = ones(*lhs_it, T); + auto lhs_size = *lhs_it; + auto rhs = ones(*rhs_it, T); + auto rhs_size = *rhs_it; + bool should_pass = should_expand(lhs_size, rhs_size); + TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size), + ASSERT_FALSE(should_pass), + ASSERT_TRUE(should_pass); + require_equal_size_dim(result, rhs);); - // in-place functions (would be good if we can also do a non-broadcasting - // one, b/c broadcasting functions will always end up operating on tensors - // of same size; is there an example of this outside of assign_ ?) - { - bool should_pass_inplace = should_expand(rhs_size, lhs_size); - TRY_CATCH_ELSE(lhs.add_(rhs), - ASSERT_FALSE(should_pass_inplace), - ASSERT_TRUE(should_pass_inplace); - require_equal_size_dim(lhs, ones(*lhs_it, T));); + // in-place functions (would be good if we can also do a non-broadcasting + // one, b/c broadcasting functions will always end up operating on tensors + // of same size; is there an example of this outside of assign_ ?) 
+ { + bool should_pass_inplace = should_expand(rhs_size, lhs_size); + TRY_CATCH_ELSE(lhs.add_(rhs), + ASSERT_FALSE(should_pass_inplace), + ASSERT_TRUE(should_pass_inplace); + require_equal_size_dim(lhs, ones(*lhs_it, T));); + } + } } } } -} -} TEST(TestScalarTensor, TestScalarTensorCPU) { manual_seed(123); diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index ad6b34218c92..0501e8a950fb 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -73,59 +73,4 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) #undef IS_NONZERO } - -#if !defined(TH_REAL_IS_HALF) /* non half part */ - -#if !defined(TH_REAL_IS_BFLOAT16) /* non bfloat16 part*/ - -static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex) { - auto size = THTensor_sizesLegacyNoScalars(tensor); - auto stride = THTensor_stridesLegacyNoScalars(tensor); - int nDim = THTensor_nDimensionLegacyAll(tensor); - ptrdiff_t dataOffset = 0; - for (int i = nDim - 1; i >= 0; i--) { - dataOffset += (linearIndex % size[i]) * stride[i]; - linearIndex /= size[i]; - } - return dataOffset; -} - -static inline void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) { - THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel); -} - -static inline int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) { - return linearIndex < 0 ? linearIndex + numel : linearIndex; -} - -void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate) -{ - THArgCheck(THLongTensor_nElement(index) == THTensor_(nElement)(src), 3, - "src should have the same number of elements as index"); - - index = THLongTensor_newContiguous(index); - src = THTensor_(newContiguous)(src); - scalar_t* data = tensor->data(); - ptrdiff_t numel = THTensor_(nElement)(tensor); - int is_contiguous = THTensor_(isContiguous)(tensor); - - TH_TENSOR_APPLY2(int64_t, index, scalar_t, src, - THTensor_(checkLinearIndex)(*index_data, numel); - int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel); - int64_t dataOffset = is_contiguous ? 
linearIndex : THTensor_(dataOffset)(tensor, linearIndex); - if (accumulate) { - data[dataOffset] += *src_data; - } else { - data[dataOffset] = *src_data; - } - ); - - c10::raw::intrusive_ptr::decref(src); - THLongTensor_free(index); -} - -#endif - -#endif - #endif /* TH_GENERIC_FILE */ diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index 45229806306e..8a9a078405ca 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -50,7 +50,6 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathMagma.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathPairwise.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathReduce.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorIndex.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorScatterGather.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorSort.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCSortUtils.cu @@ -138,8 +137,6 @@ install(FILES generic/THCTensorMathReduce.cu generic/THCTensorScatterGather.h generic/THCTensorScatterGather.cu - generic/THCTensorIndex.h - generic/THCTensorIndex.cu generic/THCTensorSort.h generic/THCTensorSort.cu generic/THCTensorTopK.h diff --git a/aten/src/THC/THCTensorIndex.cu b/aten/src/THC/THCTensorIndex.cu deleted file mode 100644 index 35406f0a915a..000000000000 --- a/aten/src/THC/THCTensorIndex.cu +++ /dev/null @@ -1,119 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // for std::min -#include -#include - -template -__device__ __forceinline__ IndexType indexToOffset( - const TensorInfo& info, - int64_t index, - IndexType size) -{ - IndexType linearIndex = static_cast(index); - CUDA_KERNEL_ASSERT(linearIndex < size && linearIndex >= -size); - if (linearIndex < 0) { - linearIndex += size; - } - return IndexToOffset::get(linearIndex, info); -} - -struct WrapIndexOp { - WrapIndexOp(int64_t size) : size(size) {} - - __device__ __forceinline__ void operator()(int64_t* out, int64_t* in) { - auto idx = *in; - CUDA_KERNEL_ASSERT(idx < size && idx >= -size); - *out = idx < 0 ? 
idx + size : idx; - } - - int64_t size; -}; - -template -struct TensorPutOp { - TensorPutOp(TensorInfo info, IndexType numel, int64_t*, int64_t*) - : info(info), numel(numel) {} - - __device__ __forceinline__ void operator()(T* value, int64_t* index) { - auto offset = indexToOffset(info, *index, numel); - info.data[offset] = *value; - } - - const TensorInfo info; - IndexType numel; -}; - -template -struct TensorPutAccumulateOp { - TensorPutAccumulateOp(TensorInfo info, IndexType numel, int64_t* start, int64_t* end) - : info(info), numel(numel), start(start), end(end) {} - - __device__ __forceinline__ void operator()(T* value, int64_t* index) { - if (index == start || *index != *(index - 1)) { - int64_t linear_index = *index; - auto offset = indexToOffset(info, linear_index, numel); - do { - info.data[offset] = THCNumerics::add(info.data[offset], *value); - index++; - value++; - } while (index != end && *index == linear_index); - } - } - - const TensorInfo info; - IndexType numel; - int64_t* start; - int64_t* end; -}; - - -template class Op, typename TensorType> -void dispatchTakePutImpl(THCState *state, TensorType *a, TensorType *b, THCudaLongTensor *index) { - // These are only valid if index is contiguous - auto start = THCudaLongTensor_data(state, index); - auto end = start + THCudaLongTensor_numel(state, index); - - auto aInfo = getTensorInfo(state, a); - aInfo.collapseDims(); - auto numel = THCTensor_nElement(state, a); - if (aInfo.isContiguous()) { - auto op = Op(aInfo, numel, start, end); - THC_pointwiseApply2(state, b, index, op); - } else { - auto op = Op(aInfo, numel, start, end); - THC_pointwiseApply2(state, b, index, op); - } -} - -template class Op, typename TensorType> -void dispatchTakePut(THCState *state, TensorType *a, TensorType *b, THCudaLongTensor *index) { - if (THCTensor_canUse32BitIndexMath(state, a, INT_MAX)) { - dispatchTakePutImpl(state, a, b, index); - } else { - dispatchTakePutImpl(state, a, b, index); - } -} - -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h index 3802fce089b1..cb79cea52617 100644 --- a/aten/src/THC/THCTensorMath.h +++ b/aten/src/THC/THCTensorMath.h @@ -43,15 +43,6 @@ #include #include -#include -#include - -#include -#include - -#include -#include - #include #include diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu deleted file mode 100644 index e913e2b1da10..000000000000 --- a/aten/src/THC/generic/THCTensorIndex.cu +++ /dev/null @@ -1,147 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorIndex.cu" -#else - -#include -#include -#include - -// Check tensor dimensions for index operations, and return the slice size. -static ptrdiff_t THCTensor_(getSliceSize)(THCState *state, THCTensor *dst, - int dim, - THCudaLongTensor *index, - THCTensor *src) -{ - int dstDims = THCTensor_(nDimensionLegacyNoScalars)(state, dst); - int srcDims = (src == nullptr) ? 
dstDims : THCTensor_(nDimensionLegacyNoScalars)(state, src); - - THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) == 1, 4, - "expecting vector of indices"); - THArgCheck(dim >= 0 && dim < dstDims, 2, "Indexing dim is out of bounds"); - - ptrdiff_t dstSliceSize = 1; - for (int d = 0; d < dstDims; d++) { - if (d != dim) { - dstSliceSize *= THTensor_sizeLegacyNoScalars(dst, d); - } - } - - if (src == nullptr) return dstSliceSize; - - THArgCheck(dim < srcDims, 3, "Indexing dim is out of bounds"); - THArgCheck(THCudaLongTensor_nElement(state, index) == THTensor_sizeLegacyNoScalars(src, dim), 4, - "length of src.size[dim] is not equal to length of indices"); - - ptrdiff_t srcSliceSize = 1; - bool mismatch = false; - - if (dstDims != srcDims) mismatch = true; - - for (int d = 0; d < srcDims; d++) { - if (d != dim) { - srcSliceSize *= THTensor_sizeLegacyNoScalars(src, d); - if (!mismatch && THTensor_sizeLegacyNoScalars(dst, d) != THTensor_sizeLegacyNoScalars(src, d)) mismatch = true; - } - } - - THArgCheck(dstSliceSize == srcSliceSize, 2, - "Source/destination tensor have different slice sizes (%ld vs %ld)", - dstSliceSize, srcSliceSize); - - if (mismatch) { - static bool warningShown = false; - if (!warningShown) { - warningShown = true; - fprintf(stderr, - "Warning: source/destination slices have same size but different " - "shape for an index operation. This behavior is deprecated.\n"); - } - } - - return dstSliceSize; -} - -// Compare the stride between adjacent slices (sliceStride) with strides in the -// other dimensions (i.e., strides *inside* each slice). -// -// - Returns true if some dimension inside the slice has lower stride than -// sliceStride. The simplest example is a 2-D contiguous tensor with sliceDim -// == 0 (that is, each slice is a row). -// -// In this case, we choose the CUDA kernel that processes the data in -// "index-major order". For example, if thread count equals slice size, then -// all threads process slice #0 in lockstep, and then slice #1, and so on. -// -// - Otherwise (i.e., sliceStride has the lowest value), this function returns -// false. The simplest example is a 2-D contiguous tensor with sliceDim == 1 -// (each slice is a column). -// -// In this case, we choose the CUDA kernel that processes the data in -// "elementInSlice-major order". For example, each thread can process element -// #0 of every slice, and then element #1 of every slice, and so on. -bool THCTensor_(indexShouldBeMajor)(TensorInfo &info, - int sliceDim) -{ - // The stride between adjacent slices (e.g., between element #0 of slice #100 - // and element #0 of slice #101). 
- unsigned int sliceStride = info.strides[sliceDim]; - - for (int i = 0; i < info.dims; ++i) { - if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) { - return true; - } - } - - return false; -} - -static void THCTensor_(sort_indices)(THCState *state, THCudaLongTensor *index, THCTensor *src) { - THCThrustAllocator thrustAlloc(state); - - auto index_iter = thrust::device_ptr(THCudaLongTensor_data(state, index)); - auto src_iter = thrust::device_ptr(THCTensor_(data)(state, src)); - auto numel = THCTensor_(numel)(state, src); - - thrust::sort_by_key( - thrust::cuda::par(thrustAlloc).on(c10::cuda::getCurrentCUDAStream()), - index_iter, index_iter + numel, - src_iter, ThrustLTOp()); -} - -void THCTensor_(put)(THCState *state, THCTensor *dst, THCudaLongTensor *index, THCTensor *src, int accumulate) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src)); - THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index)); - - ptrdiff_t dstSize = THCTensor_(nElement)(state, dst); - ptrdiff_t numIndices = THCudaLongTensor_nElement(state, index); - THArgCheck(THCTensor_(nElement)(state, src) == numIndices, - 3, "src should have the same number of elements as index"); - - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - - if (numIndices == 0) { - return; - } - - if (accumulate) { - // wrap indices so to replace negative indices - THCudaLongTensor* sorted_index = THCudaLongTensor_new(state); - THCudaLongTensor_resizeAs(state, sorted_index, index); - THC_pointwiseApply2(state, sorted_index, index, WrapIndexOp(dstSize)); - - THCTensor* sorted_src = THCTensor_(newClone)(state, src); - - THCTensor_(sort_indices)(state, sorted_index, sorted_src); - dispatchTakePut(state, dst, sorted_src, sorted_index); - - THCTensor_(free)(state, sorted_src); - THCudaLongTensor_free(state, sorted_index); - } else { - dispatchTakePut(state, dst, src, index); - } -} - -#endif diff --git a/aten/src/THC/generic/THCTensorIndex.h b/aten/src/THC/generic/THCTensorIndex.h deleted file mode 100644 index c77061e8881c..000000000000 --- a/aten/src/THC/generic/THCTensorIndex.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorIndex.h" -#else - -TORCH_CUDA_CU_API void THCTensor_(indexSelect)( - THCState* state, - THCTensor* tensor, - THCTensor* src, - int dim, - THCudaLongTensor* index); -TORCH_CUDA_CU_API void THCTensor_(take)( - THCState* state, - THCTensor* res_, - THCTensor* src, - THCudaLongTensor* index); -TORCH_CUDA_CU_API void THCTensor_(put)( - THCState* state, - THCTensor* res_, - THCudaLongTensor* indices, - THCTensor* src, - int accumulate); - -#endif diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index abbb6926111e..b6fa4356ddca 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -87,6 +87,7 @@ ("aten::cumprod_backward", datetime.date(2021, 5, 1)), ("aten::_triangular_solve_helper", datetime.date(9999, 1, 1)), ("aten::complex*", datetime.date(2021, 5, 1)), + ("aten::take_backward", datetime.date(2021, 5, 1)), ] def allow_listed(schema, allow_list): diff --git a/test/test_autograd.py 
b/test/test_autograd.py index 73416776dfc7..73f17018f279 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2087,32 +2087,6 @@ def test_unbind(self): for j in range(3)], dim=0) self.assertEqual(g, g_expected) - def test_put(self): - root = torch.randn(4, 5, requires_grad=True) - values = torch.randn(6, requires_grad=True) - idx = Variable(torch.LongTensor([1, 2, 3, -1, -2, -3])) - - def func(root, values): - x = root.clone() - x.put_(idx, values) - return x - - gradcheck(func, [root, values]) - gradgradcheck(func, [root, values]) - - def test_put_accumulate(self): - root = torch.randn(4, 5, requires_grad=True) - values = torch.randn(6, requires_grad=True) - idx = Variable(torch.LongTensor([1, 2, 3, 1, 2, 3])) - - def func(root, values): - x = root.clone() - x.put_(idx, values, accumulate=True) - return x - - gradcheck(func, [root, values]) - gradgradcheck(func, [root, values]) - def test_fill(self): root = torch.randn(4, 5, requires_grad=True) diff --git a/test/test_torch.py b/test/test_torch.py index 9c2bb28f2d94..5c66c58a69d7 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -16,6 +16,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack from torch._six import inf, nan, string_classes from itertools import product, combinations, permutations +from functools import partial from torch import multiprocessing as mp from torch.testing._internal.common_utils import ( TestCase, TEST_WITH_ROCM, run_tests, @@ -950,44 +951,6 @@ def test_index_add_all_dtypes(self): added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1) self.assertEqual(added, -tensor) - def test_take(self): - def check(src, idx): - expected = src.contiguous().view(-1).index_select( - 0, idx.contiguous().view(-1)).view_as(idx) - actual = src.take(idx) - self.assertEqual(actual.size(), idx.size()) - self.assertEqual(expected, actual) - - src = torch.randn(2, 3, 5) - idx = torch.LongTensor([[0, 2], [3, 4]]) - check(src, idx) - check(src.transpose(1, 2), idx) - check(src.bool(), idx) - - def test_put_(self): - def check(dst, idx, value): - expected = dst.clone(memory_format=torch.contiguous_format).view(-1).index_copy_( - 0, idx.contiguous().view(-1), value.contiguous().view(-1)) - expected = expected.view_as(dst) - dst.put_(idx, value) - self.assertEqual(expected, dst) - - dst = torch.randn(2, 3, 5) - idx = torch.LongTensor([[0, 2], [3, 4]]) - values = torch.randn(2, 2) - check(dst, idx, values) - check(dst.transpose(1, 2), idx, values) - - values = torch.tensor([[False, False], [False, False]]) - check(dst.bool(), idx, values) - - def test_put_accumulate(self): - dst = torch.ones(2, 2) - idx = torch.LongTensor([[0, 1], [0, 1]]) - src = torch.Tensor([1, 2, 3, 4]) - dst.put_(idx, src, accumulate=True) - self.assertEqual(dst.tolist(), [[5, 7], [1, 1]]) - # Fill idx with valid indices. 
@staticmethod def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o): @@ -4672,6 +4635,131 @@ def ref_index_select(src, dim, idx): out = source.index_select(0, idx) self.assertEqual(out.item(), source.item()) + @dtypes(*torch.testing.get_all_dtypes()) + def test_take(self, device, dtype): + idx_size = (4,) + + make_arg = partial(make_tensor, device=device, dtype=dtype) + make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64) + + def ref_take(src, idx): + if dtype == torch.bfloat16: + src = src.half() + src = src.cpu().numpy() + idx = idx.cpu().numpy() + out = torch.from_numpy(np.take(src, idx)).to(device=device, dtype=dtype) + return out + + for src_contig, idx_contig, idx_reshape in product([True, False], repeat=3): + for src_size in ((5,), (4, 5)): + src = make_arg(src_size, discontiguous=not src_contig) + idx = make_idx(idx_size, high=src.numel(), discontiguous=not idx_contig) + if idx_reshape: + idx = idx.reshape(2, 2) + out = torch.take(src, idx) + out2 = ref_take(src, idx) + self.assertEqual(out, out2) + + # Create the 4 possible combinations of scalar sizes for source / index + for size_s, size_i in product([(), (1,)], repeat=2): + source = make_arg(size_s) + idx = make_idx(size_i, high=1) + out = source.take(idx) + self.assertEqual(out.item(), source.item()) + + # The bool instance does not work on GPU. See + # https://github.com/pytorch/pytorch/issues/54317 + @dtypes(*torch.testing.get_all_dtypes(include_bool=False)) + def test_put(self, device, dtype): + src_size = (4,) + + make_arg = partial(make_tensor, device=device, dtype=dtype) + make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64) + + def ref_put(dst, idx, src, accumulate): + new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1) + new_idx = idx.contiguous().view(-1) + new_src = src.contiguous().view(-1) + method = new_dst.index_add_ if accumulate else new_dst.index_copy_ + return method(0, new_idx, new_src).view_as(dst) + + for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product([True, False], repeat=5): + for dst_size in ((5,), (4, 5)): + dst = make_arg(dst_size, discontiguous=not dst_contig) + src = make_arg(src_size, discontiguous=not src_contig) + + # If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU + # On CUDA it may not be, but the test has enough tolerance to account for this + if accumulate: + idx = make_idx(src_size, high=dst.numel()) + else: + idx = torch.randperm(dst.numel(), dtype=torch.int64, device=device)[:src_size[0]] + if not idx_contig: + idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2] + if idx_reshape: + idx = idx.reshape(2, 2) + out = torch.put(dst, idx, src, accumulate) + # out-place + reference = ref_put(dst, idx, src, accumulate) + self.assertEqual(out, reference) + + # in-place + dst.put_(idx, src, accumulate) + self.assertEqual(dst, reference) + + + # Create the 8 possible combinations of scalar sizes for target / index / source + scalars = ((make_arg(size_t), + make_idx(size_i, high=1), + make_arg(size_s)) + for size_t, size_i, size_s in product([(), (1,)], repeat=3)) + for (dest, idx, source), accumulate in product(scalars, [True, False]): + dest_init = dest.clone() + # out-place + out = torch.put(dest, idx, source, accumulate=accumulate) + # in-place + dest1 = dest.clone() + dest1.put_(idx, source, accumulate=accumulate) + for d in [out, dest1]: + if accumulate: + self.assertEqual(d.item(), (dest_init + source).item()) + else: + self.assertEqual(d.item(), source.item()) + 
+        # Empty case
+        dest = make_arg((3, 2))
+        reference = dest.clone()
+        idx = make_idx((0,), high=1)
+        source = make_arg((0,))
+        for accumulate in [True, False]:
+            out = torch.put(dest, idx, source, accumulate=accumulate)
+            self.assertEqual(out, reference)
+            dest.put_(idx, source, accumulate=accumulate)
+            self.assertEqual(dest, reference)
+
+    # The bool instance does not work on GPU. See
+    # https://github.com/pytorch/pytorch/issues/54317
+    @dtypes(*torch.testing.get_all_dtypes(include_bool=False))
+    def test_put_accumulate(self, device, dtype):
+        # Test for parallel adds with accumulate == True
+        low_precision = dtype == torch.half or dtype == torch.bfloat16
+        # Fewer values to avoid overflow with low-precision dtypes
+        # Grain size is 3000, so the for loop is only parallelized on CPU above that size
+        sizes = ((100,)) if low_precision else ((200,), (3002,))
+        # Bfloat16 has particularly bad performance here
+        # This operation is nondeterministic on GPU, so we are generous with the rtol
+        rtol, atol = (1e-1, 1e-2) if low_precision else (1e-3, 1e-4)
+
+        make_arg = partial(make_tensor, low=-2, high=3, device=device, dtype=dtype)
+        # Dump everything into the 0-th position
+        make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
+        args = ((make_idx(size), make_arg(size)) for size in sizes)
+
+        for idx, source in args:
+            orig = make_arg((1,))
+            out = orig.put(idx, source, accumulate=True)
+            self.assertEqual(out, orig + source.sum(), rtol=rtol, atol=atol)
+
     def test_take_empty(self, device):
         for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
             for indices_shape in [(0,), (0, 1, 2, 0)]:
@@ -5412,6 +5500,25 @@ def test_bernoulli_mem_overlap(self, device):
         with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
             torch.bernoulli(torch.rand_like(x), out=x)
 
+    @onlyOnCPUAndCUDA
+    def test_put_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((6,))
+        y = torch.rand((6,), device=device)
+        ind = torch.tensor([2, 1, 0], device=device)
+        value = torch.rand((3,), device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            x.put_(ind, value)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            y.put_(ind[0], y[0])
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            ind.put_(ind, ind)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            y.put_(ind, y[:3])
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            ind.put_(ind, ind.clone())
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            ind.put_(ind.clone(), ind)
+
     @onlyOnCPUAndCUDA
     def test_index_put_mem_overlap(self, device):
         x = torch.rand((1,), device=device).expand((6,))
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index b346274601d2..bf1828cd4e32 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -912,9 +912,9 @@
   self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
 
 - name: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)
-  self: grad.clone().put_(index, zeros_like(source), accumulate)
+  self: "accumulate ? 
grad : grad.put(index, zeros_like(source), false)" index: non_differentiable - source: grad.take(index) + source: grad.take(index).reshape_as(source) - name: linalg_qr(Tensor self, str mode='reduced') -> (Tensor Q, Tensor R) self: linalg_qr_backward(grads, self, mode, Q, R) @@ -1114,7 +1114,7 @@ self: grad.rot90(-k, dims) - name: take(Tensor self, Tensor index) -> Tensor - self: take_backward(grad, self, index) + self: zeros_like(self).put_(index, grad, true) index: non_differentiable - name: tan(Tensor self) -> Tensor diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4d084c52b02a..b4d48d1a8bcf 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -88,7 +88,7 @@ 'l1_loss_backward', 'baddbmm', 'addbmm', 'addmm', 'addmv', 'addr', 'linalg_householder_product', 'constant_pad_nd', 'reflection_pad1d', 'reflection_pad2d', 'reflection_pad1d_backward', 'reflection_pad2d_backward', - 'replication_pad1d', 'replication_pad2d', 'replication_pad3d', + 'replication_pad1d', 'replication_pad2d', 'replication_pad3d', 'take', 'put_', 'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward', 'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub', 'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward', diff --git a/tools/code_analyzer/default_op_deps.yaml b/tools/code_analyzer/default_op_deps.yaml index d9d89883a54f..93c914d1ff1f 100644 --- a/tools/code_analyzer/default_op_deps.yaml +++ b/tools/code_analyzer/default_op_deps.yaml @@ -10184,12 +10184,6 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ -- name: aten::take_backward - depends: - - name: aten::eq - - name: aten::is_nonzero - - name: aten::put_ - - name: aten::zeros_like - name: aten::tan depends: - name: aten::as_strided_ diff --git a/torch/__init__.py b/torch/__init__.py index 8c10bb5873dd..4fe9249ce2f7 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -371,6 +371,8 @@ def use_deterministic_algorithms(d): ``indices`` is a list of tensors * :func:`torch.index_put` with ``accumulate=True`` when called on a CPU tensor + * :func:`torch.put` with ``accumulate=True`` when called on a CPU + tensor The following normally-nondeterministic operations will throw a :class:`RuntimeError` when `d=True`: @@ -401,6 +403,8 @@ def use_deterministic_algorithms(d): * :func:`torch.scatter_add_` when called on a CUDA tensor * :func:`torch.index_add_` when called on a CUDA tensor * :func:`torch.index_copy` + * :func:`torch.put` when ``accumulate=False`` + * :func:`torch.put` when ``accumulate=True`` and called on a CUDA tensor * :func:`torch.index_select` when called on a CUDA tensor that requires grad * :func:`torch.repeat_interleave` when called on a CUDA tensor that requires grad * :func:`torch.histc` when called on a CUDA tensor diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index dc5855004ebe..24dbbeb7a4c7 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2702,19 +2702,22 @@ def callable(a, b) -> number add_docstr_all('put_', r""" -put_(indices, tensor, accumulate=False) -> Tensor +put_(index, source, accumulate=False) -> Tensor -Copies the elements from :attr:`tensor` into the positions specified by -indices. For the purpose of indexing, the :attr:`self` tensor is treated as if +Copies the elements from :attr:`source` into the positions specified by +:attr:`index`. 
For the purpose of indexing, the :attr:`self` tensor is treated as if it were a 1-D tensor. -If :attr:`accumulate` is ``True``, the elements in :attr:`tensor` are added to -:attr:`self`. If accumulate is ``False``, the behavior is undefined if indices +:attr:`index` and :attr:`source` need to have the same number of elements, but not necessarily +the same shape. + +If :attr:`accumulate` is ``True``, the elements in :attr:`source` are added to +:attr:`self`. If accumulate is ``False``, the behavior is undefined if :attr:`index` contain duplicate elements. Args: - indices (LongTensor): the indices into self - tensor (Tensor): the tensor containing values to copy from + index (LongTensor): the indices into self + source (Tensor): the tensor containing values to copy from accumulate (bool): whether to accumulate into self Example:: @@ -2726,6 +2729,14 @@ def callable(a, b) -> number [ 10, 7, 8]]) """) +add_docstr_all('put', + r""" +put(input, index, source, accumulate=False) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.put_`. +`input` corresponds to `self` in :meth:`torch.Tensor.put_`. +""") + add_docstr_all('qr', r""" qr(some=True) -> (Tensor, Tensor) diff --git a/torch/overrides.py b/torch/overrides.py index 5a23cc0bb0f4..52695002fec4 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -740,6 +740,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, torch.pow: lambda input, exponent, out=None: -1, torch.prod: lambda input, dtype=None: -1, + torch.put: lambda input, index, source, accumulate=False: -1, torch.q_per_channel_axis: lambda input: -1, torch.q_per_channel_scales: lambda input: -1, torch.q_per_channel_zero_points: lambda input: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 75b8f59dda0b..7a83814517fb 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1114,6 +1114,75 @@ def sample_inputs_mode(op_info, device, dtype, requires_grad): for input_tensor, args in args) return inputs +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_put(op_info, device, dtype, requires_grad): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + def gen_inputs(): + # Generic inputs + tgt_gen = (make_arg((S, S), discontiguous=not ctg) for ctg in (True, False)) + src_gen = (make_arg((S,), discontiguous=not ctg) for ctg in (True, False)) + idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S] + idx_nonctg = torch.repeat_interleave(idx, 2, dim=-1)[::2] + idx_neg = -idx - 1 + idx_list = [idx, idx_nonctg, idx_neg] + for tgt, idx, src, acc in product(tgt_gen, idx_list, src_gen, (True, False)): + yield SampleInput(input=tgt, args=(idx, src, acc)) + + # Scalar cases + scalar_sizes = [(), (1,)] + tgt_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + src_gen = (make_arg(size) for size in scalar_sizes) + for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)): + yield SampleInput(input=tgt, args=(idx, src, acc)) + + # Empty cases + tgt_sizes = [(0,), (), (1,), (3, 2)] + tgt_gen = (make_arg(size) for size in tgt_sizes) + idx = 
make_idx((0,), high=1)
+        src = make_arg((0,))
+        for tgt, acc in product(tgt_gen, (True, False)):
+            yield SampleInput(input=tgt, args=(idx, src, acc))
+
+    return list(gen_inputs())
+
+def sample_inputs_take(op_info, device, dtype, requires_grad):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
+
+    S = 3
+
+    def gen_inputs():
+        # Generic inputs: take S elements out of S * S
+        src_gen = (make_arg((S, S), discontiguous=not ctg) for ctg in (True, False))
+        idx = make_idx((S,), high=S * S)
+        idx_nonctg = make_idx((S,), high=S * S, discontiguous=True)
+        idx_neg = -idx - 1
+        idx_list = [idx, idx_nonctg, idx_neg]
+        for src, idx in product(src_gen, idx_list):
+            yield SampleInput(input=src, args=(idx,))
+
+        # Scalar cases
+        scalar_sizes = [(), (1,)]
+        src_gen = (make_arg(size) for size in scalar_sizes)
+        idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
+        for src, idx in product(src_gen, idx_gen):
+            yield SampleInput(input=src, args=(idx,))
+
+        # Empty cases
+        src_sizes = [(0,), (), (1,), (3, 2)]
+        src_gen = (make_arg(size) for size in src_sizes)
+        idx = make_idx((0,), high=1)
+        for src in src_gen:
+            yield SampleInput(input=src, args=(idx,))
+
+    return list(gen_inputs())
+
 def sample_movedim_moveaxis(op_info, device, dtype, requires_grad):
     return (
         SampleInput(
@@ -3725,6 +3794,15 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                # sort does not correctly warn when resizing out= inputs
                SkipInfo('TestCommon', 'test_out'),
            )),
+    OpInfo('put',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           check_batched_gradgrad=False,  # vmap complains of the sizes
+           sample_inputs_func=sample_inputs_put),
+    OpInfo('take',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           check_batched_grad=False,  # vmap complains of the sizes
+           sample_inputs_func=sample_inputs_take),
     OpInfo('stack',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_stack,
@@ -4592,10 +4670,6 @@ def method_tests():
     ('topk', (), (1, 0), 'dim_scalar', (), [1]),
     ('topk', (), (1, 0, True), 'dim_desc_scalar', (), [1]),
     ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', (), [1]),
-    ('take', (S, S, S), (torch.LongTensor([[-3, 2], [20, 2]]),)),
-    ('take', (S, S, S), (torch.tensor(0, dtype=torch.int64),), 'scalar_index'),
-    ('take', (), (torch.LongTensor([0]),), 'scalar_data'),
-    ('take', (), (torch.tensor(0, dtype=torch.int64),), 'scalar_both'),
     ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M)), '', (True,)),
     ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all', (True,)),
     ('where', (), (bernoulli_scalar(), ()), 'scalar', (True,)),
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 11bb41876738..3997a466ebb6 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1650,7 +1650,6 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
             result = (torch.rand(size, device=device, dtype=torch.float32) * span + low).to(torch.bfloat16)
         else:
             result = torch.rand(size, device=device, dtype=dtype) * span + low
-        result.requires_grad = requires_grad
     else:
         assert dtype in complex_types()
         low = -9 if low is None else max(low, -9)
@@ -1660,12 +1659,15 @@ def make_tensor(size, device: torch.device, 
dtype: torch.dtype, *, low=None, hig real = torch.rand(size, device=device, dtype=float_dtype) * span + low imag = torch.rand(size, device=device, dtype=float_dtype) * span + low result = torch.complex(real, imag) - result.requires_grad = requires_grad if discontiguous and result.numel() > 1: result = torch.repeat_interleave(result, 2, dim=-1) result = result[..., ::2] + if dtype in floating_types_and(torch.half, torch.bfloat16) or\ + dtype in complex_types(): + result.requires_grad = requires_grad + return result def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
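
For reference, a small illustrative sketch (editorial addition, not part of the diff) of the semantics documented above and of the new `take` backward formula from `derivatives.yaml`. It assumes only public PyTorch APIs (`torch.take`, `Tensor.put_`, `torch.autograd.grad`) plus the `torch.put` added by this patch; the values in the comments follow from the documented flattened-index behavior.

```python
import torch

# put_ treats `self` as 1-D: index and source must have the same numel,
# but their shapes may differ.
dst = torch.zeros(2, 3)
index = torch.tensor([[0, 5], [2, 4]])    # numel 4, matches source below
source = torch.tensor([1., 2., 3., 4.])
dst.put_(index, source)                   # writes into the flattened view of dst
# dst is now [[1., 0., 3.], [0., 4., 2.]]

# take is the out-of-place gather counterpart; the output has index's shape.
taken = torch.take(dst, index)            # [[1., 2.], [3., 4.]]

# accumulate=True adds into self and tolerates repeated indices.
acc = torch.ones(2, 2)
acc.put_(torch.tensor([0, 0, 3]), torch.tensor([1., 2., 3.]), accumulate=True)
# acc is now [[4., 1.], [1., 4.]]

# The take backward defined above scatter-adds the incoming gradient, i.e. it
# should match zeros_like(self).put_(index, grad, accumulate=True).
x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
idx = torch.tensor([0, 0, 7, 11])
out = torch.take(x, idx)
grad_out = torch.randn_like(out)
grad_x, = torch.autograd.grad(out, x, grad_out)
expected = torch.zeros_like(x).put_(idx, grad_out, accumulate=True)
assert torch.allclose(grad_x, expected)
```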