From f23a2a11153a7feba0c3c9aa01f68996569144af Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Sat, 7 Nov 2020 01:13:19 -0800
Subject: [PATCH] The dimension being reduced should not be coalesced by
 TensorIterator (#47237)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/37583#issuecomment-720172838

When the dimension being reduced is coalesced into a neighboring dimension, the later 32-bit-indexing split can cut through the reduction and produce wrong results. The reduced dimension is identifiable by its zero stride in the output operands, so `coalesce_dimensions` now refuses to merge across any zero-stride output dimension. This PR also adds an `operator<<` overload for `TensorIterator` to make debugging easier.

The fix is covered by `test_reduction_split_cuda`, which was added in https://github.com/pytorch/pytorch/pull/37788.

Reproduce:
```python
import torch
a = torch.zeros(8, 1, 128, 1024, 1024)
a.cuda().sum(1)
```

Before:
```
TensorIterator @ 0x7ffd05b10ba0 {
  ntensors() = 2
  noutputs() = 1
  shape() = [1073741824]
  strides(*) = {
    (0) = [4]
    (1) = [4]
  }
  dtype(*) = {
    (0) = Float
    (1) = Float
  }
  is_reduction_ = 1
}
```

After:
```
TensorIterator @ 0x7fffc9051010 {
  ntensors() = 2
  noutputs() = 1
  shape() = [1, 1073741824]
  strides(*) = {
    (0) = [0, 4]
    (1) = [536870912, 4]
  }
  dtype(*) = {
    (0) = Float
    (1) = Float
  }
  is_reduction_ = 1
}
```
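To illustrate the guard, here is a minimal standalone sketch (plain Python, no PyTorch; the helper name, argument layout, and the final contiguity rule are invented for this example and simplify the real C++ in the diff below):

```python
# Sketch of the check added to TensorIterator::coalesce_dimensions.
def can_coalesce(shape, operand_strides, noutputs, dim0, dim1, is_reduction):
    if is_reduction:
        # A dimension being reduced has stride 0 in every output operand.
        # Merging it into a neighbor erases the reduction boundary, so a
        # later 32-bit-indexing split could cut through the reduction.
        for strides in operand_strides[:noutputs]:
            if strides[dim0] == 0 or strides[dim1] == 0:
                return False
    # Without the guard above, this size-1 shortcut is what used to merge
    # the reduced dimension away (the "Before" dump).
    if shape[dim0] == 1 or shape[dim1] == 1:
        return True
    # Simplified contiguity test: every operand must step through dim0 and
    # dim1 as if they were one flat dimension.
    return all(shape[dim1] * s[dim1] == s[dim0] for s in operand_strides)

# Shape and strides from the "After" dump (strides in bytes, float32):
shape = [1, 1073741824]
out_strides = [0, 4]            # stride 0 marks the reduced dimension
in_strides = [536870912, 4]
print(can_coalesce(shape, [out_strides, in_strides],
                   noutputs=1, dim0=0, dim1=1, is_reduction=True))  # False
```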
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47237

Reviewed By: ejguan

Differential Revision: D24734763

Pulled By: ngimel

fbshipit-source-id: 02bb2b15694c68f96434f55033b63b6e5ff7085b
---
 aten/src/ATen/native/ReduceOpsUtils.h   | 12 ++++++----
 aten/src/ATen/native/TensorIterator.cpp | 31 ++++++++++++++++++++++++-
 aten/src/ATen/native/TensorIterator.h   |  5 ++++
 test/test_torch.py                      | 17 ++++++++++++--
 4 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index 429b6f49a7bd..205b08d86423 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -156,15 +156,17 @@ static void allocate_reduction_result(
 }
 
 static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
-  if (keepdim) {
-    return result;
-  }
   auto shape = DimVector(result.sizes());
   auto stride = DimVector(result.strides());
   for (int dim = 0; dim < ndim; dim++) {
     if (mask[dim]) {
-      shape.insert(shape.begin() + dim, 1);
-      stride.insert(stride.begin() + dim, 0);
+      if (!keepdim) {
+        shape.insert(shape.begin() + dim, 1);
+        stride.insert(stride.begin() + dim, 0);
+      } else {
+        TORCH_INTERNAL_ASSERT(shape[dim] == 1);
+        stride[dim] = 0;
+      }
     }
   }
   return result.as_strided(shape, stride);
 }
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index cb37b2055eda..81a7ae3b9a9d 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -542,6 +542,15 @@ void TensorIterator::coalesce_dimensions() {
   auto can_coalesce = [&](int dim0, int dim1) {
     auto shape0 = shape_[dim0];
     auto shape1 = shape_[dim1];
+    if (is_reduction_) {
+      // The dimension being reduced should not be coalesced
+      for (int i = 0; i < noutputs(); i++) {
+        auto& stride = operands_[i].stride_bytes;
+        if (stride[dim0] == 0 || stride[dim1] == 0) {
+          return false;
+        }
+      }
+    }
     if (shape0 == 1 || shape1 == 1) {
       return true;
     }
@@ -802,7 +811,7 @@ void TensorIterator::narrow(int dim, int64_t start, int64_t size) {
   for (auto& op : operands_) {
     op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
   }
-  if (size == 1 && !is_reduction_) {
+  if (size == 1) {
     coalesce_dimensions();
   }
 }
@@ -1397,4 +1406,24 @@ std::array<int64_t, 2> DimCounter::max_2d_step() const {
   return {step0, step1};
 }
 
+std::ostream& operator<<(std::ostream& os, const TensorIterator& iter) {
+  os << "TensorIterator @ " << &iter << " {" << std::endl;
+  os << "  ntensors() = " << iter.ntensors() << std::endl;
+  os << "  noutputs() = " << iter.noutputs() << std::endl;
+  os << "  shape() = " << iter.shape() << std::endl;
+  os << "  strides(*) = {" << std::endl;
+  for (int i = 0; i < iter.ntensors(); i++) {
+    os << "    (" << i << ") = " << iter.strides(i) << std::endl;
+  }
+  os << "  }" << std::endl;
+  os << "  dtype(*) = {" << std::endl;
+  for (int i = 0; i < iter.ntensors(); i++) {
+    os << "    (" << i << ") = " << iter.dtype(i) << std::endl;
+  }
+  os << "  }" << std::endl;
+  os << "  is_reduction_ = " << iter.is_reduction_ << std::endl;
+  os << "}";
+  return os;
+}
+
 } // namespace at
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index febf21a290dd..3a9612e158ae 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <ostream>
 #include 
 #include 
 #include 
@@ -296,6 +297,8 @@ struct CAFFE2_API TensorIterator {
     return true;
   }
 
+  friend CAFFE2_API std::ostream& operator<<(std::ostream& os, const TensorIterator& iter);
+
 protected:
   void build(TensorIteratorConfig&);
@@ -533,4 +536,6 @@ struct CAFFE2_API SplitUntil32Bit {
   const TensorIterator& iter;
 };
 
+CAFFE2_API std::ostream& operator<<(std::ostream& os, const TensorIterator& iter);
+
 } // namespace at
diff --git a/test/test_torch.py b/test/test_torch.py
index ae446c7aabb6..6a9a8e5603d5 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1,5 +1,6 @@
 import sys
 import io
+import gc
 import inspect
 import itertools
 import math
@@ -8001,6 +8002,7 @@ def test_cholesky(self, device, dtype):
         B = torch.mm(L, L.t().conj())
         self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')
 
+    @skipIfRocm  # This test uses more tensor dimensions than ROCm supports (max is 16)
     def test_view(self, device):
         tensor = torch.rand(15, device=device)
         template = torch.rand(3, 5, device=device)
@@ -9658,7 +9660,7 @@ def test_multidim(x, dim):
             expected = fn(y, 1, keepdim=False)
             self.assertEqual(x[:, 1], expected, msg='{} with out= kwarg'.format(fn_name))
 
-    @slowTest
+    @onlyCUDA
     @largeTensorTest('10GB')
     def test_reduction_split(self, device):
         # Test reduction when there is a 32bit-indexing split
         input_ = torch.randn(5, 14400, 14400, device=device)
         result = input_.sum(dim=0)
         expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4]
         self.assertEqual(result, expect)
+        gc.collect()
+        torch.cuda.empty_cache()
+        a = torch.randn(8, 1, 128, 1024, 1024, device=device, dtype=torch.half)
+        self.assertEqual((a.sum(1) - a.squeeze()).abs().max(), 0)
+        gc.collect()
+        torch.cuda.empty_cache()
+        self.assertEqual((a.sum(1, keepdim=True) - a).abs().max(), 0)
 
     @onlyCUDA
     @dtypes(torch.half, torch.float, torch.double)
@@ -19152,7 +19161,11 @@ def test_nansum_out_dtype(self, device):
         torch_fn = partial(torch.nansum, dtype=out_dtype)
         np_out_dtype = torch_to_numpy_dtype_dict[out_dtype]
         np_fn = partial(np.nansum, dtype=np_out_dtype)
-        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+        if (inp_dtype, out_dtype) == (torch.uint8, torch.float16):
+            # fp16 accumulation is less accurate than numpy's, e.g. 25504.0 vs 25536.0
+            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None, atol=0, rtol=0.002)
+        else:
+            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
 
     @dtypes(torch.int32, torch.int64)
     def test_large_linspace(self, device, dtype):