In [4]:
# Model
from torch_geometric.nn import NNConv, GATConv, global_mean_pool
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the MLP for NNConv
class EdgeMLP(nn.Module):
    def __init__(self, num_edge_features, in_channels, out_channels):
        super(EdgeMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_edge_features, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, in_channels * out_channels)
        )
        
    def forward(self, edge_attr):
        return self.mlp(edge_attr)


# Adapted Graph Neural Network using NNConv and GATConv
class PairsNNConvGATGNN(nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden_channels, out_channels):
        super(PairsNNConvGATGNN, self).__init__()
        
        # Initialize the MLP for NNConv
        self.edge_mlp = EdgeMLP(num_edge_features, num_node_features, hidden_channels)
        
        # NNConv layer
        self.conv1 = NNConv(num_node_features, hidden_channels, self.edge_mlp)
        
        # GATConv layer
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=1, concat=False)

        # Fully connected layer
        self.fc1 = nn.Linear(hidden_channels, out_channels)
        self.fc2 = nn.Linear(out_channels, 1)
    
        
    def embedder(self, x, edge_index, edge_attr, batch):
        # NNConv layer
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        
        # GATConv layer
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        # Global average pooling
        x = global_mean_pool(x, batch) #<-- batch vector to keep track of graphs

        # Fully connected layers
        x = self.fc1(x)
        x = F.relu(x)
        
        return x
    
    def forward(self, x1, edge_index1, edge_attr1, batch1, x2, edge_index2, edge_attr2, batch2):
        # First graph's embeddings
        z1 = self.embedder(x1, edge_index1, edge_attr1, batch1)
        
        # Second graph's embeddings
        z2 = self.embedder(x2, edge_index2, edge_attr2, batch2)
        
        # Contrast the embeddings
        z = torch.abs(z1 - z2)
        
        # Logistic regression
        z = self.fc2(z)
        z = torch.sigmoid(z)
        
        return z

In [5]:
# Test case
import torch
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

# Create a dummy graph using PyG's Data class
def create_sample_graph(num_nodes, num_node_features, num_edge_features):
    x1, x2 = torch.randn(num_nodes, num_node_features), torch.randn(num_nodes, num_node_features)  # Node features
    edge_index1 = torch.tensor([[0, 2, 3, 1, 4],
                                [2, 0, 1, 3, 3]], dtype=torch.long)
    edge_index2 = edge_index1
    num_edges = edge_index1.size()[1]
    edge_attr1, edge_attr2 = torch.randn((num_edges, num_edge_features)), torch.randn((num_edges, num_edge_features))   # Edge attributes
    
    y = torch.randint(0, 2, (1,))
    y = y.to(torch.float32)
    
    return PairData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,
                    x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,
                    y=y)

# Hyperparameters
num_node_features = 10
num_edge_features = 6
hidden_channels = 20
out_channels = 5

# Create dummy data
data_list = []
for i in range(8):
    data_list.append(create_sample_graph(5, 10, 6))

# Dataloader
loader = DataLoader(data_list, batch_size=4, follow_batch=['x1', 'x2'])

# Initialize the model
pairs_model = PairsNNConvGATGNN(num_node_features=num_node_features, num_edge_features=num_edge_features, 
                                   hidden_channels=hidden_channels, out_channels=out_channels)

for batch in loader:
    output = pairs_model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                         batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
    print("Output shape:", output.shape)
    print("Output:", output)

Output shape: torch.Size([4, 1])
Output: tensor([[0.4106],
        [0.4191],
        [0.4301],
        [0.4233]], grad_fn=<SigmoidBackward0>)
Output shape: torch.Size([4, 1])
Output: tensor([[0.4209],
        [0.4339],
        [0.4166],
        [0.4280]], grad_fn=<SigmoidBackward0>)
