
Graph ConvNets in PyTorch

October 15, 2017

Xavier Bresson


Prototype implementation in PyTorch of the NIPS'16 paper:
Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering
M Defferrard, X Bresson, P Vandergheynst
Advances in Neural Information Processing Systems, 3844-3852, 2016
ArXiv preprint: arXiv:1606.09375

Code objective

The code provides a simple example of graph ConvNets for the MNIST classification task.
The graph is an 8-nearest-neighbor graph of a 2D grid.
The signals on the graph are the MNIST images vectorized as $28^2 \times 1$ vectors.
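The following is a minimal sketch of how such a grid graph could be built with NumPy. It is not the notebook's construction. The helper name `grid_knn_graph`, the Gaussian edge weights, the bandwidth (the mean squared distance to the k-th neighbor) and the "keep an edge if either endpoint chose it" symmetrization are all assumptions made for illustration. Nodes are indexed in row-major order, so a 28×28 image becomes a graph signal with `image.reshape(-1)`.

```python
import numpy as np


def grid_knn_graph(m=28, k=8):
    """Hypothetical helper: weighted k-NN adjacency of an m x m 2D grid.

    Assumptions: Gaussian weights exp(-d^2 / sigma2), where sigma2 is the mean
    squared distance to the k-th neighbor, plus max-symmetrization.
    """
    n = m * m
    # Node coordinates in [0, 1]^2, row-major order (matches image.reshape(-1)).
    xs, ys = np.meshgrid(np.linspace(0, 1, m), np.linspace(0, 1, m))
    coords = np.stack([xs.ravel(), ys.ravel()], axis=1)            # (n, 2)
    # Pairwise squared Euclidean distances between all grid nodes.
    d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(-1)  # (n, n)
    # k nearest neighbors of each node; column 0 of argsort is the node itself.
    idx = np.argsort(d2, axis=1)[:, 1:k + 1]
    rows = np.repeat(np.arange(n), k)
    cols = idx.ravel()
    dist2 = d2[rows, cols]
    # Bandwidth: mean squared distance to the k-th (farthest kept) neighbor.
    sigma2 = dist2.reshape(n, k)[:, -1].mean()
    W = np.zeros((n, n))
    W[rows, cols] = np.exp(-dist2 / sigma2)
    # Symmetrize: keep an edge if either endpoint selected the other.
    return np.maximum(W, W.T)
```

For interior pixels the 8 neighbors are exactly the surrounding 3×3 block. Pixels near the border pick more distant nodes, so a few nodes end up with more than 8 edges after symmetrization.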


Installation

git clone
cd graph_convnets_pytorch
pip install -r requirements.txt # installation for python 3.6.2
jupyter notebook # run the 2 notebooks


Results

GPU: NVIDIA Quadro M4000

  • Standard ConvNets: 01_standard_convnet_lenet5_mnist_pytorch.ipynb, accuracy = 99.31%, speed = 6.9 sec/epoch
  • Graph ConvNets: 02_graph_convnet_lenet5_mnist_pytorch.ipynb, accuracy = 99.19%, speed = 100.8 sec/epoch


PyTorch has not yet implemented the function torch.mm(sparse, dense) for variables. It will certainly be implemented, but in the meantime I defined a new autograd function for sparse variables, called "my_sparse_mm", by subclassing torch.autograd.Function and implementing the forward and backward passes.

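import torch

class my_sparse_mm(torch.autograd.Function):
    """
    Implementation of a new autograd function for sparse variables,
    called "my_sparse_mm", by subclassing torch.autograd.Function
    and implementing the forward and backward passes.
    Current PyTorch requires static methods; call as my_sparse_mm.apply(W, x).
    """

    @staticmethod
    def forward(ctx, W, x):  # W is SPARSE, x is dense
        ctx.save_for_backward(W, x)
        y = torch.mm(W, x)  # sparse-dense matrix product
        return y

    @staticmethod
    def backward(ctx, grad_output):
        W, x = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input_dL_dW = grad_input_dL_dx = None
        if ctx.needs_input_grad[0]:
            grad_input_dL_dW = torch.mm(grad_input, x.t())  # dL/dW = dL/dy x^T
        if ctx.needs_input_grad[1]:
            grad_input_dL_dx = torch.mm(W.t(), grad_input)  # dL/dx = W^T dL/dy
        return grad_input_dL_dW, grad_input_dL_dx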
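The backward pass follows the usual product rule for $y = Wx$: $\partial L/\partial x = W^T (\partial L/\partial y)$ and $\partial L/\partial W = (\partial L/\partial y)\, x^T$. In the cited paper, sparse-dense products like this one drive the Chebyshev recursion $T_k(\tilde L)x = 2\tilde L\, T_{k-1}(\tilde L)x - T_{k-2}(\tilde L)x$, with $T_0(\tilde L)x = x$ and $T_1(\tilde L)x = \tilde L x$, so a filter of order $K$ costs $K$ sparse products. The sketch below is a simplified version with one scalar coefficient per order. It is not the notebook's layer, which also mixes input and output feature maps. The function name `cheb_filter` is hypothetical, it uses `torch.sparse.mm` directly, and it assumes the rescaled Laplacian $\tilde L = 2L/\lambda_{\max} - I$ has already been computed as a sparse tensor.

```python
import torch


def cheb_filter(L_tilde, x, theta):
    """Hypothetical sketch: y = sum_k theta[k] * T_k(L_tilde) @ x.

    L_tilde: sparse (n, n) rescaled Laplacian 2 L / lambda_max - I (assumed precomputed).
    x:       dense (n, f) graph signals, one column per signal.
    theta:   (K,) Chebyshev coefficients, K >= 1 (one scalar per order, a simplification).
    """
    K = theta.shape[0]
    Tx_prev = x                                  # T_0(L~) x = x
    y = theta[0] * Tx_prev
    if K > 1:
        Tx = torch.sparse.mm(L_tilde, x)         # T_1(L~) x = L~ x
        y = y + theta[1] * Tx
        for k in range(2, K):
            # Recurrence T_k = 2 L~ T_{k-1} - T_{k-2}: one sparse product per order.
            Tx_next = 2 * torch.sparse.mm(L_tilde, Tx) - Tx_prev
            y = y + theta[k] * Tx_next
            Tx_prev, Tx = Tx, Tx_next
    return y
```

Because only products with $\tilde L$ are needed, no eigendecomposition is required. On a sparse graph the cost grows linearly with the number of edges, and a filter of order $K$ only mixes values within $K$ hops of each node.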

When to use this algorithm?

Any problem that can be cast as analyzing a set of signals defined on a fixed graph, when you want to use ConvNets for this analysis.