
General Message Passing:
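- $\mathbf{x}_i$ is the feature vector of node $i$ (the superscript $(k)$ denotes the layer)
- $\mathbf{e}_{j,i}$ is an optional feature of the edge from node $j$ to node $i$
- $\square$ is a differentiable, permutation-invariant aggregation function (e.g. sum, mean or max)
- $\gamma$ and $\phi$ are differentiable functions, such as MLPs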

$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),$$
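To make the roles of $\phi$, $\square$ and $\gamma$ concrete, here is a framework-free sketch in plain Python of one message-passing step. The specific choices are assumptions for illustration only: $\phi$ scales the neighbor feature by a scalar edge feature, $\square$ is a sum, and $\gamma$ adds the aggregate to the node's own feature. Edges are stored PyG-style as two parallel lists (sources, targets), so a message flows from `j = src[e]` to `i = dst[e]`.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def message_passing_step(
    x: Sequence[Vector],
    src: Sequence[int],
    dst: Sequence[int],
    edge_attr: Sequence[float],
    phi: Callable[[Vector, Vector, float], Vector],
    gamma: Callable[[Vector, Vector], Vector],
) -> List[Vector]:
    """One step of x_i' = gamma(x_i, sum_{j in N(i)} phi(x_i, x_j, e_ji)).

    Uses a sum as the permutation-invariant aggregation; nodes without
    incoming edges receive an all-zero aggregate.
    """
    dim = len(x[0])
    aggr = [[0.0] * dim for _ in x]
    for j, i, e in zip(src, dst, edge_attr):
        msg = phi(x[i], x[j], e)  # message along edge j -> i
        aggr[i] = [a + m for a, m in zip(aggr[i], msg)]
    return [gamma(x[i], aggr[i]) for i in range(len(x))]


# Hypothetical phi and gamma, chosen only to keep the arithmetic easy to follow.
def phi(x_i: Vector, x_j: Vector, e_ji: float) -> Vector:
    return [e_ji * v for v in x_j]


def gamma(x_i: Vector, aggr: Vector) -> Vector:
    return [a + b for a, b in zip(x_i, aggr)]


# Path graph 0 - 1 - 2, with each undirected edge stored in both directions.
x = [[1.0], [2.0], [3.0]]
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]
edge_attr = [1.0, 1.0, 1.0, 1.0]
print(message_passing_step(x, src, dst, edge_attr, phi, gamma))
# -> [[3.0], [6.0], [5.0]]
```

Node 1 receives messages from both neighbors (1.0 + 3.0), so its aggregate is the largest; swapping the order of the edges would not change the result, which is the point of a permutation-invariant $\square$.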


PyG provides the MessagePassing base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. The user only has to define the functions $\phi$, i.e. message(), and $\gamma$, i.e. update(), as well as the aggregation scheme to use, i.e. aggr="add", aggr="mean" or aggr="max".

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2): Defines the aggregation scheme to use ("add", "mean" or "max") and the flow direction of message passing (either "source_to_target" or "target_to_source"). Furthermore, the node_dim attribute indicates along which axis to propagate.
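The three aggregation schemes differ only in how the messages arriving at one node are reduced. The plain-Python sketch below, with a hypothetical helper `aggregate` and scalar messages, mirrors the semantics of "add", "mean" and "max"; it is not PyG's actual (tensor scatter) implementation. In this sketch a node with no incoming messages gets 0.0.

```python
from typing import Dict, List, Sequence


def aggregate(messages: Sequence[float], index: Sequence[int],
              num_nodes: int, aggr: str = "add") -> List[float]:
    """Reduce per-edge messages onto their target nodes (index[e] = target of edge e)."""
    buckets: Dict[int, List[float]] = {n: [] for n in range(num_nodes)}
    for m, i in zip(messages, index):
        buckets[i].append(m)
    out = []
    for n in range(num_nodes):
        vals = buckets[n]
        if not vals:
            out.append(0.0)  # no incoming messages: 0.0 in this sketch
        elif aggr == "add":
            out.append(sum(vals))
        elif aggr == "mean":
            out.append(sum(vals) / len(vals))
        elif aggr == "max":
            out.append(max(vals))
        else:
            raise ValueError(f"unknown aggregation: {aggr}")
    return out


messages = [1.0, 2.0, 2.0, 3.0]  # one message per edge
targets = [1, 0, 2, 1]           # edge e delivers messages[e] to node targets[e]
for scheme in ("add", "mean", "max"):
    print(scheme, aggregate(messages, targets, num_nodes=3, aggr=scheme))
# add [2.0, 4.0, 2.0]
# mean [2.0, 2.0, 2.0]
# max [2.0, 3.0, 2.0]
```

Only node 1, which receives two messages, distinguishes the schemes: "add" sums them, "mean" averages them, and "max" keeps the larger one.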

MessagePassing.propagate(edge_index, size=None, **kwargs): The initial call to start propagating messages. Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings. Note that propagate() is not limited to exchanging messages in square adjacency matrices of shape [N, N] only, but can also exchange messages in general sparse assignment matrices, e.g., bipartite graphs, of shape [N, M] by passing size=(N, M) as an additional argument. If set to None, the assignment matrix is assumed to be a square matrix. For bipartite graphs with two independent sets of nodes and indices, and each set holding its own information, this split can be marked by passing the information as a tuple, e.g. x=(x_N, x_M).
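For the bipartite case, here is a plain-Python sketch of what passing x=(x_N, x_M) and size=(N, M) amounts to under the default flow="source_to_target": the neighbor feature $x_j$ is read from the first node set, the central feature $x_i$ from the second, and the output has one entry per target node (M), not per source node. Scalar features, a product message and a sum aggregation are assumptions for illustration, and the function name `bipartite_propagate` is hypothetical.

```python
from typing import List, Sequence


def bipartite_propagate(x_src: Sequence[float], x_dst: Sequence[float],
                        src: Sequence[int], dst: Sequence[int]) -> List[float]:
    """Sum messages x_j * x_i along edges j -> i, where j indexes x_src and i indexes x_dst."""
    n, m = len(x_src), len(x_dst)  # plays the role of size=(N, M)
    out = [0.0] * m                 # one output per *target* node
    for j, i in zip(src, dst):
        if not (0 <= j < n and 0 <= i < m):
            raise IndexError(f"edge {j} -> {i} is outside size=({n}, {m})")
        out[i] += x_src[j] * x_dst[i]  # x_j from the source set, x_i from the target set
    return out


x_N = [1.0, 2.0, 3.0]   # N = 3 source nodes
x_M = [10.0, 100.0]     # M = 2 target nodes
src = [0, 1, 2, 2]
dst = [0, 0, 0, 1]
print(bipartite_propagate(x_N, x_M, src, dst))  # -> [60.0, 300.0]
```

The result has length M = 2 even though there are N = 3 source nodes, which is why the square-matrix assumption has to be lifted by passing size=(N, M).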

MessagePassing.message(...): Constructs messages to node $i$ in analogy to $\phi$ for each edge $(j,i) \in \mathcal{E}$ if flow="source_to_target" and $(i,j) \in \mathcal{E}$ if flow="target_to_source". Can take any argument which was initially passed to propagate(). In addition, tensors passed to propagate() can be mapped to the respective nodes $i$ and $j$ by appending _i or _j to the variable name, e.g. x_i and x_j. Note that we generally refer to $i$ as the central nodes that aggregate information, and refer to $j$ as the neighboring nodes, since this is the most common notation.
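The _i/_j lifting amounts to indexing the node features with one row of edge_index per edge. A plain-Python sketch (the helper name `lift` is hypothetical) of the convention described above: with flow="source_to_target", $j$ comes from edge_index[0] and $i$ from edge_index[1]; with flow="target_to_source" the two rows swap roles.

```python
from typing import List, Sequence, Tuple


def lift(x: Sequence[float], edge_index: Tuple[Sequence[int], Sequence[int]],
         flow: str = "source_to_target") -> Tuple[List[float], List[float]]:
    """Return (x_i, x_j): per-edge copies of the central and neighboring node features."""
    if flow == "source_to_target":
        j_idx, i_idx = edge_index  # messages flow edge_index[0] -> edge_index[1]
    elif flow == "target_to_source":
        i_idx, j_idx = edge_index  # messages flow edge_index[1] -> edge_index[0]
    else:
        raise ValueError(f"unknown flow: {flow}")
    x_i = [x[i] for i in i_idx]
    x_j = [x[j] for j in j_idx]
    return x_i, x_j


x = [10.0, 20.0, 30.0]
edge_index = ([0, 0, 1], [1, 2, 2])  # directed edges 0->1, 0->2, 1->2
print(lift(x, edge_index))                      # -> ([20.0, 30.0, 30.0], [10.0, 10.0, 20.0])
print(lift(x, edge_index, "target_to_source"))  # -> ([10.0, 10.0, 20.0], [20.0, 30.0, 30.0])
```

Both lifted lists have one entry per edge, which is why x_j in the GCN layer below has shape [E, out_channels] rather than [N, out_channels].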

MessagePassing.update(aggr_out, ...): Updates node embeddings in analogy to $\gamma$ for each node $i \in \mathcal{V}$. Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().

## Implementing the GCN Layer 

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},$$
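As a worked instance (not part of the original notes), take the undirected path graph with edges 0-1 and 1-2. After adding self-loops, every node also counts itself, so $\deg(0) = \deg(2) = 2$ and $\deg(1) = 3$, and the formula for node 1, summing over $\mathcal{N}(1) \cup \{1\} = \{0, 1, 2\}$, expands to:

$$\mathbf{x}_1^{(k)} = \frac{1}{\sqrt{3} \cdot \sqrt{2}} \mathbf{W}^{\top} \mathbf{x}_0^{(k-1)} + \frac{1}{\sqrt{3} \cdot \sqrt{3}} \mathbf{W}^{\top} \mathbf{x}_1^{(k-1)} + \frac{1}{\sqrt{3} \cdot \sqrt{2}} \mathbf{W}^{\top} \mathbf{x}_2^{(k-1)} + \mathbf{b} = \frac{1}{\sqrt{6}} \mathbf{W}^{\top} \left( \mathbf{x}_0^{(k-1)} + \mathbf{x}_2^{(k-1)} \right) + \frac{1}{3} \mathbf{W}^{\top} \mathbf{x}_1^{(k-1)} + \mathbf{b}.$$

Each message is damped by the degrees at both of its endpoints, and the bias is added once, after the sum.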

Before Message Passing:

- Add self-loops to the adjacency matrix.

- Linearly transform node feature matrix.

- Compute normalization coefficients.

During Message Passing: 

- Normalize node features in $\phi$.

- Sum up neighboring node features ("add" aggregation).

Finally:
- Apply a final bias vector.
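The steps above can be traced end to end in a framework-free sketch. To keep it small, this assumes one feature per node, so $\mathbf{W}$ shrinks to a scalar `w` and $\mathbf{b}$ to a scalar `b` (both hypothetical parameters); the function name `gcn_layer` is also just for illustration. Edges use the PyG layout of parallel (source, target) lists.

```python
import math
from typing import List, Sequence


def gcn_layer(x: Sequence[float], src: Sequence[int], dst: Sequence[int],
              w: float, b: float) -> List[float]:
    """Scalar-feature sketch of the GCN steps (W reduced to a scalar w, b to a scalar)."""
    n = len(x)
    # Step 1: add a self-loop i -> i for every node.
    src = list(src) + list(range(n))
    dst = list(dst) + list(range(n))
    # Step 2: linearly transform the node features.
    h = [w * v for v in x]
    # Step 3: degrees (counted on the target index) and per-edge coefficients.
    deg = [0] * n
    for i in dst:
        deg[i] += 1
    norm = [1.0 / math.sqrt(deg[j] * deg[i]) for j, i in zip(src, dst)]
    # Steps 4-5: normalize each message and sum it into its target node.
    out = [0.0] * n
    for j, i, c in zip(src, dst, norm):
        out[i] += c * h[j]
    # Step 6: add the bias once, after aggregation.
    return [v + b for v in out]


# Path graph 0 - 1 - 2, each undirected edge stored in both directions.
x = [1.0, 1.0, 1.0]
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]
print([round(v, 4) for v in gcn_layer(x, src, dst, w=1.0, b=0.0)])
# -> [0.9082, 1.1498, 0.9082]
```

Node 1 gets $1/3 + 2/\sqrt{6}$ and the end nodes get $1/2 + 1/\sqrt{6}$, matching the coefficients in the GCN formula. Because every node has a self-loop after step 1, no degree is zero here.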

In [None]:
import torch
from torch.nn import Linear, Parameter, ReLU, Sequential as Seq
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [None]:
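class GNNLayer(MessagePassing):
    # Initialize the class with the aggregation type ("add", "mean" or "max")
    def __init__(self, in_channels, out_channels, aggr='add'):
        super().__init__(aggr=aggr)  # the base class must be initialized first
        self.lin = Linear(in_channels, out_channels)  # instantiate the layer, not just the class

    def forward(self, x, edge_index):
        '''
        Do any math needed before message passing (here, a linear transform),

        then propagate the 'message':
            - propagate internally calls:
                - message, aggregate, and update
        '''
        x = self.lin(x)
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        '''Define the message (phi) that is passed along each edge via propagation.'''
        return x_j  # placeholder: pass the neighbor features through unchanged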

In [2]:




class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):  # message passing initialization component
        # initialize with aggregation mode 'add'; flow and node_dim could also be passed here
        super().__init__(aggr="add")
        # no bias inside the linear layer: the formula adds b once, after aggregation
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()  # re-runs torch's default init (Kaiming-uniform weights)
        self.bias.data.zero_()  # zeros all values in the bias tensor

    # forward() prepares the inputs and then calls propagate(), which runs
    # message(), aggregate() and update() internally.
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: add self-loops to the adjacency matrix.
        # add_self_loops returns (edge_index, edge_attr), so unpack the tuple.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: linearly transform the node feature matrix.
        wTx = self.lin(x)

        # Step 3: compute the normalization coefficients.
        # With flow="source_to_target", row holds the source nodes j and
        # col holds the target nodes i of each edge.
        row, col = edge_index
        # degree() scatter-adds a 1 for every occurrence of an index, so
        # degree(col) counts the incoming edges (self-loop included) of each node.
        deg = degree(col, x.size(0), dtype=wTx.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # defensive: avoid inf where a degree is 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # one coefficient per edge

        # Steps 4 and 5: normalize and sum the node features.
        # propagate() internally calls message(), aggregate() and update().
        # We pass the transformed embeddings x and the normalization
        # coefficients norm as additional arguments for message passing.
        out = self.propagate(edge_index, x=wTx, norm=norm)

        # Step 6: apply the bias vector.
        out = out + self.bias
        return out

    def message(self, x_j, norm):
        '''
        In message(), we normalize the neighboring node features x_j by norm.
        Here, x_j denotes a lifted tensor which contains the source node
        features of each edge, i.e., the neighbors of each node. Node features
        are lifted automatically by appending _i or _j to the variable name.
        In fact, any tensor can be lifted this way, as long as it holds
        source or destination node features.
        '''
        # x_j has shape [E, out_channels]

        # Step 4: normalize node features.
        return norm.view(-1, 1) * x_j


## Implementing an EdgeConv Layer ##

$$\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),$$

where $h_{\mathbf{\Theta}}$ denotes an MLP. In analogy to the GCN layer, we can use the MessagePassing class to implement this layer, this time using the channel-wise "max" aggregation:
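The channel-wise max is easiest to see on numbers. Below is a framework-free sketch of the EdgeConv update in plain Python with scalar node features; the two-channel function `h` is a hypothetical stand-in for the MLP $h_{\mathbf{\Theta}}$, chosen only so the arithmetic is easy to check.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def edge_conv(x: Sequence[float], src: Sequence[int], dst: Sequence[int],
              h: Callable[[float, float], Vector]) -> List[Vector]:
    """x_i' = max_{j in N(i)} h(x_i, x_j - x_i), with the max taken channel by channel."""
    out: List[Vector] = [[] for _ in x]  # empty list = no neighbor seen yet
    for j, i in zip(src, dst):
        msg = h(x[i], x[j] - x[i])
        out[i] = msg if not out[i] else [max(a, m) for a, m in zip(out[i], msg)]
    return out


# Hypothetical stand-in for the MLP h_Theta: two output channels.
def h(x_i: float, diff: float) -> Vector:
    return [x_i + diff, -diff]


x = [1.0, 2.0, 4.0]
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]
print(edge_conv(x, src, dst, h))  # -> [[2.0, -1.0], [4.0, 1.0], [2.0, 2.0]]
```

Node 1 shows the channel-wise behavior: its first channel comes from neighbor 2's message [4.0, -2.0] and its second from neighbor 0's message [1.0, 1.0], so the result need not equal any single message. In this sketch a node with no incoming edges keeps an empty list.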



In [7]:

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):  # message passing initialization component
        super().__init__(aggr="max")  # channel-wise "max" aggregation; flow and node_dim could also be passed
        # h_Theta: an MLP applied to the concatenation [x_i, x_j - x_i]
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    # forward() calls propagate(), which runs message(), aggregate() and update().
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]
        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)


## Making a GNN from a Point Cloud ##

In [9]:
from torch_geometric.nn import knn_graph

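class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k  # number of nearest neighbors per point

    def forward(self, x, batch=None):
        # Rebuild the k-NN graph from the current features on every call
        # (knn_graph requires the torch-cluster package); batch assigns each
        # point to its example, so neighbors are only searched within one cloud.
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)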
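Because knn_graph is called inside forward(), the graph is recomputed from the current features on every pass, which is what makes the layer "dynamic". A brute-force plain-Python sketch of the construction, assuming Euclidean distance, loop=False, a single point cloud (no batch) and flow="source_to_target", so each edge points from a neighbor $j$ (first list) to the center point $i$ (second list). The function name `knn_edges` is hypothetical, and ties are broken by index here, which may not match torch-cluster.

```python
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def knn_edges(points: Sequence[Point], k: int) -> Tuple[List[int], List[int]]:
    """Brute-force k-nearest-neighbor graph without self-loops.

    Returns (src, dst) with src = neighbor j and dst = center i, i.e. the
    source_to_target layout used for message passing.
    """
    src: List[int] = []
    dst: List[int] = []
    for i, p in enumerate(points):
        others = [j for j in range(len(points)) if j != i]  # loop=False
        others.sort(key=lambda j: (math.dist(p, points[j]), j))  # ties broken by index
        for j in others[:k]:
            src.append(j)
            dst.append(i)
    return src, dst


points = [(0.0, 0.0), (1.0, 0.0), (3.0, 0.0), (7.0, 0.0)]
print(knn_edges(points, k=1))  # -> ([1, 0, 1, 2], [0, 1, 2, 3])
```

With k=1, point 3's nearest neighbor is point 2, but point 2's nearest neighbor is point 1, so a k-NN graph is directed in general.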
