In [3]:
# Model
from torch_geometric.nn import NNConv, GATConv, global_mean_pool
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the MLP for NNConv
class EdgeMLP(nn.Module):
    def __init__(self, num_edge_features, in_channels, out_channels):
        super(EdgeMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_edge_features, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, in_channels * out_channels)
        )
        
    def forward(self, edge_attr):
        return self.mlp(edge_attr)


# Adapted Graph Neural Network using NNConv and GATConv
class PairsNNConvGATGNN(nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden_channels, out_channels):
        super(PairsNNConvGATGNN, self).__init__()
        
        # Initialize the MLP for NNConv
        self.edge_mlp = EdgeMLP(num_edge_features, num_node_features, hidden_channels)
        
        # NNConv layer
        self.conv1 = NNConv(num_node_features, hidden_channels, self.edge_mlp)
        
        # GATConv layer
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=1, concat=False)

        # Fully connected layer
        self.fc1 = nn.Linear(hidden_channels, out_channels)
        self.fc2 = nn.Linear(out_channels, 1)
    
        
    def embedder(self, x, edge_index, edge_attr, batch):
        # NNConv layer
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        
        # GATConv layer
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        # Global average pooling
        x = global_mean_pool(x, batch) #<-- batch vector to keep track of graphs

        # Fully connected layers
        x = self.fc1(x)
        x = F.relu(x)
        
        return x
    
    def forward(self, x1, edge_index1, edge_attr1, batch1, x2, edge_index2, edge_attr2, batch2):
        # First graph's embeddings
        z1 = self.embedder(x1, edge_index1, edge_attr1, batch1)
        
        # Second graph's embeddings
        z2 = self.embedder(x2, edge_index2, edge_attr2, batch2)
        
        # Contrast the embeddings
        z = torch.abs(z1 - z2)
        
        # Logistic regression
        z = self.fc2(z)
        z = torch.sigmoid(z)
        
        return z.squeeze(1)

In [4]:
# Test case
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader

# PairData class
class PairData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index1':
            return self.x1.size(0)
        if key == 'edge_index2':
            return self.x2.size(0)
        return super().__inc__(key, value, *args, **kwargs)


# Create a dummy graph using PyG's Data class
def create_sample_graph(num_nodes, num_node_features, num_edge_features):
    x1, x2 = torch.randn(num_nodes, num_node_features), torch.randn(num_nodes, num_node_features)  # Node features
    edge_index1 = torch.tensor([[0, 2, 3, 1, 4],
                                [2, 0, 1, 3, 3]], dtype=torch.long)
    edge_index2 = edge_index1
    num_edges = edge_index1.size()[1]
    edge_attr1, edge_attr2 = torch.randn((num_edges, num_edge_features)), torch.randn((num_edges, num_edge_features))   # Edge attributes
    
    y = torch.randint(0, 2, (1,))
    y = y.to(torch.float32)
    
    return PairData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,
                    x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,
                    y=y)

# Hyperparameters
num_node_features = 10
num_edge_features = 6
hidden_channels = 20
out_channels = 5

# Create dummy data
data_list = []
for i in range(8):
    data_list.append(create_sample_graph(5, 10, 6))

# Dataloader
loader = DataLoader(data_list, batch_size=4, follow_batch=['x1', 'x2'])

# Initialize the model
model = PairsNNConvGATGNN(num_node_features=num_node_features, num_edge_features=num_edge_features, 
                                   hidden_channels=hidden_channels, out_channels=out_channels)

for batch in loader:
    output = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                         batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
    print("Output shape:", output.shape)
    print("Output:", output)

Output shape: torch.Size([4])
Output: tensor([0.4297, 0.4350, 0.4422, 0.4282], grad_fn=<SqueezeBackward1>)
Output shape: torch.Size([4])
Output: tensor([0.4288, 0.4428, 0.4350, 0.4345], grad_fn=<SqueezeBackward1>)


In [7]:
# Training test
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

for epoch in range(1, 100):
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                    batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')
    if loss.item() < 1e-5:
        print("Success! Looks like we're overfitting.")
        break

Epoch: 1, Loss: 0.0
Success! Looks like we're overfitting.


In [8]:
# Load PairData class dataset of graph pairs
path = r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_pseudolabeled\relative_positioning\PyG\jh101_12s_7min_PairData.pt"
data = torch.load(path)

In [15]:
# Place into dataloader
loader = DataLoader(data, batch_size=32, follow_batch=['x1', 'x2'])
print("Number of batches:", len(loader))

Number of batches: 28969


In [12]:
# Real training
num_node_features = 1
num_edge_features = 1
hidden_channels = 64
out_channels = 32

model = PairsNNConvGATGNN(num_node_features=num_node_features, num_edge_features=num_edge_features, 
                                   hidden_channels=hidden_channels, out_channels=out_channels)             
                          
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

for epoch in range(1, 10):
    i = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                    batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()
        i += 1
        print(i)
    print(f'Epoch: {epoch}, Loss: {loss.item()}')

1
2
3
4
5
6
7
8
9


KeyboardInterrupt: 