# Message Passing Neural Networks

Supplementary resource from the Presentation Introduction to Graph Representation Learning (Timothy Lee)

In [29]:
import torch
import torch.nn.functional as F

### General MPNN Implementation

In [133]:
# This class defines one layer of the basic GNN model
class MPNNconv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, normalize=False, self_loops=False):
        super(MPNNconv, self).__init__()
        self.W_self = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
        self.W_neigh = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
        self.b = torch.nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()
        
        # extra parameters
        self.normalize = normalize
        self.self_loops = self_loops
    
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.W_self)
        torch.nn.init.xavier_uniform_(self.W_neigh)
        torch.nn.init.zeros_(self.b)
    
    def forward(self, A, h):
        # A: adjacency matrix (v x v)
        # h: node representation (v x d)
        
        if self.normalize:
            # Compute D^(-1/2)
            D = torch.diag(torch.sum(A, dim=1))
            D_inv = torch.inverse(D)
            
            A = torch.matmul(D_inv, A)
        
        H = None
        if self.self_loops:
            H = torch.matmul(torch.matmul(A, h), self.W_neigh) + self.b
        else:
            # matrix-level feed-forward operation
            H = torch.matmul(h, self.W_self) + torch.matmul(torch.matmul(A, h), self.W_neigh) + self.b
        
        # element-wise non-linearity
        H = F.relu(H)
        
        return H

class GCNConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__()
        self.W = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
        self.b = torch.nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()
    
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.W)
        torch.nn.init.zeros_(self.b)
    
    def forward(self, A, h):
        # A: adjacency matrix (v x v)
        # h: node representation (v x d)
        
        D_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
        A = torch.matmul(torch.matmul(D_inv_sqrt, A), D_inv_sqrt)
        
        H = torch.matmul(torch.matmul(A, h), self.W) + self.b
        
        # element-wise non-linearity
        H = F.relu(H)
        
        return H

In [134]:
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, normalize=False, self_loops=False, conv_type='MPNN'):
        super(GNN, self).__init__()
        if conv_type == 'MPNN':
            self.conv1 = MPNNconv(in_channels, hidden_channels, normalize, self_loops)
            self.conv2 = MPNNconv(hidden_channels, out_channels, normalize, self_loops)
        elif conv_type == 'GCN':
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
        else:
            raise ValueError('Unknown conv type')
        
        self.self_loops = self_loops
    
    def forward(self, A, X):
        if self.self_loops:
            I = torch.eye(A.size(0), device=A.device)
            A = A + I
        h1 = self.conv1(A, X)
        h2 = self.conv2(A, h1)
        
        return h2

### Training with Citation Networks

In [31]:
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

from torch_geometric.utils import to_dense_adj

In [33]:
dataset = Planetoid(root='datasets', name='Cora')
# dataset.transform = T.NormalizeFeatures()

print('number of classes: ', dataset.num_classes)
print('number of nodes (|V|): ', dataset.data.num_nodes)
print('number of node features (|d|): ', dataset.num_node_features)

data = dataset[0]
data

number of classes:  7
number of nodes (|V|):  2708
number of node features (|d|):  1433




Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [137]:
# torch dataset comes with a sparse edge index so we are converting it to an adjacency matrix (for notational convenience)
A = to_dense_adj(data.edge_index)[0]
X = data.x

model = GNN(in_channels=dataset.num_node_features, hidden_channels=64, out_channels=dataset.num_classes, normalize=True, self_loops=True, conv_type='GCN')

# training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 200

for epoch in range(num_epochs):
    optimizer.zero_grad()
    
    out = model(A, X)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    
    loss.backward()
    optimizer.step()
    if epoch % (num_epochs / 10) == 0:
        print(loss.item())

1.9460290670394897
1.3367527723312378
0.6376906633377075
0.2751449644565582
0.13458026945590973
0.07771994173526764
0.05084258317947388
0.03611980006098747
0.02713708020746708
0.021219292655587196


In [138]:
model.eval()
_, pred = model(A, X).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
acc = round(acc * 100, 3)

print('accuracy: ', acc, "%")

accuracy:  78.1 %
