SparseTensor.to_symmetric() inplace operation #293

Closed
johannaSommer opened this issue Nov 21, 2022 · 3 comments · Fixed by #327
Labels
bug Something isn't working

Comments

@johannaSommer

Hi,

Thanks a lot for your work on torch_sparse! I encountered the following error when computing gradients with respect to the values of a SparseTensor that was previously transformed with .to_symmetric(). I do not think this is intended, since converting to a symmetric matrix should be a differentiable operation.

I was able to construct the following minimal example:

import torch
import torch_sparse

torch.manual_seed(0)
test_tensor = (torch.randn(5, 5) < 0.1).float()
test_tensor_sparse = torch_sparse.SparseTensor.from_dense(test_tensor)
test_tensor_sparse.requires_grad_()
test_tensor_sparse = test_tensor_sparse.to_symmetric()
loss = test_tensor_sparse.mean()
loss.backward()

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_3483971/1228083090.py in <module>
      8 test_tensor_sparse = test_tensor_sparse.to_symmetric()
      9 loss = test_tensor_sparse.mean()
---> 10 loss.backward()

~/miniconda3/envs/molgen/lib/python3.7/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    394                 create_graph=create_graph,
    395                 inputs=inputs)
--> 396         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    397 
    398     def register_hook(self, hook):

~/miniconda3/envs/molgen/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [30]] is at version 3; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
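
The message names a torch.LongTensor, that is, an index tensor rather than the values. The same failure can be sketched in plain PyTorch, without assuming anything about torch_sparse internals (the names `value`, `perm` and `backward_after_inplace_edit` are only illustrative): autograd saves the index used in `value[perm]` for the backward pass, and editing that index in place afterwards bumps its version counter.

```python
import torch


def backward_after_inplace_edit(clone_index: bool) -> str:
    """Index `value` with `perm`, edit `perm` in place, then run backward.

    Returns "ok" if backward succeeds, otherwise the RuntimeError message.
    """
    value = torch.ones(4, requires_grad=True)
    perm = torch.tensor([3, 2, 1, 0])

    # Advanced indexing saves the index tensor for the backward pass.
    # With clone_index=True, autograd saves a separate copy instead of `perm`.
    index = perm.clone() if clone_index else perm
    out = value[index]

    # In-place edit of `perm`: the values stay the same, but the version
    # counter of the tensor is incremented.
    perm.add_(0)

    try:
        out.sum().backward()
    except RuntimeError as err:
        return str(err)
    return "ok"


if __name__ == "__main__":
    print(backward_after_inplace_edit(clone_index=False))
    print(backward_after_inplace_edit(clone_index=True))
```

In this sketch, indexing with a clone means the later in-place edit no longer touches a tensor that autograd has saved.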

With torch.autograd.set_detect_anomaly(True) enabled, the forward traceback points to to_symmetric:

...
 File "/nfs/homedirs/sommer/miniconda3/envs/molgen/lib/python3.7/site-packages/torch_sparse/tensor.py", line 373, in to_symmetric
    value = torch.cat([value, value])[perm]
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484809535/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
  allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
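
For anyone who wants to locate such a forward call themselves, here is a minimal sketch that uses the `torch.autograd.detect_anomaly()` context manager on a toy example. The toy tensors and the function name `locate_failing_forward_op` are assumptions for illustration and are unrelated to torch_sparse.

```python
import warnings

import torch


def locate_failing_forward_op() -> list:
    """Run a failing backward under anomaly mode and collect all messages."""
    value = torch.ones(4, requires_grad=True)
    perm = torch.tensor([3, 2, 1, 0])
    messages = []

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Anomaly mode records the forward traceback of every autograd node,
        # so a failing backward can be traced to the forward call behind it.
        with torch.autograd.detect_anomaly():
            out = value[perm]  # forward op whose backward will fail
            perm.add_(0)       # in-place edit of the saved index tensor
            try:
                out.sum().backward()
            except RuntimeError as err:
                messages.append(str(err))
        messages.extend(str(w.message) for w in caught)

    return messages


if __name__ == "__main__":
    for message in locate_failing_forward_op():
        print(message)
```

Anomaly mode slows execution noticeably, so it is best enabled only while debugging.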

This happens with the following versions:

python: 3.7.13
torch: 1.12.1 
torch_sparse: 0.6.15

I hope this is helpful, let me know if you need more information or if I can help. Thanks!
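
To illustrate why symmetrization can be differentiable with respect to the values, here is a sketch in plain PyTorch that uses only out-of-place operations. It is not torch_sparse's implementation: the function name `symmetrize_coo` and the choice to sum duplicate entries are assumptions made for illustration.

```python
import torch


def symmetrize_coo(row, col, value, n):
    """Hypothetical helper: symmetrize an n x n COO matrix without in-place ops."""
    # Mirror every entry (i, j) to (j, i).
    rows = torch.cat([row, col])
    cols = torch.cat([col, row])
    vals = torch.cat([value, value])

    # Linearize (i, j) -> i * n + j so duplicate coordinates share one key.
    key = rows * n + cols
    uniq, inverse = torch.unique(key, return_inverse=True)

    # Sum duplicates out of place; note that diagonal entries are counted twice.
    summed = torch.zeros(uniq.numel(), dtype=vals.dtype).index_add(0, inverse, vals)

    out_row = torch.div(uniq, n, rounding_mode="floor")
    out_col = uniq % n
    return out_row, out_col, summed


if __name__ == "__main__":
    row = torch.tensor([0, 1])
    col = torch.tensor([1, 1])
    value = torch.tensor([2.0, 3.0], requires_grad=True)
    _, _, sym_value = symmetrize_coo(row, col, value, n=2)
    sym_value.sum().backward()
    print(value.grad)
```

Because every step creates a new tensor, the backward pass through `sym_value` reaches `value` without any version-counter conflict.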

@github-actions

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

@github-actions github-actions bot added the stale label May 21, 2023
@rusty1s rusty1s added bug Something isn't working and removed stale labels May 21, 2023
@rusty1s
Owner

rusty1s commented May 22, 2023

@johannaSommer I am really sorry, I missed this issue. Taking a look now :(

@rusty1s rusty1s linked a pull request May 22, 2023 that will close this issue
@rusty1s
Owner

rusty1s commented May 22, 2023

Fixed via #327.
