In [16]:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import torch_geometric as pyg
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm

In [17]:
def sample_four_communities(p=0.03, q=0.02, seed=None):
    """Sample a 400-node stochastic block model with four blocks of 100 nodes.

    Args:
        p: Edge probability used for every *off-diagonal* entry of the block
            probability matrix, i.e. between nodes of different blocks.
        q: Edge probability written onto the *diagonal* of the block
            probability matrix, i.e. between nodes of the same block.
        seed: Optional seed / random state forwarded to
            ``nx.stochastic_block_model``. ``None`` (the default) keeps
            NetworkX's default random state, as before.

    Returns:
        The sampled graph converted with ``pyg.utils.from_networkx``, with
        ``x`` set to a 400x400 identity matrix (one one-hot row per node).
    """
    # NOTE(review): with the defaults p > q, so edges *across* blocks are more
    # likely than edges *within* a block — confirm this is the intended setup.
    block_sizes = np.full(4, 100, dtype=int)
    # Keep the scalar arguments intact instead of rebinding `p` to the matrix.
    edge_probs = np.full((4, 4), p, dtype=float)
    np.fill_diagonal(edge_probs, q)
    graph = nx.stochastic_block_model(block_sizes, edge_probs, seed=seed)
    G = pyg.utils.from_networkx(graph)
    G.x = torch.eye(int(block_sizes.sum()))
    return G

def sample_two_communities(p=0.03, q=0.02, seed=None):
    """Sample a 400-node stochastic block model with two blocks of 200 nodes.

    Args:
        p: Edge probability used for every *off-diagonal* entry of the block
            probability matrix, i.e. between nodes of different blocks.
        q: Edge probability written onto the *diagonal* of the block
            probability matrix, i.e. between nodes of the same block.
        seed: Optional seed / random state forwarded to
            ``nx.stochastic_block_model``. ``None`` (the default) keeps
            NetworkX's default random state, as before.

    Returns:
        The sampled graph converted with ``pyg.utils.from_networkx``, with
        ``x`` set to a 400x400 identity matrix (one one-hot row per node).
    """
    # NOTE(review): with the defaults p > q, so edges *across* blocks are more
    # likely than edges *within* a block — confirm this is the intended setup.
    block_sizes = np.full(2, 200, dtype=int)
    # Keep the scalar arguments intact instead of rebinding `p` to the matrix.
    edge_probs = np.full((2, 2), p, dtype=float)
    np.fill_diagonal(edge_probs, q)
    graph = nx.stochastic_block_model(block_sizes, edge_probs, seed=seed)
    G = pyg.utils.from_networkx(graph)
    G.x = torch.eye(int(block_sizes.sum()))
    return G

In [28]:
class MergeDataset(Dataset):
    """In-memory dataset of graph pairs labelled by whether their layouts differ.

    Each item is a triple ``(graph_a, graph_b, label)``:

    * ``label == 0.``: both graphs come from the same sampler (both
      four-community or both two-community).
    * ``label == 1.``: ``graph_a`` is a two-community sample and ``graph_b`` a
      four-community sample.

    All graphs are sampled eagerly in ``__init__``; items are stored in the
    order four/four, two/two, two/four.

    Args:
        p: Forwarded as ``p`` to both sampling functions.
        q: Forwarded as ``q`` to both sampling functions.
        num_same_pairs: Number of label-0 pairs generated *per* community
            layout (so ``2 * num_same_pairs`` label-0 pairs in total).
        num_mixed_pairs: Number of label-1 (two vs. four) pairs.
    """

    def __init__(
        self,
        p: float = 0.06,
        q: float = 0.02,
        num_same_pairs: int = 500,
        num_mixed_pairs: int = 1000,
    ) -> None:
        super().__init__()
        self.p = p
        self.q = q

        # (sampler for first graph, sampler for second graph, label, how many)
        recipes = [
            (sample_four_communities, sample_four_communities, 0., num_same_pairs),
            (sample_two_communities, sample_two_communities, 0., num_same_pairs),
            (sample_two_communities, sample_four_communities, 1., num_mixed_pairs),
        ]
        total_pairs = sum(count for *_, count in recipes)

        self.dataset = []
        # The context manager closes the bar when construction ends (or fails);
        # an unclosed bar makes every later bar render underneath it.
        with tqdm(desc="Building dataset.", total=total_pairs) as prog_bar:
            for sample_first, sample_second, label, count in recipes:
                for _ in range(count):
                    self.dataset.append((
                        sample_first(p=p, q=q),
                        sample_second(p=p, q=q),
                        torch.tensor(label)
                    ))
                    prog_bar.update(1)

    def __len__(self) -> int:
        """Return the number of stored pairs."""
        return len(self.dataset)

    # The annotation is quoted so it is not evaluated at class-definition time.
    def __getitem__(self, index: int) -> "tuple[pyg.data.Data, pyg.data.Data, torch.Tensor]":
        """Return the ``(graph_a, graph_b, label)`` triple stored at ``index``."""
        return self.dataset[index]

In [29]:
def collate_fn(batch):
    """Collate ``(graph_a, graph_b, label)`` triples into one batched triple.

    Args:
        batch: Sequence of triples; positions 0 and 1 hold graph objects and
            position 2 holds a label tensor.

    Returns:
        A 3-tuple: the first graphs merged with ``Batch.from_data_list``, the
        second graphs merged the same way, and the labels stacked along a new
        leading dimension.
    """
    first_graphs, second_graphs, labels = [], [], []
    for sample in batch:
        first_graphs.append(sample[0])
        second_graphs.append(sample[1])
        labels.append(sample[2])

    batched_first = pyg.data.Batch.from_data_list(first_graphs)
    batched_second = pyg.data.Batch.from_data_list(second_graphs)
    return batched_first, batched_second, torch.stack(labels)

# Two independently sampled datasets built with MergeDataset's default
# arguments. Training is constructed first, then validation.
training_data = MergeDataset()
validation_data = MergeDataset()

# Only the training loader reshuffles its samples; both use the pair-aware
# collate function so each batch is (Batch, Batch, labels).
training_dataloader = DataLoader(
    dataset=training_data,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn,
)
validation_dataloader = DataLoader(
    dataset=validation_data,
    batch_size=128,
    collate_fn=collate_fn,
)


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Building dataset.:  13%|█▎        | 254/2000 [47:25<5:26:01, 11.20s/it]
Building dataset.:   0%|          | 6/2000 [24:11<134:00:03, 241.93s/it]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


In [30]:
class GCNEncoder(nn.Module):
    """Two-layer GCN that turns node features into node embeddings.

    Args:
        in_channels: Width of the input node features (default 400).
        hidden_channels: Output width of the first ``GCNConv`` layer.
        out_channels: Width of the returned node embeddings.
        dropout: Dropout probability applied after the first layer's ReLU.
    """

    def __init__(self, in_channels=400, hidden_channels=128, out_channels=32, dropout=0.25):
        super().__init__()
        self.dropout = dropout
        self.conv1 = pyg.nn.GCNConv(in_channels=in_channels, out_channels=hidden_channels)
        self.conv2 = pyg.nn.GCNConv(in_channels=hidden_channels, out_channels=out_channels)

    def forward(self, g):
        """Encode the nodes of ``g`` (reads ``g.x`` and ``g.edge_index``).

        Returns:
            The output of the second ``GCNConv`` layer, one row per node.
        """
        x = self.conv1(g.x, g.edge_index)
        x = torch.nn.functional.relu(x)
        # The functional form defaults to training=True, which keeps dropout
        # active in eval mode; tie it to the module's own mode instead.
        x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, g.edge_index)
        return x
    
    
class Model(nn.Module):
    """Pair classifier: one logit per graph pair.

    Both graphs go through the same ``GCNEncoder``. The node-wise distance
    between the two embeddings is reduced to ``top_k`` values per graph by
    ``TopKPooling`` and fed to a two-layer MLP.

    Args:
        top_k: Number of nodes kept per graph by the pooling layer; also the
            input width of the MLP (default 32).
        hidden_channels: Hidden width of the MLP (default 64).
        dropout: Dropout probability applied after the MLP's ReLU.
    """

    def __init__(self, top_k=32, hidden_channels=64, dropout=0.25):
        super().__init__()
        self.top_k = top_k
        self.dropout = dropout
        self.encoder = GCNEncoder()
        self.pairwise_distance = nn.PairwiseDistance()
        # An integer ratio is passed to TopKPooling as the number of nodes to keep.
        self.pooling = pyg.nn.pool.TopKPooling(in_channels=1, ratio=top_k)
        self.linear1 = nn.Linear(top_k, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, 1)

    def forward(self, g1, g2):
        """Return one raw logit per graph pair.

        Args:
            g1: Batched graphs; its ``edge_index`` and ``batch`` vectors are
                also the ones handed to the pooling layer.
            g2: Batched graphs paired with ``g1``.

        Returns:
            Tensor with one row per graph and a single column (unnormalised
            logit).
        """
        # NOTE(review): node i of g1 is compared with node i of g2, so both
        # batches are assumed to have identical node counts per graph, each of
        # at least `top_k` nodes (otherwise the reshape below cannot work) —
        # confirm against the data pipeline.
        x1 = self.encoder(g1)
        x2 = self.encoder(g2)
        x = self.pairwise_distance(x1, x2).unsqueeze(1)
        x, *_ = self.pooling(
            x=x,
            edge_index=g1.edge_index,
            batch=g1.batch # this ensures top-k is selected per graph
        )
        # The pooled values of all graphs come back concatenated; fold them
        # back into one row of `top_k` values per graph.
        x = x.view(-1, self.top_k)
        x = self.linear1(x)
        x = nn.functional.relu(x)
        # The functional form defaults to training=True, which keeps dropout
        # active in eval mode; tie it to the module's own mode instead.
        x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.linear2(x)
        return x

In [31]:
def train_epoch(model, loss_fn, optimiser, training_dataloader):
    """Run one optimisation pass over ``training_dataloader``.

    Args:
        model: ``nn.Module`` called as ``model(g1, g2)``; expected to return
            one logit per pair with shape ``(batch, 1)``.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor.
        optimiser: Optimiser holding the model's parameters.
        training_dataloader: Iterable yielding ``(g1, g2, labels)`` batches,
            with ``labels`` of shape ``(batch,)``.

    Returns:
        ``(mean_loss, accuracy)``: the unweighted mean of the per-batch losses
        and the fraction of pairs whose thresholded logit (``> 0``) matches
        the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    # Make sure dropout and other mode-dependent layers are in training mode,
    # even if an earlier validation pass left the model in eval mode.
    model.train()
    total, correct = 0, 0
    losses = []
    for g1, g2, labels in training_dataloader:
        labels = labels.float().unsqueeze(1)
        # Clear gradients *before* the backward pass so stale gradients (e.g.
        # from an interrupted earlier run) never leak into this update.
        optimiser.zero_grad()
        predictions = model(g1, g2)
        loss = loss_fn(predictions, labels)
        loss.backward()
        optimiser.step()
        losses.append(loss.detach())
        # A logit above 0 corresponds to predicting the positive class.
        correct += torch.sum((predictions > 0).long() == labels).item()
        total += len(labels)
    accuracy = correct / total
    return torch.stack(losses).mean().item(), accuracy

def validate(model, loss_fn, validation_dataloader):
    """Evaluate ``model`` on ``validation_dataloader`` without updating it.

    The model is switched to eval mode for the duration of the pass and put
    back into whatever mode it was in afterwards, so calling this between
    training epochs does not change how subsequent training behaves.

    Args:
        model: ``nn.Module`` called as ``model(g1, g2)``; expected to return
            one logit per pair with shape ``(batch, 1)``.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor.
        validation_dataloader: Iterable yielding ``(g1, g2, labels)`` batches,
            with ``labels`` of shape ``(batch,)``.

    Returns:
        ``(mean_loss, accuracy)``: the unweighted mean of the per-batch losses
        and the fraction of pairs whose thresholded logit (``> 0``) matches
        the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    total, correct = 0, 0
    losses = []
    was_training = model.training
    # Evaluation must not use stochastic layers such as dropout.
    model.eval()
    try:
        with torch.no_grad():
            for g1, g2, labels in validation_dataloader:
                labels = labels.float().unsqueeze(1)
                predictions = model(g1, g2)
                loss = loss_fn(predictions, labels)
                losses.append(loss.detach())
                # A logit above 0 corresponds to predicting the positive class.
                correct += torch.sum((predictions > 0).long() == labels).item()
                total += len(labels)
    finally:
        # Restore the caller's mode even if evaluation fails part-way.
        model.train(was_training)
    accuracy = correct / total
    return torch.stack(losses).mean().item(), accuracy
    

In [32]:
# Fresh model, binary cross-entropy on raw logits, and an Adam optimiser.
model = Model()
loss_fn = nn.BCEWithLogitsLoss()
optimiser = torch.optim.Adam(params=model.parameters(), lr=0.01)

# One tab-separated line per epoch:
# epoch, train loss, train accuracy, validation loss, validation accuracy.
for epoch in range(5):
    train_loss, train_accuracy = train_epoch(
        model, loss_fn, optimiser, training_dataloader
    )
    valid_loss, valid_accuracy = validate(model, loss_fn, validation_dataloader)
    print(
        "\t".join(
            str(value)
            for value in (
                epoch,
                round(train_loss, 4),
                round(train_accuracy, 2),
                round(valid_loss, 4),
                round(valid_accuracy, 2),
            )
        )
    )

0	0.6715	0.5	0.542	0.51
1	0.4284	0.8	0.2753	0.85
2	0.1634	0.97	0.07	1.0


KeyboardInterrupt: 