
Todo functions and autograd supports for Sparse Tensor #8853

Open · 9 of 14 tasks
weiyangfb opened this issue Jun 25, 2018 · 21 comments
Assignees
Labels
module: sparse Related to torch.sparse · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@weiyangfb (Contributor) commented Jun 25, 2018

This issue summarizes the Sparse Tensor functions and autograd support requested in previous PRs. Please feel free to comment on functions that should also be added.

Functions

Wish list

  • bmm(S, D) (add an extra sparse dim at indices of SparseTensor as batch dim?)
  • broadcasting mul(S, D) -> S
  • Dataset, Dataloader
  • save, load for sparse tensors (a workaround sketch follows this list)
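
A minimal save/load workaround sketch: decompose the COO tensor into dense pieces that torch.save already handles. This assumes torch.sparse_coo_tensor and coalesce() are available; the helper names are hypothetical.

import torch

def save_sparse(t, path):
    # Decompose the COO tensor into components torch.save handles natively.
    t = t.coalesce()
    torch.save({"indices": t.indices(), "values": t.values(),
                "size": tuple(t.size())}, path)

def load_sparse(path):
    d = torch.load(path)
    return torch.sparse_coo_tensor(d["indices"], d["values"], d["size"])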

Existing

  • autograd support for values() via [sparse] Autograd indices/values and sparse_coo ctor #13001 (thanks to @ssnl!), which means all element-wise ops are now supported for sparse tensors
  • norm (cannot take a dim argument)
  • pow
  • clone
  • zero_
  • t_ / t
  • add_ / add(Sparse, Sparse, Scalar) -> Sparse
  • add_ / add(Dense, Sparse, Scalar) -> Dense
  • sub_ / sub(Sparse, Sparse, Scalar) -> Sparse
  • mul_ / mul(Sparse, Sparse) -> Sparse
  • mul_ / mul(Sparse, Scalar) -> Sparse
  • div_ / div(Sparse, Scalar) -> Sparse
  • addmm(Dense, Sparse, Dense, Scalar, Scalar) -> Dense
  • sspaddmm(Sparse, Sparse, Dense, Scalar, Scalar) -> Sparse
  • mm(Sparse, Dense) -> Dense
  • smm(Sparse, Dense) -> Sparse
  • hspmm(Sparse, Dense) -> HybridSparse
  • spmm(Sparse, Dense) -> Dense
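
For reference, a minimal usage sketch exercising a few of the ops listed above (arbitrary shapes and values, using the modern torch.sparse_coo_tensor constructor):

import torch

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0])
s = torch.sparse_coo_tensor(i, v, (2, 3))
d = torch.randn(3, 2)

print(torch.mm(s, d))        # mm(Sparse, Dense) -> Dense
print((s * 2.0).to_dense())  # mul(Sparse, Scalar) -> Sparse
print((s + s).to_dense())    # add(Sparse, Sparse) -> Sparse
print(s.t().to_dense())      # t() on a 2-D sparse tensor
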
@ezyang (Contributor) commented Jun 25, 2018

Also, Wojciej has been posting about these recently. He'd like some variant of this script to work:

import torch
import torch.nn.functional as F

# sparse matrix (e.g. weights)
indices = torch.LongTensor([[0,2], [1,0], [1,2]])
indices = indices.transpose(1,0).contiguous()
values = torch.FloatTensor(int(indices.view(-1).size(0)/2)).uniform_(-1,1)
t = torch.sparse.FloatTensor(indices, values, torch.Size([10,10])).cuda()

# sparse vector (e.g. activations)
v = torch.sparse.FloatTensor(
    torch.LongTensor([[0], [1], [2]]).transpose(1,0), 
    torch.FloatTensor([3,4,5]),
    torch.Size([10])
).cuda()

# works fine
print(F.linear(v.to_dense(), t.to_dense()))

# RuntimeError: numel is not implemented for type torch.cuda.sparse.FloatTensor
# print(F.linear(v, t.to_dense()))

# RuntimeError: numel is not implemented for type torch.cuda.sparse.FloatTensor
# print(F.linear(v, t))

# RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.cuda.sparse.FloatTensor for argument #2 'mat2'
# print(F.linear(v.to_dense(), t))
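
One hedged workaround until F.linear accepts sparse arguments is to route through the supported mm(Sparse, Dense) path, densifying only the small activation. A sketch (the helper name is hypothetical):

def sparse_weight_linear(input, sparse_weight):
    # F.linear(input, weight) computes input @ weight.T; with
    # mm(Sparse, Dense) -> Dense available, multiply the sparse weight
    # against the transposed dense input instead.
    return torch.mm(sparse_weight, input.t()).t()

# e.g. sparse_weight_linear(v.to_dense().unsqueeze(0), t)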

@ezyang (Contributor) commented Jun 25, 2018

See also #8856

@weiyangfb (Contributor, Author):

@ezyang OK, I am adding them as well. For #1550, is it asking for matrix-matrix multiplication, element-wise multiplication, or both?

@li-roy (Contributor) commented Jun 27, 2018

I think norm is already implemented.

@weiyangfb (Contributor, Author):

@li-roy Oh, you're right! I will remove it from the todo list.

facebook-github-bot pushed a commit that referenced this issue Jun 28, 2018
Summary:
- fixes log1p at #8853
- added log1p of sparse tensor in ATen
- made log1p of a sparse tensor non-differentiable and raise an error, because the local derivative of log1p at a zero element is 1 / (0 + 1) = 1, which would make the tensor dense
Closes #8969

Reviewed By: ezyang

Differential Revision: D8677491

fbshipit-source-id: 8363a613519de4bc75eda087ccd20a3eb2d18126
@weiyangfb (Contributor, Author) commented Jun 28, 2018

Also added TODO requests from @adefazio: ops and autograd support for

  • mul(S, S) -> S
  • mul(S, D) -> S
  • dim-wise ops (softmax; see the sketch after this list)
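
A minimal sketch of what a dim-wise softmax over a 2-D COO tensor could look like: only the stored (nonzero) entries of each row participate, which sidesteps densification. This assumes index_add_ is available, omits numerical stabilization (per-row max subtraction) for brevity, and the helper name is hypothetical.

import torch

def sparse_softmax_dim1(s):
    # Softmax across each row's stored entries of a 2-D COO tensor.
    s = s.coalesce()
    idx, val = s.indices(), s.values()
    row = idx[0]
    exp_val = val.exp()
    # Accumulate per-row sums of exponentials, then normalize.
    denom = torch.zeros(s.size(0), dtype=val.dtype).index_add_(0, row, exp_val)
    return torch.sparse_coo_tensor(idx, exp_val / denom[row], s.size())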

goodlux pushed a commit to goodlux/pytorch that referenced this issue Aug 15, 2018
weiyangfb added this to sparse in Issue Categories Aug 31, 2018
realdoug pushed a commit to realdoug/pytorch that referenced this issue Sep 6, 2018
@realdoug (Contributor) commented Sep 6, 2018

@weiyangfb I took a pass at implementing narrow from your list. I posted a couple of questions on the PR if you have a chance to look. Thanks!

@weiyangfb (Contributor, Author):

@realdoug Thanks for working on this feature! I will take a look!

weiyangfb added the module: sparse (Related to torch.sparse) label Sep 10, 2018
realdoug pushed a commit to realdoug/pytorch that referenced this issue Sep 18, 2018
realdoug pushed a commit to realdoug/pytorch that referenced this issue Sep 19, 2018
@realdoug (Contributor):

@weiyangfb is there a priority list at all for these ops? I think I can knock out a few more.

@weiyangfb (Contributor, Author):

@realdoug Thanks a lot for the great work! Let me take a look, stay tuned!

realdoug pushed a commit to realdoug/pytorch that referenced this issue Sep 25, 2018
realdoug added a commit to realdoug/pytorch that referenced this issue Oct 26, 2018
facebook-github-bot pushed a commit that referenced this issue Oct 27, 2018
Summary:
Here is my stab at `dense.to_sparse`.
Pull Request resolved: #12171

Differential Revision: D10859078

Pulled By: weiyangfb

fbshipit-source-id: 5df72f72ba4f8f10e283402ff7731fd535682664
@Queuecumber:

I'd like to add einsum and mm(S, S) -> S to the list, if possible.

@gongliyu:

Is it possible to add matmul(S, D) where S and D have rank > 2?
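
Until a rank > 2 matmul lands, one stop-gap is to loop over the batch dimension and reuse the 2-D mm(Sparse, Dense) path. A hypothetical sketch, assuming the sparse operand is kept as a list of 2-D slices:

import torch

def batched_sparse_mm(sparse_slices, dense):
    # sparse_slices: list of B two-dimensional sparse tensors, each (M, K)
    # dense: dense tensor of shape (B, K, N)
    # Returns a dense (B, M, N) tensor, one mm(Sparse, Dense) per slice.
    return torch.stack([torch.mm(s, dense[b])
                        for b, s in enumerate(sparse_slices)])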

pietern added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the medium priority label (this tag is deprecated) Apr 24, 2019
@rdv0011 commented Mar 23, 2020

Is this topic related to converting PyTorch models to ONNX? I am facing an error when GraphEncoder::EncodeTensor() tries to call SparseTensorImpl::is_contiguous(), which is not implemented. Please check the details here: onnx/onnx#2676

@rohanrajpal:

Hey, I'm trying to implement cosine similarity of two CSR matrices in PyTorch:

import numpy as np
from scipy.sparse import csr_matrix
# ct below refers to the Cython extension shipped with the sparse_dot_topn
# package (assumed import path):
import sparse_dot_topn.sparse_dot_topn as ct

def awesome_cossim_top(A, B, ntop, lower_bound=0):
  # force A and B into CSR format;
  # if they are already CSR, there is no overhead
  A = A.tocsr()
  B = B.tocsr()
  M, _ = A.shape
  _, N = B.shape

  idx_dtype = np.int32

  nnz_max = M*ntop

  indptr = np.zeros(M+1, dtype=idx_dtype)
  indices = np.zeros(nnz_max, dtype=idx_dtype)
  data = np.zeros(nnz_max, dtype=A.dtype)

  ct.sparse_dot_topn(
      M, N, np.asarray(A.indptr, dtype=idx_dtype),
      np.asarray(A.indices, dtype=idx_dtype),
      A.data,
      np.asarray(B.indptr, dtype=idx_dtype),
      np.asarray(B.indices, dtype=idx_dtype),
      B.data,
      ntop,
      lower_bound,
      indptr, indices, data)

  return csr_matrix((data,indices,indptr),shape=(M,N))

So far I've reached:

import torch

def csr_to_coo(X):
  X = X.tocoo()
  return torch.sparse.LongTensor(torch.LongTensor([X.row.tolist(), X.col.tolist()]),
                              torch.LongTensor(X.data.astype(np.int32)))

def cosine_distance(x1, x2=None, eps=1e-8):
    return 1 - torch.sparse.mm(x1, x2)

But if I run the code below

from scipy.sparse import csr_matrix

Acsr = csr_matrix([[1, 2, 0], [0, 0, 3], [4, 0, 5]])

print(cosine_distance(csr_to_coo(Acsr),csr_to_coo(Acsr)))

it throws the following error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-eeed50549737> in <module>()
     21   return torch.sparse.LongTensor(torch.LongTensor([X.row.tolist(), X.col.tolist()]),
     22                               torch.LongTensor(X.data.astype(np.int32)))
---> 23 print(cosine_distance(csr_to_coo(Acsr),csr_to_coo(Acsr)))

1 frames
<ipython-input-16-eeed50549737> in cosine_distance(x1, x2, eps)
      5     # w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
      6     # return 1 - torch.sparse.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
----> 7     return 1 - torch.sparse.mm(x1, x2)
      8 
      9 Acsr = csr_matrix([[1, 2, 0], [0, 0, 3], [4, 0, 5]])

/usr/local/lib/python3.6/dist-packages/torch/sparse/__init__.py in mm(mat1, mat2)
     66                size=(2, 3), nnz=6, layout=torch.sparse_coo)
     67     """
---> 68     return torch._sparse_mm(mat1, mat2)
     69 
     70 

RuntimeError: sparse tensors do not have strides

From the todo list I can see that (sparse, sparse) multiplication is supported; am I doing something wrong, or is this not supported yet?

@ezyang (Contributor) commented Jul 22, 2020

^ @aocsa @pearu maybe you'd be able to answer this?

@aocsa (Contributor) commented Jul 23, 2020

@rohanrajpal, sparse-sparse matmul is not supported on the current master branch. The current torch.sparse.mm(x1, x2) function supports sparse-dense matmul. Sparse-sparse matmul is a WIP in this PR.
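
In the meantime, a possible stop-gap is to round-trip through SciPy's CSR matmul. A minimal sketch, assuming 2-D CPU COO tensors and no autograd (the helper name is hypothetical):

import numpy as np
import torch
from scipy.sparse import coo_matrix

def sparse_mm_via_scipy(a, b):
    # Convert both COO tensors to SciPy CSR, multiply, convert back.
    def to_scipy(t):
        t = t.coalesce()
        i = t.indices().numpy()
        return coo_matrix((t.values().numpy(), (i[0], i[1])),
                          shape=tuple(t.size())).tocsr()
    c = (to_scipy(a) @ to_scipy(b)).tocoo()
    indices = torch.from_numpy(np.vstack([c.row, c.col])).long()
    return torch.sparse_coo_tensor(indices, torch.from_numpy(c.data),
                                   tuple(c.shape))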

@rohanrajpal:

@rohanrajpal, sparse-sparse matmul is not supported on the current master branch. The current torch.sparse.mm(x1, x2) function supports sparse-dense matmul. Sparse-sparse matmul is a WIP in this PR.

Alright, thanks!

@tczhangzhi (Contributor):

@weiyangfb, I am working on graph convolutional neural networks and need to perform expm1 operations on sparse tensors. In PR #46914, I implemented expm1() with autograd support. It would be nice if someone could help review it.

@vadimkantorov (Contributor):

Also, it would be good to support torch.min / torch.amin.

These could then be used in the absence of #22378.
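
A minimal sketch of what a global sparse min could look like, accounting for the implicit zeros. It assumes torch.minimum and Size.numel() are available; the helper name is hypothetical.

import torch

def sparse_min(s):
    # Global minimum of a COO tensor, counting the implicit zeros.
    s = s.coalesce()
    v = s.values()
    if v.numel() == 0:
        return torch.zeros((), dtype=s.dtype)  # all entries are implicit zeros
    m = v.min()
    has_implicit_zeros = v.numel() < s.shape.numel()
    return torch.minimum(m, torch.zeros_like(m)) if has_implicit_zeros else m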
