In [None]:
# Draw `num_samples` nodes uniformly at random (with replacement) from each of
# class 0 and class 1, then gather their ids in `ind` and features in `feat`.
#
# The positions are drawn with `torch.randint`, so every node of a class is
# equally likely.  Passing the node ids themselves to `torch.multinomial`
# would treat them as sampling weights: high-id nodes would be favoured and
# node 0 could never be selected.
num_samples = 5
sampling_rng = torch.Generator().manual_seed(0)  # fixed seed -> same nodes on every run

# Nodes labelled 0.
indices = (data_graph.y == 0).nonzero(as_tuple=True)[0]
random_indices = torch.randint(indices.numel(), (num_samples,),
                               generator=sampling_rng).to(indices.device)
random_elements = indices[random_indices]

# Nodes labelled 1.
indices = (data_graph.y == 1).nonzero(as_tuple=True)[0]
random_indices = torch.randint(indices.numel(), (num_samples,),
                               generator=sampling_rng).to(indices.device)
random_elements_2 = indices[random_indices]

# Index tensors are kept as int64, the dtype tensor indexing expects.
ind = torch.cat((random_elements, random_elements_2), dim=0).to(torch.long)
feat = data_graph.x[ind]


In [None]:
from torch_geometric.explain import Explainer, GNNExplainer


# Build the explainer: GNNExplainer as the algorithm, 'attributes' node masks,
# 'object' edge masks, and a node-level multiclass model whose outputs are
# declared as log-probabilities.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(),
    explanation_type="model",
    node_mask_type="attributes",
    edge_mask_type="object",
    model_config={
        "mode": "multiclass_classification",
        "task_level": "node",
        "return_type": "log_probs",
    },
)

# Collect one explanation per sampled node, in the same order as `ind`.
exp = []
for i in ind:
    explanation = explainer(data_graph.x, data_graph.edge_index, index=i)
    exp += [explanation]
############################

In [None]:
import os

# Save one feature-importance plot per explanation under `features/`.
# The folder is created up front so the cell also works on a fresh checkout
# (presumably the plot writer does not create missing directories - verify).
os.makedirs('features', exist_ok=True)
for e in exp:
    path = f'features/feature_importance{e.index}.png'
    e.visualize_feature_importance(path, top_k=50)

In [None]:
from torch_geometric.nn import MessagePassing, APPNP
from torch_geometric.utils import k_hop_subgraph

def subgraph(model, node_idx, x, edge_index, **kwargs):
    """Extract the computation graph of ``node_idx`` for ``model``.

    The number of hops is derived from the model: every ``MessagePassing``
    module contributes one hop, except ``APPNP`` modules which contribute
    their ``K``.  The flow direction is taken from the first
    ``MessagePassing`` module found (``'source_to_target'`` if there is none).

    :param model: Module whose ``MessagePassing`` layers define the receptive field.
    :param node_idx: Seed node(s): an int, a list/tuple, or a tensor.  A 0-d
        tensor (what iterating over an index tensor yields) is promoted to a
        one-element tensor before being handed to ``k_hop_subgraph``.
    :param x: Node-aligned matrix; its first dimension is used as the number
        of nodes of the full graph.
    :param edge_index: Edge index of the full graph; its second dimension is
        used as the number of edges.
    :param kwargs: Extra values.  Tensors whose first dimension equals the
        number of nodes are restricted to the subgraph nodes, tensors whose
        first dimension equals the number of edges are restricted to the
        subgraph edges, everything else is passed through unchanged.
    :return: ``(x, edge_index, mapping, edge_mask, kwargs)`` where ``x`` is
        restricted to the subgraph nodes, ``edge_index``, ``mapping`` and
        ``edge_mask`` are the corresponding outputs of ``k_hop_subgraph``
        (called with ``relabel_nodes=True``), and ``kwargs`` is a dict with
        the restricted extra values.
    """
    num_nodes, num_edges = x.size(0), edge_index.size(1)

    # A single pass over the modules yields both the flow and the hop count.
    message_passing_layers = [module for module in model.modules()
                              if isinstance(module, MessagePassing)]
    flow = message_passing_layers[0].flow if message_passing_layers else 'source_to_target'
    num_hops = sum(module.K if isinstance(module, APPNP) else 1
                   for module in message_passing_layers)

    # Normalise tensor seeds: scalar tensors become 1-D, and the ids are int64
    # so they can be combined with the (int64) edge index.
    if torch.is_tensor(node_idx):
        if node_idx.dim() == 0:
            node_idx = node_idx.reshape(1)
        node_idx = node_idx.to(torch.long)

    subset, edge_index, mapping, edge_mask = k_hop_subgraph(
        node_idx, num_hops, edge_index, relabel_nodes=True,
        num_nodes=num_nodes, flow=flow)

    x = x[subset]

    # Iterate over (key, value) pairs - iterating the dict directly yields
    # only the keys - and collect the results in a fresh dict.
    restricted_kwargs = {}
    for key, item in kwargs.items():
        if torch.is_tensor(item) and item.dim() > 0:
            if item.size(0) == num_nodes:
                item = item[subset]
            elif item.size(0) == num_edges:
                item = item[edge_mask]
        restricted_kwargs[key] = item

    return x, edge_index, mapping, edge_mask, restricted_kwargs

def edge_mask_to_node_mask(data, edge_mask, aggregation="mean"):
    """Turn per-edge weights into per-node weights.

    Every edge passes its weight on to both of its endpoints; the weights
    arriving at a node are combined according to ``aggregation``:

    * ``"sum"``  - each endpoint receives half of the edge weight, summed.
    * ``"mean"`` - average of the weights of all incident edge entries
      (nodes without edges get 0).
    * ``"max"``  - largest incident weight, never below the initial 0.

    :param data: Object exposing ``x`` (first dimension = number of nodes)
        and ``edge_index`` (shape ``[2, num_edges]``).
    :param edge_mask: One weight per column of ``data.edge_index``.  If the
        lengths differ, only the common prefix is used.
    :param aggregation: ``"sum"``, ``"mean"`` or ``"max"``.
    :return: CPU tensor with one weight per node.
    :raises NotImplementedError: for an unknown ``aggregation``.
    """
    if aggregation not in ("sum", "mean", "max"):
        raise NotImplementedError(f"No such aggregation method: {aggregation}")

    node_weights = torch.zeros(data.x.shape[0])

    # Work on detached CPU copies so the result does not depend on where the
    # explanation or the graph currently live.
    weights = torch.as_tensor(edge_mask).detach().to(
        device="cpu", dtype=node_weights.dtype).reshape(-1)
    edge_index = data.edge_index.detach().to(device="cpu", dtype=torch.long)
    num_edges = min(weights.numel(), edge_index.size(1))

    # Flatten to (src0, dst0, src1, dst1, ...) with the matching weight
    # repeated for both endpoints, so one scatter call handles every edge.
    endpoints = edge_index[:, :num_edges].t().reshape(-1)
    endpoint_weights = weights[:num_edges].repeat_interleave(2)

    if aggregation == "sum":
        node_weights.index_add_(0, endpoints, endpoint_weights / 2)
    elif aggregation == "mean":
        node_weights.index_add_(0, endpoints, endpoint_weights)
        node_degrees = torch.zeros_like(node_weights)
        node_degrees.index_add_(0, endpoints, torch.ones_like(endpoint_weights))
        # clamp avoids 0/0 for nodes without any incident edge
        node_weights = node_weights / node_degrees.clamp(min=1.)
    else:  # "max"
        # include_self=True keeps the initial 0 as a lower bound.
        node_weights.scatter_reduce_(0, endpoints, endpoint_weights,
                                     reduce="amax", include_self=True)

    return node_weights

# Explain every node and convert each resulting edge mask into a node mask.
def generate_node_masks(explainer, nodes, data, aggregation="mean"):
    """Compute one node mask per node in ``nodes``.

    Two explainer styles are supported:

    * objects with an ``explain_node(node, x, edge_index)`` method returning
      ``(feature_mask, edge_mask)``;
    * callables invoked as ``explainer(x, edge_index, index=node)`` whose
      result exposes ``edge_mask`` - the call style used for the
      ``Explainer`` object elsewhere in this notebook.

    :param explainer: Explainer of either style.
    :param nodes: Iterable of node indices to explain.
    :param data: Graph object exposing ``x`` and ``edge_index``.
    :param aggregation: Aggregation passed on to ``edge_mask_to_node_mask``.
    :return: List of node-weight tensors, one per entry of ``nodes``.
    """
    node_masks = []

    for node in nodes:
        if hasattr(explainer, "explain_node"):
            _, edge_mask = explainer.explain_node(node, data.x, data.edge_index)
        else:
            edge_mask = explainer(data.x, data.edge_index, index=node).edge_mask
        node_masks.append(edge_mask_to_node_mask(data, edge_mask, aggregation))

    return node_masks

# Distortion-based fidelity of an explanation, evaluated on the computation graph of a node
def fidelity(model,  # is a must
             node_idx,  # is a must
             full_feature_matrix,  # must
             edge_index=None,  # the whole, so data.edge_index
             node_mask=None,  # at least one of these three node, feature, edge
             feature_mask=None,
             edge_mask=None,
             samples=100,
             random_seed=12345,
             device="cpu"
             ):
    """
    Distortion/Fidelity (for Node Classification)

    The features of the computation graph of ``node_idx`` are perturbed
    ``samples`` times: entries selected by the mask keep their value, all
    other entries are replaced by the value of the same feature taken from a
    randomly drawn node of the full graph.  The score is the fraction of
    perturbed inputs for which the model still predicts the label it predicts
    on the unperturbed computation graph.

    :param model: GNN model which is explained. It is switched to eval mode
        for the duration of the call and restored afterwards.
    :param node_idx: The node which is explained (an int or a 0-d tensor)
    :param full_feature_matrix: The feature matrix from the Graph (X)
    :param edge_index: All edges
    :param node_mask: Is a (binary) tensor with 1/0 for each node in the computational graph
        => 1 means the features of this node will be fixed
        => 0 means the features of this node will be perturbed/randomized
        if not available torch.ones((1, num_computation_graph_nodes))
    :param feature_mask: Is a (binary) tensor with 1/0 for each feature
        => 1 means this feature is fixed for all nodes with 1
        => 0 means this feature is randomized for all nodes
        if not available torch.ones((1, number_of_features))
    :param edge_mask: Optional edge mask; it is stored on every MessagePassing
        module as ``__edge_mask__`` while the perturbed inputs are evaluated
        and removed again afterwards.
    :param samples: Number of perturbed inputs to evaluate (must be >= 1)
    :param random_seed: Seed of the generator used to draw the replacement nodes
    :param device: Device for the default masks, the generator and the random indices
    :return: Fraction (between 0 and 1) of perturbed inputs that keep the original prediction
    :raises ValueError: if no mask is supplied or ``samples`` is smaller than 1

    The combined mask is ``node_mask.T @ feature_mask`` and is expected to be
    broadcastable against the (num_computation_graph_nodes, num_features)
    feature matrix; the shapes are not validated here.
    """
    if edge_mask is None and feature_mask is None and node_mask is None:
        raise ValueError("At least supply one mask")
    if samples < 1:
        # the score is `correct / samples`; fail early instead of dividing by zero
        raise ValueError("samples must be at least 1")

    # A 0-d tensor (what iterating over an index tensor yields) is unwrapped
    # so the subgraph extraction receives a plain node id.
    if torch.is_tensor(node_idx) and node_idx.dim() == 0:
        node_idx = int(node_idx)

    computation_graph_feature_matrix, computation_graph_edge_index, mapping, _, _ = subgraph(
        model, node_idx, full_feature_matrix, edge_index)

    num_nodes, num_features = full_feature_matrix.size()
    num_computation_graph_nodes = computation_graph_feature_matrix.size(0)

    # fill missing masks
    if feature_mask is None:
        feature_mask = torch.ones((1, num_features), device=device)
    if node_mask is None:
        # all nodes selected
        node_mask = torch.ones((1, num_computation_graph_nodes), device=device)

    # retrieve complete mask as matrix
    mask = node_mask.T.matmul(feature_mask)

    # For every sample, computation-graph node and feature: the row of the
    # full feature matrix that supplies the replacement value.
    rng = torch.Generator(device=device)
    rng.manual_seed(random_seed)
    random_indices = torch.randint(num_nodes,
                                   (samples, num_computation_graph_nodes, num_features),
                                   generator=rng,
                                   device=device,
                                   )

    message_passing_layers = [module for module in model.modules()
                              if isinstance(module, MessagePassing)]

    correct = 0.0
    was_training = model.training
    model.eval()  # keep dropout / batch-norm from adding noise to the comparison
    try:
        with torch.no_grad():  # predictions only, no autograd graph needed
            # get predicted label (before any edge mask is installed)
            log_logits = model(x=computation_graph_feature_matrix,
                               edge_index=computation_graph_edge_index)
            predicted_label = log_logits.argmax(dim=-1)[mapping]

            # set edge mask
            # NOTE(review): `__explain__` is set to False while the mask is
            # installed, and `edge_mask` is not restricted to the edges of the
            # computation graph - confirm the mask is applied as intended with
            # the installed torch_geometric version.
            if edge_mask is not None:
                for module in message_passing_layers:
                    module.__explain__ = False
                    module.__edge_mask__ = edge_mask

            for i in range(samples):
                # 1. generate the perturbed input
                random_features = torch.gather(full_feature_matrix,
                                               dim=0,
                                               index=random_indices[i])
                randomized_features = (mask * computation_graph_feature_matrix
                                       + (1 - mask) * random_features)

                # 2. get the prediction from the trained model using the perturbed features
                distorted_labels = model(x=randomized_features,
                                         edge_index=computation_graph_edge_index).argmax(dim=-1)

                # 3. count the samples whose prediction for the explained node is unchanged
                if bool((distorted_labels[mapping] == predicted_label).all()):
                    correct += 1
    finally:
        # reset mask and training mode even if a forward pass raised
        if edge_mask is not None:
            for module in message_passing_layers:
                module.__explain__ = False
                module.__edge_mask__ = None
        model.train(was_training)

    return correct / samples

In [None]:
# for each node calculate the rdt fidelity for the feature and node mask

# Initialize empty lists to store the results
fidelity_scores = []
device = torch.device('cpu')

# Loop over each node
for idx, explanation in zip(ind, exp):
    # `idx` is a 0-d tensor; the subgraph extraction is given a plain node id.
    node_id = int(idx)

    # Generate masks for GNNExplainer
    edge_mask = explanation.edge_mask

    # Ids of the nodes in the computation graph of `node_id`: `subgraph`
    # restricts whatever node-aligned matrix it receives as `x`, so feeding it
    # a column of node ids returns exactly the ids it keeps.
    node_ids = torch.arange(data_graph.x.size(0), device=data_graph.x.device).view(-1, 1)
    subset = subgraph(model, node_id, node_ids, data_graph.edge_index)[0].view(-1)

    # `fidelity` expects one node weight per computation-graph node, shaped
    # (1, n_sub), whereas the converted edge mask covers the whole graph.
    node_mask_full = edge_mask_to_node_mask(data_graph, edge_mask, aggregation="mean")
    node_mask = node_mask_full[subset.to(node_mask_full.device)].view(1, -1)

    # `fidelity` expects one row of feature weights.  A per-node mask (one row
    # per graph node) is reduced to the mean over the computation-graph nodes.
    # NOTE(review): the mean is a modelling choice - confirm it is the intended
    # reduction, or use a mask type that already yields a single row.
    node_feat_mask = explanation.node_mask
    if (node_feat_mask.dim() == 2
            and node_feat_mask.size(0) == data_graph.x.size(0)
            and node_feat_mask.size(0) != 1):
        node_feat_mask = node_feat_mask[subset.to(node_feat_mask.device)].mean(dim=0, keepdim=True)

    # Calculate fidelity
    fidelity_score = fidelity(model,
                              node_idx=node_id,
                              full_feature_matrix=data_graph.x,
                              edge_index=data_graph.edge_index,
                              node_mask=node_mask,
                              edge_mask=edge_mask,
                              feature_mask=node_feat_mask,
                              device=device)
    fidelity_scores.append(fidelity_score)

# Print out the results
for i, fid in zip(ind, fidelity_scores):
    print(f'Node {i} - Fidelity: {fid}')