In [None]:
import os
from typing import Any, Dict, Optional 

In [None]:
import torch 
from torch.nn import (BatchNorm1d, Embedding, Linear, ModuleList, ReLU, Sequential)

from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T 
from torch_geometric.datasets import ZINC 
from torch_geometric.loader import DataLoader 
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool 
from torch_geometric.nn.attention import PerformerAttention 

# Data

In [None]:
# Dataset root: 'data/ZINC-PE' under the parent of the current working directory.
path = os.path.join(
    os.path.dirname(os.path.realpath('./')),
    'data',
    'ZINC-PE',
)
path

In [None]:
# Random-walk positional encoding (walk length 20), stored under the attribute name 'pe'.
transform = T.AddRandomWalkPE(
    walk_length=20,
    attr_name='pe',
)

In [None]:
# Build the three splits of the ZINC subset in the order train, val, test;
# `transform` is passed as the pre_transform of every split.
train_dataset, val_dataset, test_dataset = (
    ZINC(path, subset=True, split=split_name, pre_transform=transform)
    for split_name in ('train', 'val', 'test')
)

# GPS graph transformer

An architecture that processes the graph in parallel through a global transformer and a local MPNN. 

In [None]:
class RedrawProjection:
    """Periodically redraws the projection matrices of Performer attention modules.

    This helper is only relevant for the Performer architecture: it looks for
    `PerformerAttention` modules inside `model` and calls their
    `redraw_projection_matrix()` method.

    Args:
        model: Module whose sub-modules are searched for `PerformerAttention`.
        redraw_interval: Number of calls to `redraw_projections` that must have
            elapsed since the last redraw before the next call redraws.
            `None` disables redrawing entirely.
    """

    def __init__(self, model: torch.nn.Module, redraw_interval: Optional[int] = None):
        self.model = model
        self.redraw_interval = redraw_interval
        # Number of calls counted since the projections were last redrawn.
        self.num_last_redraw = 0

    def redraw_projections(self):
        """Redraw the projections if due, otherwise advance the call counter.

        Does nothing when the model is not in training mode or when no
        redraw interval was configured.
        """
        enabled = self.model.training and self.redraw_interval is not None
        if not enabled:
            return

        # Not due yet: just count this call.
        if self.num_last_redraw < self.redraw_interval:
            self.num_last_redraw += 1
            return

        # Collect the attention modules first, then redraw each of them.
        performer_modules = [
            submodule
            for submodule in self.model.modules()
            if isinstance(submodule, PerformerAttention)
        ]
        for performer in performer_modules:
            performer.redraw_projection_matrix()
        self.num_last_redraw = 0

### The actual GPS architecture

In [None]:
class GPS(torch.nn.Module):
    """GPS graph transformer: stacked `GPSConv` layers followed by an MLP head.

    Each `GPSConv` layer wraps a `GINEConv` (local message passing) and is
    configured with an attention type for the global part. Node features are
    the concatenation of a categorical node embedding and a linear projection
    of the (batch-normalised) positional encoding.

    Args:
        channels: Hidden width of the node/edge representations.
        pe_dim: Number of channels reserved for the projected positional
            encoding; the node embedding uses `channels - pe_dim`.
        num_layers: Number of `GPSConv` layers.
        attn_type: Attention type forwarded to `GPSConv`
            (e.g. 'multihead' or 'performer').
        attn_kwargs: Extra keyword arguments forwarded to the attention layer.
        pe_in_dim: Width of the incoming positional encoding. Defaults to 20.
            # NOTE(review): presumably must equal the `walk_length` of the
            # positional-encoding transform used on the dataset; confirm.
        num_node_types: Size of the node embedding table. Defaults to 28.
        num_edge_types: Size of the edge embedding table. Defaults to 4.
        heads: Number of attention heads per `GPSConv` layer. Defaults to 4.
        redraw_interval: Redraw interval handed to `RedrawProjection` when
            `attn_type == 'performer'`; ignored otherwise. Defaults to 1000.
    """

    def __init__(self, channels: int, pe_dim: int, num_layers: int, attn_type: str,
                 attn_kwargs: Dict[str, Any], *, pe_in_dim: int = 20,
                 num_node_types: int = 28, num_edge_types: int = 4,
                 heads: int = 4, redraw_interval: int = 1000):
        super().__init__()

        self.node_emb = Embedding(num_node_types, channels - pe_dim)  # this can be sparse
        self.pe_lin = Linear(pe_in_dim, pe_dim)
        self.pe_norm = BatchNorm1d(pe_in_dim)
        self.edge_emb = Embedding(num_edge_types, channels)

        self.convs = ModuleList()
        for _ in range(num_layers):
            # MLP used inside the local GINE message-passing layer.
            local_mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            conv = GPSConv(channels, GINEConv(local_mlp), heads=heads,
                           attn_type=attn_type, attn_kwargs=attn_kwargs)
            self.convs.append(conv)

        # Read-out head mapping the pooled representation to a single value.
        self.mlp = Sequential(
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, channels // 4),
            ReLU(),
            Linear(channels // 4, 1),
        )

        # Projection redrawing is only enabled for the Performer attention type.
        self.redraw_projection = RedrawProjection(
            self.convs,
            redraw_interval=redraw_interval if attn_type == 'performer' else None,
        )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """Compute one prediction per pooled group in `batch`.

        Args:
            x: Categorical node features; the last dimension is squeezed
                before the embedding lookup.
            pe: Positional encodings; assumed to have `pe_in_dim` features
                per node (TODO confirm against the dataset transform).
            edge_index: Graph connectivity forwarded to each conv layer.
            edge_attr: Categorical edge features, embedded before use.
            batch: Assignment vector used by the conv layers and the pooling.

        Returns:
            Output of the MLP head applied to the sum-pooled node features
            (last dimension of size 1).
        """
        # TODO: this should be rewritten in a more elegant, functional way
        x_pe = self.pe_norm(pe)
        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
        edge_attr = self.edge_emb(edge_attr)

        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=edge_attr)

        x = global_add_pool(x, batch)
        return self.mlp(x)
        



## Training and testing

In [None]:
def train(model: torch.nn.Module,
          loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          device: torch.device,) -> float:
    """Run one training epoch and return the average L1 loss per graph.

    Args:
        model: Model to train; must expose `redraw_projection.redraw_projections()`
            and accept `(x, pe, edge_index, edge_attr, batch)`.
        loader: Iterable of batches; `loader.dataset` is used for normalisation.
        optimizer: Optimizer stepping the model parameters.
        device: Device every batch is moved to.

    Returns:
        Sum of per-batch mean absolute errors weighted by `num_graphs`,
        divided by `len(loader.dataset)`.
    """
    model.train()  # training mode

    total_loss = 0.0
    # Iterate over the loader passed in, not a notebook-level global.
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        model.redraw_projection.redraw_projections()  # redraw projections if needed
        out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch)

        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        # Weight the batch mean by its size so the epoch average is per graph.
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(loader.dataset)


@torch.no_grad()
def test(model: torch.nn.Module,
         loader: DataLoader, 
         device: torch.device): 
    model.eval()  # evaluation mode

    total_error = 0.0 
    for data in loader: 
        data = data.to(device) 
        out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch) 
        total_error += (out.squeeze() - data.y).abs().sum().item() 
    return total_error /len(loader.dataset)


def run_training(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: ReduceLROnPlateau,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        device: torch.device,
        epochs: int = 101,
):
    """Train `model`, evaluating on the validation and test loaders each epoch.

    Epochs are numbered from 1 up to `epochs - 1`, so the default of 101
    runs 100 epochs. After each epoch the scheduler is stepped with the
    validation MAE and a summary line is printed.

    Returns:
        A tuple `(train_losses, val_maes, test_maes)` of per-epoch lists.
    """
    history = {'train_loss': [], 'val_mae': [], 'test_mae': []}

    for epoch in range(1, epochs):
        loss = train(model, train_loader, optimizer, device)
        # Evaluate on validation first, then on test.
        val_mae, test_mae = (
            test(model, eval_loader, device)
            for eval_loader in (val_loader, test_loader)
        )
        scheduler.step(val_mae)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val MAE: {val_mae:.4f}, Test MAE: {test_mae:.4f}')

        history['train_loss'].append(loss)
        history['val_mae'].append(val_mae)
        history['test_mae'].append(test_mae)

    return history['train_loss'], history['val_mae'], history['test_mae']

# Run training 

In [None]:
# Experiment 1: training batch size 32, evaluation batch size 64.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attention_type = 'multihead'  # 'performer', 'multihead'
attn_kwargs = {'dropout': 0.5}
model = GPS(
    channels=64,
    pe_dim=8,
    num_layers=10,
    attn_type=attention_type,
    attn_kwargs=attn_kwargs,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Halve the learning rate when the validation MAE has not improved for 20 epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=1e-5)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

# Results are suffixed with the training batch size (32) for all three series.
train_loss_32, val_mae_32, test_mae_32 = run_training(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    epochs=101,
)
# Backward-compatible alias: this series was previously stored under the
# misleading name `train_loss_64` although it belongs to the batch-size-32 run.
train_loss_64 = train_loss_32

In [None]:
# Experiment 2: training batch size 128, evaluation batch size 256.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
attention_type = 'multihead'  # 'performer', 'multihead'
attn_kwargs = dict(dropout=0.5)

model = GPS(channels=64, pe_dim=8, num_layers=10,
            attn_type=attention_type, attn_kwargs=attn_kwargs)
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-5,
)
# The scheduler is stepped with the validation MAE inside run_training.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)
test_loader = DataLoader(test_dataset, batch_size=256)

train_loss_128, val_mae_128, test_mae_128 = run_training(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    epochs=101,
)

This uses only one GPS layer and has limited heads and capacity (note that the cells above actually set `num_layers=10`, so this remark may refer to an earlier configuration). We don't expect it to be particularly good. Also, there is no early stopping, and the test loss seems unstable; perhaps the batches are too small? There also appears to be some overfitting. However, the principle is sound and the system works, although the code needs some cleanup and a better structure. PyTorch is not very strict. I would like to have something in JAX for this, but Jraph has been archived and I don't see how I could build it. Switching from batches of size 32 to 128 seems to solve the problem of the unstable test loss.