
Backpropagation for sparse matrix indexing is problematic (colab provided) #45996

Open · Tracked by #44634
jasonbian97 opened this issue Oct 7, 2020 · 3 comments
Labels: module: sparse, triaged

Comments

jasonbian97 commented Oct 7, 2020

🐛 Bug

Indexing a sparse matrix works fine in the forward pass, but in backpropagation the gradients come back all zero (i.e., an empty sparse matrix).

To Reproduce

I put up a toy case in Colab to reproduce.

The strange behavior is that the resulting gradient (a sparse tensor) is empty.
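The Colab itself is linked rather than inlined; a minimal hypothetical sketch of the reported setup (standard COO construction, made-up values) looks like:

import torch

# Hypothetical minimal sketch (the actual Colab is not inlined here):
# index a sparse COO tensor that requires grad.
ind = torch.tensor([[0, 1], [2, 1]])
vals = torch.tensor([3., 4.])
sp = torch.sparse_coo_tensor(ind, vals, (3, 3)).requires_grad_(True)

row = sp[0]              # forward indexing works fine
print(row.to_dense())    # tensor([0., 0., 3.])
print(row.grad_fn)       # None: nothing for backward to propagate through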

Expected behavior

Non-zero gradients should appear at the non-zero entry locations.

cc @vincentqb @aocsa @nikitaved @pearu @mruberry

@jasonbian97 changed the title from "Backpropagation for sparse matrix is problematic (colab provided)" to "Backpropagation for sparse matrix indexing is problematic (colab provided)" on Oct 7, 2020
@ngimel added the module: sparse and triaged labels on Oct 8, 2020
mruberry (Collaborator) commented Oct 8, 2020

Thanks for reporting this issue, @jasonbian97! We're reviewing our sparse tensor implementation now, actually, and we'll be sure to look at this behavior, too.

aocsa (Contributor) commented Jan 18, 2021

I was looking into this issue and found that part of the code doesn't make sense for sparse tensors. The indexing (get_item) operation sp1[i], which internally calls the index_select_sparse function, creates a new sparse tensor; the strided-tensor version of the code, by contrast, returns a view. So in the sparse-tensor case the computational graph is no longer connected. This makes sense, since not every sub-tensor sp1[i] is materialized when sp1 is defined as a sparse tensor. IMO this is not a bug, as there is no way to create a view of a sparse tensor; for this case the torch.sparse.sum function can probably be used instead (see the sketch after the snippet below).

cc @jasonbian97, @mruberry, @rgommers

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Construct a sparse COO matrix with four non-zero entries.
ind = torch.LongTensor([[0, 1, 1, 3],
                        [2, 1, 2, 3]])
vals = torch.FloatTensor([3, 4, 5, 9])
sp1 = torch.sparse.FloatTensor(ind, vals, torch.Size([5, 5])).to(device)
sp1 = sp1.detach().requires_grad_(True)

# Index the sparse matrix.
print("sp1.to_dense() = \n", sp1.to_dense())
print("sp1[0] = \n", sp1[0])  # this is a new sparse tensor
print("sp1[0].to_dense() = \n", sp1[0].to_dense())

print(sp1[0].grad_fn)  # None: the indexing op records no backward function

losses = []
for i in range(sp1.shape[0]):
    # Not a view: for each i a new sparse tensor is created, so the
    # computational graph back to sp1 is broken here.
    loss = torch.sparse.sum(sp1[i])
    print(loss)
    losses.append(loss)
l = sum(losses)
print(l)
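For contrast, here is a minimal sketch of the suggested torch.sparse.sum route (an illustrative continuation of the snippet above, reusing ind, vals, and device; it assumes torch.sparse.sum's autograd support for dim reductions): reduce the whole tensor in one op instead of indexing rows, so the graph stays connected.

# Sketch: reduce without indexing, so the graph back to sp2 stays connected.
sp2 = torch.sparse_coo_tensor(ind, vals, (5, 5)).to(device).coalesce().requires_grad_(True)
row_sums = torch.sparse.sum(sp2, dim=1)  # differentiable sparse reduction
total = row_sums.to_dense().sum()        # to_dense is differentiable too
total.backward()
print(sp2.grad)  # sparse gradient with non-zeros at the stored entries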

jasonbian97 (Author) commented

Thanks for looking into this!

So what if I just want to use the first non-zero value in the sparse matrix for some downstream computation: is there any way I can keep the gradient flowing back to that value?
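The thread doesn't answer this, but one hypothetical workaround sketch (not from the thread; names are illustrative): if you construct the sparse tensor yourself, keep its values as a dense leaf tensor and index that directly, since dense indexing is differentiable.

import torch

# Hypothetical workaround: index the dense values tensor, not the sparse one.
ind = torch.tensor([[0, 1, 1, 3], [2, 1, 2, 3]])
vals = torch.tensor([3., 4., 5., 9.], requires_grad=True)
sp = torch.sparse_coo_tensor(ind, vals, (5, 5))  # sparse view over the data

out = 2.0 * vals[0]  # "first non-zero value", taken from the dense values
out.backward()
print(vals.grad)     # tensor([2., 0., 0., 0.]): the gradient reaches it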

@pearu added this to In progress in Sparse tensors on Aug 10, 2021
@pearu moved this from In progress to To do in Sparse tensors on Aug 10, 2021