References:

1. https://github.com/xbresson/CS6208_2023/blob/main/codes/labs_lecture07/01_vanilla_graph_transformers.ipynb
2. https://github.com/pgniewko/pytorch_geometric/blob/master/torch_geometric/nn/conv/transformer_conv.py
3. https://arxiv.org/abs/2012.09699
4. https://arxiv.org/abs/1703.04977

In [5]:
import torch
import math
import torch.nn.functional as F
import numpy as np
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_scatter.composite import scatter_softmax
from torch_scatter.scatter import scatter_add
from torch_geometric.utils import softmax
from torch_geometric.nn.resolver import activation_resolver


torch.manual_seed(192837465)

# Generate example data
x = torch.randn(6, 3)  # Node features (6 nodes, 3-dimensional features)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5], [1, 0, 2, 1, 3, 2, 5, 4]], dtype=torch.long)  # Edge indices
edge_attr = torch.randn(8, 2)  # Edge attributes (8 edges, 2-dimensional attributes)

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
class MLP(nn.Module):
    def __init__(self, hidden_dim, num_layers=2, dropout_p=0.0, 
                 act='relu', act_kwargs=None, final_act=True):
        super(MLP, self).__init__()
    
        if isinstance(hidden_dim, int):
            hidden_dim = [hidden_dim] * num_layers
            
        assert len(hidden_dim) == num_layers
        
        layers = []
        
        for (in_dim, out_dim) in zip(hidden_dim[:-2], hidden_dim[1:]):
            layers.append(nn.Linear(in_dim, out_dim, bias=True))
            layers.append(activation_resolver(act, **(act_kwargs or {})))
            if dropout_p > 0:
                layers.append(nn.Dropout(p=dropout_p))
        
        if dropout_p > 0.0:
            layers = layers[:-1]
        if not final_act:
            layers = layers[:-1]
            
        self.mlp = nn.Sequential(*layers)
        
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout_p = dropout_p
        self.act = act
    
    def forward(self, x):
        return self.mlp(x)

    
class GTConv(MessagePassing):
    def __init__(self, in_dim, out_dim, num_heads=1, dropout_p=0.0, norm='bn'):
        super(GTConv, self).__init__(node_dim=0, aggr='add')
        
        assert out_dim % num_heads == 0
        
        self.WQ = nn.Linear(in_dim, out_dim, bias=True)
        self.WK = nn.Linear(in_dim, out_dim, bias=True)
        self.WV = nn.Linear(in_dim, out_dim, bias=True)
        
        self.WO = nn.Linear(out_dim, in_dim, bias=True)
        
        
        if norm.lower() in ['bn', 'batchnorm', 'batch_norm']:
            self.norm1 = nn.BatchNorm1d(in_dim)
            self.norm2 = nn.BatchNorm1d(in_dim)
        elif norm.lower() in ['ln', 'layernorm', 'layer_norm']:
            self.norm1 = nn.LayerNorm(in_dim)
            self.norm2 = nn.LayerNorm(in_dim)
            
        self.ffn = MLP(hidden_dim=in_dim, num_layers=2, 
                       dropout_p=dropout_p, final_act=False)
        
        self.num_heads = num_heads
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout_p = dropout_p
        
        self.reset_parameters()
        
    def reset_parameters(self):
        # TODO: init linear layers with xavier weights    
        pass
    
    def forward(self, x, edge_index, edge_attr=None):
        #print(edge_index)
        Q = self.WQ(x).view(-1, self.num_heads, self.out_dim // self.num_heads)
        K = self.WK(x).view(-1, self.num_heads, self.out_dim // self.num_heads)
        V = self.WV(x).view(-1, self.num_heads, self.out_dim // self.num_heads)
        
        out = self.propagate(edge_index, Q=Q, K=K, V=V, edge_attr=edge_attr, size=None)
        
        out = out.view(-1, self.out_dim)
        
        out = F.dropout(out, self.dropout_p)
        out = self.WO(out) + x
        out = self.norm1(out)
    
        # FFN
        mlp_in = out
        out = self.ffn(out)
        out = self.norm2(mlp_in + out)
        
        return out
        
        
    def message(self, Q_i, K_j, V_j, index, edge_attr=None):
        d_k = Q_i.size(-1)
        qijk = (Q_i * K_j).sum(dim=-1) / math.sqrt(d_k)
        
        alpha = softmax(qijk, index) # Log-Sum-Exp trick used. No need for clipping.
        
        return alpha.view(-1, self.num_heads, 1) * V_j
    
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_dim}, '
                f'{self.out_dim}, heads={self.num_heads})')

In [15]:
gt = GTConv(in_dim=3, out_dim=6, num_heads=2)
gt

GTConv(3, 6, heads=2)

In [16]:
gt(x, edge_index, edge_attr=edge_attr)

tensor([[ 0.7639, -1.3601,  1.5337],
        [ 1.7370,  0.9856,  0.9378],
        [-0.4407,  0.6072,  0.3323],
        [-1.1320, -0.8076, -0.7924],
        [-0.9604, -0.7107, -0.9446],
        [ 0.0323,  1.2855, -1.0667]], grad_fn=<NativeBatchNormBackward0>)