# Install PyTorch Geometric

In [None]:
# Setup cell: install PyTorch Geometric and its compiled companion packages
# against the torch build that is already present in the runtime.
import torch
# Seed PyTorch's global RNG.
torch.manual_seed(42)
from IPython.display import clear_output
torch_version = torch.__version__
print("Torch version: ", torch_version)
# File name of the wheel listing for the installed torch build; it is
# interpolated into the -f (find-links) URLs of the shell commands below.
pytorch_version = f"torch-{torch.__version__}.html"
# --no-index restricts pip to the find-links page given with -f for these packages.
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/$pytorch_version
# torch-geometric itself is installed from the default package index, unpinned.
!pip install torch-geometric
# Remove the pip logs from the cell output so only the final message remains.
clear_output()
print("Done.")

Done.


In [None]:
import torch

# Report whether PyTorch can see a CUDA device in this runtime.
print(
    "CUDA is available. Running on GPU."
    if torch.cuda.is_available()
    else "CUDA is not available. Running on CPU."
)


CUDA is available. Running on GPU.


In [None]:
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConv

In [None]:
# `torch_version` was captured from torch.__version__ in the installation cell.
print("Torch version: ", torch_version)

Torch version:  2.4.1+cu121


In [None]:
import os
import pandas as pd
import torch
from torch_geometric.data import Dataset, Data

In [None]:
import h5py
import numpy as np
import pickle

In [None]:
import matplotlib.pyplot as plt

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

Mounted at /content/drive


# Load dataset

In [None]:
# Read the pre-normalised graph collection from Google Drive.
# NOTE(security): unpickling can execute arbitrary code, so only load files
# that come from a trusted source.
load_path = '/content/drive/MyDrive/CV+GNN/Dataset_Classification/Classification_5/1_13safe/normalized_graph_data.pkl'
with open(load_path, mode='rb') as f:
    normalized_graph_data = pickle.Unpickler(f).load()

In [None]:
class MyGraphDataset(Dataset):
    """Index-based view over an in-memory collection of graph samples.

    The wrapped collection is stored as-is; length and item lookup are
    delegated straight to it.
    """

    def __init__(self, normalized_graph_data):
        """Keep a reference to the already-loaded collection of graphs."""
        # NOTE(review): the base-class initialiser is not invoked here, so only
        # the two dunder methods below are backed by this class.
        self.normalized_graph_data = normalized_graph_data

    def __len__(self):
        """Return how many samples the wrapped collection holds."""
        graphs = self.normalized_graph_data
        return len(graphs)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        graphs = self.normalized_graph_data
        return graphs[idx]


dataset = MyGraphDataset(normalized_graph_data)

In [None]:
# Number of samples in the loaded collection.
len(normalized_graph_data)

125000

In [None]:
from torch.utils.data import Subset
# The samples are PyG `Data` objects, so batches have to be collated by PyG's
# loader (it produces `batch.batch`, which later cells rely on). The loader
# was previously used without being imported, which fails on a fresh kernel.
from torch_geometric.loader import DataLoader

# Pre-computed index arrays for the train / validation split.
# NOTE(review): the file names suggest these come from the best cross-validation
# fold of an earlier run -- confirm with the notebook that produced them.
train_idx = np.load('/content/drive/MyDrive/CV+GNN/Dataset_Classification/Classification_5/1_13safe/1_13GIN/best_fold_train_idx.npy')
val_idx = np.load('/content/drive/MyDrive/CV+GNN/Dataset_Classification/Classification_5/1_13safe/1_13GIN/best_fold_val_idx.npy')

train_subset = Subset(dataset, train_idx)
val_subset = Subset(dataset, val_idx)

# Training batches are reshuffled every epoch; validation order is kept fixed.
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)



In [None]:
# Number of samples selected by the training indices.
len(train_subset)

100000

In [None]:
# Number of samples selected by the validation indices.
len(val_subset)

25000

In [None]:
# Show one training sample to check which fields each graph carries.
train_subset[0]

Data(x=[12, 15], edge_index=[2, 21], edge_attr=[21, 1], y=[1], graph_id='513d6906fdc9f035870045a2')

# **Define Model**

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from torch.nn import Linear, Sequential, ReLU

class GNN(torch.nn.Module):
    """Graph classifier: two GIN convolutions, mean pooling, linear head.

    Args:
        num_features: Size of each input node feature vector.
        embedding_size: Width of both GIN layers.
        dropout_rate: Dropout probability applied after each convolution.
        num_classes: Number of output scores produced per graph.
    """

    def __init__(self, num_features, embedding_size=64, dropout_rate=0.3, num_classes=5):
        super().__init__()
        # Sub-modules keep their attribute names and creation order so that
        # saved checkpoints and parameter initialisation line up as before.
        self.initial_conv = GINConv(self._make_mlp(num_features, embedding_size))
        self.conv1 = GINConv(self._make_mlp(embedding_size, embedding_size))

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.out = Linear(embedding_size, num_classes)

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build the Linear -> ReLU -> Linear network wrapped by a GIN layer."""
        return Sequential(Linear(in_dim, out_dim), ReLU(), Linear(out_dim, out_dim))

    def forward(self, x, edge_index, batch):
        """Return one row of raw class scores per graph in the batch."""
        # Each convolution is followed by ReLU and then dropout.
        for conv in (self.initial_conv, self.conv1):
            x = self.dropout(F.relu(conv(x, edge_index)))
        # Average node embeddings per graph before the classification head.
        pooled = global_mean_pool(x, batch)
        return self.out(pooled)

# **Load Model**

In [None]:
complete_model_save_path = '/content/drive/MyDrive/CV+GNN/Dataset_Classification/Classification_5/1_13safe/1_13GIN/complete_model_1.13.pth'
# The checkpoint holds the whole pickled model object (the `GNN` class must be
# defined before this cell runs), so full unpickling is requested explicitly
# instead of relying on torch.load's changing default.
# NOTE(security): full unpickling can execute arbitrary code -- only load
# checkpoints from a trusted source.
# map_location lets a checkpoint saved on GPU be opened on a CPU-only runtime.
model = torch.load(complete_model_save_path, map_location=device, weights_only=False)
# Make sure the parameters live on the same device the batches are moved to.
model = model.to(device)

# Switch to inference mode (disables dropout); as the last expression this
# also displays the module summary.
model.eval()

  model = torch.load(complete_model_save_path)


GNN(
  (initial_conv): GINConv(nn=Sequential(
    (0): Linear(in_features=15, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ))
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ))
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=64, out_features=5, bias=True)
)

# **PGExplainer**

# Define Explainer

In [None]:
from torch_geometric.explain import Explainer, PGExplainer

In [None]:
# Wrap the trained classifier in a PGExplainer-based explainer that produces
# one mask value per edge for graph-level, multi-class predictions.
# The epoch count given here is the number of epochs the explainer has to be
# trained for in the training cell.
explainer = Explainer(
    model=model,
    algorithm=PGExplainer(lr=0.00005, epochs=30).to(device),  # alternative noted by the author: lr=0.0001
    edge_mask_type='object',
    explanation_type='phenomenon',
    model_config={
        'task_level': 'graph',
        'mode': 'multiclass_classification',
        'return_type': 'raw',
    },
)

# Training

In [None]:
# Train the explainer network on the training graphs.
# The number of epochs is read from the algorithm itself instead of being
# repeated as a literal, so the loop cannot drift from the value the
# explainer was configured with.
num_epochs = explainer.algorithm.epochs

for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)

        # One integer class label per graph; already on `device` after the move above.
        target = batch.y.view(-1).long()

        loss = explainer.algorithm.train(epoch, model, batch.x, batch.edge_index, batch=batch.batch, target=target)

        # Accept either a tensor or a plain number as the returned loss.
        running_loss += loss.item() if isinstance(loss, torch.Tensor) else loss

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss}")

Epoch 1, Average Loss: 14.473759732131958
Epoch 2, Average Loss: 3.3510645901489258
Epoch 3, Average Loss: 2.0838884338378905
Epoch 4, Average Loss: 1.8982557767868042
Epoch 5, Average Loss: 1.8633137602615357
Epoch 6, Average Loss: 1.8556209352111817
Epoch 7, Average Loss: 1.8535144194412232
Epoch 8, Average Loss: 1.8528955278778076
Epoch 9, Average Loss: 1.8526734577560424
Epoch 10, Average Loss: 1.8526060467910768
Epoch 11, Average Loss: 1.8525496947479247
Epoch 12, Average Loss: 1.852520556564331
Epoch 13, Average Loss: 1.8525099988937377
Epoch 14, Average Loss: 1.8524930281829834
Epoch 15, Average Loss: 1.8524916013717652
Epoch 16, Average Loss: 1.8524893449401856
Epoch 17, Average Loss: 1.852475835494995
Epoch 18, Average Loss: 1.8524781458663941
Epoch 19, Average Loss: 1.852472984313965
Epoch 20, Average Loss: 1.8524733600616454
Epoch 21, Average Loss: 1.8524688789749146
Epoch 22, Average Loss: 1.8524667210388184
Epoch 23, Average Loss: 1.8524657190322875
Epoch 24, Average Loss:

# Generate explanations and save them to CSV

In [None]:
import pandas as pd
import torch
import json
from collections import defaultdict

# Explain every validation graph and collect one CSV row per graph.
#
# Changes compared with the earlier version of this cell:
#   * the per-batch debug prints were removed (they flooded the cell output);
#   * each edge is assigned to its graph once per batch with tensor ops
#     instead of Python-level membership tests per edge;
#   * the target is prepared exactly as in the explainer training loop.
# The columns and the values written to the CSV are unchanged.
#
# NOTE(review): every node id written below (edges, important nodes, weight
# keys) is a position inside the mini-batch, not an index local to the graph,
# so it does not index into that row's `node_features` list directly.
#
# `batch` stays bound to the last validation batch after the loop; the
# feature-importance cell further down reads it.
results = []

for batch in val_loader:
    batch = batch.to(device)
    target = batch.y.view(-1).long()
    explanation = explainer(batch.x, batch.edge_index, batch=batch.batch, target=target)

    if 'edge_mask' not in explanation:
        raise ValueError("Edge mask is not available in the explanation.")

    # An edge counts as "important" when its mask value is non-zero.
    important_edges_mask = explanation['edge_mask'].detach()
    is_important = important_edges_mask != 0

    # Graph membership of both endpoints of every edge, computed once per batch.
    src_nodes, dst_nodes = batch.edge_index
    src_graph = batch.batch[src_nodes]
    dst_graph = batch.batch[dst_nodes]

    for graph_index in batch.batch.unique().tolist():
        node_indices = (batch.batch == graph_index).nonzero(as_tuple=True)[0]

        # All edges of this graph: both endpoints belong to it.
        in_graph = (src_graph == graph_index) & (dst_graph == graph_index)
        all_edges_list = batch.edge_index[:, in_graph].t().tolist()

        # Important edges of this graph, selected by their source node.
        important_sel = (src_graph == graph_index) & is_important
        important_edges_list = batch.edge_index[:, important_sel].t().tolist()
        important_edge_weights_for_graph = important_edges_mask[important_sel].tolist()

        # A node is important when it touches an important edge; its weight
        # is the mean mask value over those edges.
        important_nodes = set()
        node_weights = defaultdict(list)
        for (node_1, node_2), edge_weight in zip(important_edges_list, important_edge_weights_for_graph):
            important_nodes.add(node_1)
            important_nodes.add(node_2)
            node_weights[node_1].append(edge_weight)
            node_weights[node_2].append(edge_weight)

        important_nodes_weights = {node: sum(weights) / len(weights) for node, weights in node_weights.items()}

        # Same iteration order as the `important_nodes` column below.
        important_node_features = [batch.x[node].tolist() for node in important_nodes]

        graph_result = {
            'graph_id': batch.graph_id[graph_index],
            'important_edges': json.dumps(important_edges_list),
            'important_edges_weight': json.dumps(important_edge_weights_for_graph),
            'important_nodes': json.dumps([int(node) for node in important_nodes]),
            'important_nodes_weight': json.dumps(important_nodes_weights),
            'important_node_features': json.dumps(important_node_features),
            'node_features': json.dumps(batch.x[node_indices].tolist()),
            'edge_index': json.dumps(all_edges_list),
            'target': json.dumps(target[graph_index].item())
        }
        results.append(graph_result)

Explanation_results = pd.DataFrame(results)
Explanation_results.to_csv('/content/drive/MyDrive/CV+GNN/Dataset_Classification/Model_classification/PGE_Classification/1_13safe/Explanation_results.csv', index=False)

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
         514,  515,  517,  518,  523,  525,  526,  528,  531,  533,  551,  552,
         555,  556,  562,  567,  568,  582,  587,  588,  589,  590,  595,  596,
         599,  600,  601,  602,  603,  605,  610,  611,  612,  614,  615,  617,
         618,  620,  621,  623,  624,  636,  637,  641,  642,  644,  649,  652,
         655,  659,  663,  667,  668,  669,  681,  698,  700,  702,  706,  708,
         712,  713,  715,  742,  743,  744,  745,  746,  747,  750,  751,  752,
         753,  754,  758,  760,  761,  764,  766,  768,  775,  779,  790,  811,
         834,  840,  842,  843,  855,  859,  863,  872,  873,  875,  877,  878,
         881,  882,  884,  885,  903,  905,  907,  926,  931,  933,  934,  935,
         936,  937,  938,  939,  940,  942,  944,  949,  953,  955,  957,  958,
         959,  960,  962,  963,  964,  965,  968,  969,  976,  981,  984,  985,
         987,  990,  996,  997,  998,  999, 1001, 1002, 1004, 1005, 1006, 1007,

In [None]:
# Display the table of per-graph explanations built in the previous cell.
Explanation_results

Unnamed: 0,graph_id,important_edges,important_edges_weight,important_nodes,important_nodes_weight,important_node_features,node_features,edge_index,target
0,513da120fdc9f03587008a9a,"[[0, 7], [1, 3], [2, 3], [5, 8], [6, 8], [7, 8]]","[1.362379241120689e-34, 6.9217579334600294e-37...","[0, 1, 2, 3, 5, 6, 7, 8]","{""0"": 1.362379241120689e-34, ""7"": 6.8121012838...","[[0.0, 0.9294710159301758, 0.5016834735870361,...","[[0.0, 0.9294710159301758, 0.5016834735870361,...","[[0, 4], [0, 7], [1, 3], [1, 4], [2, 3], [2, 4...",0
1,513e1ab1fdc9f03587009297,[],[],[],{},[],"[[0.0069444444961845875, 0.5314861536026001, 0...","[[9, 14], [9, 15], [10, 15], [11, 14], [11, 15...",0
2,513e5e20fdc9f0358700afa1,"[[16, 26], [17, 18], [17, 20], [17, 31], [18, ...","[2.315639382463441e-33, 9.332410458805333e-31,...","[16, 17, 18, 20, 22, 24, 25, 26, 27, 30, 31]","{""16"": 2.315639382463441e-33, ""26"": 1.15782470...","[[0.0, 0.9622166156768799, 0.7003366947174072,...","[[0.0, 0.9622166156768799, 0.7003366947174072,...","[[16, 26], [16, 32], [17, 18], [17, 20], [17, ...",0
3,513e6aa1fdc9f0358700bf0e,"[[34, 35]]",[6.600700108429512e-39],"[34, 35]","{""34"": 6.600700108429512e-39, ""35"": 6.60070010...","[[0.0, 0.8463475704193115, 0.6969696879386902,...","[[0.0, 0.20906800031661987, 0.8821548819541931...","[[33, 35], [33, 39], [33, 41], [33, 43], [33, ...",0
4,513e21b6fdc9f0358700a48f,"[[51, 52], [51, 53], [52, 55], [53, 54], [53, ...","[1.587651843605526e-25, 7.481930455525287e-27,...","[51, 52, 53, 54, 55, 56]","{""51"": 8.312355740803894e-26, ""52"": 7.93825921...","[[0.013888888992369175, 0.501259446144104, 0.2...","[[0.013888888992369175, 0.501259446144104, 0.2...","[[51, 52], [51, 53], [51, 57], [52, 55], [52, ...",0
...,...,...,...,...,...,...,...,...,...
24995,50f5e8e4fdc9f065f00075be,"[[72, 73], [73, 81], [74, 75], [74, 76], [74, ...","[2.9774483957771507e-33, 7.425710599613109e-36...","[72, 73, 74, 75, 76, 77, 78, 80, 81, 82, 85]","{""72"": 2.9774483957771507e-33, ""73"": 1.4924370...","[[0.0, 0.7833752632141113, 0.5690235495567322,...","[[0.0, 0.7833752632141113, 0.5690235495567322,...","[[72, 73], [72, 79], [72, 84], [73, 79], [73, ...",4
24996,513d7cbdfdc9f03587006fbe,"[[101, 106], [101, 107]]","[8.073106062028027e-39, 9.652443700140705e-39]","[106, 107, 101]","{""101"": 8.862774881084366e-39, ""106"": 8.073106...","[[0.2222222238779068, 0.11586901545524597, 0.6...","[[0.0069444444961845875, 0.7657430171966553, 0...","[[86, 94], [86, 108], [87, 94], [87, 95], [87,...",4
24997,513d6c04fdc9f03587004e73,"[[109, 110], [110, 111], [110, 112], [110, 114...","[1.1103197718375563e-31, 1.2629163248704218e-3...","[109, 110, 111, 112, 113, 114, 118, 120, 121]","{""109"": 1.1103197718375563e-31, ""110"": 1.96701...","[[0.013888888992369175, 0.38790929317474365, 0...","[[0.013888888992369175, 0.38790929317474365, 0...","[[109, 110], [110, 111], [110, 112], [110, 113...",4
24998,50f563b3fdc9f065f0005db9,"[[122, 131], [122, 137]]","[5.447743961847732e-38, 4.40961676950954e-37]","[137, 122, 131]","{""122"": 2.4771955828471568e-37, ""131"": 5.44774...","[[0.118055559694767, 0.03778337314724922, 0.61...","[[0.0, 0.16624684631824493, 0.5387205481529236...","[[122, 129], [122, 131], [122, 133], [122, 136...",4


# Importance analysis of node features

In [None]:
def analyze_feature_impact(model, data, perturbation_level=0.01):
    """Measure how strongly the model output reacts to shifting each node feature.

    For every feature column, all nodes are shifted by ``+perturbation_level``
    and then by ``-perturbation_level``. Each response is the absolute change
    of the model output, summed over output columns and averaged over rows.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)`` and
            returning a 2-D tensor. It is switched to eval mode and left there.
        data: Object exposing ``x`` (2-D node features), ``edge_index`` and
            ``batch``. It is not modified.
        perturbation_level: Size of the shift applied to a feature column.

    Returns:
        Tuple ``(feature_importances, feature_impacts)`` of 1-D tensors with
        one entry per feature:

        * ``feature_importances`` -- mean of the two responses, normalised to
          sum to 1. If the model does not react to any feature the values stay
          at 0 instead of becoming NaN from a 0/0 division.
        * ``feature_impacts`` -- response to the positive shift minus response
          to the negative shift. Both terms are magnitudes, so the sign shows
          which shift direction moved the output more, not whether the output
          went up or down.
    """
    model.eval()
    with torch.no_grad():
        num_features = data.x.size(1)
        original_output = model(data.x, data.edge_index, batch=data.batch)

        feature_importances = torch.zeros(num_features, device=data.x.device)
        feature_impacts = torch.zeros(num_features, device=data.x.device)

        for feature_idx in range(num_features):
            # Only the feature matrix changes, so only it needs to be copied.
            perturbed_x = data.x.clone()

            perturbed_x[:, feature_idx] += perturbation_level
            positive_output = model(perturbed_x, data.edge_index, batch=data.batch)
            positive_impact = torch.mean(torch.sum((positive_output - original_output).abs(), dim=1))

            # Going from +level to -level on the same copy.
            perturbed_x[:, feature_idx] -= 2 * perturbation_level
            negative_output = model(perturbed_x, data.edge_index, batch=data.batch)
            negative_impact = torch.mean(torch.sum((negative_output - original_output).abs(), dim=1))

            feature_importances[feature_idx] = (positive_impact + negative_impact) / 2
            feature_impacts[feature_idx] = positive_impact - negative_impact

        # Normalise to a distribution, unless there is nothing to distribute.
        total_importance = feature_importances.sum()
        if total_importance > 0:
            feature_importances /= total_importance

    return feature_importances, feature_impacts

# NOTE(review): `batch` is not created in this cell -- it is whatever
# mini-batch an earlier loop left bound (presumably the last validation
# batch), so the ranking below describes that single batch only.
feature_importances, feature_impacts = analyze_feature_impact(model, batch)

# Feature indices ordered from most to least important.
sorted_indices = feature_importances.argsort(descending=True)

print("Node feature importances (sorted):")
for idx in sorted_indices:
    # The label is "positive" only for a strictly positive impact value.
    if feature_impacts[idx].item() > 0:
        impact = "positive"
    else:
        impact = "negative"
    print(f"Feature {idx.item()}: Importance {feature_importances[idx].item():.4f}, Impact: {impact}")

Node feature importances (sorted):
Feature 0: Importance 0.2578, Impact: positive
Feature 8: Importance 0.1976, Impact: negative
Feature 11: Importance 0.0804, Impact: positive
Feature 7: Importance 0.0751, Impact: negative
Feature 3: Importance 0.0618, Impact: negative
Feature 2: Importance 0.0474, Impact: negative
Feature 6: Importance 0.0433, Impact: negative
Feature 12: Importance 0.0367, Impact: negative
Feature 14: Importance 0.0361, Impact: positive
Feature 10: Importance 0.0337, Impact: negative
Feature 5: Importance 0.0320, Impact: positive
Feature 9: Importance 0.0290, Impact: negative
Feature 4: Importance 0.0278, Impact: negative
Feature 1: Importance 0.0211, Impact: positive
Feature 13: Importance 0.0202, Impact: negative
