<a href="https://colab.research.google.com/github/shahabday/graph-neural-networks/blob/main/4_MessagePassingNeuralNetworks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Please, make a copy of the notebook before we start.
# Turn on the GPU support in the Runtime/Change runtime type.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Install Pytorch Geometric and its dependencies
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.5.0+cu121


# **[Generalized message passing](https://arxiv.org/abs/1806.01261) and relational inductive bias**

>"Infinite use of finite means" (Humboldt, 1836; Chomsky, 1965)

*   Human intellect is able to productively compose complex structures (sentences) using a small set of elements (words).
*   *Relational inductive bias* imposes constraints on relationships and interactions among entities in a learning process. It is prior knowledge that one builds into a learning algorithm. For example, a CNN assumes locality and **translation invariance**.
*   A graph structure imposes a strong relational inductive bias, since during the learning process we heavily utilize the relations between the nodes, which can be of arbitrary nature.
*   Message passing algorithms must be invariant to the ordering of a node's neighbours, since it doesn't matter from which neighbour we receive the signal first. Therefore, graph neural networks have **permutation invariance**.
  - **Example:** compute the centre of mass of the solar system $\to$ the order of the planets doesn't matter.
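
The centre-of-mass example can be checked numerically. The snippet below is a minimal sketch in plain Python; the body masses and positions are made-up values, and `centre_of_mass` is a helper introduced only for this illustration.

```python
import math
import random

# Hypothetical bodies as (mass, x, y); the numbers are illustrative, not real data.
bodies = [(10.0, 0.0, 0.0), (1.0, 1.0, 0.0), (2.0, 0.0, 3.0), (0.5, -2.0, 1.0)]


def centre_of_mass(bodies):
    """Mass-weighted mean position, built only from sums over the bodies."""
    total_mass = sum(m for m, _, _ in bodies)
    x = sum(m * px for m, px, _ in bodies) / total_mass
    y = sum(m * py for m, _, py in bodies) / total_mass
    return x, y


# Visit the same bodies in a different order.
shuffled = bodies[:]
random.Random(0).shuffle(shuffled)

reference = centre_of_mass(bodies)
permuted = centre_of_mass(shuffled)
print(reference, permuted)
```

Because the result depends on the bodies only through sums, reordering them leaves it unchanged (up to floating-point rounding). The aggregation functions in a message passing layer rely on the same property.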

### **Graph definition:**

*   $\mathcal{G}=(\mathbf{u}, V,E)$
  - $V$: nodes.
  - $E$: edges.
  - $\mathbf{u}$: global attribute.
*   The graph is:
  - Directed.
  - Attributed (nodes, edges, global).
  - Multi-graph: there can be more than one edge between nodes.

**Example:** balls connected by springs in the gravitational field:
*   $\mathbf{u}$ is the total kinetic energy of the system.
*   $V= \{v_k \}$ is the set of balls, with attributes for position and momentum.
*   $E = \{(\textbf{e}_k, r_k, s_k)\}$ is the set of springs connecting the balls, with their corresponding potential energies, $r_k$ and $s_k$ are the indices of the receiver and sender nodes.
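
As a rough illustration of this definition, the graph can be written down as a small data structure. This is a sketch only: the `Edge` and `Graph` classes, their field names and the numeric values are assumptions made for this example, with one-dimensional positions and momenta for brevity.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Edge:
    attr: float     # e_k, e.g. the potential energy of a spring
    receiver: int   # r_k, index of the receiver node
    sender: int     # s_k, index of the sender node


@dataclass
class Graph:
    u: float                          # global attribute
    nodes: List[Tuple[float, float]]  # v_k = (position, momentum)
    edges: List[Edge]                 # directed, attributed edges


# Two balls joined by three springs. Two of the springs run from ball 0 to
# ball 1, which is allowed because the graph is a multi-graph.
system = Graph(
    u=0.0,
    nodes=[(0.0, 1.0), (1.0, -1.0)],
    edges=[Edge(0.5, 1, 0), Edge(0.5, 0, 1), Edge(0.2, 1, 0)],
)

parallel = [e for e in system.edges if (e.receiver, e.sender) == (1, 0)]
print(len(parallel))
```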


<center width="100%" style="padding:10px"> <img src ="https://drive.google.com/uc?id=17jaILJ0oiN9OzwZ6ZowI1hLHowcPBl_3" width=600 height=275></center>

### **GN block update**

\begin{equation}
\begin{aligned}
    1) \, \textbf{e}^{'}_{k} = \phi^e (\textbf{e}_k, \, \textbf{v}_{r_k}, \, \textbf{v}_{s_k}, \, \textbf{u}) &&
    \hspace{2cm} 2) \, \bar{\textbf{e}}'_{i} = \rho^{e \to v} ( E^{'}_i) \hspace{0.95cm}\\
    3) \, \textbf{v}^{'}_{i} = \phi^{v} (\bar{\textbf{e}}'_{i}, \, \textbf{v}_{i}, \, \textbf{u}) \hspace{1.25cm} &&
    4) \, \bar{\textbf{e}}' = \rho^{e \to u} ( E^{'}) \hspace{0.95cm}\\
    5) \, \bar{\textbf{v}}' =  \rho^{v \to u} (V^{'}) \hspace{2cm} &&
    6) \, \textbf{u}' = \phi^{u} (\bar{\textbf{e}}', \, \bar{\textbf{v}}', \, \textbf{u}),
\end{aligned}
\end{equation}

where:
*   $E^{'}_i = \{(\textbf{e}^{'}_k, r_k, s_k)\}_{r_k = i, k=1:N^{|E|}}$ is the set of all per-edge outputs of the node $i$.
*   $E^{'} = \bigcup E^{'}_i $ is the set of all per-edge outputs.
*   $V^{'} = \{v^{'}_i \}_{i=1:N^{|V|}}$ is the set of all per-node outputs.

Ball system example:

1. Update the corresponding forces (edge attributes) between two connected balls.
2. Aggregate the edge updates for edges that project to vertex $i$, e.g. sum all the forces acting on ball $i$.
3. Update node attributes, e.g. update the position and momentum of each ball.
4. Aggregate all updated edge attributes, e.g. sum all the forces.
5. Aggregate all updated node attributes, which might correspond to calculating the total momentum of the system.
6. Update the global attribute, e.g. the kinetic energy of the system.
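
The six steps can be traced on a toy graph. The snippet below is a sketch in plain Python with scalar attributes: the update functions `phi_e`, `phi_v` and `phi_u` are arbitrary placeholders chosen for readability (in a real GN block they would be learned networks), and all numbers are made up.

```python
# Toy graph: 3 nodes, 3 directed edges, scalar attributes (hypothetical values).
v = [1.0, 2.0, 3.0]          # node attributes v_i
edges = [                    # (e_k, r_k, s_k): attribute, receiver, sender
    (0.5, 1, 0),
    (1.5, 2, 1),
    (2.5, 1, 2),
]
u = 0.0                      # global attribute


# Placeholder update functions; a GN block would use neural networks here.
def phi_e(e, v_r, v_s, u):
    return e + v_r - v_s + u


def phi_v(e_bar, v_i, u):
    return v_i + e_bar + u


def phi_u(e_bar, v_bar, u):
    return u + e_bar + v_bar


rho = sum                    # aggregation; any permutation-invariant reduction works

# 1) Update every edge from its attribute, receiver, sender and the global.
edges_new = [(phi_e(e, v[r], v[s], u), r, s) for e, r, s in edges]
# 2) Aggregate the updated edges per receiver node.
e_bar_i = [rho(e for e, r, _ in edges_new if r == i) for i in range(len(v))]
# 3) Update every node.
v_new = [phi_v(e_bar_i[i], v[i], u) for i in range(len(v))]
# 4) Aggregate all updated edges.
e_bar = rho(e for e, _, _ in edges_new)
# 5) Aggregate all updated nodes.
v_bar = rho(v_new)
# 6) Update the global attribute.
u_new = phi_u(e_bar, v_bar, u)

print(edges_new, v_new, u_new)
```

Node 0 receives no edges, so its aggregate in step 2 is the empty sum, and its attribute is unchanged by this particular `phi_v`.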

**All the functions $\phi$ can be arbitrary differentiable functions, e.g. neural networks, and all the functions $\rho$ are aggregation functions, e.g. *sum* or *mean*.** (figure credit - [Peter Battaglia](https://arxiv.org/abs/1806.01261))

<center width="100%" style="padding:10px"> <img src ="https://drive.google.com/uc?id=12I2j2TQtWHWSBZBJ3wYtkZHN45sux62t" width=700 height=175></center>

# **Graph classification with generalized message passing**

*   Given a dataset of multiple graphs, classify entire graphs rather than single nodes or edges.
*   A typical example is molecular property prediction, in which molecules are represented as graphs: each atom is a node, and the bonds between atoms are the edges.

<center width="100%" style="padding:10px"> <img src ="https://drive.google.com/uc?id=1o_BXW38qr-9uGGPPGgfghCz7P1OOKpsh" width=600></center>

We will use a slightly modified version of the MUTAG dataset. Let's add a global attribute to each graph and set it to zero.




In [None]:
from torch_geometric.datasets import TUDataset

def add_global_attr(data):
  data.u = torch.tensor([[0]]).to(torch.float32)
  return data

dataset = TUDataset(root='data/TUDataset', name='MUTAG', transform=add_global_attr)

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


*   Node features are the one-hot encoded atom types.
*   Edge features are different types of atom bonds.
*   The binary graph labels represent a graph's "mutagenic effect on a specific gram negative bacterium". (Not so important for us)
*   We added a global attribute to each graph.

Let's shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:

In [None]:
torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:150]
test_dataset = dataset[150:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

### **Mini-batching of graphs in PyG**

*   Each graph in the batch can have a different number of nodes and edges, hence we would require a lot of padding to obtain a single tensor.
*   Represent $N$ graphs in a batch as a single large graph with concatenated node and edge lists.
*   There are no edges between different graphs.

<center width="100%" style="padding:10px"> <img src ="https://drive.google.com/uc?id=1C5Ob2YQxrMH-Xf55mZ2RNO79q2Rw5E9P" width=600></center>

Advantages over other batching procedures:

1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.

2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.
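
To make the scheme concrete, the concatenation can be written by hand in plain PyTorch. This is a sketch under assumptions: the two toy graphs are invented, and the loop only mimics the idea of shifting node indices, not PyG's actual implementation.

```python
import torch

# Two hypothetical toy graphs, each given as (num_nodes, edge_index of shape [2, E]).
graphs = [
    (3, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])),
    (2, torch.tensor([[0, 1], [1, 0]])),
]

edge_indices, batch_parts, offset = [], [], 0
for graph_id, (num_nodes, edge_index) in enumerate(graphs):
    # Shift node ids so that they do not collide with earlier graphs.
    edge_indices.append(edge_index + offset)
    # Remember which graph every node belongs to.
    batch_parts.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
    offset += num_nodes

big_edge_index = torch.cat(edge_indices, dim=1)  # [2, 6]
batch = torch.cat(batch_parts)                   # [5], maps node -> graph id

print(big_edge_index)
print(batch)
```

Every edge in `big_edge_index` connects two nodes with the same entry in `batch`, so no message can cross from one graph to another.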

**PyG's `DataLoader` does the batching for us automatically.**

In [None]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Each `Batch` object is equipped with a **`batch` vector**, which maps each node to its respective graph in the batch:

$$
\textrm{batch} = [ 0, \ldots, 0, 1, \ldots, 1, 2, \ldots ]
$$

### **Graph level predictions**

Graph classification in a nutshell:

1. Embed each node by performing multiple rounds of message passing.
2. Aggregate node embeddings into a unified graph embedding (**readout layer**).
3. Train a final classifier on the graph embedding.

The most common **readout layer** simply takes the average of the node embeddings:

$$
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{z}^{(L)}_v
$$

We can do it via [`torch_geometric.nn.global_mean_pool`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_mean_pool).
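
For intuition, the same per-graph average can be computed by hand from the `batch` vector. The snippet is a sketch in plain PyTorch with made-up embeddings; it illustrates the formula above and is not meant as a description of how PyG implements the pooling.

```python
import torch

# Hypothetical final node embeddings for 5 nodes spread over 2 graphs.
z = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [10.0, 20.0],
                  [30.0, 40.0]])
batch = torch.tensor([0, 0, 0, 1, 1])  # node -> graph id
num_graphs = int(batch.max()) + 1

# Sum the node embeddings of every graph, then divide by its node count.
sums = torch.zeros(num_graphs, z.size(1)).index_add_(0, batch, z)
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1)
graph_embedding = sums / counts        # [num_graphs, 2]

print(graph_embedding)
```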

### **Train a GN model**

Let's create three neural networks for the edge, node and global attribute updates. The documentation of the [torch_scatter.scatter_mean](https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/mean.html) function may help with understanding the aggregation step.


In [None]:
from torch_scatter import scatter_mean

# TODO: complete the edge, node and global update networks

class EdgeModel(torch.nn.Module):
  def __init__(self, hidden_channels):
    # TODO: complete together
    pass

  def forward(self, src, dest, edge_attr, u, batch):
    # src, dest: [E, F_x], where E is the number of edges.
    # edge_attr: [E, F_e]
    # u: [B, F_u], where B is the number of graphs.
    # batch: [E] with max entry B - 1.
    # TODO: complete together
    pass


class NodeModel(torch.nn.Module):
  def __init__(self, hidden_channels):
    # TODO: complete together
    pass


  def forward(self, x, edge_index, edge_attr, u, batch):
    # x: [N, F_x], where N is the number of nodes in all graphs of the batch.
    # edge_index: [2, E] with max entry N - 1.
    # edge_attr: [E, F_e]
    # u: [B, F_u]
    # batch: [N] with max entry B - 1.
    # TODO: complete together
    pass

class GlobalModel(torch.nn.Module):
  def __init__(self, hidden_channels):
    # TODO: complete together
    pass

  def forward(self, node_attr_prime, edge_out_bar, u, batch):
    # node_attr_prime: [N, F_x], where N is the number of nodes in the batch.
    # edge_out_bar: [N, F_e]
    # u: [B, F_u]
    # batch: [N] with max entry B - 1.
    # Average all node attributes for each graph, using batch tensor.
    # TODO: complete together
    pass
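
The skeletons above are left as an exercise. For orientation, here is one possible way to fill them in; it is a sketch and not necessarily the intended solution. The assumptions are: the class names with the `Sketch` suffix and the helpers `mean_by_index` and `mlp` are invented for this example, the constructors take explicit feature sizes instead of a single `hidden_channels` argument, every update keeps its attribute's dimensionality, and the scatter mean is replaced by a plain-PyTorch stand-in.

```python
import torch
from torch import nn


def mean_by_index(values, index, num_groups):
    """Average the rows of `values` that share the same entry in `index`."""
    sums = values.new_zeros(num_groups, values.size(1)).index_add_(0, index, values)
    counts = torch.bincount(index, minlength=num_groups).clamp(min=1).unsqueeze(1)
    return sums / counts


def mlp(in_dim, hidden_channels, out_dim):
    return nn.Sequential(nn.Linear(in_dim, hidden_channels), nn.ReLU(),
                         nn.Linear(hidden_channels, out_dim))


class EdgeModelSketch(nn.Module):
    def __init__(self, f_x, f_e, f_u, hidden_channels):
        super().__init__()
        self.net = mlp(2 * f_x + f_e + f_u, hidden_channels, f_e)

    def forward(self, src, dest, edge_attr, u, batch):
        # batch: [E], graph id of every edge; u[batch] copies u to each edge.
        return self.net(torch.cat([src, dest, edge_attr, u[batch]], dim=1))


class NodeModelSketch(nn.Module):
    def __init__(self, f_x, f_e, f_u, hidden_channels):
        super().__init__()
        self.net = mlp(f_x + f_e + f_u, hidden_channels, f_x)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # Average the incoming edge attributes of every node (receivers: edge_index[1]).
        edge_bar = mean_by_index(edge_attr, edge_index[1], x.size(0))
        return self.net(torch.cat([x, edge_bar, u[batch]], dim=1))


class GlobalModelSketch(nn.Module):
    def __init__(self, f_x, f_e, f_u, hidden_channels):
        super().__init__()
        self.net = mlp(f_x + f_e + f_u, hidden_channels, f_u)

    def forward(self, node_attr_prime, edge_out_bar, u, batch):
        # Average node attributes and per-node edge aggregates within each graph.
        v_bar = mean_by_index(node_attr_prime, batch, u.size(0))
        e_bar = mean_by_index(edge_out_bar, batch, u.size(0))
        return self.net(torch.cat([e_bar, v_bar, u], dim=1))


# Toy batch: 2 graphs, 5 nodes, 6 directed edges, random placeholder features.
torch.manual_seed(0)
f_x, f_e, f_u = 7, 4, 1
x, edge_attr, u = torch.randn(5, f_x), torch.randn(6, f_e), torch.zeros(2, f_u)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
batch = torch.tensor([0, 0, 0, 1, 1])

sender, receiver = edge_index
edge_new = EdgeModelSketch(f_x, f_e, f_u, 64)(x[sender], x[receiver], edge_attr, u, batch[sender])
x_new = NodeModelSketch(f_x, f_e, f_u, 64)(x, edge_index, edge_new, u, batch)
edge_bar = mean_by_index(edge_new, receiver, x.size(0))
u_new = GlobalModelSketch(f_x, f_e, f_u, 64)(x_new, edge_bar, u, batch)
print(edge_new.shape, x_new.shape, u_new.shape)
```

Keeping the output sizes equal to the input sizes lets the same three networks be applied for several passes in a row.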

Then let's write the GN class that takes the three update models as its arguments. `num_passes` is the number of times we perform the edge, node and global updates.

In [None]:
from torch_geometric.nn import global_mean_pool

class GN(torch.nn.Module):

  def __init__(self, edge_model, node_model, global_model, num_passes):
    super().__init__()
    torch.manual_seed(12345)
    self.edge_model = edge_model
    self.node_model = node_model
    self.global_model = global_model
    num_features = dataset.num_features + dataset.num_edge_features
    num_features += dataset[0].u.size(1)
    self.lin = torch.nn.Linear(num_features, dataset.num_classes)
    self.num_passes = num_passes
    self.reset_parameters()

  def reset_parameters(self):
    for item in [self.node_model, self.edge_model, self.global_model]:
      if hasattr(item, 'reset_parameters'):
        item.reset_parameters()

  # TODO: write the forward pass together

  def __repr__(self) -> str:
      return (f'{self.__class__.__name__}(\n'
              f'  edge_model={self.edge_model},\n'
              f'  node_model={self.node_model},\n'
              f'  global_model={self.global_model}\n'
              f')')
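
The forward pass is left as a TODO above. A possible shape for it is sketched below as a standalone function, under stated assumptions: `gn_forward_sketch` and `mean_by_index` are names invented here, the argument order follows the `gn_model(data.x, data.edge_attr, data.u, data.edge_index, data.batch)` call used during training, and the readout concatenates mean-pooled node attributes, mean-pooled edge attributes and the global attribute, which is inferred from the input width of `self.lin`. The stand-in update functions simply return their inputs so that the snippet runs on its own.

```python
import torch
from torch import nn


def mean_by_index(values, index, num_groups):
    """Average the rows of `values` that share the same entry in `index`."""
    sums = values.new_zeros(num_groups, values.size(1)).index_add_(0, index, values)
    counts = torch.bincount(index, minlength=num_groups).clamp(min=1).unsqueeze(1)
    return sums / counts


def gn_forward_sketch(edge_model, node_model, global_model, lin, num_passes,
                      x, edge_attr, u, edge_index, batch):
    sender, receiver = edge_index
    num_graphs = u.size(0)
    for _ in range(num_passes):
        # Edge update, node update, then global update, as in the GN block.
        edge_attr = edge_model(x[sender], x[receiver], edge_attr, u, batch[sender])
        x = node_model(x, edge_index, edge_attr, u, batch)
        edge_bar = mean_by_index(edge_attr, receiver, x.size(0))
        u = global_model(x, edge_bar, u, batch)
    # Readout: one vector per graph, followed by a linear classifier.
    graph_repr = torch.cat([
        mean_by_index(x, batch, num_graphs),
        mean_by_index(edge_attr, batch[sender], num_graphs),
        u,
    ], dim=1)
    return lin(graph_repr)


# Stand-in update functions that leave the attributes unchanged (placeholders only).
def keep_edges(src, dest, edge_attr, u, batch):
    return edge_attr


def keep_nodes(x, edge_index, edge_attr, u, batch):
    return x


def keep_global(node_attr_prime, edge_out_bar, u, batch):
    return u


# Toy batch: 2 graphs, 5 nodes, 6 directed edges, random placeholder features.
torch.manual_seed(0)
x, edge_attr, u = torch.randn(5, 7), torch.randn(6, 4), torch.zeros(2, 1)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
batch = torch.tensor([0, 0, 0, 1, 1])
lin = nn.Linear(7 + 4 + 1, 2)

out = gn_forward_sketch(keep_edges, keep_nodes, keep_global, lin, 3,
                        x, edge_attr, u, edge_index, batch)
print(out.shape)
```

The output has one row of class scores per graph, which is the shape `CrossEntropyLoss` expects together with `data.y`.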

So we can start training:

In [None]:
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

edge_model, node_model, global_model = EdgeModel(64), NodeModel(64), GlobalModel(64)
NUM_PASSES = 3

gn_model = GN(edge_model, node_model, global_model, NUM_PASSES)
optimizer = torch.optim.Adam(gn_model.parameters(), lr=0.01)
loss_function = torch.nn.CrossEntropyLoss()

print(gn_model)
torch.manual_seed(12345)

def train():
    gn_model.train()

    for data in train_loader: # Iterate over the batches of graphs
         out = gn_model(data.x, data.edge_attr, data.u, data.edge_index, data.batch)  # Forward pass(es)
         loss = loss_function(out, data.y)  # Compute the loss
         loss.backward() # Compute the gradients
         optimizer.step()  # Update the weights based on the computed gradients
         optimizer.zero_grad() # Clear the computed gradients

@torch.no_grad()  # Gradients are not needed for evaluation
def test(loader):
     gn_model.eval()

     correct = 0
     for data in loader:  # Iterate over the batches
         out = gn_model(data.x, data.edge_attr, data.u, data.edge_index, data.batch)  # Forward pass
         pred = out.argmax(dim=1)  # Predict the class with the highest score
         correct += int((pred == data.y).sum())  # Check against the ground truth
     return correct / len(loader.dataset) # Compute accuracy

for epoch in range(1, 120):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

The fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs); they usually disappear once one applies GNNs to larger datasets.
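
A quick back-of-the-envelope calculation shows how coarse an accuracy estimate on 38 graphs is. The snippet is a sketch: the accuracy value `p` is a hypothetical number chosen for illustration, and the binomial standard error is used as a rough approximation.

```python
import math

n_test = 38               # number of test graphs
step = 1 / n_test         # change in accuracy when one more graph is classified correctly

p = 0.8                   # hypothetical true accuracy, for illustration only
standard_error = math.sqrt(p * (1 - p) / n_test)

print(f'One graph changes the test accuracy by {step:.4f}')
print(f'Approximate standard error at p = {p}: {standard_error:.4f}')
```

A single graph moves the test accuracy by about 2.6 percentage points, so swings of a few points between epochs are expected at this sample size.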