In [1]:
import torch
import os
from torch_geometric.utils import to_networkx
import networkx as nx


def evaluate_prototype(predicted_prototype, ground_truth_prototype, max_extra_nodes=3):
    pred_nx = to_networkx(predicted_prototype, to_undirected=True)
    gt_nx = to_networkx(ground_truth_prototype, to_undirected=True)
    pred_nodes = set(pred_nx.nodes)
    gt_nodes = set(gt_nx.nodes)
    extra_nodes = pred_nodes - gt_nodes
    num_extra_nodes = len(extra_nodes)
    if num_extra_nodes > max_extra_nodes:
        return False
    subgraph_isomorphism = nx.is_isomorphic(gt_nx, pred_nx.subgraph(gt_nodes))
    if not subgraph_isomorphism:
        return False
    return True


prototype = torch.load("../../mutag0_ground_truth.pt", weights_only=False)

In [2]:
def plot_mutag(
    graph1,
    graph2=None,
    original_x1=None,
    perm1=None,
    original_x2=None,
    perm2=None,
    with_labels=False,
):
    import matplotlib.pyplot as plt
    import networkx as nx
    import matplotlib
    import matplotlib.patches as mpatches
    from torch_geometric.utils import to_networkx

    colormap = matplotlib.colormaps.get_cmap("Pastel1")

    color_map = {
        0: colormap(0),  # C
        1: colormap(1),  # O
        2: colormap(2),  # Cl
        3: colormap(3),  # H
        4: colormap(4),  # N
        5: colormap(5),  # F
        6: colormap(6),  # Br
        7: colormap(7),  # S
        8: colormap(8),  # P
        9: colormap(9),  # I
        "other": "gray",
    }

    atom_types = {
        0: "C",
        1: "O",
        2: "Cl",
        3: "H",
        4: "N",
        5: "F",
        6: "Br",
        7: "S",
        8: "P",
        9: "I",
    }

    def plot_single_graph(graph, ax, original_x=None, perm=None):
        G = to_networkx(graph, to_undirected=True)

        node_colors = []
        node_labels = {}

        if original_x is not None and perm is not None:
            mapped_x = original_x[perm[: graph.num_nodes]]
            for node in range(graph.num_nodes):
                one_hot = mapped_x[node].tolist()
                try:
                    node_type = one_hot.index(1)
                except ValueError:
                    node_type = "other"
                node_colors.append(color_map.get(node_type, "gray"))
                node_labels[node] = perm[node].item()
        else:
            for node in range(graph.num_nodes):
                one_hot = graph.x[node].tolist()
                try:
                    node_type = one_hot.index(1)
                except ValueError:
                    node_type = "other"
                node_colors.append(color_map.get(node_type, "gray"))
                node_labels[node] = perm[node].item() if perm is not None else node

        pos = nx.spring_layout(G)
        nx.draw(
            G,
            pos,
            node_color=node_colors,
            with_labels=with_labels,
            labels=node_labels,
            node_size=500,
            font_weight="bold",
            ax=ax,
        )

    legend_handles = [
        mpatches.Patch(color=colormap(0), label="C"),
        mpatches.Patch(color=colormap(1), label="O"),
        mpatches.Patch(color=colormap(2), label="Cl"),
        mpatches.Patch(color=colormap(3), label="H"),
        mpatches.Patch(color=colormap(4), label="N"),
        mpatches.Patch(color=colormap(5), label="F"),
        mpatches.Patch(color=colormap(6), label="Br"),
        mpatches.Patch(color=colormap(7), label="S"),
        mpatches.Patch(color=colormap(8), label="P"),
        mpatches.Patch(color=colormap(9), label="I"),
        mpatches.Patch(color="gray", label="Other"),
    ]

    if graph2 is None:
        fig, ax = plt.subplots(figsize=(8, 8))
        plot_single_graph(graph1, ax, original_x1, perm1)
    else:
        fig, axes = plt.subplots(1, 2, figsize=(24, 16))
        plot_single_graph(graph1, axes[0], original_x1, perm1)
        plot_single_graph(graph2, axes[1], original_x2, perm2)

    fig.legend(handles=legend_handles, loc="lower left", title="Node Types")
    plt.show()

In [3]:
# for _, graph in gnnexplainer_graphs:
#     plot_mutag(graph)

In [None]:
from explainer_main import *

results = []


def run(threshold, lr, num_epochs):
    main(
        thershold=threshold,
        lr=lr,
        num_epochs=num_epochs,
    )

    gnnexplainer_graphs = []

    for file_name in os.listdir("gnnexplainer_mutag0_graphs"):
        if file_name.endswith(".pt"):
            file_path = os.path.join("gnnexplainer_mutag0_graphs", file_name)
            pt_file = torch.load(file_path, weights_only=False)
            gnnexplainer_graphs.append(pt_file)

    for _, graph in gnnexplainer_graphs:
        graph.x = graph.feat

    for _, graph in gnnexplainer_graphs:
        plot_mutag(graph)


run(10, 0., 500)

In [9]:
gnnexplainer_graphs = []

for file_name in os.listdir("gnnexplainer_mutag0_graphs"):
    if file_name.endswith(".pt"):
        file_path = os.path.join("gnnexplainer_mutag0_graphs", file_name)
        pt_file = torch.load(file_path, weights_only=False)
        gnnexplainer_graphs.append(pt_file)

(
    sum(
        1
        for _, graph in gnnexplainer_graphs
        if graph and evaluate_prototype(graph, prototype)
    )
)

2