In [None]:
import itertools
import random

import torch
from torch.nn import Linear
from torch.nn import functional as F
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, MessagePassing, TopKPooling
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import (
    dense_to_sparse,
    is_undirected,
    to_networkx,
    to_undirected,
)
from torch_scatter import scatter_mean

from custom.args import grey, purple
from custom.dataset import GraphDataset, create_dataset
from custom.utils import *
from custom.model import GraphMatchingNetwork

In [None]:
# Load the MUTAG graph-classification benchmark from the TU collection into
# ./data, keeping the extra node attributes and applying NormalizeFeatures
# as the per-graph transform.
dataset = TUDataset(
    name="MUTAG",
    root="data",
    transform=NormalizeFeatures(),
    use_node_attr=True,
)

# analyze_dataset is not defined in this notebook (presumably it arrives via
# the `custom.utils` star import -- confirm). Its four results are kept as
# notebook globals; `classes` is later indexed with keys like "class_0".
(small_graphs, medium_graphs, large_graphs, classes) = analyze_dataset(dataset)

In [None]:
def train(model, optimizer, pairs, labels, batch_size, title=""):
    """Train ``model`` on graph pairs, one optimizer step per mini-batch, then plot the curves.

    Every pair is pushed through the model on its own. The per-pair losses and
    accuracies are collected, and once ``batch_size`` pairs have been seen
    their mean loss is back-propagated and a single optimizer step is taken.
    A final, shorter mini-batch is processed as well, so no pair is dropped.

    Args:
        model: Module called as
            ``model(x1, edge_index1, x2, edge_index2, sizes1, sizes2)`` that
            returns a 4-tuple whose first two items are fed to
            ``model.compute_metrics(emb_1, emb_2, label_tensor)``. That call
            must return a mapping with tensor entries ``"loss"`` and ``"acc"``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        pairs: Sequence of ``(graph1, graph2)`` tuples; each graph exposes
            ``x``, ``edge_index`` and ``num_nodes``.
        labels: Sequence indexed in parallel with ``pairs``.
        batch_size: Number of pairs per optimizer step; must be >= 1.
        title: Super-title of the resulting figure.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    model.train()
    train_losses = []
    train_accuracies = []
    losses = []
    accs = []
    last_index = len(pairs) - 1

    for i, (graph1, graph2) in enumerate(pairs):
        label = labels[i]

        sizes_1 = torch.tensor([graph1.num_nodes])
        sizes_2 = torch.tensor([graph2.num_nodes])

        emb_1, emb_2, _, _ = model(
            graph1.x, graph1.edge_index, graph2.x, graph2.edge_index, sizes_1, sizes_2
        )

        metrics = model.compute_metrics(emb_1, emb_2, torch.tensor([label]))
        losses.append(metrics["loss"])
        accs.append(metrics["acc"])

        # Close the mini-batch after exactly `batch_size` pairs, and also at
        # the very last pair so a trailing partial batch still contributes.
        if (i + 1) % batch_size == 0 or i == last_index:
            batch_loss = torch.mean(torch.stack(losses))
            batch_acc = torch.mean(torch.stack(accs))
            losses = []
            accs = []
            # .item() yields a plain Python number regardless of the device
            # the tensors live on.
            train_losses.append(batch_loss.item())
            train_accuracies.append(batch_acc.item())
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

    # NOTE: each plotted point corresponds to one mini-batch / optimizer step,
    # even though the axis labels below say "Epochs".
    plt.figure(figsize=(12, 5))

    plt.suptitle(title)

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Training Loss")
    plt.title("Loss over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label="Training Accuracy")
    plt.title("Accuracy over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()

    plt.show()

In [None]:
def test(model, title="", layers=3):
    """Run the model on one random same-class graph pair per class and plot the pooled graphs.

    For every class index in ``dataset.num_classes`` two graphs are drawn from
    ``classes["class_<i>"]`` (independent draws, so the same graph may be
    picked twice) and passed through ``model`` in eval mode. The per-layer
    entries of ``model.topk_outputs`` are then wrapped into ``Data`` objects
    and handed, together with a per-layer count, to ``plot_all_classes``.

    Relies on the notebook globals ``dataset`` and ``classes``.

    Args:
        model: Module called as
            ``model(x1, edge_index1, x2, edge_index2, sizes1, sizes2)``; after
            the call ``model.topk_outputs`` must be an iterable of
            ``((x, edge_index, perm), (x, edge_index, perm))`` pairs.
        title: Forwarded to ``plot_all_classes``.
        layers: Forwarded to ``plot_all_classes``.
    """
    class_clusters = []
    class_accs = []

    # The per-layer "acc" counts how many of node indices 0..7 are present in
    # the kept permutation. NOTE(review): the meaning of the constant 8 is not
    # visible in this notebook -- confirm what it is meant to represent.
    tracked_nodes = set(range(8))

    model.eval()

    # Evaluation only: skip autograd bookkeeping so the pooled tensors passed
    # on to plotting are plain, graph-free tensors.
    with torch.no_grad():
        for class_idx in range(dataset.num_classes):
            graphs = classes[f"class_{class_idx}"]
            graph1 = graphs[random.randrange(len(graphs))]
            graph2 = graphs[random.randrange(len(graphs))]

            sizes_1 = torch.tensor([len(graph1.x)])
            sizes_2 = torch.tensor([len(graph2.x)])
            # Return values are not needed; the forward pass is run for the
            # `model.topk_outputs` it leaves behind.
            model(
                graph1.x, graph1.edge_index, graph2.x, graph2.edge_index, sizes_1, sizes_2
            )

            clusters = []
            accs = []
            for (
                (x_pooled_1, edge_index_pooled_1, perm1),
                (x_pooled_2, edge_index_pooled_2, perm2),
            ) in model.topk_outputs:
                clusters.append(
                    (
                        Data(x=x_pooled_1, edge_index=edge_index_pooled_1),
                        Data(x=x_pooled_2, edge_index=edge_index_pooled_2),
                    )
                )
                accs.append(
                    (
                        len(tracked_nodes & set(perm1.tolist())),
                        len(tracked_nodes & set(perm2.tolist())),
                    )
                )

            # zip(*...) regroups the per-layer pairs so that all layers of the
            # first graph come first, followed by all layers of the second.
            class_accs.extend(itertools.chain.from_iterable(zip(*accs)))
            class_clusters.extend(itertools.chain.from_iterable(zip(*clusters)))

    plot_all_classes(class_clusters, class_accs, title, layers, 2)

In [None]:
# Training configurations to sweep. Each tuple is unpacked positionally into
# Args(dim, num_layers, margin, lr, batch_size, num_pairs).
# Every configuration is listed once; a repeated entry would only trigger a
# redundant training run.
hyperparameters = [
    (32, 3, 0.2, 0.001, 64, 500),
    (32, 3, 0.4, 0.01, 16, 1000),
    (32, 3, 0.4, 0.01, 32, 1000),
    (32, 3, 0.4, 0.01, 64, 3000),
    (32, 3, 0.5, 0.01, 64, 1000),
    (32, 3, 0.5, 0.01, 128, 3000),
    (32, 3, 0.5, 0.001, 64, 3000),
    (32, 4, 0.1, 0.01, 16, 500),
    (32, 4, 0.1, 0.001, 32, 1000),
    (32, 4, 0.3, 0.0001, 16, 500),
    (32, 4, 0.4, 0.001, 32, 1000),
    (32, 4, 0.5, 0.01, 32, 500),
    (32, 4, 0.5, 0.01, 64, 1000),
    (32, 4, 0.5, 0.01, 128, 2000),
    (32, 4, 0.5, 0.01, 128, 3000),
    (32, 4, 0.5, 0.001, 64, 2000),
    (32, 4, 1, 0.01, 32, 500),
    (32, 4, 1, 0.01, 128, 2000),
    (32, 5, 0.2, 0.0001, 64, 1000),
    (32, 5, 0.1, 0.01, 64, 1000),
    (32, 5, 1, 0.001, 64, 2000),
    (32, 6, 0.2, 0.001, 16, 2000),
    (32, 7, 0.2, 0.001, 16, 500),
    (32, 7, 0.4, 0.001, 64, 500),
    (32, 7, 0.4, 0.001, 128, 1000),
    (32, 8, 0.2, 0.001, 64, 2000),
    (32, 8, 0.5, 0.01, 64, 3000),
]

In [None]:
class Args:
    """Settings object handed to ``GraphMatchingNetwork`` for one training run.

    The feature dimension and the number of classes are read from the
    notebook-level ``dataset``; the device is CUDA when available, else CPU.

    Args:
        dim: Stored as ``self.dim``.
        num_layers: Stored as ``self.num_layers``.
        margin: Stored as ``self.margin``.
        lr: Learning rate; also used by the caller to build the optimizer.
        batch_size: Number of pairs per optimizer step.
        num_pairs: Number of graph pairs to generate for training.
        n_clusters: Stored as ``self.n_clusters``; defaults to the previously
            hard-coded 0.7. NOTE(review): presumably a pooling ratio consumed
            by the model -- confirm against ``custom.model``.
    """

    def __init__(
        self, dim, num_layers, margin, lr, batch_size, num_pairs, n_clusters=0.7
    ):
        self.dim = dim
        self.feat_dim = dataset.num_features
        self.num_layers = num_layers
        self.margin = margin
        self.lr = lr
        self.n_classes = dataset.num_classes
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.n_clusters = n_clusters
        self.num_pairs = num_pairs


# Sweep: for every configuration build a fresh model and optimizer, sample
# training pairs, train once, then draw ten random evaluation plots.
for hyperparams in hyperparameters:
    # Re-seed per configuration so each run is reproducible independently of
    # its position in the sweep. NOTE(review): if the custom helpers draw from
    # another RNG (e.g. NumPy), that one needs seeding as well -- confirm.
    random.seed(0)
    torch.manual_seed(0)

    newargs = Args(*hyperparams)
    m = GraphMatchingNetwork(newargs)
    o = Adam(m.parameters(), lr=newargs.lr, weight_decay=1e-5)
    p, l = create_graph_pairs(dataset, newargs.num_pairs)
    train(m, o, p, l, newargs.batch_size, str(hyperparams))
    for _ in range(10):
        # Label the evaluation figures with the same configuration string as
        # the training figure so the plots can be matched up afterwards.
        test(m, title=str(hyperparams), layers=newargs.num_layers)