Skip to content

Commit

Permalink
dense.to_sparse() re: #8853 (#12171)
Browse files Browse the repository at this point in the history
Summary:
Here is my stab at ```dense.to_sparse```
Pull Request resolved: #12171

Differential Revision: D10859078

Pulled By: weiyangfb

fbshipit-source-id: 5df72f72ba4f8f10e283402ff7731fd535682664
  • Loading branch information
realdoug authored and facebook-github-bot committed Oct 27, 2018
1 parent 5182fda commit bc352ac
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ class CAFFE2_API Tensor {
Tensor values() const;
int64_t numel() const;
std::vector<Tensor> unbind(int64_t dim=0) const;
Tensor to_sparse(int64_t sparse_dim) const;
Tensor to_sparse() const;
Tensor to(Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) const;
Tensor to(ScalarType dtype, bool non_blocking=false, bool copy=false) const;
Tensor to(Device device, bool non_blocking=false, bool copy=false) const;
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/core/TensorMethods.h
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,12 @@ inline int64_t Tensor::numel() const {
inline std::vector<Tensor> Tensor::unbind(int64_t dim) const {
return type().unbind(*this, dim);
}
inline Tensor Tensor::to_sparse(int64_t sparse_dim) const {
return type().to_sparse(*this, sparse_dim);
}
inline Tensor Tensor::to_sparse() const {
return type().to_sparse(*this);
}
inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, bool copy) const {
return type().to(*this, device, dtype, non_blocking, copy);
}
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ struct CAFFE2_API Type {
virtual Tensor values(const Tensor & self) const = 0;
virtual int64_t numel(const Tensor & self) const = 0;
virtual std::vector<Tensor> unbind(const Tensor & self, int64_t dim) const = 0;
virtual Tensor to_sparse(const Tensor & self, int64_t sparse_dim) const = 0;
virtual Tensor to_sparse(const Tensor & self) const = 0;
virtual Tensor to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy) const = 0;
virtual Tensor to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy) const = 0;
virtual Tensor to(const Tensor & self, Device device, bool non_blocking, bool copy) const = 0;
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ _(aten, threshold) \
_(aten, threshold_backward) \
_(aten, threshold_forward) \
_(aten, to) \
_(aten, to_sparse) \
_(aten, to_dense) \
_(aten, topk) \
_(aten, trace) \
Expand Down
12 changes: 12 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2204,6 +2204,18 @@
- func: unbind(Tensor self, int64_t dim=0) -> TensorList
variants: function, method

- func: to_sparse(Tensor self, int64_t sparse_dim) -> Tensor
variants: method
dispatch:
CPU: dense_to_sparse
CUDA: dense_to_sparse

- func: to_sparse(Tensor self) -> Tensor
variants: method
dispatch:
CPU: dense_to_sparse
CUDA: dense_to_sparse

- func: to(Tensor self, Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) -> Tensor
variants: method
device_guard: False
Expand Down
32 changes: 32 additions & 0 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,38 @@ SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
return self;
}

SparseTensor dense_to_sparse(const Tensor& self){
return dense_to_sparse(self, self.dim());
}

SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
int64_t dims = self.dim();
AT_CHECK(sparse_dim > 0, "sparse_dim must be >0");
AT_CHECK(sparse_dim <= dims,
"sparse_dim must be less than or equal to self.dim()");
at::TensorOptions sparse_options = self.type().toSparse().options();
std::vector<int64_t> sizes = self.sizes().vec();

Tensor nz = self.nonzero().transpose(0, 1);
if (nz.numel() == 0) {
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, sparse_options);
}
LongTensor indices;
if (sparse_dim == dims) {
indices = nz.clone();
} else {
Tensor i = nz.narrow(0, 0, sparse_dim);
std::tie(indices, std::ignore) = _unique_dim(i, 1);
indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
}

std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
Tensor values = self.index(ix).squeeze(0).clone();

Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
return sparse._coalesced_(true);
}

// NB: Dropped the resizeNd variants

Tensor sparse_to_dense(const SparseTensor& self) {
Expand Down
1 change: 1 addition & 0 deletions docs/source/tensors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: tanh_
.. automethod:: tolist
.. automethod:: topk
.. automethod:: to_sparse
.. automethod:: trace
.. automethod:: transpose
.. automethod:: transpose_
Expand Down
18 changes: 18 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,24 @@ def test_tensor(x, res):
res = self.ValueTensor(3, 4, 5, 0)
test_tensor(x, res)

@skipIfRocm # see https://github.com/pytorch/pytorch/pull/12171#issuecomment-431069849
def test_to_sparse(self):
shape = [10, 5, 19, 8]
max_nnz = 1
for dim, dim_sz in enumerate(shape, 1):
max_nnz *= dim_sz
rnnz = torch.randint(2, max_nnz, (1,)).item()
for nnz in [0, 1, rnnz]:
expected, _, _ = self._gen_sparse(dim, nnz, shape)
d = expected.to_dense()
result = d.to_sparse(dim)
self.assertEqual(d, result.to_dense()) # == not implemented for sparse tensors yet
self.assertEqual(expected.size(), result.size())
self.assertEqual(dim, result.sparse_dim())

sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3])
self.assertRaises(RuntimeError, lambda: sp.to_sparse())

@skipIfRocm
def test_shared(self):
i = self.IndexTensor([[2]])
Expand Down
24 changes: 24 additions & 0 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,30 @@ def callable(a, b) -> number
See :func:`torch.topk`
""")

add_docstr_all('to_sparse',
r"""
to_sparse(sparseDims) -> Tensor
Returns a sparse copy of the tensor. PyTorch supports sparse tensors in
:ref:`coordinate format <sparse-docs>`.
Args:
sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor
Example::
>>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]])
>>> d
tensor([[ 0, 0, 0],
[ 9, 0, 10],
[ 0, 0, 0]])
>>> d.to_sparse()
tensor(indices=tensor([[1, 1],
[0, 2]]),
values=tensor([ 9, 10]),
size=(3, 3), nnz=2, layout=torch.sparse_coo)
>>> d.to_sparse(1)
tensor(indices=tensor([[1]]),
values=tensor([[ 9, 0, 10]]),
size=(3, 3), nnz=1, layout=torch.sparse_coo)
""")

add_docstr_all('trace',
r"""
trace() -> Tensor
Expand Down

0 comments on commit bc352ac

Please sign in to comment.