Skip to content

Commit

Permalink
Add sparse tensors support to dataloader. (#112842)
Browse files (browse the repository at this point in the history)
  • Loading branch information
pearu authored and pytorchmergebot committed Nov 19, 2023
1 parent 12f95df commit 0bd4d1f
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,8 @@
"rebuild_cuda_tensor",
"rebuild_event",
"rebuild_nested_tensor",
"rebuild_sparse_coo_tensor",
"rebuild_sparse_compressed_tensor",
"rebuild_storage_empty",
"rebuild_storage_fd",
"rebuild_storage_filename",
Expand Down
26 changes: 26 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4309,6 +4309,19 @@ def test_basic(self):
self.assertEqual(r.values(), torch.empty(0, 4, device='meta'))


class _SparseDataset(torch.utils.data.Dataset):
# An utility class used in TestSparseAny.test_dataloader method.

def __init__(self, sparse_tensors):
self.sparse_tensors = sparse_tensors

def __len__(self):
return len(self.sparse_tensors)

def __getitem__(self, index):
return self.sparse_tensors[index]


class TestSparseAny(TestCase):

@onlyCPU
Expand Down Expand Up @@ -5130,6 +5143,19 @@ def identity(x):

gradcheck(func, x.requires_grad_(True), masked=masked, fast_mode=fast_mode)

@onlyCPU
@all_sparse_layouts('layout', include_strided=False)
@dtypes(torch.double)
def test_dataloader(self, device, layout, dtype):
    """Round-trip sparse tensors through a multi-worker DataLoader.

    With ``num_workers=2`` the samples cross process boundaries, so this
    exercises the sparse-tensor paths in torch.multiprocessing.reductions;
    ``batch_size=None`` disables collation so each sample comes back as-is.
    """
    samples = list(self.generate_simple_inputs(layout, device=device, dtype=dtype))

    loader = torch.utils.data.DataLoader(
        _SparseDataset(samples), batch_size=None, num_workers=2
    )

    self.assertEqual(samples, list(loader))


# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
Expand Down
93 changes: 91 additions & 2 deletions torch/multiprocessing/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,6 @@ def rebuild_cuda_tensor(


def reduce_tensor(tensor):
storage = tensor._typed_storage()

if tensor.requires_grad and not tensor.is_leaf:
raise RuntimeError(
"Cowardly refusing to serialize non-leaf tensor which requires_grad, "
Expand Down Expand Up @@ -299,6 +297,17 @@ def reduce_tensor(tensor):
if tensor.is_nested and not isinstance(tensor, NestedTensor):
return reduce_nested_tensor(tensor)

if tensor.layout in {
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_bsr,
torch.sparse_csc,
torch.sparse_bsc,
}:
return reduce_sparse_tensor(tensor)

storage = tensor._typed_storage()

if storage._untyped_storage.device.type == "cuda":
(
device,
Expand Down Expand Up @@ -387,6 +396,86 @@ def reduce_nested_tensor(nt):
)


def rebuild_sparse_coo_tensor(
rebuild_indices_func,
rebuild_indices_args,
rebuild_values_func,
rebuild_values_args,
shape,
is_coalesced,
):
indices = rebuild_indices_func(*rebuild_indices_args)
values = rebuild_values_func(*rebuild_values_args)
return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)


def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    """Reassemble a compressed sparse tensor (CSR/CSC/BSR/BSC) from its
    pickled components.

    Each of the three constituent tensors (compressed indices, plain
    indices, values) is rebuilt via its own (rebuild_func, rebuild_args)
    pair; ``layout`` selects which compressed format to reconstruct.
    """
    compressed = rebuild_compressed_indices_func(*rebuild_compressed_indices_args)
    plain = rebuild_plain_indices_func(*rebuild_plain_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_compressed_tensor(
        compressed, plain, values, shape, layout=layout
    )


def reduce_sparse_tensor(sparse):
    """Pickle support for sparse tensors: return a (rebuild_func, args) pair.

    The constituent dense tensors (indices/values) are each reduced with
    ``reduce_tensor`` so they travel through the same shared-memory
    machinery as ordinary tensors; only the layout-specific bookkeeping
    (shape, layout, coalesced flag) is carried alongside.
    """
    layout = sparse.layout
    if layout is torch.sparse_coo:
        # COO: two components — indices and values — plus the coalesced flag.
        return (
            rebuild_sparse_coo_tensor,
            (
                *reduce_tensor(sparse._indices()),
                *reduce_tensor(sparse._values()),
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )

    # Compressed layouts: pick the accessor pair matching row- vs
    # column-compressed formats.
    if layout in (torch.sparse_csr, torch.sparse_bsr):
        compressed_indices = sparse.crow_indices()
        plain_indices = sparse.col_indices()
    elif layout in (torch.sparse_csc, torch.sparse_bsc):
        compressed_indices = sparse.ccol_indices()
        plain_indices = sparse.row_indices()
    else:
        raise NotImplementedError(layout)

    return (
        rebuild_sparse_compressed_tensor,
        (
            *reduce_tensor(compressed_indices),
            *reduce_tensor(plain_indices),
            *reduce_tensor(sparse.values()),
            sparse.shape,
            layout,
        ),
    )


def fd_id(fd):
# Returns a tuple which uniquely identifies a file descriptor. In Mac OS,
# this doesn't work with shared memory handles, which is why we don't
Expand Down
5 changes: 5 additions & 0 deletions torch/utils/data/_utils/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[
"Batches of nested tensors are not currently supported by the default collate_fn; "
"please provide a custom collate_fn to handle them appropriately."
)
if elem.layout in {torch.sparse_coo, torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
raise RuntimeError(
"Batches of sparse tensors are not currently supported by the default collate_fn; "
"please provide a custom collate_fn to handle them appropriately."
)
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
Expand Down

0 comments on commit 0bd4d1f

Please sign in to comment.