diff --git a/docs/source/conf.py b/docs/source/conf.py index ae27491dadab2..9965bb6c6ac33 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1204,6 +1204,8 @@ "rebuild_cuda_tensor", "rebuild_event", "rebuild_nested_tensor", + "rebuild_sparse_coo_tensor", + "rebuild_sparse_compressed_tensor", "rebuild_storage_empty", "rebuild_storage_fd", "rebuild_storage_filename", diff --git a/test/test_sparse.py b/test/test_sparse.py index fad4db97e4d2c..ae8bf0d1ac1c9 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -4309,6 +4309,19 @@ def test_basic(self): self.assertEqual(r.values(), torch.empty(0, 4, device='meta')) +class _SparseDataset(torch.utils.data.Dataset): + # A utility class used in the TestSparseAny.test_dataloader method. + + def __init__(self, sparse_tensors): + self.sparse_tensors = sparse_tensors + + def __len__(self): + return len(self.sparse_tensors) + + def __getitem__(self, index): + return self.sparse_tensors[index] + + class TestSparseAny(TestCase): @onlyCPU @@ -5130,6 +5143,19 @@ def identity(x): gradcheck(func, x.requires_grad_(True), masked=masked, fast_mode=fast_mode) + @onlyCPU + @all_sparse_layouts('layout', include_strided=False) + @dtypes(torch.double) + def test_dataloader(self, device, layout, dtype): + + data = list(self.generate_simple_inputs(layout, device=device, dtype=dtype)) + + dataset = _SparseDataset(data) + loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2) + + loaded_data = list(loader) + self.assertEqual(data, loaded_data) + # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta') diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index e17fd232da987..f5eb0a6abd86f 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -188,8 +188,6 @@ def rebuild_cuda_tensor( def reduce_tensor(tensor): - storage = tensor._typed_storage() - if 
tensor.requires_grad and not tensor.is_leaf: raise RuntimeError( "Cowardly refusing to serialize non-leaf tensor which requires_grad, " @@ -299,6 +297,17 @@ def reduce_tensor(tensor): if tensor.is_nested and not isinstance(tensor, NestedTensor): return reduce_nested_tensor(tensor) + if tensor.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + return reduce_sparse_tensor(tensor) + + storage = tensor._typed_storage() + if storage._untyped_storage.device.type == "cuda": ( device, @@ -387,6 +396,86 @@ def reduce_nested_tensor(nt): ) +def rebuild_sparse_coo_tensor( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + is_coalesced, +): + indices = rebuild_indices_func(*rebuild_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced) + + +def rebuild_sparse_compressed_tensor( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + layout, +): + compressed_indices = rebuild_compressed_indices_func( + *rebuild_compressed_indices_args + ) + plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_compressed_tensor( + compressed_indices, plain_indices, values, shape, layout=layout + ) + + +def reduce_sparse_tensor(sparse): + if sparse.layout is torch.sparse_coo: + rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices()) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values()) + return ( + rebuild_sparse_coo_tensor, + ( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.is_coalesced(), + ), + ) + else: + if sparse.layout in {torch.sparse_csr, 
torch.sparse_bsr}: + compressed_indices = sparse.crow_indices() + plain_indices = sparse.col_indices() + elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}: + compressed_indices = sparse.ccol_indices() + plain_indices = sparse.row_indices() + else: + raise NotImplementedError(sparse.layout) + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + ) = reduce_tensor(compressed_indices) + rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor( + plain_indices + ) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values()) + return ( + rebuild_sparse_compressed_tensor, + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.layout, + ), + ) + + def fd_id(fd): # Returns a tuple which uniquely identifies a file descriptor. In Mac OS, # this doesn't work with shared memory handles, which is why we don't diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index af9033014a345..dc9bc7c3c1491 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -160,6 +160,11 @@ def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[ "Batches of nested tensors are not currently supported by the default collate_fn; " "please provide a custom collate_fn to handle them appropriately." ) + if elem.layout in {torch.sparse_coo, torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}: + raise RuntimeError( + "Batches of sparse tensors are not currently supported by the default collate_fn; " + "please provide a custom collate_fn to handle them appropriately." + ) if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy