Skip to content

Commit

Permalink
The dimension being reduced should not be coalesced by TensorIterator (
Browse files Browse the repository at this point in the history
…#47237)

Summary:
Fixes #37583 (comment)

Also add overload of `<<` for convenience of debugging.

This PR is tested by `test_reduction_split_cuda` which was added in #37788.

Reproduce
```python
import torch

a = torch.zeros(8, 1, 128, 1024, 1024)
a.cuda().sum(1)
```

Before

```
TensorIterator @ 0x7ffd05b10ba0 {
  ntensors() = 2
  noutputs() = 1
  shape() = [1073741824]
  strides(*) = {
    (0) = [4]
    (1) = [4]
  }
  dtype(*) = {
    (0) = Float
    (1) = Float
  }
  is_reduction_ = 1
}
```

After

```
TensorIterator @ 0x7fffc9051010 {
  ntensors() = 2
  noutputs() = 1
  shape() = [1, 1073741824]
  strides(*) = {
    (0) = [0, 4]
    (1) = [536870912, 4]
  }
  dtype(*) = {
    (0) = Float
    (1) = Float
  }
  is_reduction_ = 1
}
```

Pull Request resolved: #47237

Reviewed By: ejguan

Differential Revision: D24734763

Pulled By: ngimel

fbshipit-source-id: 02bb2b15694c68f96434f55033b63b6e5ff7085b
  • Loading branch information
zasdfgbnm authored and facebook-github-bot committed Nov 7, 2020
1 parent 29184f8 commit f23a2a1
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
12 changes: 7 additions & 5 deletions aten/src/ATen/native/ReduceOpsUtils.h
Expand Up @@ -156,15 +156,17 @@ static void allocate_reduction_result(
}

static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
if (keepdim) {
return result;
}
auto shape = DimVector(result.sizes());
auto stride = DimVector(result.strides());
for (int dim = 0; dim < ndim; dim++) {
if (mask[dim]) {
shape.insert(shape.begin() + dim, 1);
stride.insert(stride.begin() + dim, 0);
if (!keepdim) {
shape.insert(shape.begin() + dim, 1);
stride.insert(stride.begin() + dim, 0);
} else {
TORCH_INTERNAL_ASSERT(shape[dim] == 1);
stride[dim] = 0;
}
}
}
return result.as_strided(shape, stride);
Expand Down
31 changes: 30 additions & 1 deletion aten/src/ATen/native/TensorIterator.cpp
Expand Up @@ -542,6 +542,15 @@ void TensorIterator::coalesce_dimensions() {
auto can_coalesce = [&](int dim0, int dim1) {
auto shape0 = shape_[dim0];
auto shape1 = shape_[dim1];
if (is_reduction_) {
// The dimension being reduced should not be coalesced
for (int i = 0; i < noutputs(); i++) {
auto& stride = operands_[i].stride_bytes;
if (stride[dim0] == 0 || stride[dim1] == 0) {
return false;
}
}
}
if (shape0 == 1 || shape1 == 1) {
return true;
}
Expand Down Expand Up @@ -802,7 +811,7 @@ void TensorIterator::narrow(int dim, int64_t start, int64_t size) {
for (auto& op : operands_) {
op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
}
if (size == 1 && !is_reduction_) {
if (size == 1) {
coalesce_dimensions();
}
}
Expand Down Expand Up @@ -1397,4 +1406,24 @@ std::array<int64_t, 2> DimCounter::max_2d_step() const {
return {step0, step1};
}

std::ostream& operator<<(std::ostream& os, const TensorIterator& iter) {
os << "TensorIterator @ " << &iter << " {" << std::endl;
os << " ntensors() = " << iter.ntensors() << std::endl;
os << " noutputs() = " << iter.noutputs() << std::endl;
os << " shape() = " << iter.shape() << std::endl;
os << " strides(*) = {" << std::endl;
for (int i = 0; i < iter.ntensors(); i++) {
os << " (" << i << ") = " << iter.strides(i) << std::endl;
}
os << " }" << std::endl;
os << " dtype(*) = {" << std::endl;
for (int i = 0; i < iter.ntensors(); i++) {
os << " (" << i << ") = " << iter.dtype(i) << std::endl;
}
os << " }" << std::endl;
os << " is_reduction_ = " << iter.is_reduction_ << std::endl;
os << "}";
return os;
}

} // namespace at
5 changes: 5 additions & 0 deletions aten/src/ATen/native/TensorIterator.h
@@ -1,5 +1,6 @@
#pragma once

#include <iostream>
#include <ATen/ATen.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/SmallVector.h>
Expand Down Expand Up @@ -296,6 +297,8 @@ struct CAFFE2_API TensorIterator {
return true;
}

friend CAFFE2_API std::ostream& operator<<(std::ostream& os, const TensorIterator& iter);

protected:
void build(TensorIteratorConfig&);

Expand Down Expand Up @@ -533,4 +536,6 @@ struct CAFFE2_API SplitUntil32Bit {
const TensorIterator& iter;
};

CAFFE2_API std::ostream& operator<<(std::ostream& os, const TensorIterator& iter);

} // namespace at
17 changes: 15 additions & 2 deletions test/test_torch.py
@@ -1,5 +1,6 @@
import sys
import io
import gc
import inspect
import itertools
import math
Expand Down Expand Up @@ -8001,6 +8002,7 @@ def test_cholesky(self, device, dtype):
B = torch.mm(L, L.t().conj())
self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')

@skipIfRocm # This test has many dimensions, which is larger than the maximum dims supported by ROCm (16)
def test_view(self, device):
tensor = torch.rand(15, device=device)
template = torch.rand(3, 5, device=device)
Expand Down Expand Up @@ -9658,7 +9660,7 @@ def test_multidim(x, dim):
expected = fn(y, 1, keepdim=False)
self.assertEqual(x[:, 1], expected, msg='{} with out= kwarg'.format(fn_name))

@slowTest
@onlyCUDA
@largeTensorTest('10GB')
def test_reduction_split(self, device):
# Test reduction when there is a 32bit-indexing split
Expand All @@ -9667,6 +9669,13 @@ def test_reduction_split(self, device):
result = input_.sum(dim=0)
expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4]
self.assertEqual(result, expect)
gc.collect()
torch.cuda.empty_cache()
a = torch.randn(8, 1, 128, 1024, 1024, device=device, dtype=torch.half)
self.assertEqual((a.sum(1) - a.squeeze()).abs().max(), 0)
gc.collect()
torch.cuda.empty_cache()
self.assertEqual((a.sum(1, keepdim=True) - a).abs().max(), 0)

@onlyCUDA
@dtypes(torch.half, torch.float, torch.double)
Expand Down Expand Up @@ -19152,7 +19161,11 @@ def test_nansum_out_dtype(self, device):
torch_fn = partial(torch.nansum, dtype=out_dtype)
np_out_dtype = torch_to_numpy_dtype_dict[out_dtype]
np_fn = partial(np.nansum, dtype=np_out_dtype)
self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
if (inp_dtype, out_dtype) == (torch.uint8, torch.float16):
# 25504.0 vs 25536.0
self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None, atol=0, rtol=0.002)
else:
self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

@dtypes(torch.int32, torch.int64)
def test_large_linspace(self, device, dtype):
Expand Down

0 comments on commit f23a2a1

Please sign in to comment.