# Graph Attention Network (GAT)
Code implementation of GAT from the paper (https://arxiv.org/abs/1710.10903) by Kipf and Welling

Additional resources:
- author's blogpost: https://petar-v.com/GAT/
- torch-geometric implementation: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAT.html

In [2]:
import torch
import torch.nn.functional as F

In [82]:
class GATConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.6, alpha=0.2, concat=True):
        super(GATConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat

        self.W = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
        self.a = torch.nn.Parameter(torch.Tensor(2*out_channels, 1))
        self.b = torch.nn.Parameter(torch.Tensor(out_channels))
        
        # leakyReLU was used in the original GAT paper
        self.leakyrelu = torch.nn.LeakyReLU(self.alpha)
        
        self.reset_parameters()
    
    def forward(self, A, H):        
        Wh = H @ self.W
        Whu = Wh @ self.a[:self.out_channels]
        Whv = Wh @ self.a[self.out_channels:]
        # broadcast add
        e = Whu + Whv.T
        e = self.leakyrelu(e)
        
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(A > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        
        H_prime = attention @ Wh
        
        if self.concat:
            return F.elu(H_prime)
        else:
            return H_prime
    
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.W)
        torch.nn.init.xavier_uniform_(self.a)
        torch.nn.init.zeros_(self.b)

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.6, alpha=0.2, num_heads=4):
        super(GAT, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.alpha = alpha
        self.num_heads = num_heads
        
        self.attention_heads = [GATConv(in_channels, hidden_channels, dropout, alpha) for _ in range(num_heads)]
        self.attention_out = GATConv(hidden_channels*num_heads, out_channels, dropout, alpha, concat=False)
    
    def forward(self, A, H):
        H = F.dropout(H, self.dropout, training=self.training)
        out = torch.cat([head(A, H) for head in self.attention_heads], dim=1)
        out = F.elu(self.attention_out(A, out))
        return F.log_softmax(out, dim=1)
        

In [49]:
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

from torch_geometric.utils import to_dense_adj

In [66]:
dataset = Planetoid(root='datasets', name='Cora')
# dataset.transform = T.NormalizeFeatures()

print('number of classes: ', dataset.num_classes)
print('number of nodes (|V|): ', dataset.data.num_nodes)
print('number of node features (|d|): ', dataset.num_node_features)

data = dataset[0]

number of classes:  7
number of nodes (|V|):  2708
number of node features (|d|):  1433


In [87]:
# torch dataset comes with a sparse edge index so we are converting it to an adjacency matrix (for notational convenience)
A = to_dense_adj(data.edge_index)[0]
X = data.x

model = GAT(dataset.num_node_features, 64, dataset.num_classes, dropout=0.4, alpha=0.2, num_heads=8)

# training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 200

for epoch in range(num_epochs):
    optimizer.zero_grad()
    
    out = model(A, X)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    
    loss.backward()
    optimizer.step()
    if epoch % (num_epochs / 10) == 0:
        print(loss.item())

1.954589605331421
1.8099957704544067
1.6648118495941162
1.5188413858413696
1.4498655796051025
1.3341245651245117
1.1727324724197388
1.2138220071792603
1.0909379720687866
1.077467918395996


In [88]:
model.eval()
_, pred = model(A, X).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
acc = round(acc * 100, 3)

print('accuracy: ', acc, "%")

accuracy:  76.6 %
