
Does spspmm operation support autograd? #45

Open
changym3 opened this issue Mar 9, 2020 · 17 comments
Labels
enhancement New feature or request

Comments


changym3 commented Mar 9, 2020

Hi, you say autograd is supported for value tensors, but it doesn't seem to work in spspmm.

Like this:

import torch
import torch_sparse

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])  # A: 3x3 sparse
valueA = torch.tensor([1, 2.0, 3, 4, 5], requires_grad=True)
indexB = torch.tensor([[0, 2], [1, 0]])  # B: 3x2 sparse
valueB = torch.tensor([2, 4.0], requires_grad=True)
indexC, valueC = torch_sparse.spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)

print(valueC.requires_grad)
print(valueC.grad_fn)

And the output is:

False
None

In my case, I want to parameterize both the sparse adjacency matrix and the feature matrix in a GCN, so both inputs need to be differentiable. I wonder if this is a bug or just the way it is.

Regards.


rusty1s commented Mar 9, 2020

That's the only function that does not have proper autograd support. Gradients for sparse-sparse matrix multiplication are quite difficult to obtain (since they are usually dense). I had a working but slow implementation up to the 0.4.4 release, but removed it since it wasn't a very good implementation. If you desperately need it, feel free to try it out.
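
For intuition: for C = A·B, the gradients with respect to A and B are built from grad_C·Bᵀ and Aᵀ·grad_C, which are generally dense even when A and B are sparse. Below is a minimal sketch of a dense-fallback workaround (this is not the removed 0.4.4 implementation, and spspmm_dense_fallback is a hypothetical helper name; it only illustrates why memory becomes the bottleneck):

import torch

def spspmm_dense_fallback(indexA, valueA, indexB, valueB, m, k, n):
    # Densify both operands; autograd then tracks valueA and valueB through
    # sparse_coo_tensor -> to_dense -> matmul, at the cost of materializing
    # dense m x k, k x n, and m x n matrices.
    A = torch.sparse_coo_tensor(indexA, valueA, (m, k)).to_dense()
    B = torch.sparse_coo_tensor(indexB, valueB, (k, n)).to_dense()
    C = A @ B
    # Keep only the non-zero entries of the dense product.
    indexC = C.nonzero().t()
    valueC = C[indexC[0], indexC[1]]
    return indexC, valueC

With the tensors from the example above, valueC.requires_grad would then be True, but the intermediate dense matrices make this impractical for large graphs.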

@changym3 (Author)

Hey! Thanks for your great work! I have installed the 0.4.4 release of torch_sparse and it totally works in my experiments! Maybe you could add this information to the documentation. It took me so long to figure out this no-autograd problem.

Thanks a lot again!

@LuciusMos

Thank you so much for raising this question! It troubled me for almost a week!


rusty1s commented Jul 27, 2020

Sorry for the inconvenience. I plan to add backward support for spspmm back ASAP, see pyg-team/pytorch_geometric#1465.


jlevy44 commented Dec 27, 2020

Do you have any updates on autograd support?


jlevy44 commented Dec 27, 2020

I'm parameterizing the weights of a sparse matrix to treat it as a locally connected network for a sparsely connected MLP implementation. Could I still run a backward pass to update these weights after calling matmul between this sparse matrix and a dense input?
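
For this sparse-weight × dense-input pattern, the relevant call is the sparse-dense product (spmm) rather than spspmm, and the README states that autograd is supported for the value tensor there. A rough sketch, assuming spmm's documented signature spmm(index, value, m, n, matrix):

import torch
import torch_sparse

index = torch.tensor([[0, 0, 1, 2], [1, 3, 2, 0]])      # fixed sparsity pattern of a 3x4 weight matrix
value = torch.nn.Parameter(torch.randn(index.size(1)))  # learnable non-zero entries
x = torch.randn(4, 8)                                    # dense input

out = torch_sparse.spmm(index, value, 3, 4, x)           # dense [3, 8] output
out.sum().backward()
print(value.grad)                                        # gradients reach the sparse weights

If the second operand really has to be sparse as well (so that spspmm is required), the limitation discussed in this thread still applies.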

jlevy44 commented Dec 27, 2020

@JRD971000

Does spspmm still lack autograd support?

@github-actions

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

@github-actions github-actions bot added the stale label Feb 17, 2022
@rusty1s rusty1s added enhancement New feature or request and removed stale labels Feb 17, 2022

jaynanavati-az commented Apr 14, 2022

Does spspmm still lack autograd support, @rusty1s? It seems to use SparseTensor, which is supposed to be fully supported by autograd.


rusty1s commented Apr 14, 2022

Sadly yes :(

@jaynanavati-az

Is there an alternative? It is difficult to get the earlier versions of torch_sparse that have this working on newer CUDA versions. :(


rusty1s commented Apr 19, 2022

There isn't a workaround except for installing an earlier version. If you are interested, we can try to bring it back with your help. WDYT?

@jaynanavati-az

Sounds good, @rusty1s. Why don't we start by putting back your existing implementation? Isn't that better than having nothing?


rusty1s commented Apr 22, 2022

RexYing pushed a commit to RexYing/pytorch_sparse that referenced this issue Apr 26, 2022

jaynanavati-az commented Oct 11, 2022 via email


rusty1s commented Oct 11, 2022

With PyTorch 1.12, I assume you can also try to use the sparse-matrix multiplication from PyTorch directly. PyTorch recently integrated better sparse matrix support into its library :)
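
Reusing the matrices from the original example, the PyTorch-native route might look like the sketch below (PyTorch 1.12 or newer assumed; whether grad_fn is actually populated for a sparse × sparse product depends on the PyTorch version, so verify it on your setup):

import torch

A = torch.sparse_coo_tensor(
    torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]]),
    torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True),
    (3, 3),
)
B = torch.sparse_coo_tensor(
    torch.tensor([[0, 2], [1, 0]]),
    torch.tensor([2.0, 4.0], requires_grad=True),
    (3, 2),
)

C = torch.sparse.mm(A, B)  # sparse 3x2 result
print(C.grad_fn)           # non-None if the product is differentiable on this version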
