PyTorch implementation of a GNN with a CNN filter


shobrook/MatrixConv


MatrixConv

MatrixConv is a graph convolutional filter for graphs whose node features are n-dimensional matrices, such as the 2D or 3D images attached to the nodes of a scene graph. The filter applies a (non-graph) convolution, i.e. torch.nn.Conv{1/2/3}d, to transform the node features. Node embeddings are updated as follows:

where φ_r and φ_m are CNNs, and W_e is a weight matrix.

Installation

This module can be installed with pip:

$ pip install matrix_conv

Usage

MatrixConv is built on PyTorch Geometric and derives from the MessagePassing module. It expects an input graph in which each node's features form a matrix (1D, 2D, or 3D). Like NNConv, MatrixConv also incorporates any available edge features when collecting messages from a node's neighbors.
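To make the message-passing idea concrete, here is a deliberately simplified sketch in plain PyTorch. It is not MatrixConv's implementation: the class name NaiveMatrixMessagePassing is hypothetical, edge features are omitted entirely, aggregation is fixed to "add", and only the 3D case is handled. It only illustrates the pattern of one CNN transforming the root node and another transforming the neighbor messages before they are summed.

```python
import torch
from torch import nn


class NaiveMatrixMessagePassing(nn.Module):
    """Hypothetical, simplified stand-in for illustration; not MatrixConv itself."""

    def __init__(self, in_channels, out_channels, kernel_dims):
        super().__init__()
        # One CNN for the root node, one for messages from neighbors.
        self.root_cnn = nn.Conv3d(in_channels, out_channels, kernel_dims)
        self.message_cnn = nn.Conv3d(in_channels, out_channels, kernel_dims)

    def forward(self, x, edge_index):
        # Row 0 holds source nodes, row 1 holds target nodes.
        source, target = edge_index
        # Transform each neighbor's matrix: [num_edges, out_channels, *out_dims]
        messages = self.message_cnn(x[source])
        # Transform each node's own matrix: [num_nodes, out_channels, *out_dims]
        root = self.root_cnn(x)
        # "add" aggregation: sum the messages arriving at each target node.
        aggregated = torch.zeros_like(root)
        aggregated.index_add_(0, target, messages)
        return root + aggregated


x = torch.randn(3, 1, 5, 5, 5)  # [num_nodes, in_channels, *matrix_dims]
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

layer = NaiveMatrixMessagePassing(in_channels=1, out_channels=10, kernel_dims=(2, 3, 3))
out = layer(x, edge_index)
print(out.shape)
```

The actual layer additionally uses edge attributes and supports "mean" and "max" aggregation, so treat this only as a mental model of the data flow.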

Parameters:

  • in_channels (int): Number of channels in the input node matrix (e.g. if each node's features are a 3x5 matrix with 2 input channels, then in_channels=2)
  • out_channels (int): Number of channels in the output node embedding
  • matrix_dims (list or tuple): Dimensions of the matrix associated with each node (e.g. if each node's features are a 3x5 matrix, then matrix_dims=[3, 5])
  • num_edge_attr (int): Number of edge attributes/features
  • kernel_dims (list or tuple): Dimensions of the convolving kernel in the CNN
  • aggr (string, optional): The message aggregation scheme to use ("add", "mean", "max")
  • root_cnn (bool, optional): If set to False, the layer will not add the CNN-transformed root node features to the output
  • bias (bool, optional): If set to False, the layer will not learn an additive bias
  • **kwargs (optional): Additional arguments for torch.nn.Conv{1/2/3}d
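Because matrix_dims and kernel_dims together determine the spatial size of the output embedding, it can help to compute that size ahead of time. The helper below, conv_output_dims, is a hypothetical utility and not part of matrix_conv. It applies the standard output-size formula for torch.nn.Conv{1/2/3}d and assumes the torch defaults (stride=1, padding=0, dilation=1) unless you pass other values through **kwargs.

```python
def conv_output_dims(matrix_dims, kernel_dims, stride=1, padding=0, dilation=1):
    """Spatial output size of a torch.nn.ConvNd layer, per dimension.

    Hypothetical helper for planning layer shapes; assumes the same
    stride, padding, and dilation are used along every dimension.
    """
    return [
        (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        for size, kernel in zip(matrix_dims, kernel_dims)
    ]


# A 5x5x5 node matrix convolved with a 2x3x3 kernel.
first = conv_output_dims([5, 5, 5], [2, 3, 3])
print(first)  # [4, 3, 3]

# Feeding that result into a second convolution with a 2x2x2 kernel.
second = conv_output_dims(first, [2, 2, 2])
print(second)  # [3, 2, 2]
```

When stacking layers, the matrix_dims of each subsequent layer would presumably need to match the output dimensions of the previous one, which is what the second call illustrates.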

Example Usage:

import torch
from matrix_conv import MatrixConv

# Convolutional layer
conv_layer = MatrixConv(
    in_channels=1,
    out_channels=10,
    matrix_dims=[5, 5, 5],
    num_edge_attr=3,
    kernel_dims=[2, 3, 3]
)

# Your input graph (see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs)
x = torch.randn((3, 1, 5, 5, 5), dtype=torch.float) # Shape is [num_nodes, in_channels, *matrix_dims]
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)
edge_attr = torch.randn((4, 3), dtype=torch.float)

# Your output graph
x = conv_layer(x, edge_index, edge_attr) # Shape is now [3, 10, 4, 3, 3]

To-Do: Show example of using this in a graph classifier (include stacking)
