Code implementation references:

1) Chamberlain et al., "Beltrami Flow and Neural Diffusion on Graphs" (NeurIPS 2021).

2) UvA Deep Learning tutorials, Tutorial 7 (Graph Neural Networks overview, PyTorch Geometric section): https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html#PyTorch-Geometric

In [None]:
import os
import json
import math
import numpy as np
import pandas as pd
import time
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

try:
    import torch_geometric
except ModuleNotFoundError:
    # Installing torch geometric packages with specific CUDA+PyTorch version.
    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details
    TORCH = torch.__version__.split('+')[0]
    CUDA = 'cu' + torch.version.cuda.replace('.','')
    !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-geometric
    import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
import torch_geometric.datasets
from torch_geometric.datasets import Planetoid

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError:
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Seed the global RNGs through Lightning so runs are repeatable
pl.seed_everything(42)

# Force deterministic cuDNN kernels and turn off cuDNN autotuning: reproducibility over speed
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Pick the first GPU when one is visible, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)

In [None]:
# Download / load the Cora dataset into ./data
data_dir = "./data"

# split='random' asks Planetoid for randomly drawn train/validation/test masks
dataset = Planetoid(data_dir, "Cora", split="random")
# NOTE(review): rebinding `data` shadows the `torch.utils.data` alias imported at the top
# of the notebook; later cells read graph attributes from this name.
data = dataset[0]
print(data)

In [None]:
# Pull out the graph tensors that the following cells use
edges = data.edge_index
features = data.x
# Add a leading axis of size 1 so the features carry a batch dimension
features_updated = features.unsqueeze(0)
labels = data.y

# Report node count, then each tensor followed by its shape
print("# of Nodes: ", data.num_nodes)
for title, value in (("Edges", edges), ("Features", features_updated), ("Labels", labels)):
    print(f"\n{title}:\n", value)
    print(f"Shape of {title}: ", value.shape)

In [None]:
# Visualize the initial graph
def construct_graph(edges):
    """Build a NetworkX DiGraph from an edge index.

    Parameters:
    edges - a [2, num_edges] edge index (torch tensor, NumPy array, or nested list);
            row 0 holds source node ids, row 1 the matching destination ids

    Returns:
    networkx.DiGraph with one node per distinct node id and one edge per column.
    """
    # Convert to plain Python ints first: torch tensors hash by object identity, so
    # zipping the raw tensor rows would create a separate graph node for every edge
    # endpoint instead of one node per node id.
    rows = edges.tolist() if hasattr(edges, "tolist") else edges
    src, dst = rows[0], rows[1]
    # Directed graph; an undirected input edge appears in both directions only if the
    # edge index lists it both ways (assumed for Cora's edge_index).
    g = nx.DiGraph(zip(src, dst))
    return g

G = construct_graph(edges)
print(G)
nx.draw(G)

In [None]:
# Beltrami Flow layer
class BeltramiLayer(nn.Module):
  """Graph layer that aggregates neighbour features with Beltrami-flow attention.

  For every edge (i, j) in ``edge_list`` and every head, the un-normalised weight is the
  Beltrami diffusivity ``1 / sqrt(1 + alpha^2 * ||h_i - h_j||^2)``, where ``h`` are the
  projected node features of that head. Weights are softmax-normalised over the edges that
  share the same aggregating node ``i`` (row 0 of ``edge_list``), and node ``i``'s output is
  the weighted sum of its neighbours' projected features.
  """

  def __init__(self, in_dim, out_dim, num_heads=1, concat_heads=True, alpha=0.2):
    """
    Parameters:
    in_dim (int) - input dimension
    out_dim (int) - output dimension
    num_heads (int) - # of heads, attention mechanism applied in parallel
    concat_heads (bool) - if True, outputs of different heads concatenated rather than averaged
    alpha (float) - negative slope of the LeakyReLU module (not used by forward)

    Raises:
    ValueError - if concat_heads is True and out_dim is not a multiple of num_heads
    """
    super().__init__()
    self.num_heads = num_heads
    self.concat_heads = concat_heads

    # If concatenating the heads, each head gets out_dim // num_heads channels
    if self.concat_heads:
      if out_dim % num_heads != 0:
        raise ValueError(
            f"out_dim ({out_dim}) must be divisible by num_heads ({num_heads}) when concat_heads=True")
      out_dim = out_dim // num_heads

    # Sub-modules and parameters
    self.projection = nn.Linear(in_dim, out_dim*num_heads)
    # NOTE(review): `a` and `leakyrelu` are not used by forward(); they are kept so the
    # module's parameter names / state_dict layout stay unchanged.
    self.a = nn.Parameter(torch.Tensor(num_heads, 2*out_dim))
    self.leakyrelu = nn.LeakyReLU(alpha)

    # Xavier initialization
    nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
    nn.init.xavier_uniform_(self.a.data, gain=1.414)

  def forward(self, node_feats, edge_list, print_attn_probs=False, alpha=1):
    """
    Parameters:
        node_feats - node features of shape [batch, num_nodes, in_dim]; a 2-D
                     [num_nodes, in_dim] tensor is also accepted and yields 2-D output
        edge_list - integer tensor of shape [2, num_edges]; row 0 is the node that
                    aggregates, row 1 the neighbour it aggregates from
        print_attn_probs - if True, the per-edge attention probabilities are printed (debugging)
        alpha - scale applied to feature differences inside the Beltrami diffusivity

    Returns:
        Tensor of shape [batch, num_nodes, out_dim] ([num_nodes, out_dim] for 2-D input).
        Nodes that never appear in edge_list[0] receive all-zero features.
    """
    unbatched = node_feats.dim() == 2
    if unbatched:
      node_feats = node_feats.unsqueeze(0)
    batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

    # Apply linear layer, and sort nodes by head: [batch, num_nodes, num_heads, head_dim]
    node_feats = self.projection(node_feats)
    node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

    edge_list = edge_list.to(node_feats.device)
    src, dst = edge_list[0], edge_list[1]

    # Beltrami diffusivity per edge and head, vectorised over all edges
    diff = node_feats[:, src] - node_feats[:, dst]        # [batch, num_edges, heads, head_dim]
    sq_dist = diff.pow(2).sum(dim=-1)                     # [batch, num_edges, heads]
    attn = torch.rsqrt(1 + (alpha ** 2) * sq_dist)        # values in (0, 1]

    # Softmax over each node's edges (grouped by src). attn lies in (0, 1], so exp()
    # cannot overflow and no max-subtraction is needed.
    exp_attn = attn.exp()
    denom = exp_attn.new_zeros(batch_size, num_nodes, self.num_heads).index_add(1, src, exp_attn)
    attn_probs = exp_attn / denom[:, src]
    if print_attn_probs:
      print("Attention probs\n", attn_probs)

    # Weighted sum of neighbour features, scattered back onto the aggregating node
    messages = attn_probs.unsqueeze(-1) * node_feats[:, dst]
    node_feats = messages.new_zeros(
        batch_size, num_nodes, self.num_heads, messages.size(-1)).index_add(1, src, messages)

    # If heads concatenated, we can do this by reshaping. Otherwise, take mean
    if self.concat_heads:
      node_feats = node_feats.reshape(batch_size, num_nodes, -1)
    else:
      node_feats = node_feats.mean(dim=2)

    if unbatched:
      node_feats = node_feats.squeeze(0)
    return node_feats

In [None]:
layer = BeltramiLayer(in_dim=1433, out_dim=1433, num_heads=1)

# One forward pass over the full feature tensor without autograd, printing attention probabilities
with torch.no_grad():
  out_feats = layer(features_updated, edges, print_attn_probs=True)

In [None]:
# BLEND network
class BLENDModel(nn.Module):
  """Stack of BeltramiLayers with ReLU + dropout between them."""

  def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, drop_rate=0.1, **kwargs):
    """
    Parameters:
    in_dim (int) - input dimension
    hidden_dim (int) - dimension of hidden features
    out_dim (int) - output dimension (e.g. # of classes); width of the final layer
    num_layers (int) - total # of BeltramiLayers (num_layers - 1 hidden ones plus the output one)
    drop_rate (float) - dropout rate to apply throughout the network
    **kwargs - extra keyword arguments forwarded to every BeltramiLayer
    """
    super().__init__()

    layers = []
    in_channels, out_channels = in_dim, hidden_dim
    for _ in range(num_layers - 1):
      layers += [
          BeltramiLayer(in_dim=in_channels, out_dim=out_channels, **kwargs),
          nn.ReLU(inplace=True),
          nn.Dropout(drop_rate),
      ]
      in_channels = hidden_dim
    # The last layer must produce out_dim features (the logits), not hidden_dim
    layers += [
        BeltramiLayer(in_dim=in_channels, out_dim=out_dim, **kwargs)
    ]
    self.layers = nn.ModuleList(layers)

  def forward(self, x, edge_index):
    """
    Parameters:
    x (Tensor) - input node features, [num_nodes, in_dim] (or [batch, num_nodes, in_dim])
    edge_index (Tensor) - [2, num_edges] vertex index pairs representing the edges (PyTorch geometric notation)

    Returns:
    Tensor of node outputs with the same leading shape as x and out_dim features.
    """
    for layer in self.layers:
      if isinstance(layer, BeltramiLayer):
        # BeltramiLayer is documented to take [batch, num_nodes, c_in]; add and remove
        # a batch axis around it when the features are given per node only.
        if x.dim() == 2:
          x = layer(x.unsqueeze(0), edge_index).squeeze(0)
        else:
          x = layer(x, edge_index)
      # PyTorch geometric graph layers all inherit the MessagePassing class
      elif isinstance(layer, geom_nn.MessagePassing):
        x = layer(x, edge_index)
      else:
        x = layer(x)
    return x

In [None]:
# Training, validation, & testing
class NodeLevelBLEND(pl.LightningModule):
  """Lightning wrapper that trains and evaluates a BLENDModel for node classification.

  Loss and accuracy are computed only on the nodes selected by the split mask
  (train_mask / val_mask / test_mask) of the incoming graph batch.
  """

  def __init__(self, **model_kwargs):
    """
    Parameters:
    **model_kwargs - forwarded unchanged to BLENDModel; also recorded with
                     save_hyperparameters() so load_from_checkpoint can rebuild the model
    """
    super().__init__()
    # Save hyperparameters
    self.save_hyperparameters()

    self.model = BLENDModel(**model_kwargs)
    self.loss_module = nn.CrossEntropyLoss()

  def forward(self, data, mode="train"):
    """
    Parameters:
    data - graph batch exposing x, edge_index, y and train/val/test node masks
    mode (str) - "train", "val" or "test"; selects which mask is scored

    Returns:
    (loss, accuracy) computed over the masked nodes only

    Raises:
    ValueError - if mode is not one of the three split names
    """
    # Resolve the mask first so an invalid mode fails before running the model
    if mode == "train":
      mask = data.train_mask
    elif mode == "val":
      mask = data.val_mask
    elif mode == "test":
      mask = data.test_mask
    else:
      raise ValueError(f"Unknown forward mode: {mode}")

    x = self.model(data.x, data.edge_index)

    loss = self.loss_module(x[mask], data.y[mask])
    accuracy = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()
    return loss, accuracy

  def configure_optimizers(self):
    # Use SGD here
    optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-3)
    return optimizer

  def training_step(self, batch, batch_idx):
    loss, accuracy = self.forward(batch, mode="train")
    self.log('train_loss', loss)
    self.log('train_accuracy', accuracy)
    return loss

  def validation_step(self, batch, batch_idx):
    _, accuracy = self.forward(batch, mode="val")
    # Key must match the ModelCheckpoint monitor used by the training helper
    self.log('val_accuracy', accuracy)

  def test_step(self, batch, batch_idx):
    _, accuracy = self.forward(batch, mode="test")
    self.log('test_accuracy', accuracy)

In [None]:
# Define a training function
def train_node_classifier(dataset, **model_kwargs):
  """Train a NodeLevelBLEND on a single-graph dataset and report split accuracies.

  Parameters:
  dataset - PyG dataset whose first graph carries x, edge_index, y and train/val/test masks
  **model_kwargs - extra BLENDModel arguments (hidden_dim, num_layers, drop_rate, ...)

  Returns:
  (model, result) where model is the best checkpoint (by validation accuracy) and result
  maps "train"/"val"/"test" to that model's accuracy on each split.
  """
  # torch_geometric.data.DataLoader was moved to torch_geometric.loader in PyG 2.x
  try:
    from torch_geometric.loader import DataLoader
  except ImportError:
    DataLoader = geom_data.DataLoader

  pl.seed_everything(42)
  node_data_loader = DataLoader(dataset, batch_size=1)

  # Create a PyTorch lightning trainer; the checkpoint monitors the key logged in validation_step
  trainer = pl.Trainer(callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_accuracy")],
                       accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                       devices=1, max_epochs=200, enable_progress_bar=False)
  # Optional logging argument
  trainer.logger._default_hp_metric = None

  # Start model training (re-seed so model initialisation is reproducible)
  pl.seed_everything(42)
  model = NodeLevelBLEND(in_dim=dataset.num_node_features, out_dim=dataset.num_classes, **model_kwargs)
  trainer.fit(model, node_data_loader, node_data_loader)
  model = NodeLevelBLEND.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

  # Test best model on the test set
  test_result = trainer.test(model, node_data_loader, verbose=False)
  batch = next(iter(node_data_loader))
  batch = batch.to(model.device)
  # Evaluate without dropout and without building an autograd graph
  model.eval()
  with torch.no_grad():
    _, train_accuracy = model.forward(batch, mode="train")
    _, val_accuracy = model.forward(batch, mode="val")
  result = {"train": train_accuracy,
            "val": val_accuracy,
            "test": test_result[0]['test_accuracy']}
  return model, result

In [None]:
# Print test results
def print_results(result_dict):
    """Print accuracies from ``result_dict`` as percentages.

    The "train" and "val" lines are printed only when those keys are present;
    the "test" entry is always printed.
    """
    optional_rows = (("train", "Train accuracy: "), ("val", "Val accuracy:   "))
    for key, label in optional_rows:
        if key in result_dict:
            print(f"{label}{(100.0*result_dict[key]):4.2f}%")
    print(f"Test accuracy:  {(100.0*result_dict['test']):4.2f}%")

In [None]:
# Hyperparameters for a 2-layer BLEND model, then train it and report accuracies
blend_kwargs = {"hidden_dim": 16, "num_layers": 2, "drop_rate": 0.1}
node_blend_model, node_blend_result = train_node_classifier(dataset=dataset, **blend_kwargs)

print_results(node_blend_result)