In [1]:
# Links: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data
from torch_geometric.data import Data
import torch

In [2]:
class PairData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index_1':
            return self.x_1.size(0)
        if key == 'edge_index_2':
            return self.x_2.size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [13]:
# Node features of shape (num_nodes, num_node_features) and type torch.float32
x_1 = torch.tensor([[0, 0, 0],
                    [1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]], dtype=torch.float32)
x_2 = torch.tensor([[4, 4, 4],
                    [5, 5, 5],
                    [6, 6, 6]], dtype=torch.float32)

# Edge indices of shape (2, num_edges) and type torch.long
edge_index_1 = torch.tensor([[0, 1, 1, 2, 2, 3],
                             [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_index_2 = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]], dtype=torch.long)

# Edge features of shape (num_edges, num_edge_features) and type torch.float32
edge_attr_1 = torch.tensor([[0],
                            [1],
                            [2],
                            [3],
                            [4],
                            [5]], dtype=torch.float32)
edge_attr_2 = torch.tensor([[6],
                            [7],
                            [8],
                            [9]], dtype=torch.float32)

# Pair label of shape (1,) and type torch.long
y = torch.tensor([1], dtype=torch.float32)

data = PairData(x_1=x_1, edge_index_1=edge_index_1, edge_attr_1=edge_attr_1,  # Graph 1.
                x_2=x_2, edge_index_2=edge_index_2, edge_attr_2=edge_attr_2,  # Graph 2.
                y=y) #Graph pair label. 

print(data.edge_attr_1)

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.]])


In [14]:
# Dataloader for pairs of graphs
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

# We will have our list of graphs in the form of Data objects
data_list = [data, data, data]

# Create the dataloader. The follow_batch tells the dataloader which nodes belong to which graph in this 
# giant disconnected graph that the batch creates. We can typically split the data_list into train, val, test and then 
# create individual loaders correspondingly.
train_loader = DataLoader(data_list, batch_size=2, follow_batch=['x_1', 'x_2'])


# We can iterate through batches with the following. Each batch is a data.Batch() object
for batch in train_loader:
    inputs = ((batch.x_1, batch.edge_index_1, batch.edge_attr_1), (batch.x_2, batch.edge_index_2, batch.edge_attr_2))
    graph_1, graph_2 = inputs
    labels = batch.y
    print(batch.x_2)
    print(labels)
    print("Which nodes correspond to which graph:", batch.x_1_batch)
    print("Which nodes correspond to which graph:", batch.x_2_batch)

tensor([[4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.],
        [4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.]])
tensor([1., 1.])
Which nodes correspond to which graph: tensor([0, 0, 0, 0, 1, 1, 1, 1])
Which nodes correspond to which graph: tensor([0, 0, 0, 1, 1, 1])
tensor([[4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.]])
tensor([1.])
Which nodes correspond to which graph: tensor([0, 0, 0, 0])
Which nodes correspond to which graph: tensor([0, 0, 0])


In [15]:
# A simple GCN model which only takes in two graphs

import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super(GraphClassifier, self).__init__()

        # Node feature transformation layers
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 128)

        # Edge feature transformation layers
        self.edge_mlp = Sequential(Linear(num_edge_features, 32),
                                   ReLU(),
                                   Linear(32, 64))
        
        # Readout layer
        self.readout = global_mean_pool

        # Classifier
        self.classifier = Linear(128, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        # Update node features
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Update edge features
        edge_attr = self.edge_mlp(edge_attr)

        # Readout layer to get graph-level representation
        x = self.readout(x, batch)  # <-- Use the batch vector here

        # Classifier to predict the graph label
        x = self.classifier(x)
        x = torch.sigmoid(x)

        return x.squeeze(-1)

In [20]:
from torch_geometric.data import DataLoader

# Initialize model
model = GraphClassifier(num_node_features=3, num_edge_features=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

# Training loop
for epoch in range(1, 100):
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x_1, batch.edge_index_1, batch.edge_attr_1, batch.x_1_batch)  # Use batch.x_1_batch
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')

Epoch: 1, Loss: 0.5858218669891357
Epoch: 2, Loss: 0.32294079661369324
Epoch: 3, Loss: 0.11305849999189377
Epoch: 4, Loss: 0.020637409761548042
Epoch: 5, Loss: 0.0024458253756165504
Epoch: 6, Loss: 0.00025639243540354073
Epoch: 7, Loss: 2.8252999982214533e-05
Epoch: 8, Loss: 3.57628505298635e-06
Epoch: 9, Loss: 4.768372718899627e-07
Epoch: 10, Loss: 1.1920928955078125e-07
Epoch: 11, Loss: 0.0
Epoch: 12, Loss: 0.0
Epoch: 13, Loss: 0.0
Epoch: 14, Loss: 0.0
Epoch: 15, Loss: 0.0
Epoch: 16, Loss: 0.0
Epoch: 17, Loss: 0.0
Epoch: 18, Loss: 0.0
Epoch: 19, Loss: 0.0
Epoch: 20, Loss: 0.0
Epoch: 21, Loss: 0.0
Epoch: 22, Loss: 0.0
Epoch: 23, Loss: 0.0
Epoch: 24, Loss: 0.0
Epoch: 25, Loss: 0.0
Epoch: 26, Loss: 0.0
Epoch: 27, Loss: 0.0
Epoch: 28, Loss: 0.0
Epoch: 29, Loss: 0.0
Epoch: 30, Loss: 0.0
Epoch: 31, Loss: 0.0
Epoch: 32, Loss: 0.0
Epoch: 33, Loss: 0.0
Epoch: 34, Loss: 0.0
Epoch: 35, Loss: 0.0
Epoch: 36, Loss: 0.0
Epoch: 37, Loss: 0.0
Epoch: 38, Loss: 0.0
Epoch: 39, Loss: 0.0
Epoch: 40, Loss:

In [25]:
class PairGraphClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super(PairGraphClassifier, self).__init__()

        # Node feature transformation layers
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 128)

        # Edge feature transformation layers
        self.edge_mlp = Sequential(Linear(num_edge_features, 32),
                                   ReLU(),
                                   Linear(32, 64))

        # Classifier
        self.classifier = Linear(256, 1)  # 128 features from each graph

    def forward_one(self, x, edge_index, edge_attr, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        edge_attr = self.edge_mlp(edge_attr)
        x = global_mean_pool(x, batch)  # Use batch vector for separate pooling
        return x

    def forward(self, x_1, edge_index_1, edge_attr_1, batch_1, x_2, edge_index_2, edge_attr_2, batch_2):
        
        x_1 = self.forward_one(x_1, edge_index_1, edge_attr_1, batch_1)
        x_2 = self.forward_one(x_2, edge_index_2, edge_attr_2, batch_2)

        x = torch.cat([x_1, x_2], dim=1)
        x = self.classifier(x)
        x = torch.sigmoid(x)

        return x.squeeze(-1)

In [None]:
# Initialize model
model = PairGraphClassifier(num_node_features=3, num_edge_features=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

# Training loop
for epoch in range(1, 20):
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x_1, batch.edge_index_1, batch.edge_attr_1, batch.x_1_batch,
                    batch.x_2, batch.edge_index_2, batch.edge_attr_2, batch.x_2_batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')
print("We have a graph pair model working!!!")

Epoch: 1, Loss: 0.20939941704273224
Epoch: 2, Loss: 0.008939467370510101
Epoch: 3, Loss: 0.0002089361078105867
Epoch: 4, Loss: 4.768382950715022e-06
Epoch: 5, Loss: 1.1920928955078125e-07
Epoch: 6, Loss: 0.0
Epoch: 7, Loss: 0.0
Epoch: 8, Loss: 0.0
Epoch: 9, Loss: 0.0
Epoch: 10, Loss: 0.0
Epoch: 11, Loss: 0.0
Epoch: 12, Loss: 0.0
Epoch: 13, Loss: 0.0
Epoch: 14, Loss: 0.0
Epoch: 15, Loss: 0.0
Epoch: 16, Loss: 0.0
Epoch: 17, Loss: 0.0
Epoch: 18, Loss: 0.0
Epoch: 19, Loss: 0.0
Epoch: 20, Loss: 0.0
Epoch: 21, Loss: 0.0
Epoch: 22, Loss: 0.0
Epoch: 23, Loss: 0.0
Epoch: 24, Loss: 0.0
Epoch: 25, Loss: 0.0
Epoch: 26, Loss: 0.0
Epoch: 27, Loss: 0.0
Epoch: 28, Loss: 0.0
Epoch: 29, Loss: 0.0
Epoch: 30, Loss: 0.0
Epoch: 31, Loss: 0.0
Epoch: 32, Loss: 0.0
Epoch: 33, Loss: 0.0
Epoch: 34, Loss: 0.0
Epoch: 35, Loss: 0.0
Epoch: 36, Loss: 0.0
Epoch: 37, Loss: 0.0
Epoch: 38, Loss: 0.0
Epoch: 39, Loss: 0.0
Epoch: 40, Loss: 0.0
Epoch: 41, Loss: 0.0
Epoch: 42, Loss: 0.0
Epoch: 43, Loss: 0.0
Epoch: 44, Loss: 0