In [None]:
# from cuda_test import test_cuda_availability, matrix_multiplication_test
# test_cuda_availability()
# matrix_multiplication_test(size=1000, runs=5)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class SWaTGraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super(SWaTGraphSAGE, self).__init__()
        self.num_layers = num_layers
        
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x, edge_index):
        for i in range(self.num_layers - 1):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
            x = self.convs[-1](x, edge_index)
        return torch.sigmoid(x)

In [4]:
def create_swat_graph(normal_df, attack_df=None, save_path='graph_data.pt'):
    normal_df.columns = normal_df.columns.str.strip()
    if attack_df is not None:
        attack_df.columns = attack_df.columns.str.strip()
    

    feature_cols = [col for col in normal_df.columns 
                   if col not in ['Timestamp', 'Normal/Attack']]
    
    print("\nnormal_df.columns:")
    print(normal_df.columns.tolist())
    print("\nattack_df.columns:")
    print(attack_df.columns.tolist() if attack_df is not None else "None")
    

    node_features = normal_df[feature_cols].values
    if attack_df is not None:
        attack_features = attack_df[feature_cols].values
        node_features = np.vstack([node_features, attack_features])
    

    edges = []
    feature_to_idx = {name: idx for idx, name in enumerate(feature_cols)}
    

    connections = [
    # P1 connections
    ('FIT101', 'LIT101'),
    ('MV101', 'FIT101'),
    ('P101', 'LIT101'),
    ('P102', 'FIT101'),

    # P2 connections
    ('AIT201', 'AIT202'),
    ('AIT202', 'AIT203'),
    ('FIT201', 'AIT201'),
    ('MV201', 'FIT201'),
    ('P201', 'FIT201'),
    ('P202', 'AIT202'),
    ('P203', 'AIT203'),
    ('P204', 'FIT201'),
    ('P205', 'AIT202'),  
    ('P206', 'AIT203'), 

    # P3 connections
    ('DPIT301', 'FIT301'),
    ('FIT301', 'LIT301'),
    ('MV301', 'FIT301'),
    ('MV302', 'LIT301'),
    ('MV303', 'FIT301'),
    ('MV304', 'LIT301'),
    ('P301', 'FIT301'),
    ('P302', 'LIT301'),

    # P4 connections
    ('AIT401', 'AIT402'),
    ('FIT401', 'LIT401'),
    ('P401', 'FIT401'),
    ('P402', 'LIT401'),
    ('P403', 'FIT401'),
    ('P404', 'LIT401'),
    ('UV401', 'FIT401'),

    # P5 connections
    ('AIT501', 'AIT502'),
    ('AIT502', 'AIT503'),
    ('AIT503', 'AIT504'),
    ('FIT501', 'AIT501'),
    ('FIT502', 'AIT502'),
    ('FIT503', 'AIT503'),
    ('FIT504', 'AIT504'),
    ('P501', 'FIT501'),
    ('P502', 'FIT502'),
    ('PIT501', 'FIT503'),
    ('PIT502', 'FIT504'),
    ('PIT503', 'FIT503'),

    # P6 connections
    ('FIT601', 'P601'),
    ('P601', 'P602'),
    ('P602', 'P603'),

    # Cross-process connections
    ('LIT101', 'AIT201'),  # P1 -> P2
    ('AIT203', 'DPIT301'),  # P2 -> P3
    ('LIT301', 'AIT401'),  # P3 -> P4
    ('FIT401', 'AIT501'),  # P4 -> P5
    ('AIT503', 'FIT601'),  # P5 -> P6
    ('LIT301', 'FIT201'),  # P3 -> P2 
    ('AIT401', 'DPIT301'),  # P4 -> P3
    ('FIT503', 'AIT401'),  # P5 -> P4
    ('P205', 'LIT301'),    
    ('P206', 'FIT503')     
    ]
    
    print("\nCreate.Edge:")
    for src, dst in connections:
        if src in feature_to_idx and dst in feature_to_idx:
            i, j = feature_to_idx[src], feature_to_idx[dst]
            edges.extend([[i, j], [j, i]])  
            print(f"{src} <-> {dst}")
    
    edge_index = torch.tensor(edges, dtype=torch.long).t()
    x = torch.tensor(node_features, dtype=torch.float)
    
    y = torch.zeros(len(node_features))
    if attack_df is not None:
        y[len(normal_df):] = 1
    
    torch.save({'x': x, 'edge_index': edge_index}, save_path)
    print(f"x & edge_index save to {save_path}")
    
    print(f"Node {x.size(0)}")
    print(f"Dim: {x.size(1)}")
    print(f"Edge: {edge_index.size(1)}")
    
    return Data(x=x, edge_index=edge_index, y=y)


In [5]:
def train_graphsage(model, data, epochs=100, lr=0.01):

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCELoss()

    num_nodes = data.x.size(0)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    
    train_indices = np.random.choice(num_nodes, int(0.8 * num_nodes), replace=False)
    train_mask[train_indices] = True
    test_mask[~train_mask] = True
    
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        
        out = model(data.x, data.edge_index)
        loss = criterion(out[train_mask].squeeze(), data.y[train_mask])
        
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                pred = (out[test_mask] > 0.5).float()
                acc = (pred.squeeze() == data.y[test_mask]).float().mean()
                print(f'Epoch {epoch+1:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')
            model.train()
    
    return model

In [6]:
def evaluate_model(model, data):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = (out > 0.5).float()
        acc = (pred.squeeze() == data.y).float().mean()
        
        tp = ((pred.squeeze() == 1) & (data.y == 1)).sum()
        fp = ((pred.squeeze() == 1) & (data.y == 0)).sum()
        tn = ((pred.squeeze() == 0) & (data.y == 0)).sum()
        fn = ((pred.squeeze() == 0) & (data.y == 1)).sum()
        
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        
    return {
        'accuracy': acc.item(),
        'precision': precision.item(),
        'recall': recall.item(),
        'f1': f1.item()
    }

In [7]:
if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    

    normal_df = pd.read_csv('processed_data/SWaT_normal.csv')
    attack_df = pd.read_csv('processed_data/SWaT_attack.csv')
    print(f"normal_df: {normal_df.shape}, attack_df: {attack_df.shape}")
    
    data = create_swat_graph(normal_df, attack_df)
    print(f"data.x.size(0): {data.x.size(0)}, data.x.size(1): {data.x.size(1)}")
    print(f"data.edge_index.size(1): {data.edge_index.size(1)}")
    
    in_channels = data.x.size(1)  
    hidden_channels = 64
    out_channels = 1
    print(f"in_channels: {in_channels}")
    print(f"hidden_channels: {hidden_channels}")
    print(f"out_channels: {out_channels}")
    
    model = SWaTGraphSAGE(in_channels, hidden_channels, out_channels)
    print(model)
    
    model = train_graphsage(model, data, epochs=100)
    metrics = evaluate_model(model, data)
    print("\nModel Performance:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
    
    torch.save({'model_state_dict': model.state_dict(),'x': data.x,'edge_index':  data.edge_index}, 'swat_graphsage_model.pt') 

normal_df: (495000, 53), attack_df: (449919, 53)

normal_df.columns:
['Timestamp', 'FIT101', 'LIT101', 'MV101', 'P101', 'P102', 'AIT201', 'AIT202', 'AIT203', 'FIT201', 'MV201', 'P201', 'P202', 'P203', 'P204', 'P205', 'P206', 'DPIT301', 'FIT301', 'LIT301', 'MV301', 'MV302', 'MV303', 'MV304', 'P301', 'P302', 'AIT401', 'AIT402', 'FIT401', 'LIT401', 'P401', 'P402', 'P403', 'P404', 'UV401', 'AIT501', 'AIT502', 'AIT503', 'AIT504', 'FIT501', 'FIT502', 'FIT503', 'FIT504', 'P501', 'P502', 'PIT501', 'PIT502', 'PIT503', 'FIT601', 'P601', 'P602', 'P603', 'Normal/Attack']

attack_df.columns:
['Timestamp', 'FIT101', 'LIT101', 'MV101', 'P101', 'P102', 'AIT201', 'AIT202', 'AIT203', 'FIT201', 'MV201', 'P201', 'P202', 'P203', 'P204', 'P205', 'P206', 'DPIT301', 'FIT301', 'LIT301', 'MV301', 'MV302', 'MV303', 'MV304', 'P301', 'P302', 'AIT401', 'AIT402', 'FIT401', 'LIT401', 'P401', 'P402', 'P403', 'P404', 'UV401', 'AIT501', 'AIT502', 'AIT503', 'AIT504', 'FIT501', 'FIT502', 'FIT503', 'FIT504', 'P501', 'P502'

In [8]:
with torch.no_grad():
    predictions = model(data.x, data.edge_index)
    pred_labels = (predictions > 0.5).float().numpy().flatten()
    true_labels = data.y.numpy()

    accuracy = (pred_labels == true_labels).mean()
    normal_mask = (true_labels == 0)
    attack_mask = (true_labels == 1)
    
    normal_accuracy = (pred_labels[normal_mask] == true_labels[normal_mask]).mean()
    attack_accuracy = (pred_labels[attack_mask] == true_labels[attack_mask]).mean()
    
    print(f"normal_accuracy: {normal_accuracy:.4f}")
    print(f"attack_accuracy: {attack_accuracy:.4f}")

    tp = np.sum((pred_labels == 1) & (true_labels == 1))
    tn = np.sum((pred_labels == 0) & (true_labels == 0))
    fp = np.sum((pred_labels == 1) & (true_labels == 0))
    fn = np.sum((pred_labels == 0) & (true_labels == 1))
    
    print(f" (True Positives): {tp}")
    print(f" (True Negatives): {tn}")
    print(f" (False Positives): {fp}")
    print(f" (False Negatives): {fn}")
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f" (Precision): {precision:.4f}")
    print(f" (Recall): {recall:.4f}")
    print(f"F1score: {f1:.4f}")
    
    anomaly_scores = predictions.numpy().flatten()
    
    print(f"anomaly_scores > 0.5: {(anomaly_scores > 0.5).mean():.4f}")
    print(f"anomaly_scores.max(): {anomaly_scores.max():.4f}")
    print(f"anomaly_scores.min(): {anomaly_scores.min():.4f}")
    print(f"anomaly_scores.mean(): {anomaly_scores.mean():.4f}")
    
    thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
    for threshold in thresholds:
        pred_at_threshold = (anomaly_scores > threshold).astype(float)
        acc_at_threshold = (pred_at_threshold == true_labels).mean()
        print(f"threshold {threshold}: {acc_at_threshold:.4f}")

normal_accuracy: 0.9999
attack_accuracy: 0.9993
 (True Positives): 449582
 (True Negatives): 494929
 (False Positives): 71
 (False Negatives): 337
 (Precision): 0.9998
 (Recall): 0.9993
F1score: 0.9995
anomaly_scores > 0.5: 0.4759
anomaly_scores.max(): 1.0000
anomaly_scores.min(): 0.0000
anomaly_scores.mean(): 0.4759
threshold 0.3: 0.9991
threshold 0.4: 0.9995
threshold 0.5: 0.9996
threshold 0.6: 0.9995
threshold 0.7: 0.9994


In [9]:
def analyze_feature_importance(model, data, feature_names):
    plt.figure(figsize=(20, 10))
    with torch.no_grad():
        feature_importance = []
        base_output = model(data.x, data.edge_index)
        base_pred = (base_output > 0.5).float()
        
        for i in range(data.x.size(1)):
            perturbed_x = data.x.clone()
            perturbed_x[:, i] = torch.zeros_like(perturbed_x[:, i])
            
            new_output = model(perturbed_x, data.edge_index)
            new_pred = (new_output > 0.5).float()
            
            importance = (base_pred != new_pred).float().mean().item()
            feature_importance.append(importance)

        feature_importance = np.array(feature_importance)
        feature_importance = (feature_importance - feature_importance.min()) / (feature_importance.max() - feature_importance.min())

        plt.subplot(121)
        importance_df = pd.DataFrame({
            'Feature': feature_names,
            'Importance': feature_importance
        }).sort_values('Importance', ascending=True)
        
        sns.barplot(x='Importance', y='Feature', data=importance_df, 
                   palette='YlOrRd')
        plt.title('Feature Importance')
        plt.xlabel('Normalized Importance')
        
        plt.subplot(122)
        G = nx.Graph()
        
        for i, name in enumerate(feature_names):
            G.add_node(i, name=name, importance=feature_importance[i])

        edge_index = data.edge_index.numpy()
        edges = list(zip(edge_index[0], edge_index[1]))
        G.add_edges_from(edges)

        pos = nx.spring_layout(G, k=1, iterations=50)

        node_sizes = [3000 * G.nodes[node]['importance'] for node in G.nodes()]
        node_colors = [G.nodes[node]['importance'] for node in G.nodes()]
        
        nx.draw_networkx_nodes(G, pos, 
                             node_size=node_sizes,
                             node_color=node_colors,
                             cmap=plt.cm.YlOrRd)
        nx.draw_networkx_edges(G, pos, alpha=0.2, edge_color='gray')
        
        labels = {i: f"{name}\n{feature_importance[i]:.2f}" 
                 for i, name in enumerate(feature_names)}
        nx.draw_networkx_labels(G, pos, labels, font_size=8)
        
        plt.title('Node Relationship Graph\n(Node size and color indicate importance)')
    
    plt.tight_layout()
    plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
    plt.close()

    importance_ranking = [(name, feature_importance[i]) 
                         for i, name in enumerate(feature_names)]
    importance_ranking.sort(key=lambda x: x[1], reverse=True)
    
    print("\n10:")
    for name, importance in importance_ranking[:10]:
        print(f"{name}: {importance:.4f}")
    
    process_importance = {}
    for name, importance in importance_ranking:
        process_num = name[-3:] if name[-3:].isdigit() else name[-2:] if name[-2:].isdigit() else name[-1]
        process = f"P{process_num}"
        if process not in process_importance:
            process_importance[process] = []
        process_importance[process].append(importance)
  
    process_avg_importance = {
        process: np.mean(importances) 
        for process, importances in process_importance.items()
    }
    sorted_processes = sorted(
        process_avg_importance.items(), 
        key=lambda x: x[1], 
        reverse=True
    )
    
    for process, avg_importance in sorted_processes:
        print(f"{process}: {avg_importance:.4f}")
    
    return importance_ranking

In [10]:
def get_feature_names(normal_df):
    return [col for col in normal_df.columns 
            if col not in ['Timestamp', 'Normal/Attack']]

In [13]:
normal_df = pd.read_csv('processed_data/SWaT_normal.csv')
attack_df = pd.read_csv('processed_data/SWaT_attack.csv')

feature_names = get_feature_names(normal_df)

data = create_swat_graph(normal_df, attack_df)

model = SWaTGraphSAGE(
    in_channels=data.x.size(1),
    hidden_channels=64,
    out_channels=1
)
checkpoint = torch.load('swat_graphsage_model.pt')

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

importance_ranking = analyze_feature_importance(model, data, feature_names)

print("\nfeature_importance.png")


normal_df.columns:
['Timestamp', 'FIT101', 'LIT101', 'MV101', 'P101', 'P102', 'AIT201', 'AIT202', 'AIT203', 'FIT201', 'MV201', 'P201', 'P202', 'P203', 'P204', 'P205', 'P206', 'DPIT301', 'FIT301', 'LIT301', 'MV301', 'MV302', 'MV303', 'MV304', 'P301', 'P302', 'AIT401', 'AIT402', 'FIT401', 'LIT401', 'P401', 'P402', 'P403', 'P404', 'UV401', 'AIT501', 'AIT502', 'AIT503', 'AIT504', 'FIT501', 'FIT502', 'FIT503', 'FIT504', 'P501', 'P502', 'PIT501', 'PIT502', 'PIT503', 'FIT601', 'P601', 'P602', 'P603', 'Normal/Attack']

attack_df.columns:
['Timestamp', 'FIT101', 'LIT101', 'MV101', 'P101', 'P102', 'AIT201', 'AIT202', 'AIT203', 'FIT201', 'MV201', 'P201', 'P202', 'P203', 'P204', 'P205', 'P206', 'DPIT301', 'FIT301', 'LIT301', 'MV301', 'MV302', 'MV303', 'MV304', 'P301', 'P302', 'AIT401', 'AIT402', 'FIT401', 'LIT401', 'P401', 'P402', 'P403', 'P404', 'UV401', 'AIT501', 'AIT502', 'AIT503', 'AIT504', 'FIT501', 'FIT502', 'FIT503', 'FIT504', 'P501', 'P502', 'PIT501', 'PIT502', 'PIT503', 'FIT601', 'P601',

  checkpoint = torch.load('swat_graphsage_model.pt')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x='Importance', y='Feature', data=importance_df,



10:
AIT202: 1.0000
AIT501: 0.3716
DPIT301: 0.1933
LIT301: 0.0970
FIT201: 0.0899
LIT101: 0.0460
AIT502: 0.0459
PIT502: 0.0390
AIT402: 0.0360
P501: 0.0359
P202: 0.5000
P501: 0.1030
P301: 0.0600
P201: 0.0248
P304: 0.0225
P502: 0.0217
P402: 0.0200
P101: 0.0160
P504: 0.0155
P205: 0.0146
P302: 0.0139
P503: 0.0114
P401: 0.0105
P303: 0.0068
P203: 0.0064
P403: 0.0006
P601: 0.0000
P102: 0.0000
P204: 0.0000
P206: 0.0000
P404: 0.0000
P602: 0.0000
P603: 0.0000

feature_importance.png
