Port put_ and take from TH to ATen (#53356)
Summary:
The two ports were done together, as the two ops can be implemented with the same kernel; in TH they were already implemented with the same kernel.

Resolves #24751
Resolves #24614
Resolves #24640
Resolves #24772

This port makes sure that both operations interact correctly with the "deterministic algorithms" flag, following the approach of #51388.
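A minimal sketch of the intended interaction (assuming a CUDA device; `torch.use_deterministic_algorithms` is the flag referred to above, and the exact failure mode is illustrative):

```python
import torch

torch.use_deterministic_algorithms(True)

x = torch.zeros(10, device='cuda')
index = torch.tensor([1, 1], device='cuda')   # note the duplicate index
source = torch.tensor([1., 2.], device='cuda')

# accumulate=True sums with atomics, so the result is deterministic
# (modulo float reordering), and the call is allowed.
x.put_(index, source, accumulate=True)

# accumulate=False with duplicate indices depends on write order, so with
# the flag enabled this call is expected to raise a RuntimeError on CUDA.
x.put_(index, source, accumulate=False)
```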

This PR also makes these two functions correct in the following aspects (all of them added to the tests as well):
- Support for complex numbers
- Correct handling of scalar inputs and zero-dimensional inputs
- An implementation that performs no copies or sorting of the input tensors
- Faster and more correct implementation of the backward (it now works as it should when `source.shape() != index.shape()`)
- `put_(..., accumulate=True)` is now implemented correctly with atomic operations on GPU / CPU (when possible) and is deterministic (modulo the loss of precision that may happen due to the reordering of a floating-point sum)
- Adds the missing out-of-place `torch.put` function (`index_put` already exists, for example)
- Corrected docs

It also adds much more thorough testing of the operations and their gradients.
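For reference, a minimal usage sketch of the semantics being tested (both ops treat `self` as flattened, addressing elements in row-major order):

```python
import torch

t = torch.zeros(2, 3)
t.put_(torch.tensor([0, 4]), torch.tensor([1., 2.]))   # writes into t.view(-1)
t.put_(torch.tensor([4, 4]), torch.tensor([1., 1.]), accumulate=True)
print(t)                                      # tensor([[1., 0., 0.],
                                              #         [0., 4., 0.]])

# take is the read-side counterpart: it gathers by linear index.
print(torch.take(t, torch.tensor([0, 4])))    # tensor([1., 4.])

# put is the new out-of-place variant of put_ (t itself is left unchanged).
u = t.put(torch.tensor([0]), torch.tensor([7.]))
```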

There is a BC-breaking change: we now check that the inputs to `put_` do not overlap. The TH implementation handled some of these cases (and got others wrong) by making contiguous copies of the inputs. How should we handle this one?
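To make the new behavior concrete, a small example of the overlap check (illustrative; the exact error wording may differ):

```python
import torch

t = torch.arange(6.)
overlapping_source = t[0:2]  # a view into t itself

# TH silently worked around (some of) these cases by copying; the ATen
# port rejects the overlap between self and source up front.
try:
    t.put_(torch.tensor([3, 4]), overlapping_source)
except RuntimeError as e:
    print(e)
```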

**Edit.** Benchmarks:
<details>
<summary>Script</summary>

```python
from IPython import get_ipython
import torch
from itertools import product

torch.manual_seed(13)
torch.set_num_threads(1)

ipython = get_ipython()

cpu = torch.device('cpu')
cuda = torch.device('cuda')

def run_test(ndims, size, index_len, device, cmd):
    print(f"cmd: {cmd}, ndims: {ndims}, tensor_size: {size}, index_len: {index_len}, device: {device}")

    large_tensor = torch.rand(*([size] * ndims), device=device)
    small_tensor = torch.rand((index_len,), device=device)
    index = torch.randint(size * ndims, (index_len,), dtype=torch.long, device=device)
    if cmd == "put":
        command = "large_tensor.put_(index, small_tensor, accumulate=False)"
    elif cmd == "accumulate":
        command = "large_tensor.put_(index, small_tensor, accumulate=True)"
    elif cmd == "take":
        command = "torch.take(large_tensor, index)"
    # synchronize on CUDA so the timing measures the kernel, not just the launch
    if device == cuda:
        command += "; torch.cuda.synchronize()"
    ipython.magic(f"timeit {command}")
    print()

for method, device in product(["accumulate", "put", "take"], [cpu, cuda]):
    run_test(3, 1000, 10, device, method)
    run_test(3, 1000, 1000, device, method)
    run_test(3, 1000, 10000, device, method)
    run_test(2, 10000, 100000, device, method)
```
</details>

```python
put_(accumulate=False)
```

<details>
<summary>ATen CPU (1.5x - 2x speedup)</summary>

```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.05 µs ± 2.35 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
3.15 µs ± 5.13 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
21.6 µs ± 13.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
238 µs ± 781 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
</details>

<details>
<summary>TH CPU</summary>

```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
722 ns ± 2.67 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
4.89 µs ± 18.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
42.5 µs ± 96.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
428 µs ± 774 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
</details>
<details>
<summary>ATen GPU (same speed)</summary>

```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
8.99 µs ± 16 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
10.4 µs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
10.4 µs ± 11.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
15.6 µs ± 1.12 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
</details>

<details>
<summary>TH GPU</summary>

```python
cmd: put, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
8.44 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
9.09 µs ± 4.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
9.77 µs ± 0.998 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: put, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
15.8 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
</details>

```python
put_(accumulate=True)
```

<details>
<summary>ATen CPU (2x speedup)</summary>

```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.12 µs ± 2.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
3.14 µs ± 2.05 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
20.8 µs ± 25.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
264 µs ± 263 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
</details>

<details>
<summary>TH CPU</summary>

```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
814 ns ± 1.87 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
5.11 µs ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
43.9 µs ± 49.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
442 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
</details>
<details>
<summary>ATen GPU (3x - 11x speedup)</summary>

```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
9.01 µs ± 14.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
10.4 µs ± 15.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
10.3 µs ± 44.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
12.6 µs ± 19 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
</details>

<details>
<summary>TH GPU</summary>

```python
cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
34.7 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
38.2 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
61.2 µs ± 50.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cmd: accumulate, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
140 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
</details>

```python
take()
```

<details>
<summary>ATen CPU (1.1x speedup)</summary>

```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.18 µs ± 2.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
2.79 µs ± 2.96 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
16.6 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
161 µs ± 984 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
</details>

<details>
<summary>TH CPU</summary>

```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cpu
1.1 µs ± 3.14 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cpu
2.93 µs ± 7.31 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cpu
18.6 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cpu
178 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
</details>
<details>
<summary>ATen GPU (same speed)</summary>

```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
9.38 µs ± 23.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
10.7 µs ± 9.77 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
10.6 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
11.5 µs ± 21.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
</details>

<details>
<summary>TH GPU</summary>

```python
cmd: take, ndims: 3, tensor_size: 1000, index_len: 10, device: cuda
9.31 µs ± 7.57 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 1000, device: cuda
9.52 µs ± 5.78 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 3, tensor_size: 1000, index_len: 10000, device: cuda
9.73 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

cmd: take, ndims: 2, tensor_size: 10000, index_len: 100000, device: cuda
11.7 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
</details>

cc mruberry

Pull Request resolved: #53356

Reviewed By: mruberry

Differential Revision: D27520243

Pulled By: ngimel

fbshipit-source-id: e3979349c2c62d2949e09fb05e5fd4883fbc9093
lezcano authored and facebook-github-bot committed Apr 6, 2021
1 parent bf37bf7 commit fd02fc5
Showing 29 changed files with 656 additions and 864 deletions.
1 change: 0 additions & 1 deletion BUILD.bazel
```diff
@@ -381,7 +381,6 @@ filegroup(
         "aten/src/THC/THCStorageCopy.cu.cc",
         "aten/src/THC/THCTensor.cu.cc",
         "aten/src/THC/THCTensorCopy.cu.cc",
-        "aten/src/THC/THCTensorIndex.cu.cc",
         "aten/src/THC/THCTensorMath.cu.cc",
         "aten/src/THC/THCTensorMathMagma.cu.cc",
         "aten/src/THC/THCTensorMathPairwise.cu.cc",
```
66 changes: 0 additions & 66 deletions aten/src/ATen/LegacyTHFunctionsCPU.cpp
```diff
@@ -165,72 +165,6 @@ Tensor _th_nonzero(const Tensor & self) {
     }
     return result;
 }
-Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THBoolTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THByteTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Char: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THCharTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Int: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THIntTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Long: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THLongTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Short: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CPU, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type);
-            THShortTensor_put(self_, index_, source_, accumulate);
-            break;
-        }
-        default:
-            AT_ERROR("_th_put_ not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return self;
-}
 Tensor _th_var(const Tensor & self, bool unbiased) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
```
1 change: 0 additions & 1 deletion aten/src/ATen/LegacyTHFunctionsCPU.h
```diff
@@ -22,7 +22,6 @@ Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor &
 Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
 Tensor& _th_nonzero_out(const Tensor& self, Tensor& result);
 Tensor _th_nonzero(const Tensor & self);
-Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate);
 Tensor _th_var(const Tensor & self, bool unbiased);
 Tensor _th_std(const Tensor & self, bool unbiased);
 Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result);
```
1 change: 0 additions & 1 deletion aten/src/ATen/LegacyTHFunctionsCUDA.h
```diff
@@ -20,7 +20,6 @@ namespace cuda {
 
 Tensor & _th_masked_fill_(Tensor & self, const Tensor & mask, const Scalar& value);
 Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, const Scalar& value);
-Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate);
 std::tuple<Tensor &,Tensor &> _th_sort_out(const Tensor & self, int64_t dim, bool descending, Tensor & values, Tensor & indices);
 std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descending);
 std::tuple<Tensor &,Tensor &> _th_sort_out_stable(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, Tensor & values, Tensor & indices);
```
73 changes: 0 additions & 73 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
```diff
@@ -39,79 +39,6 @@ namespace {
     return at::cuda::getCUDADeviceAllocator();
   }
 }
-Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaBoolTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaByteTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Char: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaCharTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Int: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaIntTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Long: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaLongTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Short: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaShortTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        case ScalarType::Half: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_put_", false, DeviceType::CUDA, ScalarType::Long);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_put_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaHalfTensor_put(globalContext().getTHCState(), self_, index_, source_, accumulate);
-            break;
-        }
-        default:
-            AT_ERROR("_th_put_ not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return self;
-}
 std::tuple<Tensor &,Tensor &> _th_sort_out_stable(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, Tensor & values, Tensor & indices) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
```
6 changes: 4 additions & 2 deletions aten/src/ATen/cuda/detail/OffsetCalculator.cuh
```diff
@@ -7,8 +7,10 @@
 #include <ATen/native/TensorIterator.h>
 #include <THC/THCIntegerDivider.cuh>
 
-/// OffsetCalculator calculates the offset in bytes of a linear index for NARGS
-/// operands that share the same shape, but may have different strides.
+// If element_sizes is nullptr, then the strides will be in bytes, otherwise
+// the strides will be in # of elements.
+// Operands that share the same shape, but may have different strides.
+// OffsetCalculator iterates the tensor in a column-major order
 
 #ifdef __HIP_PLATFORM_HCC__
 constexpr int MAX_DIMS = 16;
```
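As a rough mental model of the comment above (a toy sketch, not the real CUDA implementation): an offset calculator maps a linear index to one offset per operand, walking the dimensions in column-major order (dim 0 varies fastest), with strides interpreted in bytes or in elements depending on the caller. The function name below is hypothetical:

```python
def offsets_for(linear_idx, sizes, strides_per_operand):
    """Toy pure-Python model of OffsetCalculator."""
    offsets = [0] * len(strides_per_operand)
    for dim, size in enumerate(sizes):
        linear_idx, coord = divmod(linear_idx, size)  # dim 0 moves fastest
        for arg, strides in enumerate(strides_per_operand):
            offsets[arg] += coord * strides[dim]
    return offsets

# Sanity check: one contiguous operand of shape (2, 3), sizes listed with the
# fastest dimension first, element strides [1, 3]; the offset of linear
# index 4 is then 4 itself.
print(offsets_for(4, sizes=[3, 2], strides_per_operand=[[1, 3]]))  # [4]
```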
