# Graph Attention Networks (GAT) in PyTorch

This notebook is a PyTorch implementation of the Graph Attention Networks (GAT) paper. It uses the labml package and is primarily based on code from labml.

In essence, GATs are designed to operate on graph-structured data, where a graph comprises nodes connected by edges. For instance, in the Cora dataset, nodes represent research papers, and edges represent citations between the papers.

GAT employs masked self-attention, which is akin to the mechanism used in transformers. Here "masked" means that each node attends only to the nodes it is connected to. A GAT consists of stacked graph attention layers, where each layer receives node embeddings as input and produces transformed embeddings as output. These layers let nodes attend to the embeddings of their connected nodes, which helps the model learn relationships within the graph. The implementation below includes detailed descriptions of the graph attention layer operations.
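
As a rough illustration of masked attention on its own, the sketch below applies a softmax to toy scores after masking out unconnected node pairs. The scores, the 3-node graph, and the names `scores` and `adj` are assumptions made up for this example; they are not taken from the layer's code.

```python
import torch

# Toy raw attention scores for 3 nodes (hypothetical values).
# scores[i, j] is the score node i assigns to node j.
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0],
                       [0.5, 0.0, 2.0]])

# Hypothetical adjacency with self-loops:
# nodes 0 and 1 are connected to each other, node 2 only to itself.
adj = torch.tensor([[True, True, False],
                    [True, True, False],
                    [False, False, True]])

# Mask: pairs without an edge get -inf, so they receive zero weight after softmax.
masked = scores.masked_fill(~adj, float('-inf'))

# Normalize over j (dim=1) so each row of attention weights sums to 1.
attention = torch.softmax(masked, dim=1)
print(attention)
```

Each row has at least one unmasked entry here because of the self-loops. A row that is entirely `-inf` would produce `NaN` values after the softmax.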

In [7]:
import torch
from torch import nn
from labml_helpers.module import Module

## Graph attention layer

This is a single graph attention layer.
A GAT is made up of multiple such layers.

It takes as input
$$\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$$
where $\overrightarrow{h_i} \in \mathbb{R}^F$,
and outputs
$$\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}$$
where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.


* `in_features`, $F$, is the number of input features per node
* `out_features`, $F'$, is the number of output features per node
* `n_heads`, $K$, is the number of attention heads
* `is_concat` indicates whether the multi-head results should be concatenated or averaged
* `dropout` is the dropout probability
* `leaky_relu_negative_slope` is the negative slope for the LeakyReLU activation
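
To make the effect of `is_concat` on shapes concrete, here is a minimal sketch. The helper `head_dim` is hypothetical and only mirrors the per-head size rule used in the constructor. Combining heads with `reshape` and `mean` is an assumption about how the head outputs are merged, and the random tensors stand in for real per-head outputs.

```python
import torch


def head_dim(out_features: int, n_heads: int, is_concat: bool) -> int:
    """Hypothetical helper: number of features each head produces."""
    if is_concat:
        # Concatenated heads must split `out_features` evenly.
        if out_features % n_heads != 0:
            raise ValueError("out_features must be divisible by n_heads")
        return out_features // n_heads
    # Averaged heads each produce the full `out_features`.
    return out_features


n_nodes, n_heads = 5, 8

# Concatenation: 8 heads of 8 features each give 64 output features per node.
per_head = torch.randn(n_nodes, n_heads, head_dim(64, n_heads, is_concat=True))
concatenated = per_head.reshape(n_nodes, -1)

# Averaging: 8 heads of 7 features each give 7 output features per node.
per_head_avg = torch.randn(n_nodes, n_heads, head_dim(7, n_heads, is_concat=False))
averaged = per_head_avg.mean(dim=1)

print(concatenated.shape, averaged.shape)
```

Concatenation is the usual choice for hidden layers, while averaging suits a final layer whose output size must equal the number of classes.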


In [9]:
class GraphAttentionLayer(Module):
    
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads

        # Calculate the number of dimensions per head
        if is_concat:
            assert out_features % n_heads == 0
            # If we are concatenating the multiple heads
            self.n_hidden = out_features // n_heads
        else:
            # If we are averaging the multiple heads
            self.n_hidden = out_features

        # Linear layer for initial transformation;
        # i.e. to transform the node embeddings before self-attention
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        # Linear layer to compute attention score $e_{ij}$
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
        # The activation for attention score $e_{ij}$
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        # Softmax to compute attention $\alpha_{ij}$
        self.softmax = nn.Softmax(dim=1)
        # Dropout layer to be applied for attention
        self.dropout = nn.Dropout(dropout)


* `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.
We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head.

The adjacency matrix represents the edges (or connections) among nodes.
`adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.
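
The following sketch shows one way such an adjacency matrix could be built from an edge list and broadcast across heads. The edge list, the node count, and the added self-loops are assumptions for illustration; the zero tensor `e` stands in for real attention scores.

```python
import torch

n_nodes = 4
# Hypothetical directed edges as (source, target) pairs.
edges = torch.tensor([[0, 1], [1, 2], [3, 0]])

# adj_mat[i][j] is True if there is an edge from node i to node j.
adj_mat = torch.zeros(n_nodes, n_nodes, dtype=torch.bool)
adj_mat[edges[:, 0], edges[:, 1]] = True

# Assumption: add self-loops so every node can attend to itself.
adj_mat |= torch.eye(n_nodes, dtype=torch.bool)

# Add a trailing dimension of size 1 so the same mask is shared by all heads.
adj_mat = adj_mat.unsqueeze(-1)  # shape [n_nodes, n_nodes, 1]

# Stand-in attention scores of shape [n_nodes, n_nodes, n_heads].
n_heads = 2
e = torch.zeros(n_nodes, n_nodes, n_heads)

# The size-1 head dimension broadcasts against `n_heads` during masking.
masked = e.masked_fill(adj_mat == 0, float('-inf'))
print(adj_mat.shape, masked.shape)
```

Because the edges are directed, `adj_mat[0][1]` is `True` here while `adj_mat[1][0]` is `False`. For an undirected graph, both directions would need to be set.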


In [11]:
    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):

        # Number of nodes
        n_nodes = h.shape[0]
        # The initial transformation,
        # $$\overrightarrow{g^k_i} = \mathbf{W}^k \overrightarrow{h_i}$$
        # for each head.
        # We do single linear transformation and then split it up for each head.
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
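        # Build every ordered pair of nodes so that all attention scores are computed at once.
        # `g_repeat` tiles the whole sequence: $\{g_1, g_2, \dots, g_N, g_1, g_2, \dots, g_N, \dots\}$
        g_repeat = g.repeat(n_nodes, 1, 1)
        # `g_repeat_interleave` repeats each node in place: $\{g_1, g_1, \dots, g_1, g_2, g_2, \dots, g_2, \dots\}$
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        # Concatenate to get $\{g_1 \Vert g_1, g_1 \Vert g_2, \dots, g_1 \Vert g_N, g_2 \Vert g_1, \dots\}$
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)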
        # Reshape so that `g_concat[i, j]` is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        
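        # Attention score $e_{ij} = \text{LeakyReLU}(\mathbf{a}^\top [\overrightarrow{g_i} \Vert \overrightarrow{g_j}])$,
        # of shape `[n_nodes, n_nodes, n_heads, 1]`
        e = self.activation(self.attn(g_concat))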
        # Remove the last dimension of size `1`
        e = e.squeeze(-1)
        
        # Each dimension of `adj_mat` must either match `e` or be `1` so that it broadcasts against `e`
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
        # Mask $e_{ij}$ based on adjacency matrix.
        # $e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.
        e = e.masked_fill(adj_mat == 0, float('-inf'))
        
        