# Graph Neural Networks

**Author**: Oleg Platonov

Graph Neural Networks (GNNs) are currently the most popular approach to machine learning on graphs. Many GNN architectures can be unified by the Message-Passing Neural Networks (MPNNs) framework. Below we will describe (a variant of) this framework and implement and train several examples of MPNNs.

First, let's introduce the notation we will be using in this notebook. Let $G = (V, E)$ be a graph with nodeset $V$ and edgeset $E$, $|V| = n$, $|E| = m$. Let $N(v)$ be the one-hop neighborhood of the node $v$ and $deg(v)$ be the degree of node $v$, $deg(v) = |N(v)|$. Let $A$ be the adjacency matrix of graph $G$ and $D$ be the diagonal degree matrix of graph $G$, i.e., $D = diag \Big( deg(v_1), \; deg(v_2), \; ..., \; deg(v_n) \Big)$.

In each layer $l$, an MPNN creates a representation $h_i^l$ of each node $v_i$ from its previous-layer representation and the previous-layer representations of its neighbors using the following formula:

$$ h_i^{l+1} = \mathrm{Update} \Bigg( h_i^l, \; \mathrm{Aggregate} \Big( \Big\{ (h_i^l, \; h_j^l): \; v_j \in N(v_i) \Big\} \Big) \Bigg) $$

Here, $\mathrm{Aggregate}$ is a function that aggregates information from the set of neighbors (since it operates on a set, it should be invariant to the order of neighbors) and $\mathrm{Update}$ is a function that combines the node's previous-layer representation with the aggregated information from its neighbors. For example, $\mathrm{Aggregate}$ can be the elementwise mean operation over the set of neighbors and $\mathrm{Update}$ can be an MLP that takes two concatenated vectors as input:

$$ h_i^{l+1} = \mathrm{MLP} \Bigg( \bigg[ h_i^l \; \mathbin\Vert \; \mathrm{mean} \Big( \Big\{ h_j^l: \; v_j \in N(v_i) \Big\} \Big) \bigg] \Bigg) $$

(this is actually the first GNN that we will implement in this seminar).

The $\mathrm{Aggregate}$ operation is often called graph convolution.

Note that variations of the above MPNN formula are possible. For example, edge representations can be added, but we won't do it in this seminar.
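To make the mean-aggregation formula above concrete, here is a minimal sketch of a single layer on a toy three-node path graph. The graph, the feature values, and the MLP sizes are made up for illustration; the model we actually train below is built differently.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy undirected path graph 0 - 1 - 2 (hypothetical example data).
adjacency = torch.tensor([[0., 1., 0.],
                          [1., 0., 1.],
                          [0., 1., 0.]])
h = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])

# Aggregate: elementwise mean over each node's neighbors.
degrees = adjacency.sum(dim=1, keepdim=True)
h_mean = adjacency @ h / degrees

# Update: an MLP applied to the concatenation [h_i || mean of neighbors].
mlp = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
h_next = mlp(torch.cat([h, h_mean], dim=1))

print(h_mean)
print(h_next.shape)
```

Node 1 has two neighbors, so its aggregated vector is the average of the representations of nodes 0 and 2, while nodes 0 and 2 simply receive the representation of node 1.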

In [None]:
from tqdm.notebook import tqdm
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler

In [None]:
device = 'cuda:0'

Now, let's get us a graph. The PyTorch Geometric library provides a lot of popular graph datasets. We will use the Amazon-Computers dataset. It is a co-purchasing network where nodes represent products, edges indicate that two products are frequently bought together, node features are bag-of-words-encoded product reviews, and node labels are product categories. The graph is a simple undirected graph without self-loops.

In [None]:
# !pip install torch_geometric

In [None]:
from torch_geometric import datasets

In [None]:
data = datasets.Amazon(name='computers', root='data')[0]
features = data.x
labels = data.y
edges = data.edge_index.T

# PyG stores each undirected edge in both directions, so `edges` has 2m rows.
print(f'Number of nodes: {len(labels)}')
print(f'Number of undirected edges: {len(edges) // 2}')
print(f'Average node degree: {len(edges) / len(labels):.2f}')

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
full_idx = np.arange(len(labels))
train_idx, val_and_test_idx = train_test_split(full_idx, test_size=0.5, random_state=0,
                                               stratify=labels)

val_idx, test_idx = train_test_split(val_and_test_idx, test_size=0.5, random_state=0,
                                     stratify=labels[val_and_test_idx])

train_idx = torch.from_numpy(train_idx)
val_idx = torch.from_numpy(val_idx)
test_idx = torch.from_numpy(test_idx)

Let's prepare a training loop.

In [None]:
def train_step(model, optimizer, scaler, amp, graph, features, labels, train_idx):
    model.train()

    with autocast(enabled=amp):
        logits = model(graph=graph, x=features)
        loss = F.cross_entropy(input=logits[train_idx], target=labels[train_idx])

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


@torch.no_grad()
def evaluate(model, amp, graph, features, labels, train_idx, test_idx, val_idx):
    model.eval()

    with autocast(enabled=amp):
        logits = model(graph=graph, x=features)

    preds = logits.argmax(axis=1)
    
    train_accuracy = (preds[train_idx] == labels[train_idx]).float().mean().item()
    val_accuracy = (preds[val_idx] == labels[val_idx]).float().mean().item()
    test_accuracy = (preds[test_idx] == labels[test_idx]).float().mean().item()
    
    metrics = {
        'train accuracy': train_accuracy,
        'val accuracy': val_accuracy,
        'test accuracy': test_accuracy
    }

    return metrics


def run_experiment(graph, features, labels, train_idx, val_idx, test_idx, graph_conv_module, num_layers=2,
                   hidden_dim=256, num_heads=4, dropout=0.2, lr=3e-5, num_steps=500, device='cuda:0', amp=False):
    model = Model(graph_conv_module=graph_conv_module,
                  num_layers=num_layers,
                  input_dim=features.shape[1],
                  hidden_dim=hidden_dim,
                  output_dim=len(labels.unique()),
                  num_heads=num_heads,
                  dropout=dropout)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=amp)
    
    graph = graph.to(device)
    features = features.to(device)
    labels = labels.to(device)
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)
    
    best_val_metric = 0
    corresponding_test_metric = 0
    best_step = None
    with tqdm(total=num_steps) as progress_bar:
        for step in range(1, num_steps + 1):
            train_step(model=model, optimizer=optimizer, scaler=scaler, amp=amp, graph=graph, features=features,
                       labels=labels, train_idx=train_idx)
            metrics = evaluate(model=model, amp=amp, graph=graph, features=features, labels=labels,
                               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)

            progress_bar.update()
            progress_bar.set_postfix({metric: f'{value:.2f}' for metric, value in metrics.items()})
            
            if metrics['val accuracy'] > best_val_metric:
                best_val_metric = metrics['val accuracy']
                corresponding_test_metric = metrics['test accuracy']
                best_step = step
    
    print(f'Best val accuracy: {best_val_metric:.4f}')
    print(f'Corresponding test accuracy: {corresponding_test_metric:.4f}')
    print(f'(step {best_step})')


This should look quite similar to your standard training loop, but with one notable difference: there are no mini-batches, and we always train on the whole graph. Since the data samples (graph nodes) are not independent, we cannot trivially sample a mini-batch.

Now, let's implement a model. Don't forget about skip connections and layer normalization - they can significantly boost the performance of a deep learning model.

In [None]:
class FeedForwardModule(nn.Module):
    def __init__(self, dim, num_inputs, dropout):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=num_inputs * dim, out_features=dim)
        self.dropout_1 = nn.Dropout(p=dropout)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(in_features=dim, out_features=dim)
        self.dropout_2 = nn.Dropout(p=dropout)
    
    def forward(self, x):
        x = self.linear_1(x)
        x = self.dropout_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        x = self.dropout_2(x)

        return x


class ResidualModule(nn.Module):
    def __init__(self, graph_conv_module, dim, num_heads, dropout):
        super().__init__()
        self.normalization = nn.LayerNorm(normalized_shape=dim)
        self.graph_conv = graph_conv_module(dim=dim, num_heads=num_heads)
        self.feed_forward = FeedForwardModule(dim=dim, num_inputs=2, dropout=dropout)
    
    def forward(self, graph, x):
        x_res = self.normalization(x)
        
        x_aggregated = self.graph_conv(graph, x_res)
        x_res = torch.cat([x_res, x_aggregated], axis=1)
        
        x_res = self.feed_forward(x_res)
        
        x = x + x_res
        
        return x


class Model(nn.Module):
    def __init__(self, graph_conv_module, num_layers, input_dim, hidden_dim, output_dim, num_heads, dropout):
        super().__init__()
        self.input_linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.input_dropout = nn.Dropout(p=dropout)
        self.input_act = nn.GELU()
        
        self.residual_modules = nn.ModuleList(
            ResidualModule(graph_conv_module=graph_conv_module, dim=hidden_dim, num_heads=num_heads,
                           dropout=dropout)
            for _ in range(num_layers)
        )
        
        self.output_normalization = nn.LayerNorm(hidden_dim)
        self.output_linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)
    
    def forward(self, graph, x):
        x = self.input_linear(x)
        x = self.input_dropout(x)
        x = self.input_act(x)
        
        for residual_module in self.residual_modules:
            x = residual_module(graph, x)
        
        x = self.output_normalization(x)
        logits = self.output_linear(x)
        
        return logits


Now everything is ready - except for the graph convolution module. We will implement several variants of this module, which will constitute the only difference between our GNNs. But first - as a simple baseline - let's implement a graph convolution module that does nothing. It will allow us to see how a graph-agnostic model performs, so we can then compare our GNNs to this baseline.

In [None]:
class DummyGraphConv(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
    
    def forward(self, graph, x):
        return torch.zeros_like(x)


In [None]:
graph = torch.empty(0)   # We don't care about graph representation for this experiment.

run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=DummyGraphConv,
               device=device, amp=True)

Now let's implement some real graph convolutions. Simple graph convolutions can be represented as operations with (sparse) matrices. Thus, they can be implemented in pure PyTorch. We will need the graph adjacency matrix $A$, the graph degree matrix $D$, and the matrix $H^l$ of node representations at layer $l$. Further, let $\tilde{h}_i^{l}$ be the output of the $\mathrm{Aggregate}$ function at layer $l$ for node $v_i$ and let $\widetilde{H}^l$ be the matrix of stacked vectors $\tilde{h}_i^{l}$ for all nodes.

For the next couple of experiments, assume that the graph argument of the graph convolution forward method is a sparse adjacency matrix of the graph.

In [None]:
graph = torch.sparse_coo_tensor(indices=edges.T, values=torch.ones(len(edges)), size=(len(labels), len(labels)))
graph

Let's implement a graph convolution that simply takes the mean of neighbors' representations. We can write:

$$ \tilde{h}_i^{l+1} = \frac{1}{|N(v_i)|} \sum_{v_j \in N(v_i)} h_j^l $$

This operation can be written in matrix form:

$$ \widetilde{H}^{l+1} = D^{-1} A H^l $$

Let's implement it!

In [None]:
class MeanGraphConv(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
    
    def forward(self, graph, x):
        ### YOUR CODE HERE ###
        
        
        
        ######################
        
        return x_aggregated


(The computations can be sped up by precomputing $D^{-1} A$, but we won't do it.)
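If you want to compare your solution against something, here is one possible sketch of the matrix form $D^{-1} A H^l$, checked on a toy path graph. The class name `MeanGraphConvSketch`, the toy data, and the clamping of degrees (to avoid dividing by zero for isolated nodes) are assumptions made for this illustration, not part of the exercise statement.

```python
import torch
from torch import nn


class MeanGraphConvSketch(nn.Module):
    """Hypothetical reference: mean over neighbors via a sparse adjacency matrix."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, graph, x):
        # Row sums of the adjacency matrix are the node degrees.
        ones = torch.ones(graph.shape[1], 1, dtype=x.dtype, device=x.device)
        degrees = torch.sparse.mm(graph, ones)
        # Assumption: isolated nodes get degree 1 so that they aggregate to zeros.
        degrees = degrees.clamp(min=1)
        # D^{-1} A H: sum neighbor representations, then divide by the degree.
        x_aggregated = torch.sparse.mm(graph, x) / degrees
        return x_aggregated


# Toy undirected path graph 0 - 1 - 2, stored with both edge directions.
toy_indices = torch.tensor([[0, 1, 1, 2],
                            [1, 0, 2, 1]])
toy_graph = torch.sparse_coo_tensor(indices=toy_indices, values=torch.ones(4), size=(3, 3))
toy_x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])

toy_out = MeanGraphConvSketch()(toy_graph, toy_x)
print(toy_out)
```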

In [None]:
run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=MeanGraphConv,
               device=device)

As we can see, the accuracy is a lot better than in the previous experiment - our GNN works better than a graph-agnostic model on this dataset.

Now, let's try another simple GNN variant - this time we will implement a graph convolution proposed in [the GCN paper](https://arxiv.org/abs/1609.02907). The formula is:

$$ \tilde{h}_i^{l+1} = \sum_{v_j \in N(v_i)} \frac{1}{\sqrt{deg(v_i) deg(v_j)}} h_j^l $$

It's very similar to the mean convolution, except we normalize each neighbor's representation not by the degree of the ego node, but by the geometric mean of the degrees of the ego node and the neighbor. This operation can be written in matrix form:

$$ \widetilde{H}^{l+1} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}} H^l $$

Let's implement it!

In [None]:
class GCNGraphConv(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
    
    def forward(self, graph, x):
        ### YOUR CODE HERE ###
        
        
        
        ######################
        
        return x_aggregated


(The computations can be sped up by precomputing $D^{-\frac{1}{2}} A D^{-\frac{1}{2}}$, but we won't do it.)
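As with the mean convolution, here is one possible sketch of the matrix form $D^{-\frac{1}{2}} A D^{-\frac{1}{2}} H^l$ for comparison. The class name `GCNGraphConvSketch`, the toy graph, and the degree clamping for isolated nodes are assumptions of this illustration.

```python
import torch
from torch import nn


class GCNGraphConvSketch(nn.Module):
    """Hypothetical reference: symmetric degree normalization with a sparse adjacency matrix."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, graph, x):
        ones = torch.ones(graph.shape[1], 1, dtype=x.dtype, device=x.device)
        # Assumption: clamp degrees so that isolated nodes do not cause division by zero.
        degrees = torch.sparse.mm(graph, ones).clamp(min=1)
        norm = degrees.pow(-0.5)
        # Scale inputs by D^{-1/2}, propagate with A, then scale outputs by D^{-1/2}.
        x_aggregated = norm * torch.sparse.mm(graph, norm * x)
        return x_aggregated


# Toy undirected path graph 0 - 1 - 2, stored with both edge directions.
toy_indices = torch.tensor([[0, 1, 1, 2],
                            [1, 0, 2, 1]])
toy_graph = torch.sparse_coo_tensor(indices=toy_indices, values=torch.ones(4), size=(3, 3))
toy_x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])

toy_out = GCNGraphConvSketch()(toy_graph, toy_x)
print(toy_out)
```

For node 1 (degree 2, with two degree-1 neighbors), each neighbor is weighted by $1 / \sqrt{2 \cdot 1}$, which is what the middle row of the output reflects.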

In [None]:
run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=GCNGraphConv,
               device=device)

The results are similar to those in the previous experiment.

Simple graph convolutions can be expressed as matrix operations, and thus can be implemented in pure PyTorch. However, efficient implementation of more complex graph convolutions requires using specialized libraries. The two most popular GNN libraries for PyTorch are [PyTorch Geometric (PyG)](https://github.com/pyg-team/pytorch_geometric) and [Deep Graph Library (DGL)](https://www.dgl.ai/). In this seminar, we will be using DGL, because ~it is objectively better~ the instructor likes it more.

In [None]:
# !pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html

In [None]:
import dgl
from dgl import ops

There are many features for deep learning on graphs in DGL, but we will only be using two of them - the Graph class, which is obviously used for representing a graph, and the [ops module](https://docs.dgl.ai/api/python/dgl.ops.html), which contains operators for message passing on graphs.
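Conceptually, a message-passing operator works on the edge list: each edge carries a message computed from its endpoints, and each node reduces the messages arriving over its incoming edges. The sketch below spells this out in plain PyTorch on a toy graph. It only illustrates the idea and says nothing about how DGL implements its operators internally; the toy data is hypothetical.

```python
import torch

# Toy edge list (source -> destination) for the undirected path 0 - 1 - 2.
src = torch.tensor([0, 1, 1, 2])
dst = torch.tensor([1, 0, 2, 1])
x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
num_nodes = x.shape[0]

# Message: every edge carries a copy of its source node's features.
messages = x[src]

# Reduce: sum the incoming messages at each destination node.
summed = torch.zeros_like(x).index_add_(0, dst, messages)

# Dividing by the number of incoming edges turns the sum into a mean.
in_degrees = torch.zeros(num_nodes).index_add_(0, dst, torch.ones(len(dst)))
mean = summed / in_degrees.clamp(min=1).unsqueeze(1)

print(summed)
print(mean)
```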

First, let's create a graph representation which we will be using in the next few experiments.

In [None]:
graph = dgl.graph((edges[:, 0], edges[:, 1]), num_nodes=len(labels))
graph

Now let's reimplement the mean graph convolution, this time using DGL. For this we will need a certain operation from the ops module - can you guess which one from its name?

In [None]:
class DGLMeanGraphConv(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
    
    def forward(self, graph, x):
        ### YOUR CODE HERE ###
        
        
        
        ######################
        
        return x_aggregated


In [None]:
run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=DGLMeanGraphConv,
               device=device, amp=True)

The results are roughly the same as for the pure PyTorch implementation, but the training is faster. There are two reasons for this: graph message passing operations in DGL are generally faster than PyTorch sparse matrix multiplications, and DGL supports AMP for most of its operations, while PyTorch does not (yet) allow using AMP with sparse matrix operations.

By simply swapping the ops.copy_u_mean function for the ops.copy_u_max function, we can get another graph convolution that computes the elementwise maximum of neighbors' representations. This one cannot be efficiently implemented in pure PyTorch. Let's see how it performs.
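To see what the elementwise maximum over neighbors should return, it can help to check a tiny graph by hand. The sketch below does this with `torch.Tensor.scatter_reduce` (available in recent PyTorch versions) on a hypothetical toy graph. It is meant only as a semantic reference for small inputs, with no claim about efficiency, and it is not the DGL solution to the exercise.

```python
import torch

# Toy edge list (source -> destination) for the undirected path 0 - 1 - 2.
src = torch.tensor([0, 1, 1, 2])
dst = torch.tensor([1, 0, 2, 1])
x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])

# Every edge carries a copy of its source node's features.
messages = x[src]

# scatter_reduce expects an index with the same shape as the messages.
index = dst.unsqueeze(1).expand(-1, x.shape[1])

# include_self=False ignores the initial zeros in rows that receive at least one message,
# so zeros only remain for nodes without incoming edges.
x_max = torch.zeros_like(x).scatter_reduce(0, index, messages, reduce='amax', include_self=False)

print(x_max)
```

Node 1 receives `[1, 0]` and `[1, 1]` from its two neighbors, so its output is `[1, 1]`; nodes 0 and 2 each have a single neighbor and simply copy its representation.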

In [None]:
class DGLMaxGraphConv(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
    
    def forward(self, graph, x):
        ### YOUR CODE HERE ###
        
        
        
        ######################
        
        return x_aggregated


In [None]:
run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=DGLMaxGraphConv,
               device=device)   # This one currently does not work with AMP.

Now, let's reimplement the GCN graph convolution using DGL.

In [None]:
class DGLGCNGraphConv(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
    
    def forward(self, graph, x):
        ### YOUR CODE HERE ###
        
        
        
        ######################
        
        return x_aggregated


(The computations can be sped up by precomputing weights, but we won't do it.)

In [None]:
run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=DGLGCNGraphConv,
               device=device, amp=True)

Now let's implement something more complex - the graph convolution proposed in [the GAT paper](https://arxiv.org/abs/1710.10903). This one uses attention (although a very simple version of it). The formulas are:

$$ a_{ij} = \mathrm{LeakyReLU} \Big( w_1^T h_i^l + w_2^T h_j^l + b \Big) $$

$$ \alpha_{ij} = \mathrm{softmax}_{ij}(a_{i,1}, \; a_{i,2}, \; ... \; a_{i, deg(v_i)}) = \frac{\mathrm{exp}(a_{ij})}{\sum_{v_k \in N(v_i)} \mathrm{exp}(a_{ik})} $$

$$ \tilde{h}_i^{l+1} = \sum_{v_j \in N(v_i)} \alpha_{ij} h_j^l $$

where $\mathrm{softmax}_{ij}(a_{i,1}, \; a_{i,2}, \; ... \; a_{i, deg(v_i)})$ is the $j$-th output of softmax of the values $a_{ik}$ corresponding to the neighbors of the node $v_i$, i.e., softmax is taken only over the ego node's neighborhood. This function is available in DGL.

Note that, in addition, the attention mechanism is multi-headed.
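The three formulas above can be traced on a toy graph in plain PyTorch. The sketch below uses a single attention head with random, untrained parameters; the names `w_1`, `w_2`, `b` follow the formulas, while the toy graph, the feature size, and the default LeakyReLU slope are assumptions of this illustration. It is not the DGL-based solution to the exercise below.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy edges (v_j -> v_i) for the undirected path 0 - 1 - 2: src holds j, dst holds i.
src = torch.tensor([0, 1, 1, 2])
dst = torch.tensor([1, 0, 2, 1])
h = torch.randn(3, 4)
num_nodes = h.shape[0]

# Hypothetical single-head attention parameters.
w_1, w_2, b = torch.randn(4), torch.randn(4), torch.zeros(())

# a_ij = LeakyReLU(w_1^T h_i + w_2^T h_j + b), one score per edge.
scores = F.leaky_relu(h[dst] @ w_1 + h[src] @ w_2 + b)

# Softmax over each node's incoming edges only. Subtracting a global constant before
# exponentiating does not change the result and helps numerical stability.
exp_scores = torch.exp(scores - scores.max())
denominators = torch.zeros(num_nodes).index_add_(0, dst, exp_scores)
alpha = exp_scores / denominators[dst]

# Weighted sum of neighbor representations with the attention coefficients.
h_aggregated = torch.zeros_like(h).index_add_(0, dst, alpha.unsqueeze(1) * h[src])

print(alpha)
```

Nodes 0 and 2 have a single neighbor each, so their attention coefficient is exactly 1 and they just copy the representation of node 1; only node 1 actually mixes two neighbors.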

In [None]:
from dgl.nn.functional import edge_softmax

In [None]:
class DGLGATGraphConv(nn.Module):
    def __init__(self, dim, num_heads=4, **kwargs):
        super().__init__()
        ### YOUR CODE HERE ###
        
        
        
        ######################
    
    def forward(self, graph, x):
        ### YOUR CODE HERE ###
        
        
        
        ######################
        
        return x_aggregated


In [None]:
run_experiment(graph=graph, features=features, labels=labels,
               train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
               graph_conv_module=DGLGATGraphConv,
               device=device, amp=True)