
Commit 65a72ca

gchanan authored and facebook-github-bot committed
Fix type promotion for trace on CPU. (#47305)
Summary:
Pull Request resolved: #47305

Fixes #47127. Ideally this would just use diag and sum (as the CUDA implementation does), but that seems to have performance problems, which I'll link in the GitHub PR.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D24729627

Pulled By: gchanan

fbshipit-source-id: 151b786b53e7b958f0929c803dbf8e95981c6884
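As a quick illustration of the user-visible change (a sketch, not part of the commit; it assumes a build that includes this fix, and the values are illustrative): on CPU, trace on an integer matrix should now promote to int64, matching sum and the diag-plus-sum formulation used on CUDA.

    import torch

    # A 3x3 uint8 matrix; its diagonal is 0, 4, 8 (sum 12).
    t = torch.arange(9, dtype=torch.uint8).reshape(3, 3)
    out = t.trace()

    # With this commit, trace on CPU promotes integer dtypes the same way sum does.
    assert out.dtype == t.sum().dtype == torch.int64
    # And it agrees with the diag-then-sum formulation used by the CUDA implementation.
    assert out.item() == t.diag().sum().item() == 12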
1 parent 57dcb04 commit 65a72ca

File tree

3 files changed: +41 -42 lines changed


aten/src/ATen/native/ReduceOps.cpp

Lines changed: 35 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include <ATen/native/ReduceOps.h>

 #include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/WrapDimUtils.h>
@@ -473,6 +474,40 @@ static Tensor& prod_out_impl(Tensor& result, const Tensor& self, IntArrayRef dim
   return result;
 }

+// NOTE: this could be implemented via diag and sum, but this has perf problems,
+// see https://github.com/pytorch/pytorch/pull/47305,
+Tensor trace_cpu(const Tensor& self) {
+  Tensor result;
+  ScalarType dtype = get_dtype(result, self, c10::nullopt, true);
+  result = at::empty({}, self.options().dtype(dtype));
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] {
+    using accscalar_t = at::acc_type<scalar_t, false>;
+    accscalar_t sum = 0;
+    const auto* t_data = self.data_ptr<scalar_t>();
+
+    int64_t t_stride_0, t_stride_1, t_diag_size;
+
+    TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());
+
+    t_stride_0 = self.stride(0);
+    t_stride_1 = self.stride(1);
+
+    t_diag_size = std::min(self.size(0), self.size(1));
+    for (int64_t i = 0; i < t_diag_size; i++) {
+      sum += t_data[i * (t_stride_0 + t_stride_1)];
+    }
+
+    // all integer types get promoted to kLong
+    if (result.scalar_type() == at::kLong) {
+      *result.data_ptr<int64_t>() = sum;
+    } else {
+      *result.data_ptr<scalar_t>() = sum;
+    }
+  });
+
+  return result;
+}
+
 Tensor prod(const Tensor& self, int64_t dim, bool keepdim, c10::optional<ScalarType> dtype) {
   Tensor result;
   native::prod_out_impl(result, self, dim, keepdim, dtype);
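For readers skimming the diff, here is a rough Python mirror of what the new CPU kernel does (an illustrative sketch only, not the PR's code): it walks the diagonal via the two strides, accumulates in a wider type, and writes the result into an int64 output when the input is an integer type, per get_dtype with integer promotion enabled.

    import torch

    def trace_cpu_sketch(t: torch.Tensor) -> torch.Tensor:
        # Mirrors the TORCH_CHECK in trace_cpu: only matrices are accepted.
        assert t.dim() == 2, "trace: expected a matrix"
        # Integer inputs get promoted to int64 (kLong), floating inputs keep their dtype,
        # roughly what get_dtype(result, self, c10::nullopt, true) decides.
        out_dtype = t.dtype if t.dtype.is_floating_point else torch.int64
        # Summing t[i, i] is the Python-level equivalent of reading
        # t_data[i * (t_stride_0 + t_stride_1)] from the raw storage.
        acc = 0
        for i in range(min(t.size(0), t.size(1))):
            acc += t[i, i].item()
        return torch.tensor(acc, dtype=out_dtype)

    print(trace_cpu_sketch(torch.ones(3, 4, dtype=torch.int32)))  # tensor(3), dtype torch.int64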

aten/src/ATen/native/TensorShape.cpp

Lines changed: 0 additions & 25 deletions
@@ -1985,29 +1985,4 @@ Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
   return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
 }

-Tensor trace_cpu(const Tensor& self) {
-  Tensor result = at::empty({}, self.options());
-  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] {
-    using accscalar_t = at::acc_type<scalar_t, false>;
-    accscalar_t sum = 0;
-    const auto* t_data = self.data_ptr<scalar_t>();
-
-    int64_t t_stride_0, t_stride_1, t_diag_size;
-
-    TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());
-
-    t_stride_0 = self.stride(0);
-    t_stride_1 = self.stride(1);
-
-    t_diag_size = std::min(self.size(0), self.size(1));
-    for (int64_t i = 0; i < t_diag_size; i++) {
-      sum += t_data[i * (t_stride_0 + t_stride_1)];
-    }
-
-    *result.data_ptr<scalar_t>() = sum;
-  });
-
-  return result;
-}
-
 }} // at::native
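For context on why the removed version was wrong (#47127): it allocated the result with the input's dtype, so the widened accumulator was narrowed back on the final store. A hedged before/after illustration, assuming a pre-fix CPU build for the "before" line:

    import torch

    # Diagonal entries are 200 each, so the true trace is 600, which does not fit in uint8.
    m = torch.full((3, 3), 200, dtype=torch.uint8)

    # Before this commit (CPU): the result kept dtype uint8, so 600 would wrap to 88.
    # After this commit (CPU): the result is promoted to int64 and holds 600.
    print(m.trace(), m.trace().dtype)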

test/test_torch.py

Lines changed: 6 additions & 17 deletions
@@ -6053,15 +6053,14 @@ def test_diagonal_multidim(self, device, dtype):
         self.assertEqual(expected.shape, result.shape)
         self.assertEqual(expected, result)

-    def _test_trace(self, device, dtype, legacy):
+    @onlyOnCPUAndCUDA
+    @dtypesIfCPU(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
+                                                include_bfloat16=False))
+    @dtypesIfCUDA(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
+    def test_trace(self, device, dtype):
         def test(shape):
             tensor = make_tensor(shape, device, dtype, low=-9, high=9)
-            diag = tensor.diag()
-            if legacy:
-                # NB: trace on cpu doesn't do type promotion... #47127
-                expected_dtype = dtype
-            else:
-                expected_dtype = tensor.sum().dtype
+            expected_dtype = tensor.sum().dtype
             expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]

             result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
@@ -6078,16 +6077,6 @@ def test(shape):
         for shape in shapes:
             test(shape)

-    @onlyCPU
-    @dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))
-    def test_trace_legacy(self, device, dtype):
-        self._test_trace(device, dtype, legacy=True)
-
-    @onlyCUDA
-    @dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
-    def test_trace(self, device, dtype):
-        self._test_trace(device, dtype, legacy=False)
-
     @onlyCPU
     @dtypes(torch.float)
     def test_broadcast_tensors(self, device, dtype):
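A minimal standalone version of what the consolidated test_trace checks (a sketch only: it hardcodes int32 input, the matching NumPy dtype, and illustrative shapes instead of using make_tensor, the device decorators, and torch_to_numpy_dtype_dict from the real suite):

    import numpy as np
    import torch

    def check_trace(shape, device="cpu"):
        tensor = torch.randint(-9, 10, shape, device=device, dtype=torch.int32)
        expected_dtype = tensor.sum().dtype            # trace should promote like sum
        assert expected_dtype == torch.int64
        expected = np.trace(tensor.cpu().numpy(), dtype=np.int64)
        result = tensor.trace()
        assert result.dtype == expected_dtype
        assert result.item() == int(expected)

    for shape in [(3, 3), (2, 5), (5, 2), (1, 1)]:
        check_trace(shape)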
