In [1]:
import itertools
import random

import torch
from torch.nn import Linear
from torch.nn import functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, MessagePassing, TopKPooling
from torch_geometric.transforms import NormalizeFeatures
from torch_scatter import scatter_mean
from collections import defaultdict

from custom.args import grey, purple
from custom.dataset import GraphDataset, create_dataset
from custom.utils import *
from networkx.algorithms.centrality import degree_centrality
import pickle
from itertools import zip_longest
from torch_geometric.utils import degree
from torch_geometric.utils import degree, from_networkx, to_networkx


from custom.utils import *

In [2]:
class GraphMatchingConvolution(MessagePassing):
    """Message-passing layer that mixes edge messages with a cross-graph attention term.

    ``forward`` projects the node features, ``message`` builds one message per
    edge from the projected endpoint features, and ``update`` combines the
    aggregated messages with ``original_x - cross_graph_attention`` before a
    linear layer and batch normalisation.
    """

    def __init__(self, in_channels, out_channels, args, aggr="add"):
        """Create the layer.

        Args:
            in_channels: Width of the incoming node features.
            out_channels: Width of the node features produced by the layer.
            args: Configuration object; stored on the instance but not read by
                any method of this class.
            aggr: Aggregation scheme handed to ``MessagePassing`` (default "add").
        """
        super(GraphMatchingConvolution, self).__init__(aggr=aggr)
        self.args = args
        # Projects raw node features before they are used to build messages.
        self.lin_node = torch.nn.Linear(in_channels, out_channels)
        # Consumes the concatenation [x_i, x_j] of two projected feature rows.
        self.lin_message = torch.nn.Linear(out_channels * 2, out_channels)
        # Consumes [aggregated messages, attention input]; the attention input
        # keeps the width of the *unprojected* features, hence ``+ in_channels``.
        self.lin_passing = torch.nn.Linear(out_channels + in_channels, out_channels)
        self.batch_norm = BatchNorm(out_channels)

    def forward(self, x, edge_index, batch):
        """Run one round of message passing.

        Returns whatever ``propagate`` returns; the caller in this file unpacks
        it as ``(features, edge_index, batch, diagnostics)``, i.e. the tuple
        built in ``update``.
        """
        x_transformed = self.lin_node(x)
        # Both the projected and the original features are forwarded: the
        # projected ones feed ``message``, the original ones feed ``update``.
        return self.propagate(edge_index, x=x_transformed, original_x=x, batch=batch)

    def message(self, edge_index_i, x_i, x_j):
        """Build one message per edge from the two projected endpoint features.

        ``edge_index_i`` is accepted but not used.
        """
        # x_i / x_j are presumably the per-edge endpoint rows that PyG derives
        # from the ``x`` keyword passed to ``propagate`` -- see PyG docs.
        x = torch.cat([x_i, x_j], dim=1)
        m = self.lin_message(x)
        return m

    def update(self, aggr_out, edge_index, x, original_x, batch):
        """Combine aggregated messages with the cross-graph attention signal.

        Returns:
            A 4-tuple ``(aggr_out, edge_index, batch, diagnostics)`` where
            ``diagnostics`` is ``(attention_input, cross_graph_attention, a_x,
            a_y, norms, cross_attention_sums)``.
        """
        # Number of distinct graph ids present in the batch vector.
        n_graphs = torch.unique(batch).shape[0]
        # batch_block_pair_attention is defined outside this file view; it is
        # called on the *unprojected* features and unpacked into three values.
        cross_graph_attention, a_x, a_y = batch_block_pair_attention(
            original_x, batch, n_graphs
        )
        # Difference between each node's own features and its attention term.
        attention_input = original_x - cross_graph_attention
        aggr_out = self.lin_passing(torch.cat([aggr_out, attention_input], dim=1))
        aggr_out = self.batch_norm(aggr_out)

        # Per-node diagnostics: L2 norm of the updated features and the row sum
        # of the attention term.
        norms = torch.norm(aggr_out, p=2, dim=1)
        cross_attention_sums = cross_graph_attention.sum(dim=1)

        return (
            aggr_out,
            edge_index,
            batch,
            (
                attention_input,
                cross_graph_attention,
                a_x,
                a_y,
                norms,
                cross_attention_sums,
            ),
        )


class GraphAggregator(torch.nn.Module):
    """Pools node features into one vector per graph using softmax gates."""

    def __init__(self, in_channels, out_channels, args):
        """Create the state, gate and output projections.

        Args:
            in_channels: Width of the incoming node features.
            out_channels: Width of the per-graph embedding.
            args: Configuration object; stored but not read by this class.
        """
        super().__init__()
        # Creation order is kept (state, gate, final) so that seeded
        # initialisation yields the same weights as before.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.lin_gate = torch.nn.Linear(in_channels, out_channels)
        self.lin_final = torch.nn.Linear(out_channels, out_channels)
        self.args = args

    def forward(self, x, edge_index, batch):
        """Return one embedding row per graph id in ``batch``.

        ``edge_index`` is accepted for signature compatibility and not used.
        """
        # Each node state is scaled by a softmax taken across its feature
        # dimension (dim=1).
        gated_states = self.lin(x) * torch.softmax(self.lin_gate(x), dim=1)
        # Average the gated node states of every graph.
        graph_states = scatter_mean(gated_states, batch, dim=0)
        return self.lin_final(graph_states)


class GraphMatchingNetwork(torch.nn.Module):
    """Embeds a pair of graphs jointly and scores their similarity.

    The two graphs are concatenated into one batch, passed through a stack of
    ``GraphMatchingConvolution`` layers and pooled by ``GraphAggregator`` into
    one embedding per graph. Per-layer intermediate results are recorded on
    the instance (``layer_outputs``, ``layer_cross_attentions``,
    ``topk_outputs``, ``norms_per_layer``, ``attention_sums_per_layer``) and
    are reset at the start of every ``forward`` call.
    """

    def __init__(self, args):
        """Build the layer stack.

        Args:
            args: Configuration object providing ``margin``, ``n_classes``,
                ``feat_dim``, ``dim``, ``num_layers``, ``n_clusters`` and
                ``device``.
        """
        super(GraphMatchingNetwork, self).__init__()
        self.args = args
        self.margin = self.args.margin
        if args.n_classes > 2:
            self.f1_average = "micro"
        else:
            self.f1_average = "binary"

        # First layer maps raw features to the hidden width; the remaining
        # ``num_layers - 1`` layers keep the hidden width.
        self.layers = torch.nn.ModuleList()
        self.layers.append(
            GraphMatchingConvolution(self.args.feat_dim, self.args.dim, args)
        )
        for _ in range(self.args.num_layers - 1):
            self.layers.append(
                GraphMatchingConvolution(self.args.dim, self.args.dim, args)
            )
        self.aggregator = GraphAggregator(self.args.dim, self.args.dim, self.args)

        # Per-forward diagnostics (see class docstring).
        self.layer_outputs = []
        self.layer_cross_attentions = []
        self.mincut = []  # reset in forward(); not written by any method here
        self.topk_outputs = []
        self.norms_per_layer = []
        self.attention_sums_per_layer = []

        # Cluster head; not called by any method of this class.
        self.mlp = torch.nn.Sequential()
        self.mlp.append(Linear(self.args.dim, self.args.n_clusters))

        # combine_pair_embedding() moves every input to ``args.device``; the
        # parameters have to live there too, otherwise the first layer fails
        # with a device mismatch whenever that device is not the CPU.
        self.to(self.args.device)

    def compute_emb(
        self, feats, edge_index, batch, sizes_1, sizes_2, edge_index_1, edge_index_2
    ):
        """Run all layers on the combined batch and pool to graph embeddings.

        Args:
            feats, edge_index, batch: Combined pair as produced by
                ``combine_pair_embedding``.
            sizes_1, sizes_2: One-element tensors holding the node count of
                each graph.
            edge_index_1, edge_index_2: Edge indices of the individual graphs
                (local node numbering), used for the per-graph top-k pooling.

        Returns:
            ``(embeddings, edge_index, batch)``; ``embeddings`` is the output
            of the aggregator.
        """
        n_1, n_2 = sizes_1.item(), sizes_2.item()
        device = feats.device

        # An integer ``ratio`` is passed, i.e. the smaller node count of the
        # pair. NOTE(review): the pooling module is re-created on every call,
        # so its scoring weights are freshly initialised and never trained.
        topk_pooling = TopKPooling(self.args.dim, ratio=min(n_1, n_2)).to(device)
        # The per-graph edge indices arrive on the caller's device; align them
        # with the node features they are pooled with.
        edge_index_1 = edge_index_1.to(device)
        edge_index_2 = edge_index_2.to(device)

        for i in range(self.args.num_layers):
            (
                feats,
                edge_index,
                batch,
                (
                    _attention_input,
                    cross_graph_attention,
                    a_x,
                    a_y,
                    norms,
                    attention_sums,
                ),
            ) = self.layers[i](feats, edge_index, batch)

            # Graph 1 occupies the first n_1 rows, graph 2 the next n_2 rows.
            x_1, x_2 = feats[:n_1, :], feats[n_1 : n_1 + n_2, :]
            norms_1, norms_2 = norms[:n_1], norms[n_1 : n_1 + n_2]
            attention_sums_1 = attention_sums[:n_1]
            attention_sums_2 = attention_sums[n_1 : n_1 + n_2]

            x_pooled_1, edge_index_pooled_1, _, _, perm1, score1 = topk_pooling(
                x_1,
                edge_index_1,
            )
            x_pooled_2, edge_index_pooled_2, _, _, perm2, score2 = topk_pooling(
                x_2,
                edge_index_2,
            )

            self.topk_outputs.append(
                (
                    (x_pooled_1, edge_index_pooled_1, perm1, score1),
                    (x_pooled_2, edge_index_pooled_2, perm2, score2),
                )
            )
            self.layer_cross_attentions.append((cross_graph_attention, a_x, a_y))
            self.layer_outputs.append((x_1, edge_index_1, x_2, edge_index_2))
            self.norms_per_layer.append((norms_1, norms_2))
            self.attention_sums_per_layer.append((attention_sums_1, attention_sums_2))

        feats = self.aggregator(feats, edge_index, batch)
        return feats, edge_index, batch

    def combine_pair_embedding(
        self, feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
    ):
        """Stack two graphs into one disconnected graph on ``args.device``.

        Node indices of the second graph are shifted by the node count of the
        first; the batch vector is built by ``create_batch`` (defined outside
        this file view) from the concatenated sizes.
        """
        feats = torch.cat([feats_1, feats_2], dim=0)
        max_node_idx_1 = sizes_1.sum()
        edge_index_2_offset = edge_index_2 + max_node_idx_1
        edge_index = torch.cat([edge_index_1, edge_index_2_offset], dim=1)
        batch = create_batch(torch.cat([sizes_1, sizes_2], dim=0))
        feats, edge_index, batch = (
            feats.to(self.args.device),
            edge_index.to(self.args.device),
            batch.to(self.args.device),
        )
        return feats, edge_index, batch

    def forward(self, feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2):
        """Embed a pair of graphs.

        Returns:
            ``(emb_1, emb_2, None, None, None, None)``. The four trailing
            entries (cluster / layer slots) are always ``None``; they are kept
            so existing callers can keep unpacking six values.
        """
        # Drop diagnostics recorded for the previous pair.
        self.layer_outputs = []
        self.layer_cross_attentions = []
        self.topk_outputs = []
        self.mincut = []
        self.norms_per_layer = []
        self.attention_sums_per_layer = []

        feats, edge_index, batch = self.combine_pair_embedding(
            feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
        )
        emb, _, _ = self.compute_emb(
            feats, edge_index, batch, sizes_1, sizes_2, edge_index_1, edge_index_2
        )
        # First half of the rows belongs to graph 1, second half to graph 2.
        half = emb.shape[0] // 2
        emb_1 = emb[:half, :]
        emb_2 = emb[half:, :]

        return emb_1, emb_2, None, None, None, None

    def compute_metrics(self, emb_1, emb_2, labels):
        """Compute the margin loss and accuracy for embedded pairs.

        Args:
            emb_1, emb_2: Embeddings of the two graphs of each pair.
            labels: Tensor of +1 / -1 targets (predictions are produced as
                +1 / -1 and compared against it).

        Returns:
            ``{"loss": ..., "acc": ...}`` with scalar tensors.
        """
        distances = torch.norm(emb_1 - emb_2, p=2, dim=1)
        # Labels are typically created on the CPU by the caller; align them
        # with the embeddings so the arithmetic below works on any device.
        labels = labels.to(distances.device)
        loss = F.relu(self.margin - labels * (1 - distances)).mean()
        # NOTE(review): pairs are predicted "similar" when distance < margin,
        # while the loss above is already zero for similar pairs with
        # distance <= 1 - margin -- confirm that this threshold is intended.
        predicted_similar = torch.where(
            distances < self.args.margin,
            torch.ones_like(labels),
            -torch.ones_like(labels),
        )
        acc = (predicted_similar == labels).float().mean()
        return {"loss": loss, "acc": acc}

    def init_metric_dict(self):
        """Return the sentinel metrics used before any evaluation has run."""
        return {"acc": -1, "f1": -1}

    def has_improved(self, m1, m2):
        """Return True when ``m2`` has a strictly higher accuracy than ``m1``."""
        return m1["acc"] < m2["acc"]

In [3]:
def train(model, optimizer, pairs, labels, batch_size, title):
    """Train ``model`` for one pass over ``pairs`` with mini-batch updates.

    Every pair is embedded on its own; the losses of ``batch_size``
    consecutive pairs are averaged and back-propagated together. The final
    batch may be smaller than ``batch_size`` and is still used for an update.

    Args:
        model: Module called as ``model(feats_1, edge_index_1, feats_2,
            edge_index_2, sizes_1, sizes_2)`` and exposing
            ``compute_metrics(emb_1, emb_2, labels)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        pairs: Sequence of ``(graph1, graph2)`` with ``x``, ``edge_index`` and
            ``num_nodes`` attributes.
        labels: Sequence of per-pair targets, indexed like ``pairs``.
        batch_size: Number of pairs per optimizer step.
        title: Unused; kept so existing call sites keep working (it labelled
            a plot that is no longer produced here).

    Returns:
        The accuracy of the last batch as a float, or ``None`` when ``pairs``
        is empty.
    """
    model.train()
    last_accuracy = None
    losses = []
    accs = []
    last_index = len(pairs) - 1

    for i, (graph1, graph2) in enumerate(pairs):
        label = labels[i]

        feats_1, edge_index_1 = graph1.x, graph1.edge_index
        feats_2, edge_index_2 = graph2.x, graph2.edge_index
        sizes_1 = torch.tensor([graph1.num_nodes])
        sizes_2 = torch.tensor([graph2.num_nodes])

        emb_1, emb_2, _, _, _, _ = model(
            feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
        )

        metrics = model.compute_metrics(emb_1, emb_2, torch.tensor([label]))
        losses.append(metrics["loss"])
        accs.append(metrics["acc"])

        # Step after every full batch and once more for the trailing partial
        # batch, so no computed pair is discarded.
        if (i + 1) % batch_size == 0 or i == last_index:
            batch_loss = torch.stack(losses).mean()
            batch_acc = torch.stack(accs).mean()
            losses, accs = [], []

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # .item() works for tensors on any device (unlike .numpy()).
            last_accuracy = batch_acc.item()

    return last_accuracy

In [4]:
import networkx as nx
import torch
from torch_geometric.utils import degree
import time


class MCS:
    """Backtracking search for the largest common subgraph mappings of two graphs.

    A mapping is a dict ``{node in G1: node in G2}``. It is grown from every
    seed pair contained in ``mp`` by repeatedly adding a feasible pair taken
    from the neighbourhood of the already mapped nodes. The largest mappings
    are kept (ties broken by ``count_edges``), de-duplicated by the node sets
    they cover. The whole search is cut off after ``time_limit`` seconds.

    Graphs are expected to expose ``x``, ``edge_index`` (2 x E tensor) and
    ``num_nodes``.
    """

    def __init__(self, mp):
        """Store the container of allowed ``(G1 node, G2 node)`` pairs."""
        self.mp = mp
        self.max_size = 0
        # Edge count of the best mappings; initialised here so that match()
        # never depends on short-circuit evaluation to avoid an AttributeError.
        self.edge_count = 0
        self.all_mappings = []
        self.unique_mappings = set()
        self.visited = set()
        self.time_limit = 5  # seconds, measured from the start of find_mcs()
        self.start_time = None

    @staticmethod
    def _neighbors(G, node):
        """Return the targets of all edges in ``G.edge_index`` that start at ``node``."""
        return G.edge_index[1][G.edge_index[0] == node].tolist()

    def find_mcs(self, G1, G2):
        """Search from every allowed seed pair and return the best mappings.

        Returns:
            List of mapping dicts that share the maximum size (and, among
            those, the maximum ``count_edges`` value). Empty when ``mp``
            allows no seed pair.
        """
        self.max_size = 0
        self.edge_count = 0
        self.all_mappings = []
        self.unique_mappings = set()
        self.visited = set()
        self.start_time = time.time()

        G1_degrees = degree(G1.edge_index[0], G1.num_nodes)
        G2_degrees = degree(G2.edge_index[0], G2.num_nodes)

        for n1 in range(G1.num_nodes):
            for n2 in range(G2.num_nodes):
                if (n1, n2) in self.mp:
                    self.match(
                        G1,
                        G2,
                        {n1: n2},
                        G1_degrees,
                        G2_degrees,
                        set(self._neighbors(G1, n1)),
                        set(self._neighbors(G2, n2)),
                    )
        return self.all_mappings

    def match(self, G1, G2, M, G1_degrees, G2_degrees, neighbors1, neighbors2):
        """Record ``M`` if it is among the best so far, then try to extend it.

        ``neighbors1`` / ``neighbors2`` are the candidate frontiers in G1 / G2.
        They are modified in place while a branch is explored and are restored
        to their incoming content before this method returns.
        """
        if time.time() - self.start_time > self.time_limit:
            return

        # Skip states that were already expanded via another insertion order.
        state = (frozenset(M.items()), frozenset(neighbors1), frozenset(neighbors2))
        if state in self.visited:
            return
        self.visited.add(state)

        edge_count = self.count_edges(M, G1, G2)

        if len(M) > self.max_size or (
            len(M) == self.max_size and edge_count > self.edge_count
        ):
            # Strictly better: restart the result list.
            self.max_size = len(M)
            self.edge_count = edge_count
            self.all_mappings = [M.copy()]
            self.unique_mappings = {self.canonical_form(M)}
        elif len(M) == self.max_size and edge_count == self.edge_count:
            canonical = self.canonical_form(M)
            if canonical not in self.unique_mappings:
                self.all_mappings.append(M.copy())
                self.unique_mappings.add(canonical)

        # Snapshots: the frontier sets change while iterating. High-degree
        # nodes are tried first.
        candidates1 = sorted(neighbors1, key=lambda n: -G1_degrees[n].item())
        candidates2 = sorted(neighbors2, key=lambda n: -G2_degrees[n].item())

        for n1 in candidates1:
            if n1 in M:
                continue
            for n2 in candidates2:
                if n2 in M.values() or not self.feasible(n1, n2, M, G1, G2):
                    continue

                M[n1] = n2
                # Only nodes that are genuinely new to the frontier are
                # remembered, so that undoing this step cannot drop candidates
                # that were already present before it.
                added1 = set(self._neighbors(G1, n1)) - set(M.keys()) - neighbors1
                added2 = set(self._neighbors(G2, n2)) - set(M.values()) - neighbors2
                neighbors1 |= added1
                neighbors2 |= added2

                self.match(
                    G1,
                    G2,
                    M,
                    G1_degrees,
                    G2_degrees,
                    neighbors1,
                    neighbors2,
                )

                # Backtrack: undo exactly what this step added.
                del M[n1]
                neighbors1 -= added1
                neighbors2 -= added2

    def feasible(self, n1, n2, M, G1, G2):
        """Return True when ``n1 -> n2`` can be added to ``M``.

        The pair must have identical feature rows, be contained in ``mp``,
        have the same number of already-mapped neighbours on both sides, and
        every mapped neighbour of ``n1`` must map to a neighbour of ``n2``.
        """
        if not torch.equal(G1.x[n1], G2.x[n2]):
            return False
        if (n1, n2) not in self.mp:
            return False

        # Neighbour lists are fetched once instead of re-scanning the edge
        # tensors for every membership test.
        adjacent1 = self._neighbors(G1, n1)
        adjacent2 = self._neighbors(G2, n2)
        mapped_targets = set(M.values())

        count1 = sum(1 for neighbor in adjacent1 if neighbor in M)
        count2 = sum(1 for neighbor in adjacent2 if neighbor in mapped_targets)
        if count1 != count2:
            return False

        adjacent2_set = set(adjacent2)
        return all(M[neighbor] in adjacent2_set for neighbor in adjacent1 if neighbor in M)

    def count_edges(self, M, G1, G2):
        """Count ordered pairs of mapped nodes that are adjacent in both graphs.

        Every common undirected edge is therefore counted twice; the value is
        only used to compare mappings of equal size.
        """
        edges1 = set(zip(G1.edge_index[0].tolist(), G1.edge_index[1].tolist()))
        edges2 = set(zip(G2.edge_index[0].tolist(), G2.edge_index[1].tolist()))

        edge_count = 0
        for u1, v1 in M.items():
            for u2, v2 in M.items():
                if u1 == u2:
                    continue
                # Adjacency is tested in either direction.
                in_g1 = (u1, u2) in edges1 or (u2, u1) in edges1
                in_g2 = (v1, v2) in edges2 or (v2, v1) in edges2
                if in_g1 and in_g2:
                    edge_count += 1
        return edge_count

    def canonical_form(self, M):
        """Return ``(frozenset of G1 nodes, frozenset of G2 nodes)`` covered by ``M``."""
        return (frozenset(M.keys()), frozenset(M.values()))

In [5]:
def plot_attentions(graph1, graph2, attention_pairs, title=""):
    """Draw two graphs side by side, without and then with attention links.

    Two figures are shown: the first contains only the edges of the two
    graphs, the second additionally contains one edge per
    ``(node in graph1, node in graph2)`` entry of ``attention_pairs``,
    re-drawn in light grey on top.
    """
    left = to_networkx(graph1, to_undirected=True)
    right = to_networkx(graph2, to_undirected=True)
    # Nodes of the second graph are renumbered to follow those of the first.
    offset = len(left.nodes)

    G_combined = nx.Graph()
    for source, shift in ((left, 0), (right, offset)):
        for node, attrs in source.nodes(data=True):
            G_combined.add_node(node + shift, **attrs)
    for source, shift in ((left, 0), (right, offset)):
        G_combined.add_edges_from([(u + shift, v + shift) for u, v in source.edges()])

    # Copy taken before the attention links are added.
    G_plain = G_combined.copy()

    for node1, node2 in attention_pairs:
        G_combined.add_edge(node1, node2 + offset)

    # Independent layouts, pushed apart horizontally by 1.5 each.
    layout_left = nx.spring_layout(left)
    layout_right = nx.spring_layout(right)
    pos_combined = {
        node: [xy[0] - 1.5, xy[1]] for node, xy in layout_left.items()
    }
    pos_combined.update(
        {node + offset: [xy[0] + 1.5, xy[1]] for node, xy in layout_right.items()}
    )

    # Every node is labelled with its index inside its own graph.
    node_labels_combined = {node: node for node in range(graph1.num_nodes)}
    node_labels_combined.update(
        {node + offset: node for node in range(graph2.num_nodes)}
    )

    def _open_figure():
        plt.figure(figsize=(12, 6))
        plt.gcf().patch.set_alpha(0)

    def _draw(graph):
        nx.draw(
            graph,
            pos=pos_combined,
            with_labels=False,
            labels=node_labels_combined,
            node_color="#3b8bc2",
            edge_color="black",
            node_size=500,
        )

    # Figure 1: the two graphs only.
    _open_figure()
    _draw(G_plain)
    plt.title(title)
    plt.show()

    # Figure 2: the two graphs plus the attention links.
    _open_figure()
    _draw(G_combined)
    attention_edges = [(node1, node2 + offset) for node1, node2 in attention_pairs]
    nx.draw_networkx_edges(
        G_combined,
        pos=pos_combined,
        edgelist=attention_edges,
        edge_color="lightgrey",
    )
    plt.title(title)
    plt.show()

In [6]:
def loop(graph1, graph2, attentions, t):
    """Find the largest isomorphic common pattern over the first three layers.

    For layer indices 0-2 the attention nodes are thresholded at ``t``, turned
    into mutual pairs and handed to ``MCS``. The first pattern returned is
    kept when it has more than two nodes and the two induced subgraphs are
    isomorphic.

    Returns:
        ``(largest_summary, final_pattern)``: the subgraph of ``graph1`` and
        the node mapping of the largest accepted pattern, or ``(None, None)``
        when no layer yields an acceptable pattern.
    """
    most_nodes = 0
    largest_summary = None
    # Initialised so the return below cannot hit an unbound local when no
    # pattern is accepted.
    final_pattern = None

    for i in range(3):
        attention_nodes = extract_dynamic_attention_nodes(attentions, threshold=t)
        mp = mutual_pairs(attention_nodes, i)
        patterns = MCS(mp).find_mcs(graph1, graph2)
        if not patterns:
            continue

        pattern = patterns[0]
        g1_subgraph, g2_subgraph = create_subgraphs(pattern, graph1, graph2)

        accepted = (
            nx.is_isomorphic(
                to_networkx(g1_subgraph, to_undirected=True),
                to_networkx(g2_subgraph, to_undirected=True),
            )
            and len(pattern) > 2
        )
        if accepted and len(pattern) > most_nodes:
            most_nodes = len(pattern)
            largest_summary = g1_subgraph
            final_pattern = pattern

    return largest_summary, final_pattern

In [7]:
import os


def load_graphs(name, c="cycle"):
    """Load one pickled graph and convert it to a PyG ``Data`` object.

    Args:
        name: Identifier inserted into the file name
            ``gnnexplainer_graphs/graph{name}_{c}.pkl``.
        c: Graph family suffix of the file name (default "cycle").

    Returns:
        The converted graph with ``x`` set to the float version of its
        ``feat`` attribute.
    """
    # The path is built from the ``name`` argument (not from a module-level
    # loop variable), so the function can be called from anywhere.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use it on files from a trusted source.
    with open(f"gnnexplainer_graphs/graph{name}_{c}.pkl", "rb") as f:
        original_G, _ = pickle.load(f)

    # Drop the "self" node attribute before conversion.
    for node in original_G.nodes:
        if "self" in original_G.nodes[node]:
            del original_G.nodes[node]["self"]

    data = from_networkx(original_G)
    data.x = data.feat.float()
    return data


# One list per graph family; each receives the graphs with indices 0-9, read
# from "gnnexplainer_graphs/graph{index}_{family}.pkl" by load_graphs().
cycles = []
completes = []
lines = []
stars = []
# ``i`` is both the module-level loop variable and the index handed to
# load_graphs().
for i in range(10):
    cycles.append(load_graphs(i, "cycle"))
    completes.append(load_graphs(i, "complete"))
    lines.append(load_graphs(i, "line"))
    stars.append(load_graphs(i, "star"))

# (graph1, graph2, attentions) = torch.load(
#     "examples/method2_oof_v2_exact_cycle_8.pt", weights_only=False
# )

# summary, final_pattern = loop(graph1, graph1, attentions, 0.08)

# print(final_pattern)

# print(summary.x)
# plot_graph_pair(graph1, graph2)
# plot_graph(summary)

In [8]:
import random

# Every index 0-9 appears exactly twice; after shuffling, consecutive entries
# are grouped into ten pairs. A pair may contain the same index twice.
numbers = [index for _ in range(2) for index in range(10)]
random.shuffle(numbers)
random_pairs = list(zip(numbers[0::2], numbers[1::2]))

print(random_pairs)

[(6, 1), (9, 2), (8, 4), (5, 4), (0, 8), (7, 9), (6, 0), (7, 2), (3, 1), (5, 3)]


In [9]:
# The file holds pickled Python objects rather than plain tensors, so full
# unpickling is requested explicitly instead of relying on torch.load's
# default for ``weights_only``.
# NOTE(review): weights_only=False can execute arbitrary code while loading;
# only use it on files from a trusted source.
dataset = GraphDataset(
    torch.load("data/cycle_line_star_complete_1.pt", weights_only=False)
)


class NewArgs:
    """Plain attribute container with the settings read by the model and training code.

    ``feat_dim`` and ``n_classes`` are taken from the module-level ``dataset``;
    ``device`` is CUDA when available, otherwise CPU; ``n_clusters`` is fixed
    to 8.
    """

    def __init__(self, dim, num_layers, margin, lr, batch_size, num_pairs):
        # Attributes are assigned in the same order as before so that
        # ``vars(args)`` lists them identically.
        self.dim, self.feat_dim = dim, dataset.num_features
        self.num_layers, self.margin, self.lr = num_layers, margin, lr
        self.n_classes, self.batch_size = dataset.num_classes, batch_size
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.n_clusters, self.num_pairs = 8, num_pairs

  dataset = GraphDataset(torch.load("data/cycle_line_star_complete_1.pt"))


In [10]:
# Positional values in the order of NewArgs.__init__:
# (dim, num_layers, margin, lr, batch_size, num_pairs).
hyperparams = (32, 7, 0.2, 0.01, 64, 500)
newargs = NewArgs(*hyperparams)
model = GraphMatchingNetwork(newargs)
optimizer = Adam(model.parameters(), lr=newargs.lr, weight_decay=1e-5)
# create_graph_pairs is defined outside this file view; its result is unpacked
# as (pairs, labels) and handed to train().
pairs, labels = create_graph_pairs(dataset, newargs.num_pairs)
# The stringified hyperparameters are passed as the ``title`` argument.
train(model, optimizer, pairs, labels, newargs.batch_size, str(hyperparams))

In [11]:
def run(n, l, random_pairs, t=0.08, num_layers=7):
    """Extract the largest common pattern for the ``n``-th random pair of graphs.

    Uses the module-level ``model`` to obtain attentions for the pair, then
    searches each layer index for a common pattern with ``MCS``.

    Args:
        n: Position in ``random_pairs`` of the pair to process.
        l: Sequence of graphs that the pair indices refer to.
        random_pairs: Sequence of ``(idx1, idx2)`` index pairs.
        t: Threshold passed to ``extract_dynamic_attention_nodes``.
        num_layers: Number of layer indices (``0 .. num_layers - 1``) to
            search; the default of 7 is the value that used to be hard-coded.

    Returns:
        ``(largest_summary, idx1, idx2)``; ``largest_summary`` is the subgraph
        of the first graph for the largest accepted pattern, or ``None`` when
        no layer yields a pattern with more than two nodes whose two induced
        subgraphs are isomorphic.
    """
    idx1, idx2 = random_pairs[n]
    graph1, graph2 = l[idx1], l[idx2]

    feats_1, edge_index_1 = graph1.x, graph1.edge_index
    feats_2, edge_index_2 = graph2.x, graph2.edge_index
    sizes_1 = torch.tensor([len(graph1.x)])
    sizes_2 = torch.tensor([len(graph2.x)])

    # The outputs are not needed here, but the call is kept because it fills
    # the per-layer diagnostics on ``model``.
    # NOTE(review): the model is not switched to eval mode before this call --
    # confirm that running the batch-norm layers in training mode is intended.
    model(feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2)

    _, attentions = extract_embeddings_and_attention(
        model, feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
    )

    most_nodes = 0
    largest_summary = None

    for i in range(num_layers):
        attention_nodes = extract_dynamic_attention_nodes(attentions, threshold=t)
        mp = mutual_pairs(attention_nodes, i)
        patterns = MCS(mp).find_mcs(graph1, graph2)
        # Only the first returned pattern is considered, and only when it has
        # more than two nodes.
        if not patterns or len(patterns[0]) <= 2:
            continue

        pattern = patterns[0]
        g1_subgraph, g2_subgraph = create_subgraphs(pattern, graph1, graph2)

        is_common = nx.is_isomorphic(
            to_networkx(g1_subgraph, to_undirected=True),
            to_networkx(g2_subgraph, to_undirected=True),
        )
        if is_common and len(pattern) > most_nodes:
            most_nodes = len(pattern)
            largest_summary = g1_subgraph

    return largest_summary, idx1, idx2


# Retrain a fresh model and re-extract ten cycle prototypes until at least
# five of them are recognised by is_cycle(); the successful set is saved.
# NOTE(review): there is no limit on the number of attempts, so this cell runs
# until the condition is met or it is interrupted.
while True:
    hyperparams = (32, 7, 0.2, 0.01, 64, 500)
    newargs = NewArgs(*hyperparams)
    model = GraphMatchingNetwork(newargs)
    optimizer = Adam(model.parameters(), lr=newargs.lr, weight_decay=1e-5)
    pairs, labels = create_graph_pairs(dataset, newargs.num_pairs)
    train(model, optimizer, pairs, labels, newargs.batch_size, str(hyperparams))

    # Fresh random pairing of the ten cycle graphs for this attempt.
    numbers = list(range(10)) * 2
    random.shuffle(numbers)
    random_pairs = [(numbers[i], numbers[i + 1]) for i in range(0, len(numbers), 2)]

    cycles_prototypes = []
    for i in range(10):
        cycles_prototypes.append(run(i, cycles, random_pairs))

    counter = 0
    for prototype in cycles_prototypes:
        summary = prototype[0]
        # run() returns None as the summary when no pattern was accepted;
        # such entries can never count as a cycle and are not passed on.
        if summary is not None and is_cycle(summary):
            counter += 1

    if counter >= 5:
        torch.save(cycles_prototypes, "finally.pt")
        print("OOOOO")
        break

# completes_prototypes = []
# for i in range(10):
#     completes_prototypes.append(run(i, completes, 0.1))

# lines_prototypes = []
# for i in range(10):
#     lines_prototypes.append(run(i, lines))

# stars_prototypes = []
# for i in range(10):
#     stars_prototypes.append(run(i, stars))

KeyboardInterrupt: 