In [1]:
from torch_geometric.nn import NNConv, GATConv, global_mean_pool
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
# Define the MLP for NNConv
class EdgeMLP(nn.Module):
    """Edge network for ``NNConv``: maps each edge's features to a flattened weight matrix.

    For every edge the MLP emits ``in_channels * out_channels`` values, which
    ``NNConv`` uses as that edge's ``in_channels x out_channels`` weight matrix.

    Args:
        num_edge_features: Size of each edge feature vector (last dim of ``edge_attr``).
        in_channels: Node feature size consumed by the paired ``NNConv``.
        out_channels: Node feature size produced by the paired ``NNConv``.
        hidden_dims: Widths of the hidden layers, each followed by a ReLU.
            Defaults to ``(128, 64)``, the original fixed architecture, so the
            layer layout (and ``state_dict`` keys) match the previous version.
    """

    def __init__(self, num_edge_features, in_channels, out_channels, hidden_dims=(128, 64)):
        super(EdgeMLP, self).__init__()
        layers = []
        prev_dim = num_edge_features
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        # Final projection has no activation so generated weights may be negative.
        layers.append(nn.Linear(prev_dim, in_channels * out_channels))
        self.mlp = nn.Sequential(*layers)

    def forward(self, edge_attr):
        """Return per-edge flattened weights, last dim ``in_channels * out_channels``.

        Assumes ``edge_attr`` is ``[num_edges, num_edge_features]``.
        """
        return self.mlp(edge_attr)


# Adapted Graph Neural Network using NNConv and GATConv
class NNConvGATGNN(nn.Module):
    """Graph-level scorer: NNConv (edge-conditioned) -> GATConv -> mean pool -> 2-layer MLP -> sigmoid.

    Args:
        num_node_features: Size of each input node feature vector.
        num_edge_features: Size of each edge feature vector fed to the edge MLP.
        hidden_channels: Node embedding size after both graph convolutions.
        out_channels: Width of the first fully connected layer. Note this is *not*
            the model's output size; the head always ends in a single sigmoid unit.
    """

    def __init__(self, num_node_features, num_edge_features, hidden_channels, out_channels):
        super(NNConvGATGNN, self).__init__()

        # Edge network that generates NNConv's per-edge weight matrices.
        # NOTE: the same module object is also registered inside conv1 (as its `nn`),
        # so parameters are shared but it is listed twice in print() and under both
        # `edge_mlp.*` and `conv1.nn.*` in state_dict(). Kept for checkpoint compatibility.
        self.edge_mlp = EdgeMLP(num_edge_features, num_node_features, hidden_channels)

        # Edge-conditioned convolution: num_node_features -> hidden_channels.
        self.conv1 = NNConv(num_node_features, hidden_channels, self.edge_mlp)

        # Single-head graph attention, hidden_channels -> hidden_channels.
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=1, concat=False)

        # Graph-level head: hidden_channels -> out_channels -> 1.
        self.fc1 = nn.Linear(hidden_channels, out_channels)
        self.fc2 = nn.Linear(out_channels, 1)

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Score each graph in the (mini-)batch.

        Args:
            x: Node features, assumed ``[num_nodes, num_node_features]``.
            edge_index: COO connectivity, assumed ``[2, num_edges]``.
            edge_attr: Edge features, assumed ``[num_edges, num_edge_features]``.
            batch: Vector assigning each node to its graph (e.g. ``Batch.batch``).
                ``None`` treats all nodes as belonging to a single graph.

        Returns:
            Tensor of shape ``[num_graphs, 1]`` with sigmoid scores in ``[0, 1]``.
        """
        x = F.relu(self.conv1(x, edge_index, edge_attr))

        # GATConv is called without edge_attr, so edge features only influence conv1.
        x = F.relu(self.conv2(x, edge_index))

        # Collapse node embeddings to one vector per graph.
        x = global_mean_pool(x, batch)

        x = F.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))

# Build the NNConv + GAT model with small example dimensions and show its layer structure.
model_with_gat = NNConvGATGNN(
    num_node_features=10,
    num_edge_features=6,
    hidden_channels=20,
    out_channels=5,
)
print(model_with_gat)


NNConvGATGNN(
  (edge_mlp): EdgeMLP(
    (mlp): Sequential(
      (0): Linear(in_features=6, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=200, bias=True)
    )
  )
  (conv1): NNConv(10, 20, aggr=add, nn=EdgeMLP(
    (mlp): Sequential(
      (0): Linear(in_features=6, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=200, bias=True)
    )
  ))
  (conv2): GATConv(20, 20, heads=1)
  (fc1): Linear(in_features=20, out_features=5, bias=True)
  (fc2): Linear(in_features=5, out_features=1, bias=True)
)


In [7]:
# Smoke test: forward pass on a mini-batch of two small ring graphs.
from torch_geometric.data import Batch, Data

# Seed so the random weights/features (and therefore the printed output) are reproducible.
torch.manual_seed(0)

# Initialize the model
model = NNConvGATGNN(num_node_features=10, num_edge_features=6, hidden_channels=20, out_channels=5)

# Graph 1: directed 8-node ring (0 -> 1 -> ... -> 7 -> 0)
x1 = torch.randn((8, 10))  # 8 nodes with 10 features each
edge_index1 = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 0]], dtype=torch.long)  # 8 edges
edge_attr1 = torch.randn((8, 6))  # 8 edges with 6 features each
data1 = Data(x=x1, edge_index=edge_index1, edge_attr=edge_attr1)

# Graph 2: directed 6-node ring (0 -> 1 -> ... -> 5 -> 0)
x2 = torch.randn((6, 10))  # 6 nodes with 10 features each
edge_index2 = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]], dtype=torch.long)  # 6 edges
edge_attr2 = torch.randn((6, 6))  # 6 edges with 6 features each
data2 = Data(x=x2, edge_index=edge_index2, edge_attr=edge_attr2)

# Merge into one disjoint graph; batch.batch maps each node to its source graph.
batch = Batch.from_data_list([data1, data2])

# Forward pass
output = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

# Global mean pooling collapses nodes, and fc2 has one unit:
# expect one sigmoid score per graph, i.e. shape [num_graphs, 1].
assert output.shape == (batch.num_graphs, 1), output.shape
assert ((output >= 0) & (output <= 1)).all()
print("Output shape:", output.shape)
print("Output:", output)

Output shape: torch.Size([2, 1])
Output: tensor([[0.5656],
        [0.5635]], grad_fn=<SigmoidBackward0>)
