In [None]:
import itertools
import random

import torch
from torch.nn import Linear
from torch.nn import functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, MessagePassing, TopKPooling
from torch_geometric.transforms import NormalizeFeatures
from torch_scatter import scatter_mean

from custom.args import grey, purple
from custom.dataset import GraphDataset, create_dataset
from custom.utils import *

In [None]:
class GraphMatchingConvolution(MessagePassing):
    """One graph-matching propagation layer over a batch holding a graph pair.

    Combines within-graph message passing (``message`` + ``aggr``) with a
    cross-graph attention term computed by ``batch_block_pair_attention`` on the
    layer's *input* features, then applies ``BatchNorm``.

    NOTE(review): ``update`` returns a tuple rather than a tensor, so
    ``forward`` returns that tuple:
    ``(node_out, edge_index, batch, diagnostics)``.
    """

    def __init__(self, in_channels, out_channels, args, aggr="add"):
        super(GraphMatchingConvolution, self).__init__(aggr=aggr)
        self.args = args
        # Projects the raw node features to out_channels before messages are built.
        self.lin_node = torch.nn.Linear(in_channels, out_channels)
        # Maps the concatenated [x_i, x_j] (both already projected) to a message.
        self.lin_message = torch.nn.Linear(out_channels * 2, out_channels)
        # Fuses aggregated messages (out_channels) with the cross-graph attention
        # residual, which lives in the un-projected input space (in_channels).
        self.lin_passing = torch.nn.Linear(out_channels + in_channels, out_channels)
        self.batch_norm = BatchNorm(out_channels)

    def forward(self, x, edge_index, batch):
        """Run one propagation step.

        Args:
            x: node features of the concatenated pair, ``in_channels`` wide.
            edge_index: edges of the concatenated pair.
            batch: graph id per node.

        Returns:
            Whatever ``update`` returns (a 4-tuple, see ``update``).
        """
        x_transformed = self.lin_node(x)
        # original_x is forwarded so update() can compute the cross-graph
        # attention on the un-projected input features.
        return self.propagate(edge_index, x=x_transformed, original_x=x, batch=batch)

    def message(self, edge_index_i, x_i, x_j):
        # PyG's MessagePassing maps arguments by parameter name (x_i / x_j are
        # the target / source node features), so these names must not change.
        # edge_index_i is requested but not used.
        x = torch.cat([x_i, x_j], dim=1)
        m = self.lin_message(x)
        return m

    def update(self, aggr_out, edge_index, x, original_x, batch):
        """Combine aggregated messages with a cross-graph attention term.

        Args:
            aggr_out: aggregated messages per node (``aggr``, ``"add"`` by default).
            edge_index: passed through unchanged in the return value.
            x: projected node features (not used here).
            original_x: the layer's un-projected input features.
            batch: graph id per node.

        Returns:
            ``(node_out, edge_index, batch, (attention_input,
            cross_graph_attention, a_x, a_y, norms, cross_attention_sums))``.
        """
        n_graphs = torch.unique(batch).shape[0]
        # batch_block_pair_attention is not defined in this notebook (presumably
        # it comes from the custom.utils star import); its first result is used
        # as the cross-graph attention vectors, a_x / a_y are only passed through.
        cross_graph_attention, a_x, a_y = batch_block_pair_attention(
            original_x, batch, n_graphs
        )
        # Per-node difference between the input feature and its cross-graph
        # attention vector.
        attention_input = original_x - cross_graph_attention
        aggr_out = self.lin_passing(torch.cat([aggr_out, attention_input], dim=1))
        aggr_out = self.batch_norm(aggr_out)

        # Diagnostics: per-node L2 norm of the output and per-node sum of the
        # attention vector entries.
        norms = torch.norm(aggr_out, p=2, dim=1)
        cross_attention_sums = cross_graph_attention.sum(dim=1)

        return (
            aggr_out,
            edge_index,
            batch,
            (
                attention_input,
                cross_graph_attention,
                a_x,
                a_y,
                norms,
                cross_attention_sums,
            ),
        )


class GraphAggregator(torch.nn.Module):
    """Gated mean readout producing one vector per graph.

    Each node feature is projected by ``lin`` and scaled elementwise by a
    softmax gate (over the feature axis) computed with ``lin_gate``. The gated
    nodes are averaged per graph id in ``batch`` and finally projected by
    ``lin_final``. Assumes ``x`` is ``(num_nodes, in_channels)``.
    """

    def __init__(self, in_channels, out_channels, args):
        super(GraphAggregator, self).__init__()
        # Creation order is kept so parameter initialisation consumes the RNG
        # exactly as before.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.lin_gate = torch.nn.Linear(in_channels, out_channels)
        self.lin_final = torch.nn.Linear(out_channels, out_channels)
        self.args = args

    def forward(self, x, edge_index, batch):
        """Return per-graph embeddings; ``edge_index`` is accepted but unused."""
        projected = self.lin(x)
        gate = torch.nn.functional.softmax(self.lin_gate(x), dim=1)
        per_graph_mean = scatter_mean(projected * gate, batch, dim=0)
        return self.lin_final(per_graph_mean)

In [None]:
class GraphMatchingNetwork(torch.nn.Module):
    """Graph matching model over a single pair of graphs.

    A stack of ``GraphMatchingConvolution`` layers runs on the two graphs
    concatenated into one batch. A ``GraphAggregator`` readout follows. During
    every forward pass, per-layer diagnostics are recorded on the instance for
    later inspection:

    - TopK-pooled subgraphs;
    - cross-graph attention;
    - node norms;
    - attention sums.

    Args:
        args: hyper-parameter object. This class reads ``margin``,
            ``n_classes``, ``feat_dim``, ``dim``, ``num_layers``, ``n_clusters``
            and ``device``.
    """

    def __init__(self, args):
        super(GraphMatchingNetwork, self).__init__()
        self.args = args
        self.margin = self.args.margin
        # NOTE(review): f1_average is not used by any method of this class.
        if args.n_classes > 2:
            self.f1_average = "micro"
        else:
            self.f1_average = "binary"
        # The first layer maps raw features (feat_dim) to dim; the rest keep dim.
        self.layers = torch.nn.ModuleList()
        self.layers.append(
            GraphMatchingConvolution(self.args.feat_dim, self.args.dim, args)
        )
        for _ in range(self.args.num_layers - 1):
            self.layers.append(
                GraphMatchingConvolution(self.args.dim, self.args.dim, args)
            )
        self.aggregator = GraphAggregator(self.args.dim, self.args.dim, self.args)
        # Per-forward diagnostics; reset at the start of every forward().
        self.layer_outputs = []
        self.layer_cross_attentions = []
        self.mincut = []
        # NOTE(review): mlp is registered as a submodule (so its parameters are
        # in model.parameters()) but is not used by compute_emb/forward.
        self.mlp = torch.nn.Sequential()
        # No-op: self.args is args, so this assigns the attribute to itself.
        self.args.n_clusters = args.n_clusters
        self.mlp.append(Linear(self.args.dim, self.args.n_clusters))
        self.topk_outputs = []
        self.norms_per_layer = []
        self.attention_sums_per_layer = []

    def compute_emb(
        self, feats, edge_index, batch, sizes_1, sizes_2, edge_index_1, edge_index_2
    ):
        """Run all matching layers on the concatenated pair and read out embeddings.

        Layout assumptions:
        - ``feats`` holds graph 1's ``sizes_1`` nodes followed by graph 2's
          ``sizes_2`` nodes (the layout built by ``combine_pair_embedding``).
        - ``sizes_1``/``sizes_2`` are one-element tensors (``.item()`` is used).
        - ``edge_index_1``/``edge_index_2`` are the per-graph, un-offset edges
          used to TopK-pool each graph's slice.

        Side effects: appends one entry per layer to each of
        ``topk_outputs``, ``layer_cross_attentions``, ``layer_outputs``,
        ``norms_per_layer`` and ``attention_sums_per_layer``.

        Returns:
            ``(graph_embeddings, edge_index, batch)``.
        """

        # NOTE(review): a new TopKPooling with freshly initialised weights is
        # built on every call. It is a local, not a registered submodule, so it
        # is never trained and is not moved by model.to(). The recorded pooling
        # results therefore depend on the torch RNG state.
        # An int ratio makes TopKPooling keep that many nodes per graph.
        topk_pooling = TopKPooling(
            self.args.dim, ratio=min(sizes_1.item(), sizes_2.item())
        )
        for i in range(self.args.num_layers):
            # Each layer returns (node features, edge_index, batch, diagnostics).
            (
                feats,
                edge_index,
                batch,
                (
                    attention_input,
                    cross_graph_attention,
                    a_x,
                    a_y,
                    norms,
                    attention_sums,
                ),
            ) = self.layers[i](feats, edge_index, batch)

            # Split the concatenated node axis back into graph 1 / graph 2.
            x_1 = feats[: sizes_1.item(), :]
            x_2 = feats[sizes_1.item() : sizes_1.item() + sizes_2.item(), :]

            norms_1 = norms[: sizes_1.item()]
            norms_2 = norms[sizes_1.item() : sizes_1.item() + sizes_2.item()]

            attention_sums_1 = attention_sums[: sizes_1.item()]
            attention_sums_2 = attention_sums[
                sizes_1.item() : sizes_1.item() + sizes_2.item()
            ]

            # Pool each graph separately with its own (local) edge indices.
            x_pooled_1, edge_index_pooled_1, _, _, perm1, score1 = topk_pooling(
                x_1,
                edge_index_1,
            )
            x_pooled_2, edge_index_pooled_2, _, _, perm2, score2 = topk_pooling(
                x_2,
                edge_index_2,
            )

            self.topk_outputs.append(
                (
                    (x_pooled_1, edge_index_pooled_1, perm1, score1),
                    (x_pooled_2, edge_index_pooled_2, perm2, score2),
                )
            )
            self.layer_cross_attentions.append((cross_graph_attention, a_x, a_y))
            self.layer_outputs.append((x_1, edge_index_1, x_2, edge_index_2))
            self.norms_per_layer.append((norms_1, norms_2))
            self.attention_sums_per_layer.append((attention_sums_1, attention_sums_2))

        feats = self.aggregator(feats, edge_index, batch)
        return feats, edge_index, batch

    def combine_pair_embedding(
        self, feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
    ):
        """Concatenate two graphs into one disjoint batch.

        Graph 2's edge indices are shifted by the total node count of graph 1,
        so both graphs share one node index space. ``create_batch`` (not defined
        in this notebook; presumably from the custom.utils star import) builds
        the per-node graph ids from the sizes. All three tensors are moved to
        ``self.args.device``.

        Returns:
            ``(feats, edge_index, batch)`` for the combined graph.
        """
        feats = torch.cat([feats_1, feats_2], dim=0)
        max_node_idx_1 = sizes_1.sum()
        edge_index_2_offset = edge_index_2 + max_node_idx_1
        edge_index = torch.cat([edge_index_1, edge_index_2_offset], dim=1)
        batch = create_batch(torch.cat([sizes_1, sizes_2], dim=0))
        feats, edge_index, batch = (
            feats.to(self.args.device),
            edge_index.to(self.args.device),
            batch.to(self.args.device),
        )
        return feats, edge_index, batch

    def forward(self, feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2):
        """Embed a graph pair and pick, per graph, the best TopK-pooled layer.

        Returns:
            ``(emb_1, emb_2, cluster1, cluster2, layer1, layer2)``:

            - ``emb_1``/``emb_2``: the first and second halves of the readout rows.
            - ``cluster1``/``cluster2``: ``Data`` objects of the pooled subgraph
              whose kept node indices overlap ``0..7`` the most.
            - ``layer1``/``layer2``: the 1-based layers those clusters came from.

            Clusters and layers stay ``None`` if no layer keeps any node in ``0..7``.
        """
        # Reset the per-forward diagnostics.
        self.layer_outputs = []
        self.layer_cross_attentions = []
        self.topk_outputs = []
        self.mincut = []
        self.norms_per_layer = []
        self.attention_sums_per_layer = []

        feats, edge_index, batch = self.combine_pair_embedding(
            feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
        )
        emb, _, _ = self.compute_emb(
            feats, edge_index, batch, sizes_1, sizes_2, edge_index_1, edge_index_2
        )
        # Assumes the readout has one row per graph (two rows for a pair).
        emb_1 = emb[: emb.shape[0] // 2, :]
        emb_2 = emb[emb.shape[0] // 2 :, :]

        best_acc1, best_acc2 = 0.0, 0.0
        cluster1, cluster2 = None, None
        layer1, layer2 = None, None
        for i in range(len(self.topk_outputs)):
            (
                (x_pooled_1, edge_index_pooled_1, perm1, score1),
                (x_pooled_2, edge_index_pooled_2, perm2, score2),
            ) = self.topk_outputs[i]
            # NOTE(review): nodes 0..7 presumably form the planted motif;
            # confirm against the dataset generator.
            acc1 = len(set(range(8)) & set(perm1.tolist()))
            acc2 = len(set(range(8)) & set(perm2.tolist()))
            if acc1 > best_acc1:
                cluster1 = Data(x=x_pooled_1, edge_index=edge_index_pooled_1)
                best_acc1 = acc1
                layer1 = i + 1
            if acc2 > best_acc2:
                cluster2 = Data(x=x_pooled_2, edge_index=edge_index_pooled_2)
                best_acc2 = acc2
                layer2 = i + 1

        return emb_1, emb_2, cluster1, cluster2, layer1, layer2

    def compute_metrics(self, emb_1, emb_2, labels):
        """Margin loss and accuracy from Euclidean embedding distances.

        ``labels`` appear to be +1 (similar) / -1 (dissimilar), because the
        predictions below are encoded that way. TODO: confirm against the pair
        generator.

        Returns:
            dict with two entries:

            - ``"loss"``: the mean of ``relu(margin - labels * (1 - d))``.
            - ``"acc"``: the fraction of pairs whose prediction
              (+1 if ``d < margin`` else -1) equals the label.
        """
        distances = torch.norm(emb_1 - emb_2, p=2, dim=1)
        loss = F.relu(self.margin - labels * (1 - distances)).mean()
        predicted_similar = torch.where(
            distances < self.args.margin,
            torch.ones_like(labels),
            -torch.ones_like(labels),
        )
        acc = (predicted_similar == labels).float().mean()
        metrics = {"loss": loss, "acc": acc}
        return metrics

    def init_metric_dict(self):
        # Placeholder "best so far" metrics.
        # NOTE(review): compute_metrics never produces an "f1" entry.
        return {"acc": -1, "f1": -1}

    def has_improved(self, m1, m2):
        # True when m2's accuracy is strictly greater than m1's.
        return m1["acc"] < m2["acc"]

In [None]:
# Load the pre-generated graph dataset and split it by size and by class.
# NOTE(review): torch.load unpickles the file; only load trusted data files.
dataset = GraphDataset(torch.load("data/cycle_line_star_complete_1.pt"))
# analyze_dataset is not defined in this notebook (presumably from the
# custom.utils star import). `classes` is used below as a mapping from
# "class_<i>" to an indexable collection of graphs.
small_graphs, medium_graphs, large_graphs, classes = analyze_dataset(dataset)

In [None]:
def train(model, optimizer, pairs, labels, batch_size):
    """Train ``model`` for one pass over ``pairs`` with mini-batch updates.

    Each pair is encoded individually. The per-pair losses of ``batch_size``
    consecutive pairs are averaged, and one optimizer step is taken per batch.
    A trailing partial batch also produces an update, so no computed loss is
    discarded.

    Args:
        model: a ``GraphMatchingNetwork``-like module. It is called as
            ``model(feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2)``
            and must expose ``compute_metrics(emb_1, emb_2, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        pairs: sequence of ``(graph1, graph2)`` objects with ``x``,
            ``edge_index`` and ``num_nodes`` attributes.
        labels: per-pair labels, indexed in parallel with ``pairs``.
        batch_size: number of pairs per optimizer step (must be >= 1).

    Returns:
        ``(train_losses, train_accuracies)``: one float per optimizer step,
        suitable for plotting training curves.
    """
    model.train()
    train_losses = []
    train_accuracies = []
    batch_losses = []
    batch_accs = []
    n_pairs = len(pairs)

    for i in range(n_pairs):
        graph1, graph2 = pairs[i]
        label = labels[i]

        sizes_1 = torch.tensor([graph1.num_nodes])
        sizes_2 = torch.tensor([graph2.num_nodes])
        emb_1, emb_2, *_ = model(
            graph1.x, graph1.edge_index, graph2.x, graph2.edge_index, sizes_1, sizes_2
        )

        metrics = model.compute_metrics(emb_1, emb_2, torch.tensor([label]))
        batch_losses.append(metrics["loss"])
        batch_accs.append(metrics["acc"])

        # Step after every `batch_size` pairs and after the final pair. Every
        # batch then holds exactly `batch_size` pairs (the last may be smaller).
        # The previous `i % batch_size == 0 and i > 0` test put batch_size + 1
        # pairs in the first batch and silently dropped the remainder.
        if (i + 1) % batch_size == 0 or i == n_pairs - 1:
            batch_loss = torch.stack(batch_losses).mean()
            batch_acc = torch.stack(batch_accs).mean()
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # .item() works on any device (unlike .numpy() on CUDA tensors).
            train_losses.append(batch_loss.item())
            train_accuracies.append(batch_acc.item())
            batch_losses = []
            batch_accs = []

    return train_losses, train_accuracies

In [None]:
def test(model, title=""):
    """Plot TopK-pooled clusters for one random same-class pair per class.

    For each of the four classes, two graphs are drawn at random from the
    notebook-level ``classes`` and passed through ``model``. Every layer's
    pooled subgraphs are collected, together with how many of their kept node
    indices fall in ``0..7``. The bar plots and the cluster grid are then drawn
    by helpers that are not defined in this notebook (presumably from the
    custom.utils star import).
    """
    class_clusters = []
    class_accs = []
    reference_nodes = set(range(8))

    for class_idx in range(4):
        members = classes[f"class_{str(class_idx)}"]
        first_pick = random.sample(range(len(members)), 1)[0]
        second_pick = random.sample(range(len(members)), 1)[0]
        graph1, graph2 = members[first_pick], members[second_pick]

        model.eval()
        # The forward pass fills model.topk_outputs, which is read below.
        model(
            graph1.x,
            graph1.edge_index,
            graph2.x,
            graph2.edge_index,
            torch.tensor([len(graph1.x)]),
            torch.tensor([len(graph2.x)]),
        )

        pooled_1, pooled_2 = [], []
        overlap_1, overlap_2 = [], []
        for (x_p1, ei_p1, perm1, _), (x_p2, ei_p2, perm2, _) in model.topk_outputs:
            pooled_1.append(Data(x=x_p1, edge_index=ei_p1))
            pooled_2.append(Data(x=x_p2, edge_index=ei_p2))
            overlap_1.append(len(reference_nodes & set(perm1.tolist())))
            overlap_2.append(len(reference_nodes & set(perm2.tolist())))

        # Graph-1 entries for every layer come first, then the graph-2 entries.
        class_accs.extend(overlap_1 + overlap_2)
        class_clusters.extend(pooled_1 + pooled_2)

    for barplot in (norm_barplot, cross_barplot):
        for graph_idx in (0, 1):
            barplot(model, graph_idx)
    plot_all_classes(class_clusters, class_accs, title, model.args.num_layers)

In [None]:
from scipy.stats import mode


def best_k(model, threshold=0.8):
    """Estimate how many top-scored nodes to keep, from recorded TopK scores.

    Each layer entry in ``model.topk_outputs`` is processed for both graphs:

    1. Softmax the pooling scores and take their cumulative sum.
    2. ``k`` is the smallest count whose normalised cumulative mass reaches
       ``threshold``.
    3. The "confidence" of that ``k`` is the numerical gradient of the
       cumulative curve at position ``k``.

    Args:
        model: object with ``topk_outputs``, a list of
            ``((x, edge_index, perm, score), (x, edge_index, perm, score))``.
        threshold: cumulative softmax mass to reach, expected in ``(0, 1]``.

    Returns:
        ``(overall_k_confidence, overall_k_median, k_values_graph1,
        k_values_graph2, overall_k_mode)``, where:

        - ``overall_k_confidence`` is the confidence-weighted mean of all ``k``.
        - ``overall_k_median`` is their median, truncated to ``int``.
        - ``overall_k_mode`` is ``scipy.stats.mode(...).mode``.
    """

    def calculate_cumulative_scores(scores):
        normalized_scores = F.softmax(scores, dim=0)
        return torch.cumsum(normalized_scores, dim=0)

    def calculate_confidence(cumulative_scores, k):
        values = cumulative_scores.detach().cpu().numpy()
        # np.gradient needs at least two samples. With a single pooled node
        # the slope of the cumulative curve is that node's softmax mass.
        if values.size < 2:
            return float(values[k - 1])
        return np.gradient(values)[k - 1]

    def find_k_and_confidence(cumulative_scores, total_percentage):
        reached = (cumulative_scores / cumulative_scores[-1]) >= total_percentage
        k = reached.nonzero()[0].item() + 1
        return k, calculate_confidence(cumulative_scores, k)

    def calculate_weighted_average(ks, confidences):
        total = sum(confidences)
        weighted_ks = sum(k * (float(c) / total) for k, c in zip(ks, confidences))
        return int(round(weighted_ks))

    k_values_graph1, confidences_graph1 = [], []
    k_values_graph2, confidences_graph2 = [], []
    for (_, _, _, score1), (_, _, _, score2) in model.topk_outputs:
        k1, conf1 = find_k_and_confidence(calculate_cumulative_scores(score1), threshold)
        k_values_graph1.append(k1)
        confidences_graph1.append(conf1)

        k2, conf2 = find_k_and_confidence(calculate_cumulative_scores(score2), threshold)
        k_values_graph2.append(k2)
        confidences_graph2.append(conf2)

    combined_ks = k_values_graph1 + k_values_graph2
    combined_confidences = confidences_graph1 + confidences_graph2

    overall_k_confidence = calculate_weighted_average(combined_ks, combined_confidences)
    overall_k_median = int(np.median(combined_ks))
    overall_k_mode = mode(combined_ks).mode

    return (
        overall_k_confidence,
        overall_k_median,
        k_values_graph1,
        k_values_graph2,
        overall_k_mode,
    )

In [None]:
def find_most_common_graph(graphs):
    """Return the first graph whose sorted undirected edge list is most frequent.

    Graphs are compared by their exact edge sets after conversion to undirected
    networkx graphs, so identically-indexed graphs are grouped, not merely
    isomorphic ones. Raises IndexError if ``graphs`` is empty.
    """

    def edge_signature(graph):
        nx_graph = to_networkx(graph, to_undirected=True)
        return tuple(sorted((min(u, v), max(u, v)) for u, v in nx_graph.edges))

    signatures = [edge_signature(graph) for graph in graphs]
    winner = Counter(signatures).most_common(1)[0][0]
    return next(
        (graph for graph, signature in zip(graphs, signatures) if signature == winner),
        None,
    )

In [None]:
def calculate_weighted_confidence(scores, external_weights=None, w_avg=0.5, w_var=0.5):
    """Score how confidently a set of pooling scores selects its nodes.

    ``confidence = w_avg * mean(scores) + w_var / var(scores)``. High, tightly
    clustered scores therefore give a high confidence. The unbiased variance is
    used. With fewer than two scores, or with identical scores, the
    inverse-variance term is 0 rather than NaN/inf.

    Args:
        scores: 1-D tensor of scores.
        external_weights: optional tensor; ``0.1 * external_weights.sum()`` is
            added to the result.
        w_avg: weight of the mean score.
        w_var: weight of the inverse variance.

    Returns:
        The confidence as a Python float.
    """
    average_score = scores.mean()

    # The unbiased variance of a single score is NaN, and identical scores have
    # zero variance. In both cases drop the term instead of producing NaN/inf,
    # which would otherwise poison later comparisons and sorting.
    inverse_variance = torch.zeros_like(average_score)
    if scores.numel() >= 2:
        variance_score = scores.var()
        if variance_score != 0:
            inverse_variance = 1 / variance_score

    weighted_confidence = w_avg * average_score + w_var * inverse_variance

    if external_weights is not None:
        weighted_confidence = weighted_confidence + external_weights.sum() * 0.1

    return weighted_confidence.item()

In [None]:
def find_min_total_ged_graph(graphs):
    """
    Finds the graph that has the smallest total graph edit distance (GED) to all other graphs in the list.

    The graphs are converted to undirected networkx graphs and compared with
    networkx's default (symmetric, unit) edit costs, so GED(a, b) == GED(b, a).
    Each unordered pair is therefore evaluated once and its distance credited
    to both graphs. This halves the number of exponential GED searches while
    producing the same totals. Exact GED is still exponential in graph size;
    keep inputs small.

    Parameters:
        graphs (list of Data): The list of PyTorch Geometric graph data instances.

    Returns:
        Data: The PyTorch Geometric graph with the smallest total GED (the
        first such graph on ties).

    Raises:
        ValueError: if ``graphs`` is empty.
    """
    if not graphs:
        raise ValueError("find_min_total_ged_graph() needs at least one graph")

    nx_graphs = [to_networkx(graph, to_undirected=True) for graph in graphs]
    total_geds = [0] * len(nx_graphs)

    for i, j in itertools.combinations(range(len(nx_graphs)), 2):
        # optimize_graph_edit_distance yields successively better estimates;
        # the smallest value it yields is the exact GED.
        ged = min(
            nx.optimize_graph_edit_distance(nx_graphs[i], nx_graphs[j]),
            default=float("inf"),
        )
        total_geds[i] += ged
        total_geds[j] += ged

    min_ged_index = total_geds.index(min(total_geds))
    return graphs[min_ged_index]

In [None]:
def select_cluster(model, k_threshold=0.8):
    """Pick one representative pooled subgraph from the model's last forward pass.

    For every entry in ``model.layer_outputs``, a fresh ``TopKPooling(ratio=8)``
    pools each graph of the pair down to 8 nodes. New random weights are drawn
    per layer, so results depend on the torch RNG. The representative is then
    chosen by priority:

    1. ``"isomorphic"``: some connected pooled subgraph of graph 1 is isomorphic
       to one of graph 2. The most common edge set among those matches is
       returned, with node features averaged across each matched pair.
    2. ``"connected"``: otherwise, the connected pooled subgraph with the
       lowest ``calculate_weighted_confidence`` score.
    3. ``"random"``: otherwise, a random pooled subgraph.

    Args:
        model: a ``GraphMatchingNetwork`` after a forward pass. This function
            reads ``layer_outputs`` and ``args.dim``, and ``topk_outputs`` via
            ``best_k``.
        k_threshold: cumulative-score threshold forwarded to ``best_k``.

    Returns:
        ``(cluster, k_confidence, cluster1, cluster2, k_mode, selection_type)``.
        ``cluster1``/``cluster2`` are the per-graph pooled subgraphs whose kept
        node indices overlap ``0..7`` the most (``None`` if no layer overlaps).
    """
    k_confidence, _, _, _, k_mode = best_k(model, k_threshold)

    target_nodes = set(range(8))
    best_acc1, best_acc2 = 0.0, 0.0
    cluster1, cluster2 = None, None

    # One entry per layer: ((pooled graph 1, confidence), (pooled graph 2, confidence)).
    clustered_graphs = []
    for x_1, edge_index_1, x_2, edge_index_2 in model.layer_outputs:
        topk_pooling = TopKPooling(model.args.dim, ratio=8)

        x_pooled_1, edge_index_pooled_1, _, _, perm1, score1 = topk_pooling(
            x_1, edge_index_1
        )
        x_pooled_2, edge_index_pooled_2, _, _, perm2, score2 = topk_pooling(
            x_2, edge_index_2
        )

        clustered_graphs.append(
            (
                (
                    Data(x=x_pooled_1, edge_index=edge_index_pooled_1),
                    calculate_weighted_confidence(score1),
                ),
                (
                    Data(x=x_pooled_2, edge_index=edge_index_pooled_2),
                    calculate_weighted_confidence(score2),
                ),
            )
        )

        acc1 = len(target_nodes & set(perm1.tolist()))
        acc2 = len(target_nodes & set(perm2.tolist()))
        if acc1 > best_acc1:
            cluster1 = Data(x=x_pooled_1, edge_index=edge_index_pooled_1)
            best_acc1 = acc1
        if acc2 > best_acc2:
            cluster2 = Data(x=x_pooled_2, edge_index=edge_index_pooled_2)
            best_acc2 = acc2

    # Convert each pooled graph to networkx once and reuse it for both the
    # connectivity and the isomorphism checks.
    connected_1 = []  # (graph, confidence, networkx graph)
    connected_2 = []
    all_clusters = []  # (graph, confidence)
    for (g1, s1), (g2, s2) in clustered_graphs:
        nx_g1 = to_networkx(g1, to_undirected=True)
        nx_g2 = to_networkx(g2, to_undirected=True)
        if nx.is_connected(nx_g1):
            connected_1.append((g1, s1, nx_g1))
        if nx.is_connected(nx_g2):
            connected_2.append((g2, s2, nx_g2))
        all_clusters.append((g1, s1))
        all_clusters.append((g2, s2))

    isomorphic_graphs = [
        Data(x=(g1.x + g2.x) / 2, edge_index=g1.edge_index)
        for g1, _, nx_g1 in connected_1
        for g2, _, nx_g2 in connected_2
        if nx.is_isomorphic(nx_g1, nx_g2)
    ]

    if isomorphic_graphs:
        return (
            find_most_common_graph(isomorphic_graphs),
            k_confidence,
            cluster1,
            cluster2,
            k_mode,
            "isomorphic",
        )
    if connected_1 or connected_2:
        # NOTE(review): ascending sort, so this picks the LOWEST confidence.
        # Confirm whether the highest-confidence graph was intended.
        connected = sorted(connected_1 + connected_2, key=lambda item: item[1])
        return (
            connected[0][0],
            k_confidence,
            cluster1,
            cluster2,
            k_mode,
            "connected",
        )
    return (
        random.choice(all_clusters)[0],
        k_confidence,
        cluster1,
        cluster2,
        k_mode,
        "random",
    )

In [None]:
def acc_test(model, print_results=True, k_threshold=0.8, n_trials=500):
    """Measure how often pooled clusters recover each class's motif shape.

    The four classes are checked with these predicates:

    - ``class_0``: ``is_cycle``
    - ``class_1``: ``is_complete``
    - ``class_2``: ``is_line``
    - ``class_3``: ``is_star``

    For ``n_trials`` rounds and for each class, two random graphs of that class
    are run through ``model``. ``select_cluster`` then returns the per-graph
    best clusters and one selected cluster, and each is checked with the
    class's predicate.

    Args:
        model: a ``GraphMatchingNetwork``.
        print_results: print a summary report when True.
        k_threshold: forwarded to ``select_cluster``.
        n_trials: rounds per class. The default of 500 is the previously
            hard-coded value.

    Returns:
        ``(class0_acc, class1_acc, class2_acc, class3_acc, overall_acc)`` as
        percentages. Each class accuracy counts both per-graph clusters, i.e.
        ``2 * n_trials`` checks per class.
    """
    # Class key -> (shape name used in the report, structural predicate).
    class_checks = {
        "class_0": ("cycle", is_cycle),
        "class_1": ("complete", is_complete),
        "class_2": ("line", is_line),
        "class_3": ("star", is_star),
    }
    correct = dict.fromkeys(class_checks, 0)  # hits over cluster1 + cluster2
    best = dict.fromkeys(class_checks, 0)  # pairs where either cluster hits
    correct_new = dict.fromkeys(class_checks, 0)  # hits of the selected cluster
    ks = []
    ks_mode = []
    ts = {"isomorphic": 0, "connected": 0, "random": 0}
    correct_isomorphic = 0
    correct_connected = 0

    model.eval()
    for _ in range(n_trials):
        for c, (_, predicate) in class_checks.items():
            members = classes[c]
            idx1 = random.sample(range(len(members)), 1)[0]
            idx2 = random.sample(range(len(members)), 1)[0]
            graph1, graph2 = members[idx1], members[idx2]

            # The forward pass populates the per-layer diagnostics that
            # select_cluster reads; its return values are not needed here.
            model(
                graph1.x,
                graph1.edge_index,
                graph2.x,
                graph2.edge_index,
                torch.tensor([len(graph1.x)]),
                torch.tensor([len(graph2.x)]),
            )

            cluster, k, cluster1, cluster2, k_mode, t = select_cluster(
                model, k_threshold
            )
            ks.append(k)
            ks_mode.append(k_mode)
            ts[t] += 1

            hit1, hit2, hit = predicate(cluster1), predicate(cluster2), predicate(cluster)
            correct[c] += hit1 + hit2
            best[c] += any([hit1, hit2])
            correct_new[c] += hit
            if t == "isomorphic":
                correct_isomorphic += hit
            elif t == "connected":
                correct_connected += hit

    # Convert hit counts to percentages (two clusters are checked per pair).
    pair_scale = 2 * n_trials / 100
    single_scale = n_trials / 100
    class_acc = {c: correct[c] / pair_scale for c in class_checks}
    best_acc = {c: best[c] / single_scale for c in class_checks}
    new_acc = {c: correct_new[c] / single_scale for c in class_checks}
    overall_acc = sum(class_acc.values()) / len(class_checks)
    best_overall_acc = sum(best_acc.values()) / len(class_checks)
    overall_acc_new = sum(new_acc.values()) / len(class_checks)

    unique, counts = np.unique(ks, return_counts=True)
    results = dict(zip(unique, counts / len(ks)))

    unique_mode, counts_mode = np.unique(ks_mode, return_counts=True)
    results_mode = dict(zip(unique_mode, counts_mode / len(ks_mode)))

    if print_results:
        for c, (name, _) in class_checks.items():
            print(f"Correct {name} predictions: {class_acc[c]:.1f}% ({best_acc[c]:.1f}%)")
        print(f"Overall accuracy: {overall_acc:.1f}% ({best_overall_acc:.1f}%)")
        print("-")
        for c, (name, _) in class_checks.items():
            print(f"New correct {name} predictions: {new_acc[c]:.1f}%")
        print(f"New overall accuracy: {overall_acc_new:.1f}%")
        print("-")
        print(f"Selected k: {results}")
        print(f"Selected k mode: {results_mode}")
        print(f"Types of selected graphs: {ts}")
        print(f"Correct isomorphic: {correct_isomorphic}")
        print(f"Correct connected: {correct_connected}")
    return tuple(class_acc[c] for c in class_checks) + (overall_acc,)

In [None]:
class NewArgs:
    """Hyper-parameter bundle consumed by ``GraphMatchingNetwork`` and the training loop.

    ``feat_dim`` and ``n_classes`` are read from ``graph_dataset`` when it is
    given. Otherwise they come from the notebook-level ``dataset``, which keeps
    the previous behaviour.
    """

    def __init__(
        self, dim, num_layers, margin, lr, batch_size, num_pairs, graph_dataset=None
    ):
        source = dataset if graph_dataset is None else graph_dataset
        self.dim = dim
        self.feat_dim = source.num_features
        self.num_layers = num_layers
        self.margin = margin
        self.lr = lr
        self.n_classes = source.num_classes
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Number of output units of the model's (currently unused) mlp head.
        self.n_clusters = 8
        self.num_pairs = num_pairs

In [None]:
# Hyper-parameter configurations to train and evaluate, one tuple per run.
# Field order matches NewArgs(dim, num_layers, margin, lr, batch_size, num_pairs).
top_10 = [
    (32, 7, 0.5, 0.0001, 128, 3000),
    (32, 7, 0.2, 0.01, 64, 500),
    (32, 5, 0.2, 0.01, 32, 3000),
    (32, 8, 0.3, 0.0001, 64, 1000),
    (32, 7, 0.5, 0.01, 32, 3000),
    (32, 7, 0.4, 0.0001, 32, 500),
    (32, 5, 0.1, 0.01, 64, 3000),
    (32, 7, 0.4, 0.0001, 128, 3000),
    (32, 7, 0.5, 0.01, 128, 1000),
    (32, 8, 0.3, 0.0001, 64, 500),
]

In [None]:
# Train a fresh model for every configuration and report its motif-recovery
# accuracy.
# NOTE(review): no random seed is set in the visible cells, so results vary
# between runs.
# The `model` from the last iteration is reused by the cells below.
for hyperparams in top_10:
    newargs = NewArgs(*hyperparams)
    model = GraphMatchingNetwork(newargs)
    optimizer = Adam(model.parameters(), lr=newargs.lr, weight_decay=1e-5)
    # create_graph_pairs is not defined in this notebook (presumably from the
    # custom.utils star import).
    pairs, labels = create_graph_pairs(dataset, newargs.num_pairs)
    train(model, optimizer, pairs, labels, newargs.batch_size)
    acc_test(model, True, 0.8)
    print("-------------------------------------------------------------")

In [None]:
def find_joint_clusters(
    clusters_graph1, clusters_graph2, node_features1, node_features2
):
    """Return the pair of clusters whose mean node features are most similar.

    Every cluster of graph 1 is compared with every cluster of graph 2 using the
    cosine similarity of their mean feature vectors; the highest-scoring pair wins
    (the earliest pair on ties).

    Args:
        clusters_graph1: Iterable of node selections, each usable to index
            ``node_features1`` along dim 0 (e.g. a ``perm`` tensor from TopKPooling).
        clusters_graph2: Same, for ``node_features2``.
        node_features1: Node-feature matrix of graph 1 (nodes along dim 0).
        node_features2: Node-feature matrix of graph 2 (nodes along dim 0).

    Returns:
        ``(cluster1, cluster2)`` — the original cluster objects of the best pair —
        or ``None`` if either collection is empty or every similarity is NaN
        (e.g. empty clusters, whose mean is NaN).
    """
    # Mean vector of each graph-2 cluster, computed once. Materialising the list
    # also means a one-shot iterable is not exhausted after the first outer pass.
    candidates2 = [
        (cluster2, node_features2[cluster2].mean(0, keepdim=True))
        for cluster2 in clusters_graph2
    ]

    # Cosine similarity can be negative. Starting at 0 (as before) made the
    # function return None whenever no pair was positively aligned.
    best_similarity = float("-inf")
    best_cluster_pair = None
    for cluster1 in clusters_graph1:
        mean1 = node_features1[cluster1].mean(0, keepdim=True)
        for cluster2, mean2 in candidates2:
            # (1, F) vs (1, F) -> 1-element tensor; compare as a plain float.
            similarity = cosine_similarity(mean1, mean2).item()
            if similarity > best_similarity:
                best_similarity = similarity
                best_cluster_pair = (cluster1, cluster2)
    return best_cluster_pair

In [None]:
# Inspect the trained model on one randomly chosen pair of graphs.
model.eval()

c = "class_0"
# Two independent draws, so the pair may be the same graph twice.
# NOTE(review): assumes classes[c] is an indexable collection of PyG Data graphs.
idx1 = random.randrange(len(classes[c]))
idx2 = random.randrange(len(classes[c]))
graph1, graph2 = classes[c][idx1], classes[c][idx2]

feats_1, edge_index_1 = graph1.x, graph1.edge_index
feats_2, edge_index_2 = graph2.x, graph2.edge_index
# Each graph is passed as a single-graph batch whose size is its node count.
sizes_1 = torch.tensor([len(graph1.x)])
sizes_2 = torch.tensor([len(graph2.x)])
# Inference only, so no autograd graph is needed.
# NOTE(review): the forward pass presumably fills model.layer_outputs /
# model.topk_outputs, which the pooling loop below reads — confirm.
with torch.no_grad():
    emb1, emb2, cluster1, cluster2, layer1, layer2 = model(
        feats_1, edge_index_1, feats_2, edge_index_2, sizes_1, sizes_2
    )

ks_confidence = []
ks_median = []
accs = []
cs = []
clustered_graphs = []
clusters1 = []
clusters2 = []

# NOTE(review): k_values_graph is not used in this cell (possibly meant k_values_graph2).
(k_confidence, k_median, k_values_graph1, k_values_graph, k_mode) = best_k(model)
ks_confidence.append(k_confidence)
ks_median.append(k_median)

# Reset the "best layer" trackers filled in by the pooling loop below. This
# discards cluster1/cluster2/layer1/layer2 as returned by the forward pass.
best_acc1, best_acc2 = 0.0, 0.0
cluster1, cluster2 = None, None
layer1, layer2 = None, None

layers = []
sims = []  # NOTE(review): not filled in this cell; only a commented-out cell reads it.

# ONE TopKPooling module for ALL layers, so every layer is scored with the same
# projection vector. Previously a new module (with new random weights) was built
# inside the loop, so which layer "won" below depended on initialisation noise
# rather than on the layer's representation.
# An integer ratio makes TopKPooling keep min(8, num_nodes) nodes per graph.
# This module is freshly initialised and never trained here; the combine/plot
# cell further down reuses the same `topk_pooling` instance.
# NOTE(review): consider using the model's own pooling (model.topk_outputs) instead.
topk_pooling = TopKPooling(model.args.dim, ratio=8)

for i in range(len(model.layer_outputs)):
    (x_1, edge_index_1, x_2, edge_index_2) = model.layer_outputs[i]

    x_pooled_1, edge_index_pooled_1, _, _, perm1, score1 = topk_pooling(
        x_1,
        edge_index_1,
    )
    x_pooled_2, edge_index_pooled_2, _, _, perm2, score2 = topk_pooling(
        x_2,
        edge_index_2,
    )

    # The model's own pooling results for this layer. They are unpacked only for
    # manual comparison with the pooling recomputed above and are otherwise unused.
    (x_pooled_1_model, edge_index_pooled_1_model, perm1_model, score1_model), (
        x_pooled_2_model,
        edge_index_pooled_2_model,
        perm2_model,
        score2_model,
    ) = model.topk_outputs[i]

    # Number of kept nodes whose original index lies in 0..7.
    # NOTE(review): presumably nodes 0..7 form the target structure — confirm.
    acc1 = len(set(range(8)) & set(perm1.tolist()))
    acc2 = len(set(range(8)) & set(perm2.tolist()))

    layers.append(
        (
            (x_pooled_1, edge_index_pooled_1, score1),
            (x_pooled_2, edge_index_pooled_2, score2),
            i,
        )
    )

    clustered_graphs.append(
        (
            (
                Data(x=x_pooled_1, edge_index=edge_index_pooled_1),
                calculate_weighted_confidence(score1),
            ),
            (
                Data(x=x_pooled_2, edge_index=edge_index_pooled_2),
                calculate_weighted_confidence(score2),
            ),
        )
    )

    clusters1.append(perm1)
    clusters2.append(perm2)

    # Track the 1-based layer with the highest overlap; strict ">" keeps the
    # earliest layer on ties.
    if acc1 > best_acc1:
        x = x_pooled_1  # NOTE(review): `x` is not read in this cell.
        cluster1 = Data(x=x_pooled_1, edge_index=edge_index_pooled_1)
        best_acc1 = acc1
        layer1 = i + 1
    if acc2 > best_acc2:
        cluster2 = Data(x=x_pooled_2, edge_index=edge_index_pooled_2)
        best_acc2 = acc2
        layer2 = i + 1

cs.append((layers, layer1, layer2))

accs.append(best_acc1)
accs.append(best_acc2)

# Keep only the pooled graphs that are connected; entries are (Data, confidence).
connected_1 = []
connected_2 = []

for i in range(len(clustered_graphs)):
    (g1, s1), (g2, s2) = clustered_graphs[i]
    if nx.is_connected(to_networkx(g1, to_undirected=True)):
        connected_1.append((g1, s1))
    if nx.is_connected(to_networkx(g2, to_undirected=True)):
        connected_2.append((g2, s2))

# Convert each connected graph to networkx once instead of once per pair.
nx_connected_1 = [to_networkx(g, to_undirected=True) for g, _ in connected_1]
nx_connected_2 = [to_networkx(g, to_undirected=True) for g, _ in connected_2]

# Entries are (averaged graph, 1-based index into connected_1, 1-based index into connected_2).
graphs = []

for i, (g1, s1) in enumerate(connected_1):
    for j, (g2, s2) in enumerate(connected_2):
        matcher = nx.isomorphism.GraphMatcher(nx_connected_1[i], nx_connected_2[j])
        if matcher.is_isomorphic():
            # matcher.mapping maps each graph-1 node to its graph-2 counterpart.
            # Pooled graphs order nodes by TopK score, so isomorphic graphs need not
            # share numbering; align graph-2 features before averaging instead of
            # averaging by raw index.
            order = [matcher.mapping[n] for n in range(g1.num_nodes)]
            graphs.append(
                (
                    Data(x=(g1.x + g2.x[order]) / 2, edge_index=g1.edge_index),
                    i + 1,
                    j + 1,
                )
            )

# for graph in graphs:
#     plot_graph(graph[0], title=f"({str(graph[1])}, {str(graph[2])})")

In [None]:
# Best-matching pair among the per-layer TopK node selections of the two graphs,
# scored on the raw input features (the cell's displayed output).
find_joint_clusters(
    clusters_graph1=clusters1,
    clusters_graph2=clusters2,
    node_features1=feats_1,
    node_features2=feats_2,
)

In [None]:
# ls, l1, l2 = cs[0]
# print(f"Graph 1 layer: {l1}")
# print(f"Graph 2 layer: {l2}")
# clusters = []
# for i, layer in enumerate(ls):
#     (x1, e1, s1), (x2, e2, s2), l = layer
#     print(x1.shape)
#     g1 = to_nx(x1, e1)
#     g2 = to_nx(x2, e2)
#     clusters.append(g1)
#     print(nx.is_connected(g1), nx.is_connected(g2))
#     print(nx.is_isomorphic(g1, g2))
#     if i > 0:
#         print(sims[i])
#     # plot_graph_pair(Data(x=x1, edge_index=e1), Data(x=x2, edge_index=e2))

In [None]:
# unique_confidence, counts_confidence = np.unique(ks_confidence, return_counts=True)
# results_confidence = dict(zip(unique_confidence, counts_confidence))

# unique_median, counts_median = np.unique(ks_median, return_counts=True)
# results_median = dict(zip(unique_median, counts_median))

# unique_accs, count_accs = np.unique(accs, return_counts=True)
# count_accs = count_accs / len(accs)
# results_accs = dict(zip(unique_accs, count_accs))

# print(f"Accs: {results_accs}")
# print(f"Confidence: {results_confidence}")
# print(f"Median: {results_median}")

In [None]:
# best_acc1, best_acc2 = 0.0, 0.0
# cluster1, cluster2 = None, None
# layer1, layer2 = None, None

# topk_pooling = TopKPooling(model.args.dim, ratio=8)

# for i in range(len(model.topk_outputs)):
#     (
#         (x_pooled_1, edge_index_pooled_1, perm1, score1),
#         (x_pooled_2, edge_index_pooled_2, perm2, score2),
#     ) = model.topk_outputs[i]
#     acc1 = len(set(range(8)) & set(perm1.tolist()))
#     acc2 = len(set(range(8)) & set(perm2.tolist()))
#     if acc1 > best_acc1:
#         cluster1 = Data(x=x_pooled_1, edge_index=edge_index_pooled_1)
#         best_acc1 = acc1
#         layer1 = i + 1
#     if acc2 > best_acc2:
#         cluster2 = Data(x=x_pooled_2, edge_index=edge_index_pooled_2)
#         best_acc2 = acc2
#         layer2 = i + 1

In [None]:
# L2 norm along dim 1 of model.layer_cross_attentions[4][0]
# (NOTE(review): presumably layer index 4, first graph — confirm).
cross_attention = model.layer_cross_attentions[4][0]
norms = cross_attention.norm(p=2, dim=1)
print(norms)

In [None]:
def combine_graphs(graph1, graph2, mapping):
    """Merge two graphs into one, gluing together the node pairs in ``mapping``.

    Node numbering of the result:
      * first, one merged node per ``mapping`` entry (in ``mapping`` iteration
        order), whose features are the mean of the two matched nodes;
      * then the unmatched nodes of ``graph1``, in index order;
      * then the unmatched nodes of ``graph2``, in index order.

    Edges of both graphs are re-indexed onto the merged numbering. An edge present
    in both graphs (between matched nodes) is emitted only once; first-occurrence
    order is kept.

    Args:
        graph1: Object with ``x`` (num_nodes1 x F) and ``edge_index`` (2 x E1).
        graph2: Same for the second graph; must share the feature width F.
        mapping: dict ``{graph1_node: graph2_node}``, assumed injective
            (e.g. a networkx ISMAGS common-subgraph mapping).

    Returns:
        ``Data`` with the merged ``x`` and a ``(2, E)`` long ``edge_index``
        (shape ``(2, 0)`` when there are no edges).
    """
    x1, edge_index1 = graph1.x, graph1.edge_index
    x2, edge_index2 = graph2.x, graph2.edge_index
    n1, n2 = x1.size(0), x2.size(0)

    # Keep the inputs' floating precision and device (averaging needs a float
    # dtype, so integer features fall back to torch's default float dtype).
    dtype = torch.promote_types(x1.dtype, x2.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    new_x = torch.zeros(
        (n1 + n2 - len(mapping), x1.size(1)), dtype=dtype, device=x1.device
    )

    # Keys 0..n1-1 are graph1 nodes; keys n1.. are graph2 nodes shifted by n1.
    new_index_map = {}
    current_index = 0

    for idx1, idx2 in mapping.items():
        new_x[current_index] = (x1[idx1] + x2[idx2]) / 2
        new_index_map[idx1] = current_index
        new_index_map[n1 + idx2] = current_index
        current_index += 1

    for idx1 in range(n1):
        if idx1 not in mapping:
            new_x[current_index] = x1[idx1]
            new_index_map[idx1] = current_index
            current_index += 1

    # Set lookup instead of scanning mapping.values() for every node.
    matched2 = set(mapping.values())
    for idx2 in range(n2):
        if idx2 not in matched2:
            new_x[current_index] = x2[idx2]
            new_index_map[n1 + idx2] = current_index
            current_index += 1

    new_edge_list = []
    seen_edges = set()
    for offset, edge_index in ((0, edge_index1), (n1, edge_index2)):
        for src, dst in edge_index.t().tolist():
            edge = (new_index_map[offset + src], new_index_map[offset + dst])
            # Edges between matched nodes appear in both graphs; keep one copy.
            if edge not in seen_edges:
                seen_edges.add(edge)
                new_edge_list.append(edge)

    # reshape(-1, 2) guarantees a (2, E) long tensor even when E == 0; a bare
    # torch.tensor([]) would be a 1-D float tensor.
    new_edge_index = (
        torch.tensor(new_edge_list, dtype=torch.long, device=edge_index1.device)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )

    return Data(x=new_x, edge_index=new_edge_index)

In [None]:
# For each layer's pooled pair: plot it, then, if both graphs are connected,
# merge them along every largest common induced subgraph found by ISMAGS.
# Re-pool each merged graph with the `topk_pooling` module from the pooling
# cell, and plot the result when the re-pooled graph is still connected.
for i in range(len(clustered_graphs)):
    (g1, s1), (g2, s2) = clustered_graphs[i]
    G1 = to_networkx(g1, to_undirected=True)
    G2 = to_networkx(g2, to_undirected=True)
    plot_graph_pair(g1, g2, "Graph 1", "Graph 2")
    if nx.is_connected(G1) and nx.is_connected(G2):
        # The ISMAGS search can be expensive, so run it only when its result is used.
        ismags = nx.isomorphism.ISMAGS(G1, G2)
        # Each mapping is {G1 node: G2 node}, the orientation combine_graphs expects.
        largest_common_subgraph = list(ismags.largest_common_subgraph(symmetry=False))
        # Iterate mappings directly; the old index loop reused (and clobbered) `i`.
        for mapping in largest_common_subgraph:
            combined_data = combine_graphs(g1, g2, mapping)

            combined_clustered_x, combined_clustered_edge_index, _, _, _, _ = (
                topk_pooling(combined_data.x, combined_data.edge_index)
            )

            combined_clustered = Data(
                x=combined_clustered_x, edge_index=combined_clustered_edge_index
            )

            if nx.is_connected(to_networkx(combined_clustered, to_undirected=True)):
                plot_graph_pair(
                    combined_data, combined_clustered, "Combined", "Combined-Clustered"
                )
    print("---------------------------------------------------------------")