# Aggregation Functions in Graph Neural Networks

Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme. With $\mathbf{x}_i^{(k-1)} \in \mathbb{R}^F$ denoting the features of node $i$ in layer $(k-1)$ and $\mathbf{e}_{j,i} \in \mathbb{R}^D$ denoting (optional) edge features from node $j$ to node $i$, message passing graph neural networks can be described as

$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),
$$

where
* $\square$ denotes a differentiable, permutation-invariant function, e.g., sum, mean or max
* $\gamma$ and $\phi$ denote differentiable functions such as MLPs (Multi-Layer Perceptrons).
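To make the roles of $\phi$, $\square$ and $\gamma$ concrete, here is a minimal, framework-free sketch in plain Python (no PyG involved). The helper names (`message_passing`, `agg_sum`, `agg_mean`, `agg_max`) are hypothetical, edge features are omitted, and returning a zero vector for a node with no neighbors is an assumption of this sketch.

```python
def message_passing(x, edges, phi, aggregate, gamma, msg_dim):
    """One message passing step on a graph given as (j, i) pairs (message j -> i).

    phi(x_i, x_j) builds a message, aggregate(msgs, msg_dim) reduces a possibly
    empty list of messages to one vector, gamma(x_i, agg) produces the new x_i.
    """
    inbox = [[] for _ in x]
    for j, i in edges:
        inbox[i].append(phi(x[i], x[j]))
    return [gamma(x[i], aggregate(inbox[i], msg_dim)) for i in range(len(x))]


def agg_sum(msgs, dim):
    # An empty neighborhood sums to the zero vector.
    return [sum(m[d] for m in msgs) for d in range(dim)]


def agg_mean(msgs, dim):
    n = max(len(msgs), 1)  # assumption: empty neighborhood -> zero vector
    return [s / n for s in agg_sum(msgs, dim)]


def agg_max(msgs, dim):
    # assumption: empty neighborhood -> zero vector
    return [max((m[d] for m in msgs), default=0.0) for d in range(dim)]


def phi(x_i, x_j):
    return x_j  # message = neighbor's features


def gamma(x_i, agg):
    return [a + b for a, b in zip(x_i, agg)]  # add the aggregate to own features


x = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
edges = [(0, 2), (1, 2), (0, 1)]
for agg in (agg_sum, agg_mean, agg_max):
    print(agg.__name__, message_passing(x, edges, phi, agg, gamma, msg_dim=2))
# agg_sum [[0.0, 1.0], [2.0, 4.0], [6.0, 9.0]]
# agg_mean [[0.0, 1.0], [2.0, 4.0], [5.0, 7.0]]
# agg_max [[0.0, 1.0], [2.0, 4.0], [6.0, 8.0]]
```

Only node 2 receives more than one message, so it is the only node where the three aggregations disagree. Reordering `edges` leaves every result unchanged, which is exactly what permutation invariance of $\square$ guarantees.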

![](https://i.imgur.com/Q291Xuq.png)




The "MessagePassing" Base Class
-------------------------------

PyG provides the `MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation.
The user only has to define the functions $\phi$, *i.e.* `MessagePassing.message`, and $\gamma$, *i.e.* `MessagePassing.update`, as well as the aggregation scheme to use, *i.e.* `aggr="add"`, `aggr="mean"` or `aggr="max"`.

This is done with the help of the following methods:

* `MessagePassing(aggr="add")`: Defines the aggregation scheme to use (`"add"`, `"mean"` or `"max"`).
* `MessagePassing.propagate(edge_index, size=None, **kwargs)`:
  The initial call to start propagating messages.
  Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings.
* `MessagePassing.message(...)`: Constructs messages to node `i` in analogy to $\phi$ for each edge $(j,i) \in \mathcal{E}$. Can take any argument which was initially passed to `propagate`.
  In addition, tensors passed to `propagate` can be mapped to the respective nodes `i` and `j` by appending `_i` or `_j` to the variable name, *e.g.* `x_i` and `x_j`.
  Note that we generally refer to `i` as the central node that aggregates information, and refer to `j` as the neighboring nodes, since this is the most common notation.
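A hypothetical plain-Python stand-in for the `_i`/`_j` mapping, assuming PyG's default `flow="source_to_target"` (row 0 of `edge_index` holds the neighbors `j`, row 1 the central nodes `i`). It only illustrates the indexing; it is not how `propagate` is implemented internally.

```python
def lift(x, edge_index):
    """Gather per-edge copies of node features, one row per edge.

    Mimics the default flow="source_to_target": edge_index[0] are the
    neighbors j (message senders), edge_index[1] the central nodes i.
    """
    src, dst = edge_index
    x_j = [x[j] for j in src]
    x_i = [x[i] for i in dst]
    return x_i, x_j


x = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
edge_index = [[0, 1], [1, 2]]  # edges 0 -> 1 and 1 -> 2
x_i, x_j = lift(x, edge_index)
print(x_i)  # [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
print(x_j)  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
```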

## Implement a very simple GNN module
Let's first implement a simple GNN layer!<br>
The **message** from node $j$ to node $i$ is the concatenation of their embeddings, transformed by a weight matrix $\mathbf{W}$.<br>
The new embedding $\mathbf{x}_i$ of node $i$ is then the sum of the messages from all of its neighbors $\mathcal{N}(i)$.<br>


$$
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i)} \left( \mathbf{W}^{\top} \cdot \left( \mathbf{x}_i^{(k-1)} || \mathbf{x}_j^{(k-1)} \right)  \right),
$$
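Before running any PyG code, we can evaluate this formula by hand on the three-node graph used below (edges `0 -> 1` and `1 -> 2`, all-ones $\mathbf{W}$ of shape $6 \times 1$). The plain-Python sketch below (helper names are illustrative) computes it in two orders: transforming each message and then summing, as the formula is written, and summing first and then transforming once, which is the order `BasicGNN.forward` uses. Because $\mathbf{W}^{\top}$ is linear, both orders agree.

```python
x = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
edges = [(0, 1), (1, 2)]          # (j, i): message from node j to node i
W = [[1.0] for _ in range(6)]     # W has shape [2F, out] = [6, 1], all ones


def matvec_t(W, v):
    """Compute W^T v for W of shape [len(v), out]."""
    return [sum(W[r][c] * v[r] for r in range(len(v))) for c in range(len(W[0]))]


def concat(a, b):
    return a + b  # list concatenation plays the role of x_i || x_j


# Formula as written: transform each message, then sum per target node.
transform_then_sum = [[0.0] for _ in x]
for j, i in edges:
    msg = matvec_t(W, concat(x[i], x[j]))
    transform_then_sum[i] = [s + m for s, m in zip(transform_then_sum[i], msg)]

# Order used by BasicGNN.forward: sum the concatenations, then transform once.
summed = [[0.0] * 6 for _ in x]
for j, i in edges:
    summed[i] = [s + c for s, c in zip(summed[i], concat(x[i], x[j]))]
sum_then_transform = [matvec_t(W, s) for s in summed]

print(transform_then_sum)  # [[0.0], [15.0], [33.0]]
print(sum_then_transform)  # [[0.0], [15.0], [33.0]]
```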

In [1]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html

2.1.0+cu118
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for torch_geometric (pyproject.toml) ... done
Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu118.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_cluster-1.6.3%2Bpt21cu118-cp310-cp310-linux_x86_64.whl (3.3 MB)
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.3+pt21cu118


In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing

class BasicGNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # Sum aggregation
        self.lin = Linear(in_channels*2, out_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            if len(p.shape) > 1:
                torch.nn.init.ones_(p)
            else:
                torch.nn.init.zeros_(p)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

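        # Sum the messages of all incoming edges (aggr='add').
        # messages has shape [N, in_channels*2]: one summed message per node.
        messages = self.propagate(edge_index, x=x)
        print("message",messages)

        # Apply the linear transformation W^T. Since it is linear, transforming the
        # summed messages equals summing the individually transformed messages.
        out = self.lin(messages)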

        return out

    def message(self, x_i, x_j):

        # x_i, x_j have shape [E, in_channels]
        return torch.cat([x_i,x_j],dim=-1)

In [3]:
node_feature = torch.arange(9).view(3,3).float()
edge_index = torch.LongTensor([
    [0,1],
    [1,2]
])
print("Node feature:\n",node_feature)
print("Edges:\n",edge_index)

Node feature:
 tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
Edges:
 tensor([[0, 1],
        [1, 2]])


In [4]:
basic_model = BasicGNN(3,1)
basic_output = basic_model(node_feature, edge_index).flatten().detach()
print("Output:\n",basic_output)

message tensor([[0., 0., 0., 0., 0., 0.],
        [3., 4., 5., 0., 1., 2.],
        [6., 7., 8., 3., 4., 5.]])
Output:
 tensor([ 0., 15., 33.])


## Implementing the GCN Layer
The GCN layer is mathematically defined as
$$
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},
$$
where neighboring node features are first transformed by a weight matrix $\mathbf{W}$, normalized by the degrees of both endpoints, and finally summed up.
Lastly, we apply the bias vector $\mathbf{b}$ to the aggregated output.
This formula can be divided into the following steps:

1. Add self-loops to the adjacency matrix.
2. Linearly transform node feature matrix.
3. Compute normalization coefficients.
4. Normalize node features in $\phi$.
5. Sum up neighboring node features (`"add"` aggregation).
6. Apply a final bias vector.
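Steps 1 and 3 can be sketched without PyG as follows. This is an illustrative plain-Python version (`gcn_norm` is a hypothetical name) that, like the layer below, counts degrees over incoming edges after self-loops are added. Because every node gets a self-loop, every degree is at least 1, so the square root in the denominator is never zero.

```python
import math


def gcn_norm(edges, num_nodes):
    """Steps 1 and 3: add self-loops, then 1/sqrt(deg(i) * deg(j)) per edge.

    edges are (j, i) pairs (source j, target i); deg counts incoming edges
    per target node, including its self-loop.
    """
    edges = list(edges) + [(v, v) for v in range(num_nodes)]    # step 1
    deg = [0] * num_nodes
    for _, i in edges:
        deg[i] += 1
    norm = [1.0 / math.sqrt(deg[j] * deg[i]) for j, i in edges]  # step 3
    return edges, deg, norm


edges, deg, norm = gcn_norm([(0, 1)], num_nodes=3)
print(edges)  # [(0, 1), (0, 0), (1, 1), (2, 2)]
print(deg)    # [1, 2, 1]
print(norm)   # approximately [0.7071, 1.0, 0.5, 1.0]
```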

Steps 1-3 are typically computed before message passing takes place.
Steps 4-5 can be easily processed using the `MessagePassing` base class.
The full layer implementation is shown below:


In [5]:
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            if len(p.shape) > 1:
                torch.nn.init.ones_(p)
            else:
                torch.nn.init.zeros_(p)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [6]:
node_feature = torch.arange(9).view(3,3).float()
edge_index = torch.LongTensor([
    [0,],
    [1,]
])
print("Node feature:\n",node_feature)
print("Edges:\n",edge_index)

Node feature:
 tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
Edges:
 tensor([[0],
        [1]])


In [50]:
GCN_model = GCNConv(3,1)
gcn_output = GCN_model(node_feature, edge_index).flatten().detach()
print("Output:\n",gcn_output)

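# Hand computation with W = all ones and b = 0, so W^T x_j is the row sum of x_j.
# Self-loop term: x_i / deg(i), with degrees [1, 2, 1] after adding self-loops.
self_degree = torch.FloatTensor([1,2,1])
self_feature = node_feature.sum(dim=-1) / self_degree
# Neighbor term: the only edge 0 -> 1 adds sum(x_0) / (sqrt(deg(0)) * sqrt(deg(1))).
neighbor_feature = torch.FloatTensor([
    0,
    (0+1+2) / (1**0.5 * 2**0.5),
    0,
])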
expected_solution = self_feature + neighbor_feature
print("Expected:\n",expected_solution)

Output:
 tensor([ 3.0000,  8.1213, 21.0000])
Expected:
 tensor([ 3.0000,  8.1213, 21.0000])
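As a framework-free cross-check, the plain-Python sketch below (the `gcn_layer` name and list-based layout are illustrative assumptions) applies the same six steps with all-ones weights and a zero bias. On the single edge `0 -> 1` it gives approximately `[3.0000, 8.1213, 21.0000]`, the Expected values. On the three-edge graph `[[0, 1, 2], [1, 2, 1]]` defined in the GraphSage section it gives approximately `[3.0000, 14.3053, 15.3990]`, so a `GCNConv` cell that runs after that later `edge_index` has been defined prints those numbers instead.

```python
import math


def gcn_layer(x, edges, W, b):
    """Plain-Python GCN layer: x is [N][F], W is [F][out], edges are (j, i)."""
    n = len(x)
    edges = list(edges) + [(v, v) for v in range(n)]   # step 1: self-loops
    deg = [0] * n
    for _, i in edges:
        deg[i] += 1                                      # in-degree incl. self-loop
    # Step 2: linear transform h_v = W^T x_v
    h = [[sum(x[v][f] * W[f][c] for f in range(len(W))) for c in range(len(W[0]))]
         for v in range(n)]
    out = [[0.0] * len(W[0]) for _ in range(n)]
    for j, i in edges:
        coef = 1.0 / math.sqrt(deg[i] * deg[j])          # steps 3-4: normalize
        out[i] = [o + coef * hj for o, hj in zip(out[i], h[j])]  # step 5: sum
    return [[o + bc for o, bc in zip(row, b)] for row in out]    # step 6: bias


x = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
W = [[1.0], [1.0], [1.0]]  # all-ones weights, as in reset_parameters
b = [0.0]                  # zero bias

print(gcn_layer(x, [(0, 1)], W, b))                  # approx [[3.0], [8.1213], [21.0]]
print(gcn_layer(x, [(0, 1), (1, 2), (2, 1)], W, b))  # approx [[3.0], [14.3053], [15.3990]]
```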


## Practice: GraphSage Implementation

Now let's start working on our own implementation of layers! This part will familiarize you with implementing a PyTorch layer based on message passing. You will implement the **forward** and **message** functions; aggregation is delegated to `MessagePassing` through the `aggr="mean"` argument.

Generally, the **forward** function is where the actual message passing is conducted. All logic in each iteration happens in **forward**, where we call the **propagate** function to propagate information from neighbor nodes to central nodes. So the general paradigm is pre-processing -> propagate -> post-processing.

Recall the process of message passing we introduced in homework 1. **propagate** further calls **message**, which transforms information of neighbor nodes into messages; **aggregate**, which aggregates all messages from neighbor nodes into one; and **update**, which generates the node embeddings for the next iteration.

Our implementation differs slightly from this: we do not explicitly implement **update**, but put the logic for updating nodes in the **forward** function. More specifically, after information is propagated, we can apply further operations to the output of **propagate**. The output of **forward** is exactly the node embeddings after the current iteration.
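A toy, framework-free mock of this call structure may help. `ToyMessagePassing` is hypothetical and far simpler than PyG's `MessagePassing` (scalar features, sum aggregation only); it just shows `propagate` calling `message` once per edge and `aggregate` once over all messages, with the update logic living in `forward`.

```python
class ToyMessagePassing:
    """Hypothetical mock: pre-processing -> propagate -> post-processing."""

    def propagate(self, edge_index, x):
        src, dst = edge_index                    # j -> i (source_to_target)
        msgs = [self.message(x[i], x[j]) for j, i in zip(src, dst)]
        return self.aggregate(msgs, dst, num_nodes=len(x))

    def message(self, x_i, x_j):
        return x_j                               # pass neighbor features through

    def aggregate(self, msgs, index, num_nodes):
        out = [0.0] * num_nodes                  # sum aggregation, scalar features
        for m, i in zip(msgs, index):
            out[i] += m
        return out

    def forward(self, x, edge_index):
        x = [2.0 * v for v in x]                 # pre-processing
        agg = self.propagate(edge_index, x)      # propagate
        return [v + a for v, a in zip(x, agg)]   # post-processing plays "update"


layer = ToyMessagePassing()
print(layer.forward([1.0, 2.0, 3.0], [[0, 1, 2], [1, 2, 1]]))  # [2.0, 12.0, 10.0]
```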

In addition, tensors passed to **propagate()** can be mapped to the respective nodes $i$ and $j$ by appending `_i` or `_j` to the variable name, *e.g.* `x_i` and `x_j`. Note that we generally refer to $i$ as the central node that aggregates information, and refer to $j$ as the neighboring nodes, since this is the most common notation.

Please find more details in the comments. One thing to note is that we're adding **skip connections** to our GraphSage. Formally, the update rule for our model is described as below:

\begin{equation}
h_v^{(l)} = W_l\cdot h_v^{(l-1)} + W_r \cdot AGG(\{h_u^{(l-1)}, \forall u \in N(v) \})
\end{equation}
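The skip connection amounts to two separate weight matrices: one for the node's own embedding and one for the aggregated neighbors. A minimal plain-Python sketch of this update for a single node, with arbitrary illustrative weights and `AGG` passed in precomputed (`sage_update` and `matvec` are hypothetical helpers):

```python
def matvec(W, v):
    """W is [out][in] (the nn.Linear weight layout); returns W @ v."""
    return [sum(w * vk for w, vk in zip(row, v)) for row in W]


def sage_update(h_v, agg_v, W_l, W_r):
    """h_v^(l) = W_l . h_v^(l-1) + W_r . AGG(neighbors of v)."""
    return [a + b for a, b in zip(matvec(W_l, h_v), matvec(W_r, agg_v))]


W_l = [[1.0, 0.0], [0.0, 1.0]]  # identity on the central node
W_r = [[0.5, 0.5], [0.0, 0.0]]  # arbitrary illustrative weights
print(sage_update([1.0, 2.0], [4.0, 6.0], W_l, W_r))  # [6.0, 2.0]
```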

For simplicity, we use mean aggregation, where:

\begin{equation}
AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) = \frac{1}{|N(v)|} \sum_{u\in N(v)} h_u^{(l-1)}
\end{equation}
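A plain-Python sketch of this mean aggregation over an edge list of `(u, v)` pairs (message from `u` to `v`); `mean_aggregate` is a hypothetical helper. Returning a zero vector for a node without incoming edges is an assumption here, consistent with node 0 receiving no neighbor contribution in the outputs later in this notebook.

```python
def mean_aggregate(h, edges, num_nodes):
    """AGG(v) = (1/|N(v)|) * sum of h_u over u in N(v), for (u, v) edges."""
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for u, v in edges:
        counts[v] += 1
        sums[v] = [s + f for s, f in zip(sums[v], h[u])]
    out = []
    for v in range(num_nodes):
        c = counts[v]
        # Assumption: a node without neighbors aggregates to the zero vector.
        out.append([s / c for s in sums[v]] if c > 0 else [0.0] * dim)
    return out


h = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
edges = [(0, 1), (1, 2), (2, 1)]
print(mean_aggregate(h, edges, num_nodes=3))
# [[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0]]
```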

Additionally, $\ell_2$ normalization is applied after each iteration:

\begin{equation}
h_v^{(l)} = h_v^{(l)} /  \lVert h_v^{(l)}\rVert_2, \forall v \in \mathcal{V}
\end{equation}
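A minimal sketch of this normalization. The `eps` guard against dividing by zero mirrors how `torch.nn.functional.normalize` computes $v / \max(\lVert v \rVert_p, \epsilon)$ with a default $\epsilon = 10^{-12}$; the `l2_normalize` helper itself is hypothetical. Note the last example: a one-dimensional embedding always normalizes to $\pm 1$ (or 0).

```python
import math


def l2_normalize(v, eps=1e-12):
    """Return v / max(||v||_2, eps)."""
    norm = math.sqrt(sum(f * f for f in v))
    return [f / max(norm, eps) for f in v]


print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
print(l2_normalize([0.0, 0.0]))  # [0.0, 0.0] (no division by zero)
print(l2_normalize([24.0]))      # [1.0]
```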



In [8]:
class GraphSage(MessagePassing):

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, heads=1, **kwargs):
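        # Use mean aggregation (the AGG defined above) unless the caller overrides it;
        # MessagePassing itself defaults to sum aggregation.
        kwargs.setdefault('aggr', 'mean')
        super().__init__(**kwargs)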

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l = None
        self.lin_r = None

        ############################################################################
        # TODO: Your code here!
        # Define the layers needed for the message and update functions below.
        # self.lin_l is the linear transformation that you apply to embedding
        #            for central node.
        # self.lin_r is the linear transformation that you apply to aggregated
        #            message from neighbors.
        # Our implementation is ~2 lines, but don't worry if you deviate from this.

        ############################################################################
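        # W_l: applied to the central node's own embedding (skip connection).
        self.lin_l = nn.Linear(in_channels, out_channels, bias=bias)
        # W_r: applied to the aggregated message from neighbors.
        self.lin_r = nn.Linear(in_channels, out_channels, bias=bias)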
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            if len(p.shape) > 1:
                torch.nn.init.ones_(p)
            else:
                torch.nn.init.zeros_(p)

    def forward(self, x, edge_index, size = None):
        """"""
        ############################################################################
        # TODO: Your code here!
        # Implement message passing, as well as any post-processing (our update rule).
        # 1. First call propagate function to conduct the message passing.
        #    1.1 See there for more information:
        #        https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
        #    1.2 We use the same representations for central (x_central) and
        #        neighbor (x_neighbor) nodes, which means you'll pass x=(x, x)
        #        to propagate.
        # 2. Update our node embedding with skip connection.
        # 3. If normalize is set, do L-2 normalization (defined in
        #    torch.nn.functional)
        # Our implementation is ~5 lines, but don't worry if you deviate from this.

        ############################################################################
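        # Mean-aggregate neighbor embeddings; x=(x, x) uses the same features
        # for central and neighbor nodes.
        out = self.propagate(edge_index, x=(x, x), size=size)
        # Skip connection: W_l . h_v (central node) + W_r . AGG (neighbors).
        out = self.lin_l(x) + self.lin_r(out)
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        return out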

    def message(self, x_j):
        ############################################################################
        # TODO: Your code here!
        # Implement your message function here.
        # Our implementation is ~1 lines, but don't worry if you deviate from this.

        ############################################################################
        out = x_j
        return out

In [24]:
node_feature = torch.arange(9).view(3,3).float()

edge_index = torch.LongTensor([
    [0,1,2],
    [1,2,1]
])
print("Node feature:\n",node_feature)
print("Edges:\n",edge_index)

Node feature:
 tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
Edges:
 tensor([[0, 1, 2],
        [1, 2, 1]])


## Check your answer!
Generate the output from `SAGE_model` and calculate the expected output manually.<br>
In addition, use the official implementation `torch_geometric.nn.SAGEConv` and check whether its output matches!


In [51]:
from torch_geometric.nn import SAGEConv
SAGE_model = GraphSage(3,1,normalize=False, aggr="mean")
sage_output = SAGE_model(node_feature, edge_index).flatten().detach()
print("Output:\n",sage_output)

expected_solution = None
################################
# TODO: Your code here!
# Calculate the answer manually!

################################
# W_l . h_v with an all-ones weight matrix is the row sum of each node's features.
self_feature = node_feature.sum(dim=-1)
print("Self_Feature:\n",self_feature)
num_nodes = node_feature.size(0)
in_degree = torch.zeros(num_nodes)
in_features_sum = torch.zeros_like(node_feature)

# Accumulate incoming neighbor features along each edge (source -> target).
for src_node,dst_node in edge_index.t().tolist():
    in_degree[dst_node] += 1
    in_features_sum[dst_node] += node_feature[src_node]

# W_r . mean(h_u): some nodes have in-degree 0; their neighbor term is 0.
neighbor_feature = torch.zeros(num_nodes)
for i in range(num_nodes):
    if in_degree[i] > 0:
        neighbor_feature[i] = in_features_sum[i].sum() / in_degree[i]

expected_solution = self_feature + neighbor_feature
print("Expected:\n",expected_solution)

################################
# TODO: Your code here!
# 1. import SAGEConv from torch_geometric.nn
# 2. generate the output
# 3. Remember to initialize the weight matrix with ones, and biases with zeros
################################
official_sage_output = None
pyg_SAGE = SAGEConv(3,1,normalize=False,aggr="mean")
for p in pyg_SAGE.parameters():
    if len(p.shape) > 1:
        torch.nn.init.ones_(p)
    else:
        torch.nn.init.zeros_(p)
official_sage_output = pyg_SAGE(node_feature, edge_index).flatten().detach()

print("Official Output:\n",official_sage_output)

Output:
 tensor([ 3., 24., 33.])
Self_Feature:
 tensor([ 3., 12., 21.])
Expected:
 tensor([ 3., 24., 33.])
Official Output:
 tensor([ 3., 24., 33.])
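As an extra check independent of PyTorch, the plain-Python sketch below reproduces the three matching outputs (the `graphsage_reference` helper is hypothetical; weights use the `nn.Linear` layout `[out][in]` and biases are taken as zero). It also illustrates an observation not made in the exercise itself: with `out_channels=1`, L2 normalization sends every nonzero output to $\pm 1$, so with `normalize=True` both `GraphSage` and `SAGEConv` would output all ones here and hide any disagreement. That makes `normalize=False` the more informative setting for this check.

```python
import math


def graphsage_reference(x, edge_index, W_l, W_r, normalize=False):
    """Plain-Python GraphSage layer: W_l . h_v + W_r . mean of h_u over N(v).

    edge_index follows source_to_target: row 0 = neighbor u, row 1 = center v.
    """
    n, dim = len(x), len(x[0])
    sums = [[0.0] * dim for _ in range(n)]
    counts = [0] * n
    for u, v in zip(*edge_index):
        counts[v] += 1
        sums[v] = [s + f for s, f in zip(sums[v], x[u])]
    # Mean aggregation; nodes without neighbors get a zero vector.
    agg = [[s / counts[v] for s in sums[v]] if counts[v] else [0.0] * dim
           for v in range(n)]

    def matvec(W, vec):
        return [sum(w * f for w, f in zip(row, vec)) for row in W]

    out = [[a + b for a, b in zip(matvec(W_l, x[v]), matvec(W_r, agg[v]))]
           for v in range(n)]
    if normalize:
        norms = [max(math.sqrt(sum(t * t for t in row)), 1e-12) for row in out]
        out = [[o / nrm for o in row] for row, nrm in zip(out, norms)]
    return out


x = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
edge_index = [[0, 1, 2], [1, 2, 1]]
ones = [[1.0, 1.0, 1.0]]  # out_channels=1, in_channels=3, all-ones weights

print(graphsage_reference(x, edge_index, ones, ones))                  # [[3.0], [24.0], [33.0]]
print(graphsage_reference(x, edge_index, ones, ones, normalize=True))  # [[1.0], [1.0], [1.0]]
```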
