
Autograd error after conv with sparse adjacency mat instead of edge_index #6371

Closed
ArchieGertsman opened this issue Jan 9, 2023 · 3 comments

ArchieGertsman (Contributor) commented Jan 9, 2023

🐛 Describe the bug

I am getting an autograd error when calling backward() on the output of a simple graph conv network that uses a SparseTensor adjacency matrix. If I instead feed the original edge_index into the network, autograd works fine.

import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_sparse import matmul, SparseTensor

# fix random seed for reproducibility
torch.random.manual_seed(42)


# simple graph conv
class MyConv(MessagePassing):
    def __init__(self, in_ch, out_ch):
        super().__init__(aggr="add")
        self.lin = Linear(in_ch, out_ch)

    def forward(self, x, edge_index):
        x = self.lin(x)
        return self.propagate(edge_index, x=x)

    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x, reduce=self.aggr)

conv = MyConv(5, 8)


# generate data
x = torch.rand(6, 5)

edge_index = torch.tensor([
    [0,  1,  2,  3,  4],
    [2,  2,  3,  4,  5]
])

adj = SparseTensor(
    row=edge_index[0], 
    col=edge_index[1])


# feed data forward both ways
x1 = conv(x, edge_index) 
print(x1)

x2 = conv(x, adj.t())
print(x2)

assert (x1 == x2).all()


# compute grads
x1.sum().backward() # works!
x2.sum().backward() # error!

out:

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.4300,  1.2338,  2.0123, -0.3427, -1.3851,  0.2302, -0.9762, -0.2997],
        [ 0.7890,  0.6861,  1.0525, -0.4021, -0.9543, -0.1364, -0.3217, -0.2653],
        [ 0.3662,  0.4613,  0.7512, -0.2890, -0.8049, -0.2556, -0.4350,  0.0790],
        [ 0.4926,  0.3874,  0.8244, -0.1239, -0.4223,  0.0178, -0.5007, -0.0089]],
       grad_fn=<ScatterAddBackward0>)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.4300,  1.2338,  2.0123, -0.3427, -1.3851,  0.2302, -0.9762, -0.2997],
        [ 0.7890,  0.6861,  1.0525, -0.4021, -0.9543, -0.1364, -0.3217, -0.2653],
        [ 0.3662,  0.4613,  0.7512, -0.2890, -0.8049, -0.2556, -0.4350,  0.0790],
        [ 0.4926,  0.3874,  0.8244, -0.1239, -0.4223,  0.0178, -0.5007, -0.0089]],
       grad_fn=<CppNode<SPMMSum>>)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[40], line 52
     50 # compute grads
     51 x1.sum().backward() # works!
---> 52 x2.sum().backward()

File ~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/_tensor.py:488, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    478 if has_torch_function_unary(self):
    479     return handle_torch_function(
    480         Tensor.backward,
    481         (self,),
   (...)
    486         inputs=inputs,
    487     )
--> 488 torch.autograd.backward(
    489     self, gradient, retain_graph, create_graph, inputs=inputs
    490 )

File ~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/autograd/__init__.py:197, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    192     retain_graph = create_graph
    194 # The reason we repeat same the comment below is that
    195 # some Python versions print out the first line of a multi-line function
    196 # calls in the traceback and some print out the last line
--> 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    198     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    199     allow_unreachable=True, accumulate_grad=True)

RuntimeError: Function torch::autograd::CppNode<SPMMSum> returned an invalid gradient at index 6 - got [5, 8] but expected shape compatible with [6, 8]

Environment

  • PyG version: 2.2.0
  • PyTorch version: 1.13.1
  • OS: CentOS Linux 7 (Core)
  • Python version: 3.9.12
  • CUDA/cuDNN version: 11.7
  • How you installed PyTorch and PyG: conda
  • torch_sparse version: 0.6.16
rusty1s (Member) commented Jan 9, 2023

Can you define adj as

adj = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(6, 6))

instead? This works for me.
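
For context: without sparse_sizes, the shape is inferred from the largest row/column indices, so the adjacency here ends up as (5, 6) instead of (6, 6). The forward pass still happens to produce the right output, but the backward of SPMMSum then returns a gradient for x with only 5 rows, which is the "[5, 8] but expected shape compatible with [6, 8]" mismatch in the error. A rough sketch of the corrected reproduction, reusing MyConv, conv, x and edge_index from the snippet above:

adj = SparseTensor(
    row=edge_index[0],
    col=edge_index[1],
    sparse_sizes=(6, 6))  # make the node count explicit instead of letting it be inferred

x2 = conv(x, adj.t())              # same output as before
x2.sum().backward()                # no error: the gradient for x now has 6 rows
print(conv.lin.weight.grad.shape)  # torch.Size([8, 5])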

ArchieGertsman (Contributor, Author)

Thanks @rusty1s, this works! I was following the Memory-Efficient Aggregations tutorial and was unaware of the sparse_sizes argument. Are there docs for torch_sparse anywhere? I can't seem to find any.

rusty1s (Member) commented Jan 10, 2023

Sadly, I didn't manage to provide any internal docs for torch-sparse, but the code should hopefully be self-explanatory; see here.
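
As a small, unofficial illustration of the constructor behaviour discussed above (a sketch, not torch-sparse documentation), sparse_sizes() shows what gets inferred when the argument is omitted:

import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([
    [0, 1, 2, 3, 4],
    [2, 2, 3, 4, 5]])

# Without sparse_sizes, the shape is taken from the largest indices.
# Node 5 never appears as a row (source) index, so the row dimension
# comes out one short for a 6-node graph.
adj_inferred = SparseTensor(row=edge_index[0], col=edge_index[1])
print(adj_inferred.sparse_sizes())  # (5, 6)

# Passing sparse_sizes pins the adjacency matrix to the true node count.
adj_explicit = SparseTensor(
    row=edge_index[0], col=edge_index[1], sparse_sizes=(6, 6))
print(adj_explicit.sparse_sizes())  # (6, 6)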
