Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <ATen/native/ReduceOps.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
Expand Down Expand Up @@ -473,6 +474,40 @@ static Tensor& prod_out_impl(Tensor& result, const Tensor& self, IntArrayRef dim
return result;
}

// NOTE: this could be implemented via diag and sum, but this has perf problems,
// see https://github.com/pytorch/pytorch/pull/47305.
Tensor trace_cpu(const Tensor& self) {
  // trace is only defined for matrices; validate before allocating the result
  // so callers fail fast instead of after the (wasted) allocation.
  TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());

  // Reuse the reduction dtype-promotion logic so trace matches sum():
  // all integral input types are promoted to kLong for the result.
  Tensor result;
  const ScalarType dtype = get_dtype(result, self, c10::nullopt, /*promote_integers=*/true);
  result = at::empty({}, self.options().dtype(dtype));

  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] {
    using accscalar_t = at::acc_type<scalar_t, false>;
    const auto* t_data = self.data_ptr<scalar_t>();

    const int64_t t_stride_0 = self.stride(0);
    const int64_t t_stride_1 = self.stride(1);
    // The main diagonal of an m x n matrix has min(m, n) elements.
    const int64_t t_diag_size = std::min(self.size(0), self.size(1));

    // Accumulate in the (wider) accumulation type; element (i, i) lives at
    // offset i * (stride_0 + stride_1) from the data pointer.
    accscalar_t sum = 0;
    for (int64_t i = 0; i < t_diag_size; i++) {
      sum += t_data[i * (t_stride_0 + t_stride_1)];
    }

    // All integer types were promoted to kLong above, so the result storage
    // may not be scalar_t; write through the matching pointer type.
    if (result.scalar_type() == at::kLong) {
      *result.data_ptr<int64_t>() = sum;
    } else {
      *result.data_ptr<scalar_t>() = sum;
    }
  });

  return result;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're going to have to update test_trace:

pytorch/test/test_torch.py

Lines 5996 to 6029 in 774b638

def _test_trace(self, device, dtype, legacy):
    """Compare torch.Tensor.trace against np.trace over several shapes.

    legacy=True reproduces the pre-promotion CPU behavior (the result keeps
    the input dtype -- see #47127); legacy=False expects sum()-style dtype
    promotion. Indentation reconstructed (the pasted snippet was flattened),
    and the unused `diag = tensor.diag()` local was dropped.
    """
    def test(shape):
        tensor = make_tensor(shape, device, dtype, low=-9, high=9)
        if legacy:
            # NB: trace on cpu doesn't do type promotion... #47127
            expected_dtype = dtype
        else:
            expected_dtype = tensor.sum().dtype
        # Translate the torch dtype so np.trace accumulates the same way.
        expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]
        result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
        expected = torch.tensor(result, device=device)
        self.assertEqual(tensor.trace(), expected)

    # Cover tall, wide, square and non-square matrices.
    shapes = (
        [10, 1],
        [1, 10],
        [100, 100],
        [20, 100],
        [100, 20],
    )
    for shape in shapes:
        test(shape)

@onlyCPU
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))
def test_trace_legacy(self, device, dtype):
    self._test_trace(device, dtype, legacy=True)

@onlyCUDA
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
def test_trace(self, device, dtype):
    self._test_trace(device, dtype, legacy=False)


Tensor prod(const Tensor& self, int64_t dim, bool keepdim, c10::optional<ScalarType> dtype) {
Tensor result;
native::prod_out_impl(result, self, dim, keepdim, dtype);
Expand Down
25 changes: 0 additions & 25 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,29 +1985,4 @@ Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
}

// Legacy CPU trace: sums the main diagonal of a 2-D tensor.
// NOTE: the result keeps the input dtype (no promotion of integral types to
// kLong), unlike the CUDA path -- see pytorch/pytorch#47127.
Tensor trace_cpu(const Tensor& self) {
// 0-dim result with the *same* dtype as the input (no promotion).
Tensor result = at::empty({}, self.options());
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] {
// Accumulate in a wider type to limit intermediate overflow / precision loss...
using accscalar_t = at::acc_type<scalar_t, false>;
accscalar_t sum = 0;
const auto* t_data = self.data_ptr<scalar_t>();

int64_t t_stride_0, t_stride_1, t_diag_size;

// trace is only defined for matrices.
TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());

t_stride_0 = self.stride(0);
t_stride_1 = self.stride(1);

// The main diagonal of an m x n matrix has min(m, n) elements;
// element (i, i) lives at offset i * (stride_0 + stride_1).
t_diag_size = std::min(self.size(0), self.size(1));
for (int64_t i = 0; i < t_diag_size; i++) {
sum += t_data[i * (t_stride_0 + t_stride_1)];
}

// ...but the write-back narrows to scalar_t, so for integral inputs the
// wide accumulation can silently truncate on overflow.
*result.data_ptr<scalar_t>() = sum;
});

return result;
}

}} // at::native
23 changes: 6 additions & 17 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6048,15 +6048,14 @@ def test_diagonal_multidim(self, device, dtype):
self.assertEqual(expected.shape, result.shape)
self.assertEqual(expected, result)

def _test_trace(self, device, dtype, legacy):
@onlyOnCPUAndCUDA
@dtypesIfCPU(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
include_bfloat16=False))
@dtypesIfCUDA(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
def test_trace(self, device, dtype):
def test(shape):
tensor = make_tensor(shape, device, dtype, low=-9, high=9)
diag = tensor.diag()
if legacy:
# NB: trace on cpu doesn't do type promotion... #47127
expected_dtype = dtype
else:
expected_dtype = tensor.sum().dtype
expected_dtype = tensor.sum().dtype
expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]

result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
Expand All @@ -6073,16 +6072,6 @@ def test(shape):
for shape in shapes:
test(shape)

# CPU-only: exercises the legacy (non-promoting) trace path via legacy=True;
# half and bfloat16 are excluded from the dtype sweep here.
@onlyCPU
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))
def test_trace_legacy(self, device, dtype):
self._test_trace(device, dtype, legacy=True)

# CUDA-only: legacy=False, i.e. trace promotes its result dtype like sum().
@onlyCUDA
@dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
def test_trace(self, device, dtype):
self._test_trace(device, dtype, legacy=False)

@onlyCPU
@dtypes(torch.float)
def test_broadcast_tensors(self, device, dtype):
Expand Down