From 6170e47df559502f10b403a5258dc4c8e5c4b498 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Mon, 29 Apr 2024 17:41:56 +0300 Subject: [PATCH] Fix hybrid sparse COO tensor conversion to meta tensor ghstack-source-id: 5dc86c8e597bf27e47d46efedcdb086c5c80de93 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125120 --- aten/src/ATen/native/TensorConversions.cpp | 32 ++++++- .../ATen/native/sparse/SparseTensorMath.cpp | 5 +- test/test_sparse.py | 93 +++++++++++-------- 3 files changed, 87 insertions(+), 43 deletions(-) diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index c555706f4ced..a6c1118c4e80 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -254,7 +254,37 @@ Tensor _to_copy( // TODO: Use the dispatcher for this. // Currently there are unenumerated extensibility issues preventing this. - if (at::sparse_csr::is_sparse_compressed(self)) { + if (self.layout() == kSparse) { + TORCH_CHECK( + memory_format == MemoryFormat::Preserve, + "to(options): COO only supports memory format Preserve, but got ", memory_format, + " instead."); + auto indices = self._indices(); + const auto new_indices = at::native::to( + indices, + indices.scalar_type(), + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + const auto new_values = at::native::to( + self._values(), + dtype, + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + return at::_sparse_coo_tensor_unsafe( + new_indices, + new_values, + self.sizes(), + options, self.is_coalesced()); + } else if (at::sparse_csr::is_sparse_compressed(self)) { TORCH_CHECK( memory_format == MemoryFormat::Preserve, "to(options): ", at::sparse_csr::layoutToString(self.layout()), diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 4b3691a7af3c..a3227df942c4 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -591,8 +591,9 @@ SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src, TORCH_CHECK(is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t.sparse_dim(), " sparse dimensions while 'other' has ", src.sparse_dim(), " sparse dimensions"); r.resize_as_(src); - - if (src._values().is_contiguous() && t._values().is_contiguous()) { + if (r.is_meta()) { + return r; + } else if (src._values().is_contiguous() && t._values().is_contiguous()) { return add_out_sparse_contiguous(r, t, src, value, commonDtype); } else { return add_out_sparse_non_contiguous(r, t, src, value, commonDtype); diff --git a/test/test_sparse.py b/test/test_sparse.py index da2e08d769cd..9653f97ef4a4 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -4381,6 +4381,56 @@ def test_print_meta(self, dtype, layout): self.maxDiff = orig_maxDiff raise + def assertEqualMeta(self, x, y, expected_nnz): + self.assertEqual(x.layout, y.layout) + self.assertEqual(x.shape, y.shape) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.sparse_dim(), y.sparse_dim()) + self.assertEqual(x.dense_dim(), y.dense_dim()) + + def assertEqualAttrs(x, y, expected_shape): + self.assertEqual(x.shape, expected_shape) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.layout, y.layout) + if not x.is_meta: + self.assertEqual(x.device, y.device) + + if 
x.layout is torch.sparse_coo: + assertEqualAttrs(x._indices(), y._indices(), (*y._indices().shape[:-1], expected_nnz)) + assertEqualAttrs(x._values(), y._values(), (expected_nnz, *y._values().shape[1:])) + elif x.layout in {torch.sparse_csr, torch.sparse_bsr}: + assertEqualAttrs(x.crow_indices(), y.crow_indices(), y.crow_indices().shape) + assertEqualAttrs(x.col_indices(), y.col_indices(), (*y.col_indices().shape[:-1], expected_nnz)) + batch_dim = x.col_indices().ndim - 1 + assertEqualAttrs(x.values(), y.values(), + (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])) + elif x.layout in {torch.sparse_csc, torch.sparse_bsc}: + assertEqualAttrs(x.ccol_indices(), y.ccol_indices(), y.ccol_indices().shape) + assertEqualAttrs(x.row_indices(), y.row_indices(), (*y.row_indices().shape[:-1], expected_nnz)) + batch_dim = x.row_indices().ndim - 1 + assertEqualAttrs(x.values(), y.values(), + (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])) + + @all_sparse_layouts('layout', include_strided=False) + @parametrize("dtype", [torch.float64]) + def test_to_meta(self, dtype, layout): + index_dtype = torch.int64 + device = 'cpu' + for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): + m = t.to(device="meta") + self.assertEqual(m.device.type, "meta") + self.assertEqualMeta(m, t, t._nnz()) + + @all_sparse_layouts('layout', include_strided=False) + @parametrize("dtype", [torch.float64]) + def test_zeros_like_meta(self, dtype, layout): + index_dtype = torch.int64 + device = 'cpu' + for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): + m = torch.zeros_like(t, device="meta") + self.assertEqual(m.device.type, "meta") + self.assertEqualMeta(m, t, 0) + @all_sparse_layouts('layout', include_strided=False) @parametrize("dtype", [torch.float64]) def test_fake(self, dtype, layout): @@ -4391,45 +4441,7 @@ def test_fake(self, dtype, layout): for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): f = FakeTensor.from_tensor(t, fake_mode) self.assertIsInstance(f, FakeTensor) - self.assertEqual(f.layout, layout) - self.assertEqual(f.shape, t.shape) - self.assertEqual(f.device, t.device) - if layout is torch.sparse_coo: - nnz = 0 - indices = f._indices() - self.assertEqual(indices.dtype, index_dtype) - self.assertEqual(indices.device, t.device) - self.assertEqual(indices.shape, (*t._indices().shape[:-1], nnz)) - values = f._values() - self.assertEqual(values.dtype, dtype) - self.assertEqual(values.device, t.device) - self.assertEqual(values.shape, (nnz, *t._values().shape[1:])) - else: - nnz = 0 - if layout in {torch.sparse_csr, torch.sparse_bsr}: - f_compressed_indices, f_plain_indices = f.crow_indices(), f.col_indices() - compressed_indices, plain_indices = t.crow_indices(), t.col_indices() - else: - f_compressed_indices, f_plain_indices = f.ccol_indices(), f.row_indices() - compressed_indices, plain_indices = t.ccol_indices(), t.row_indices() - f_values = f.values() - values = t.values() - batch_dims = len(compressed_indices.shape) - 1 - self.assertEqual(f_compressed_indices.layout, compressed_indices.layout) - self.assertEqual(f_compressed_indices.shape, compressed_indices.shape) - self.assertEqual(f_compressed_indices.dtype, compressed_indices.dtype) - self.assertEqual(f_compressed_indices.device, compressed_indices.device) - - self.assertEqual(f_plain_indices.layout, plain_indices.layout) - 
self.assertEqual(f_plain_indices.shape, (*plain_indices.shape[:-1], nnz)) - self.assertEqual(f_plain_indices.dtype, plain_indices.dtype) - self.assertEqual(f_plain_indices.device, plain_indices.device) - - batch_dim = plain_indices.ndim - 1 - self.assertEqual(f_values.layout, values.layout) - self.assertEqual(f_values.shape, (*values.shape[:batch_dim], nnz, *values.shape[batch_dim + 1:])) - self.assertEqual(f_values.dtype, values.dtype) - self.assertEqual(f_values.device, values.device) + self.assertEqualMeta(f, t, 0) @all_sparse_layouts('layout', include_strided=False) @parametrize("dtype", [torch.float64]) @@ -4464,9 +4476,10 @@ def test_add_meta(self, dtype, layout): device = 'cpu' index_dtype = torch.int64 for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): + expected = torch.add(t, t).to(device='meta') m = t.to(device='meta') r = torch.add(m, m) - self.assertEqual(r, m) + self.assertEqualMeta(r, expected, 0 if layout is torch.sparse_coo else expected._nnz()) class _SparseDataset(torch.utils.data.Dataset):
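
Illustrative usage, not part of the diff above: a minimal sketch of what the new COO branch in _to_copy enables, mirroring the test_to_meta and test_add_meta tests in this patch. The indices, values, and sizes below are arbitrary example data, and the snippet assumes a PyTorch build that includes this change.

    import torch

    # Hybrid sparse COO tensor: 2 sparse dims plus 1 dense dim (values carry a trailing dim of size 2).
    indices = torch.tensor([[0, 1], [2, 0]])          # shape (sparse_dim=2, nnz=2)
    values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # shape (nnz=2, dense extent 2)
    t = torch.sparse_coo_tensor(indices, values, size=(3, 3, 2))

    # Converting to a meta tensor now goes through the dedicated COO path in _to_copy,
    # which copies indices and values separately and rebuilds the sparse tensor.
    m = t.to(device="meta")

    assert m.is_meta and m.layout == torch.sparse_coo
    assert m.shape == t.shape
    assert m.sparse_dim() == t.sparse_dim() and m.dense_dim() == t.dense_dim()
    assert m._nnz() == t._nnz()                       # nnz is preserved, as checked by test_to_meta
    assert m.is_coalesced() == t.is_coalesced()       # the coalesced flag is forwarded by _to_copy

    # add_out_sparse_cpu now short-circuits for meta outputs, so adding meta sparse
    # tensors returns a meta result, as exercised by test_add_meta.
    r = torch.add(m, m)
    assert r.is_meta and r.layout == t.layout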