References:

1. https://github.com/xbresson/CS6208_2023/blob/main/codes/labs_lecture07/01_vanilla_graph_transformers.ipynb
2. https://github.com/pgniewko/pytorch_geometric/blob/master/torch_geometric/nn/conv/transformer_conv.py
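3. https://arxiv.org/abs/2012.09699 (Dwivedi & Bresson, "A Generalization of Transformer Networks to Graphs")
4. https://arxiv.org/abs/1703.04977 (Kendall & Gal, "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?")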

TDC (Therapeutics Data Commons):
1. https://tdcommons.ai/benchmark/overview/

In [14]:
import torch
import math
import torch.nn.functional as F
import numpy as np
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_scatter.composite import scatter_softmax
from torch_scatter.scatter import scatter_add
from torch_geometric.utils import softmax
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn.aggr import MultiAggregation
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset

# Fix the RNG so the random features below are reproducible across runs.
torch.manual_seed(192837465)

# Toy graph used by the cells below: 6 nodes, 8 directed edges
# (4 undirected pairs: 0-1, 1-2, 2-3 and 4-5).
x = torch.randn(6, 3)  # one 3-dimensional feature vector per node

# Row 0 / row 1 of edge_index hold the two endpoints of every edge.
src = [0, 1, 1, 2, 2, 3, 4, 5]
dst = [1, 0, 2, 1, 3, 2, 5, 4]
edge_index = torch.tensor([src, dst], dtype=torch.long)

edge_attr = torch.randn(len(src), 2)  # one 2-dimensional attribute per edge

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [92]:
class MLP(nn.Module):
    """Multi-layer perceptron built from ``num_layers`` linear layers.

    Consecutive entries of ``dims`` define the hidden ``Linear`` layers
    (each followed by an activation and, if ``dropout > 0``, a dropout
    layer); a final ``Linear`` maps ``dims[-1]`` to ``output_dim`` with no
    activation.

    Args:
        dims: Either a sequence of ``num_layers`` layer widths (the first
            entry is the input width), or a single int that is repeated
            ``num_layers`` times.
        output_dim: Width of the final linear layer's output.
        num_layers: Total number of linear layers; must equal ``len(dims)``.
        dropout: Dropout probability applied after each hidden activation.
            No dropout module is created when it is ``0``.
        act: Activation specification forwarded to ``activation_resolver``.
        act_kwargs: Optional keyword arguments for the activation.

    Raises:
        ValueError: If ``dims`` is empty or its length differs from
            ``num_layers``.
    """

    def __init__(self, dims, output_dim, num_layers=2, dropout=0.0, act='relu', act_kwargs=None):
        super().__init__()

        if isinstance(dims, int):
            dims = [dims] * num_layers
        else:
            dims = list(dims)

        # Explicit exceptions instead of ``assert``: asserts vanish under
        # ``python -O`` and an empty ``dims`` would otherwise surface as an
        # opaque IndexError on ``dims[-1]`` below.
        if not dims:
            raise ValueError("MLP needs at least one layer (dims is empty)")
        if len(dims) != num_layers:
            raise ValueError(
                f"len(dims)={len(dims)} does not match num_layers={num_layers}"
            )

        layers = []

        for (i_dim, o_dim) in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(i_dim, o_dim, bias=True))
            layers.append(activation_resolver(act, **(act_kwargs or {})))
            if dropout > 0:
                layers.append(nn.Dropout(p=dropout))

        # Output projection: no activation or dropout after it.
        layers.append(nn.Linear(dims[-1], output_dim, bias=True))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.mlp(x)

    
class GTConv(MessagePassing):
    """Graph Transformer layer: multi-head edge attention plus a feed-forward block.

    Each node attends over its incoming edges with scaled dot-product
    attention. The concatenated heads are projected back to ``in_dim``,
    added to the input (residual) and normalised; a two-layer MLP with a
    second residual connection and normalisation follows.

    Args:
        in_dim: Size of the input (and output) node features.
        hidden_dim: Total width of the Q/K/V projections, split evenly
            across the heads.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability used on the attention output and
            inside the feed-forward MLP.
        norm: ``'bn'``/``'batchnorm'``/``'batch_norm'`` for ``BatchNorm1d``
            or ``'ln'``/``'layernorm'``/``'layer_norm'`` for ``LayerNorm``
            (case-insensitive).

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads`` or
            ``norm`` is not one of the supported names.
    """

    def __init__(self, in_dim, hidden_dim, num_heads=1, dropout=0.0, norm='bn'):
        super().__init__(node_dim=0, aggr='add')

        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        # Resolve the normalisation type up front so that an unknown name
        # fails here rather than as a missing attribute inside forward().
        norm = norm.lower()
        if norm in ('bn', 'batchnorm', 'batch_norm'):
            norm_cls = nn.BatchNorm1d
        elif norm in ('ln', 'layernorm', 'layer_norm'):
            norm_cls = nn.LayerNorm
        else:
            raise ValueError(
                f"Unsupported norm '{norm}'; expected one of 'bn', "
                f"'batchnorm', 'batch_norm', 'ln', 'layernorm', 'layer_norm'"
            )

        # Creation order is kept stable so seeded initialisation is
        # reproducible.
        self.WQ = nn.Linear(in_dim, hidden_dim, bias=True)
        self.WK = nn.Linear(in_dim, hidden_dim, bias=True)
        self.WV = nn.Linear(in_dim, hidden_dim, bias=True)
        self.WO = nn.Linear(hidden_dim, in_dim, bias=True)

        self.norm1 = norm_cls(in_dim)
        self.norm2 = norm_cls(in_dim)

        self.dropout_layer = nn.Dropout(p=dropout)

        self.ffn = MLP(dims=in_dim, output_dim=in_dim, num_layers=2, dropout=dropout)

        self.num_heads = num_heads
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.norm = norm

        self.reset_parameters()

    def reset_parameters(self):
        """Currently a no-op: sub-modules keep their default initialisation."""
        # TODO: init linear layers with xavier weights
        pass

    def forward(self, x, edge_index, edge_attr=None):
        """Run one attention + feed-forward step over the graph.

        Args:
            x: Node features whose last dimension is ``in_dim``.
            edge_index: Edge index tensor handed to ``propagate``.
            edge_attr: Optional edge features. They are passed through to
                ``message`` but are not used in the attention computation.

        Returns:
            Updated node features with last dimension ``in_dim``.
        """
        x_ = x
        head_dim = self.hidden_dim // self.num_heads
        # Split each projection into (nodes, heads, head_dim).
        Q = self.WQ(x).view(-1, self.num_heads, head_dim)
        K = self.WK(x).view(-1, self.num_heads, head_dim)
        V = self.WV(x).view(-1, self.num_heads, head_dim)

        out = self.propagate(edge_index, Q=Q, K=K, V=V, edge_attr=edge_attr, size=None)

        # Concatenate the heads back into one vector per node.
        out = out.view(-1, self.hidden_dim)

        out = self.dropout_layer(out)
        out = self.WO(out) + x_  # Residual connection
        out = self.norm1(out)

        # Feed-forward block with its own residual connection.
        mlp_in = out
        out = self.ffn(out)
        out = self.norm2(mlp_in + out)

        return out

    def message(self, Q_i, K_j, V_j, index, edge_attr=None):
        """Compute the attention-weighted value carried by each edge.

        ``Q_i`` holds the queries of the receiving nodes, ``K_j``/``V_j``
        the keys/values of the sending nodes, one row per edge. ``index``
        groups edges for the softmax. ``edge_attr`` is accepted but unused.
        """
        d_k = Q_i.size(-1)
        # Scaled dot product per edge and head -> shape (edges, heads).
        qijk = (Q_i * K_j).sum(dim=-1) / math.sqrt(d_k)

        # Softmax over the edges that share the same ``index`` entry; the
        # library routine is numerically stable, so no clipping is needed.
        alpha = softmax(qijk, index)

        return alpha.view(-1, self.num_heads, 1) * V_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_dim}, '
                f'{self.hidden_dim}, heads={self.num_heads})')
    
    
class GraphTransformerNet(nn.Module):
    """Graph-level regressor that predicts a mean and a spread per graph.

    Node features are embedded, passed through a stack of ``GTConv``
    layers, pooled per graph, and fed to two MLP heads that output the mean
    and the log-variance. ``forward`` returns a reparameterised sample
    together with the mean and the standard deviation.

    Args:
        node_dim_in: Size of the raw node features.
        edge_dim_in: Size of the raw edge features.
        hidden_dim: Width of the embeddings and of every ``GTConv`` layer.
        norm: Normalisation type forwarded to every ``GTConv`` layer.
        num_gt_layers: Number of stacked ``GTConv`` layers.
        num_heads: Attention heads per ``GTConv`` layer.
        aggregators: Names of the global pooling aggregations, concatenated
            by ``MultiAggregation``. Defaults to ``['sum']``.
    """

    def __init__(self, node_dim_in, edge_dim_in,
                 hidden_dim=128, norm='bn',
                 num_gt_layers=4, num_heads=8,
                 aggregators=None):
        super().__init__()

        # ``None`` sentinel instead of a mutable list default, which would
        # be shared between every instance.
        if aggregators is None:
            aggregators = ['sum']
        else:
            aggregators = list(aggregators)

        self.node_emb = nn.Linear(node_dim_in, hidden_dim)
        self.edge_emb = nn.Linear(edge_dim_in, hidden_dim)

        # nn.ModuleList (not a plain list) so the layers' parameters are
        # registered: visible to .parameters()/optimizers, included in
        # state_dict() and moved by .to()/.cuda(). ``norm`` is forwarded so
        # the constructor argument actually takes effect.
        self.gt_layers = nn.ModuleList([
            GTConv(in_dim=hidden_dim, hidden_dim=hidden_dim,
                   num_heads=num_heads, norm=norm)
            for _ in range(num_gt_layers)
        ])

        self.global_pool = MultiAggregation(aggregators, mode='cat')

        # 'cat' mode concatenates one hidden_dim-sized vector per aggregator.
        num_aggrs = len(aggregators)
        self.mu_mlp = MLP(dims=[num_aggrs * hidden_dim, hidden_dim, hidden_dim],
                          output_dim=1, num_layers=3, dropout=0.0)
        # Despite the name, this head outputs the log-variance (see forward).
        self.std_mlp = MLP(dims=[num_aggrs * hidden_dim, hidden_dim, hidden_dim],
                           output_dim=1, num_layers=3, dropout=0.0)

    def forward(self, x, edge_index, edge_attr, batch):
        """Predict a sampled value, its mean and its standard deviation.

        Args:
            x: Raw node features.
            edge_index: Edge index tensor of the (batched) graph.
            edge_attr: Raw edge features.
            batch: Assignment of every node to its graph, used for pooling.

        Returns:
            Tuple ``(sample, mu, std)`` where
            ``sample = mu + std * eps`` with ``eps`` drawn from a standard
            normal distribution.
        """
        # NOTE(review): squeeze() removes *every* size-1 dimension;
        # presumably meant for features shaped (N, 1, F) - confirm.
        x = self.node_emb(x.squeeze())

        # Embedded edge features are handed to the layers, which currently
        # do not use them in their attention computation.
        edge_attr = self.edge_emb(edge_attr)

        for gt_layer in self.gt_layers:
            x = gt_layer(x, edge_index, edge_attr=edge_attr)

        x = self.global_pool(x, batch)
        mu = self.mu_mlp(x)
        log_var = self.std_mlp(x)

        # Reparameterisation: std = exp(log_var / 2), sample = mu + std * eps.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu, mu, std

In [93]:
# Single attention layer sized for the toy graph: 3-dim node features,
# 6-dim projections split over 2 heads (3 dims per head).
gt = GTConv(in_dim=3, hidden_dim=6, num_heads=2)
gt  # bare expression so the notebook displays the layer's repr

GTConv(3, 6, heads=2)

In [94]:
# Run the layer once on the toy graph; the output keeps one 3-dim row per node.
gt(x, edge_index, edge_attr=edge_attr)

tensor([[ 0.8714, -1.5164,  1.4111],
        [ 1.7645,  1.4195,  0.9956],
        [-0.5366,  0.4443,  0.4392],
        [-0.4399, -0.3694, -0.8469],
        [-1.1963, -0.8110, -1.2334],
        [-0.4631,  0.8331, -0.7656]], grad_fn=<NativeBatchNormBackward0>)

In [95]:
# Full model with default hyper-parameters, matching the toy graph's
# 3-dim node features and 2-dim edge attributes.
gt_net = GraphTransformerNet(node_dim_in=3, edge_dim_in=2)

In [99]:
# Push the toy graph through the full model via a DataLoader, which adds the
# node-to-graph ``batch`` vector the model needs for global pooling.
loader = DataLoader([data], batch_size=128, shuffle=False)
# A dedicated loop variable keeps the module-level ``data``, ``x``,
# ``edge_index`` and ``edge_attr`` intact, so re-running earlier cells (or
# this one) does not silently operate on the last mini-batch instead.
for batch in loader:
    pred, mu, std = gt_net(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    print(pred, mu, std)

tensor([[-0.4718]], grad_fn=<AddBackward0>) tensor([[0.0691]], grad_fn=<AddmmBackward0>) tensor([[0.9552]], grad_fn=<ExpBackward0>)
