Creating Message Passing Networks

Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme. With \mathbf{x}^{(k-1)}_i \in \mathbb{R}^F denoting node features of node i in layer (k-1) and \mathbf{e}_{j,i} \in \mathbb{R}^D denoting (optional) edge features from node j to node i, message passing graph neural networks can be described as

\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),

where \bigoplus denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and \gamma and \phi denote differentiable functions such as MLPs (Multi Layer Perceptrons).
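To make the abstract scheme concrete, here is a minimal sketch of a single message passing step in plain PyTorch on a tiny toy graph, assuming sum aggregation, a message function that simply returns the neighbor feature \mathbf{x}_j, and an update function that adds the aggregated messages to the old node features (all of these choices are illustrative, not mandated by the scheme):

import torch

# Toy graph with 3 nodes and 4 directed edges (hypothetical example).
x = torch.tensor([[1.], [2.], [3.]])       # node features, shape [N, F]
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes j
                           [1, 0, 2, 1]])  # target nodes i

src, dst = edge_index
messages = x[src]                                        # phi: message is the source feature x_j
aggr = torch.zeros_like(x).index_add_(0, dst, messages)  # sum aggregation over neighbors
out = x + aggr                                           # gamma: add aggregated messages to x_i
print(out)  # tensor([[3.], [6.], [5.]])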

:pyg:`PyG` provides the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. The user only has to define the functions \phi , i.e. :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, and \gamma , i.e. :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`, as well as the aggregation scheme to use, i.e. :obj:`aggr="add"`, :obj:`aggr="mean"` or :obj:`aggr="max"`.

This is done with the help of the following methods:

  * :obj:`MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)`: Defines the aggregation scheme to use (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`) and the flow direction of message passing (either :obj:`"source_to_target"` or :obj:`"target_to_source"`).
  * :obj:`MessagePassing.propagate(edge_index, size=None, **kwargs)`: The initial call to start propagating messages. Takes in the edge indices and all additional data needed to construct messages and to update node embeddings.
  * :obj:`MessagePassing.message(...)`: Constructs messages to node i in analogy to \phi for each edge (j,i) \in \mathcal{E}. Can take any argument that was originally passed to :meth:`propagate`; node-level tensors passed to :meth:`propagate` can be mapped to the source and target nodes of each edge by appending :obj:`_j` or :obj:`_i` to the variable name, e.g., :obj:`x_j` and :obj:`x_i`.
  * :obj:`MessagePassing.update(aggr_out, ...)`: Updates node embeddings in analogy to \gamma for each node i. Takes in the output of aggregation as its first argument and any argument which was initially passed to :meth:`propagate`.
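Putting these pieces together, a minimal (hypothetical) layer that simply sums up the raw features of all neighboring nodes could look as follows; the two layers implemented below follow the same pattern, only with more elaborate message computations:

from torch_geometric.nn import MessagePassing

class SumNeighbors(MessagePassing):
    # Toy layer: the message is the raw source node feature, aggregated by summation.
    def __init__(self):
        super().__init__(aggr='add')  # "Add" aggregation.

    def forward(self, x, edge_index):
        # x has shape [N, F], edge_index has shape [2, E].
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j has shape [E, F] and holds the source node features of each edge.
        return x_j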

Let us verify this by re-implementing two popular GNN variants, the GCN layer from Kipf and Welling and the EdgeConv layer from Wang et al.

The GCN layer is mathematically defined as

\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},

where neighboring node features are first transformed by a weight matrix \mathbf{W}, normalized by their degree, and finally summed up. Lastly, we apply the bias vector \mathbf{b} to the aggregated output. This formula can be divided into the following steps:

  1. Add self-loops to the adjacency matrix.
  2. Linearly transform node feature matrix.
  3. Compute normalization coefficients.
  4. Normalize node features in \phi.
  5. Sum up neighboring node features (:obj:`"add"` aggregation).
  6. Apply a final bias vector.

Steps 1-3 are typically computed before message passing takes place. Steps 4-5 can be easily processed using the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class. The full layer implementation is shown below:

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

:class:`~torch_geometric.nn.conv.GCNConv` inherits from :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` with :obj:`"add"` aggregation. All the logic of the layer takes place in its :meth:`forward` method. Here, we first add self-loops to our edge indices using the :meth:`torch_geometric.utils.add_self_loops` function (step 1), and then linearly transform the node features by calling the :class:`torch.nn.Linear` instance (step 2).

The normalization coefficients are derived from the node degrees \deg(i) for each node i, which are transformed to 1/(\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}) for each edge (j,i) \in \mathcal{E}. The result is saved in the tensor :obj:`norm` of shape :obj:`[num_edges, ]` (step 3).
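For illustration, this is how the coefficients could be computed outside of the layer for a small, hypothetical path graph with four nodes (undirected edges stored as pairs of directed edges):

import torch
from torch_geometric.utils import add_self_loops, degree

# Hypothetical path graph 0-1-2-3.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
edge_index, _ = add_self_loops(edge_index, num_nodes=4)

row, col = edge_index
deg = degree(col, 4, dtype=torch.float)       # tensor([2., 3., 3., 2.])
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # one coefficient per edge, shape [E]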

We then call :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`, which internally calls :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.aggregate` and :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`. We pass the node embeddings :obj:`x` and the normalization coefficients :obj:`norm` as additional arguments for message propagation.

In the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we need to normalize the neighboring node features :obj:`x_j` by :obj:`norm`. Here, :obj:`x_j` denotes a lifted tensor, which contains the source node features of each edge, i.e., the neighbors of each node. Node features can be automatically lifted by appending :obj:`_i` or :obj:`_j` to the variable name. In fact, any tensor can be converted this way, as long as it holds source or destination node features.
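For instance, if the message computation additionally required the target node features, we could simply add an :obj:`x_i` argument to the signature, and :pyg:`PyG` would lift it from the same :obj:`x` tensor passed to :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`. The variant below is purely illustrative and not part of the GCN layer above:

    def message(self, x_i, x_j, norm):
        # x_i holds the target node features of each edge, shape [E, out_channels].
        # x_j holds the source node features of each edge, shape [E, out_channels].
        return norm.view(-1, 1) * x_j  # x_i is available here if needed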

That is all that it takes to create a simple message passing layer. You can use this layer as a building block for deep architectures. Initializing and calling it is straightforward:

conv = GCNConv(16, 32)
x = conv(x, edge_index)
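For completeness, a self-contained toy example might look like this (random features and a hypothetical edge list, chosen only to show the expected shapes):

import torch

num_nodes = 5
x = torch.randn(num_nodes, 16)                   # [N, in_channels]
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],   # source nodes
                           [1, 0, 2, 1, 4, 3]])  # target nodes

conv = GCNConv(16, 32)
out = conv(x, edge_index)                        # [N, out_channels] = [5, 32]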

The edge convolutional layer processes graphs or point clouds and is mathematically defined as

\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),

where h_{\mathbf{\Theta}} denotes an MLP. In analogy to the GCN layer, we can use the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` class to implement this layer, this time using the :obj:`"max"` aggregation:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)

Inside the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we use :obj:`self.mlp` to transform both the target node features :obj:`x_i` and the relative source node features :obj:`x_j - x_i` for each edge (j,i) \in \mathcal{E}.
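Analogously to :class:`~torch_geometric.nn.conv.GCNConv`, the layer can be applied to any graph whose node features match its input dimensionality, e.g. (hypothetical sizes):

conv = EdgeConv(3, 128)
x = conv(x, edge_index)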

The edge convolution is actually a dynamic convolution, which recomputes the graph for each layer using nearest neighbors in the feature space. Luckily, :pyg:`PyG` comes with a GPU-accelerated, batch-wise k-NN graph generation method named :meth:`torch_geometric.nn.pool.knn_graph`:

from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)

Here, :meth:`~torch_geometric.nn.pool.knn_graph` computes a nearest neighbor graph, which is then used to call the :meth:`forward` method of :class:`~torch_geometric.nn.conv.EdgeConv`.

This leaves us with a clean interface for initializing and calling this layer:

conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)
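For example, on a hypothetical mini-batch of two point clouds with 3-dimensional coordinates, the :obj:`batch` vector assigns each point to its example, so that :meth:`~torch_geometric.nn.pool.knn_graph` only connects points within the same point cloud:

import torch

pos = torch.randn(8, 3)                         # 8 points with 3D coordinates
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # first four points belong to example 0, the rest to example 1

conv = DynamicEdgeConv(3, 128, k=3)
out = conv(pos, batch)                          # shape [8, 128]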

Finally, let us test our understanding of the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface with a few exercises. Imagine we are given the following :class:`~torch_geometric.data.Data` object:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous())

Try to answer the following questions related to :class:`~torch_geometric.nn.conv.GCNConv`:

  1. What information do :obj:`row` and :obj:`col` hold?
  2. What does :meth:`~torch_geometric.utils.degree` do?
  3. Why do we use :obj:`degree(col, ...)` rather than :obj:`degree(row, ...)`?
  4. What do :obj:`deg_inv_sqrt[col]` and :obj:`deg_inv_sqrt[row]` do?
  5. What information does :obj:`x_j` hold in the :meth:`~torch_geometric.nn.conv.MessagePassing.message` function? If :obj:`self.lin` denotes the identity function, what is the exact content of :obj:`x_j`?
  6. Add an :meth:`~torch_geometric.nn.conv.MessagePassing.update` function to :class:`~torch_geometric.nn.conv.GCNConv` that adds transformed central node features to the aggregated output.

Try to answer the following questions related to :class:`~torch_geometric.nn.conv.EdgeConv`:

  1. What are :obj:`x_i` and :obj:`x_j - x_i`?
  2. What does :obj:`torch.cat([x_i, x_j - x_i], dim=1)` do? Why :obj:`dim = 1`?