In [1]:
from cuda_test import test_cuda_availability, matrix_multiplication_test
# Environment sanity check: confirm PyTorch can see the GPU, then compare the
# average CPU vs GPU time of a 1000x1000 matrix multiplication over 5 runs.
test_cuda_availability()
matrix_multiplication_test(runs=5, size=1000)

=== CUDA 可用性測試 ===
PyTorch版本: 2.4.1+cu124
CUDA是否可用: True
CUDA版本: 12.4
當前CUDA設備: 0
設備名稱: NVIDIA GeForce RTX 2060
設備數量: 1
設備屬性: _CudaDeviceProperties(name='NVIDIA GeForce RTX 2060', major=7, minor=5, total_memory=6143MB, multi_processor_count=30)

=== 矩陣乘法性能測試 (大小: 1000x1000) ===
CPU平均時間: 0.0066 秒
GPU平均時間: 0.0008 秒
GPU加速比: 8.24x


In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.explain import Explainer, GNNExplainer
import matplotlib.pyplot as plt
import networkx as nx

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class SWaTGraphSAGE(torch.nn.Module):
    """GraphSAGE network that produces a sigmoid-activated output per node.

    Architecture: ``num_layers`` stacked ``SAGEConv`` layers with widths
    ``in_channels -> hidden_channels -> ... -> hidden_channels -> out_channels``.
    ReLU and dropout (p=0.2) sit between consecutive layers, and a sigmoid is
    applied to the output of the last layer.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Width of every hidden layer.
        out_channels: Number of output channels per node.
        num_layers: Total number of ``SAGEConv`` layers; must be >= 1. With
            ``num_layers == 1`` a single ``in_channels -> out_channels`` layer
            is used.

    Raises:
        ValueError: If ``num_layers < 1``.

    Note:
        ``forward`` returns probabilities (sigmoid already applied), not
        logits. Losses and explainer configs consuming the output must treat
        it accordingly.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.num_layers = num_layers

        # Layer widths: in -> hidden * (num_layers - 1) -> out.
        # For num_layers >= 2 this yields the same layer stack (and the same
        # `convs.<i>` parameter names) as before, so saved checkpoints load.
        # For num_layers == 1 it gives a single in -> out layer. The previous
        # construction built an extra in -> hidden layer that forward() skipped,
        # feeding in_channels-wide input into a hidden-wide layer.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x, edge_index):
        """Run message passing and return per-node sigmoid outputs.

        Args:
            x: Node feature matrix passed to the first ``SAGEConv``.
            edge_index: Graph connectivity passed to every ``SAGEConv``.

        Returns:
            Output of the last layer after ``torch.sigmoid`` (values in (0, 1)).
        """
        # Every layer except the last is followed by ReLU + dropout.
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
        x = self.convs[-1](x, edge_index)
        return torch.sigmoid(x)

In [4]:
# Load the training checkpoint. Besides 'model_state_dict' it also stores the
# graph ('x', 'edge_index') that later cells read.
# map_location='cpu' keeps the notebook runnable on machines without a GPU; the
# model below is never moved off the CPU anyway.
# NOTE(review): torch.load unpickles arbitrary objects. Once the checkpoint is
# confirmed to hold only tensors / plain containers, pass weights_only=True
# (this also silences the FutureWarning). Never load untrusted checkpoints.
checkpoint = torch.load('swat_graphsage_model.pt', map_location='cpu')

# Hyper-parameters must match those the checkpoint was trained with.
# The input width is read from the stored node features instead of being
# hard-coded, so it cannot drift from the data. A mismatch with the saved
# weights surfaces as a size-mismatch error in load_state_dict.
in_channels = checkpoint['x'].size(1)
hidden_channels = 64
out_channels = 1
num_layers = 2

model = SWaTGraphSAGE(in_channels, hidden_channels, out_channels, num_layers)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()


  checkpoint = torch.load('swat_graphsage_model.pt')


SWaTGraphSAGE(
  (convs): ModuleList(
    (0): SAGEConv(51, 64, aggr=mean)
    (1): SAGEConv(64, 1, aggr=mean)
  )
  (dropout): Dropout(p=0.2, inplace=False)
)

In [5]:
# The graph is stored in the same checkpoint as the model weights.
x = checkpoint['x']
edge_index = checkpoint['edge_index']

# Fail fast on a malformed graph instead of deep inside SAGEConv / the explainer.
assert x.dim() == 2 and x.size(1) == in_channels, (
    f'expected node features of shape (num_nodes, {in_channels}), got {tuple(x.shape)}'
)
assert edge_index.dim() == 2 and edge_index.size(0) == 2, (
    f'expected edge_index of shape (2, num_edges), got {tuple(edge_index.shape)}'
)
assert edge_index.numel() == 0 or (
    int(edge_index.min()) >= 0 and int(edge_index.max()) < x.size(0)
), 'edge_index references a node id outside the rows of x'
# NOTE(review): the printed edge_index only shows node ids up to 50 (51 ids,
# same as in_channels), while x appears to have many rows that look like time
# steps. If edge_index is the sensor topology and rows of x are samples,
# message passing mixes unrelated rows. Confirm the intended graph construction.

print("x shape:", tuple(x.shape), "| edge_index shape:", tuple(edge_index.shape))
print("x:", x)
print("edge_index:", edge_index)

x: tensor([[0.0000, 0.0053, 0.5000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0054, 0.5000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0055, 0.5000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.9172, 0.4500, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.9134, 0.4501, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.9064, 0.4506, 1.0000,  ..., 0.0000, 0.0000, 0.0000]])
edge_index: tensor([[ 0,  1,  2,  0,  3,  1,  4,  0,  5,  6,  6,  7,  8,  5,  9,  8, 10,  8,
         11,  6, 12,  7, 13,  8, 14,  6, 15,  7, 16, 17, 17, 18, 19, 17, 20, 18,
         21, 17, 22, 18, 23, 17, 24, 18, 25, 26, 27, 28, 29, 27, 30, 28, 31, 27,
         32, 28, 33, 27, 34, 35, 35, 36, 36, 37, 38, 34, 39, 35, 40, 36, 41, 37,
         42, 38, 43, 39, 44, 40, 45, 41, 46, 40, 47, 48, 48, 49, 49, 50,  1,  5,
          7, 16, 18, 25, 27, 34, 36, 47, 18,  8, 25, 16, 40, 25, 14, 18, 15, 40],
        [ 1,  0,  0,  2,  1,  3,  0,  4,  6,  5,  7,  6,  5,  8,  8,  9,  8, 10,
          6, 11,  7

In [6]:
# Configure GNNExplainer to explain the trained model's per-node predictions.
# `return_type` must describe what `model.forward` actually returns.
# SWaTGraphSAGE ends with torch.sigmoid, so its outputs are probabilities
# ('probs'). Declaring them 'raw' (logits) would make PyG derive the
# binary-classification target as `prediction > 0`, which is always true for
# sigmoid outputs, so every node would be explained as the positive class.
# The explainer loss would also treat the probabilities as logits.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [7]:
# Explain the model's prediction for a single node (index 20).
# GNNExplainer starts its masks from random values (PyG uses torch.randn), so
# seed the global RNG to make this explanation, and the plots saved from it,
# reproducible across runs.
torch.manual_seed(0)

node_index = 20
explanation = explainer(x, edge_index, index=node_index)

# Print available explanations
print(f'Generated explanations: {explanation.available_explanations}')

Generated explanations: ['edge_mask', 'node_mask']


In [8]:
# Save a bar chart of the 10 most important input features, as scored by the
# explanation's node-feature mask, to a PNG file.
feature_importance_path = 'swat_feature_importance.png'
explanation.visualize_feature_importance(path=feature_importance_path, top_k=10)
print(f'Feature importance plot saved to {feature_importance_path}')

Feature importance plot saved to swat_feature_importance.png


In [9]:
# Render the explanation graph, with edges weighted by the learned edge mask,
# to a PDF file.
subgraph_path = 'swat_subgraph.pdf'
explanation.visualize_graph(path=subgraph_path)
print(f'Subgraph visualization saved to {subgraph_path}')

Subgraph visualization saved to swat_subgraph.pdf
