In [29]:
import pandas as pd
import networkx as nx
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.utils import from_networkx
from sklearn.metrics import confusion_matrix, classification_report
import os
from tqdm import tqdm
import joblib
import matplotlib.pyplot as plt

In [30]:
data_path = 'stage3_data_cleaning/v2/type1_label_merged_final_decoded_clean3.xlsx'
data = pd.read_excel(data_path)
data['can_id'] = data['can_id'].astype(str)

In [31]:
data.head()

Unnamed: 0,label,timestamp,can_id,data_length,source_node_id_decimal,service_flag,priority,message_type_decimal,destination_node_id_decimal,request_or_response,...,end_of_message,single_message_frame,transfer_ID,effective_data_0,effective_data_1,effective_data_2,effective_data_3,effective_data_4,effective_data_5,effective_data_6
0,0,0.0,10015501,8,1,0,16,341,-99,-99,...,1,1,0,0,0,0,0,8,0,0
1,0,0.192053,104E2001,2,1,0,16,20000,-99,-99,...,1,1,0,0,-199,-199,-199,-199,-199,-199
2,0,0.192335,1F043901,8,1,0,31,1081,-99,-99,...,0,0,0,0,0,0,246,0,-199,-199
3,0,0.192504,1F043901,8,1,0,31,1081,-99,-99,...,0,0,0,0,248,0,0,247,255,223
4,0,0.192637,1F043901,4,1,0,31,1081,-99,-99,...,1,0,0,254,0,0,-199,-199,-199,-199


In [32]:
import os
output_dir = "can_graphs/v8"
visualization_dir = os.path.join(output_dir, "visualizations")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visualization_dir, exist_ok=True)

In [33]:
def create_graph(window_df):
    G = nx.DiGraph()
    index_tracker = {}
    
    for i in range(len(window_df) - 1):
        node1 = window_df.iloc[i]['can_id']
        node2 = window_df.iloc[i + 1]['can_id']
        timestamp_diff = window_df.iloc[i + 1]['timestamp'] - window_df.iloc[i]['timestamp']
        label = window_df.iloc[i]['label']
        transfer_id1 = window_df.iloc[i]['transfer_ID']
        transfer_id2 = window_df.iloc[i + 1]['transfer_ID']
        
        if node1 != node2 or transfer_id1 != transfer_id2:  # Avoid self-loops
            if G.has_edge(node1, node2):
                G[node1][node2]['weight'] += timestamp_diff
            else:
                G.add_edge(node1, node2, weight=timestamp_diff)
        
        if node1 not in index_tracker:
            index_tracker[node1] = []
        index_tracker[node1].append((i, label))
        
        # if i == len(window_df)-1:
        if node2 not in index_tracker:
            index_tracker[node2] = []
        index_tracker[node2].append((i, label))
    
    # print(index_tracker)
    # index_tracker = {k: sorted(list(v)) for k, v in index_tracker.items()}
    # print(index_tracker)
    # Calculate PageRank and in-degree
    pagerank = nx.pagerank(G)
    indegree = dict(G.in_degree())
    
    for node in G.nodes:
        G.nodes[node]['pagerank'] = pagerank.get(node, 0.0)
        G.nodes[node]['indegree'] = indegree.get(node, 0)
    
    return G, index_tracker

In [34]:
# Function to visualize the graph and save to file
def visualize_graph(G, window_index):
    pos = nx.spring_layout(G)
    pagerank = nx.get_node_attributes(G, 'pagerank')
    indegree = nx.get_node_attributes(G, 'indegree')
    labels = {node: f'{node}\nPR: {pagerank[node]:.2f}\nInDeg: {indegree[node]}' for node in G.nodes()}
    
    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, labels=labels, node_size=7000, node_color='skyblue', font_size=10, edge_color='gray')
    plt.title(f"Graph for Window {window_index}")
    output_path = os.path.join(visualization_dir, f'graph_window_{window_index}.png')
    plt.savefig(output_path)
    plt.close()

In [35]:

def preprocess_data(data, window_size=100):
    pyg_data_list = []
    for window_start in tqdm(range(0, len(data), window_size)):
        window_end = min(window_start + window_size, len(data))
        window_data = data.iloc[window_start:window_end]
        G, index_tracker = create_graph(window_data)
        # break
        # Convert networkx graph to PyG data object
        pyg_data = from_networkx(G, group_node_attrs=['pagerank', 'indegree'])
        
        # Add labels to PyG data object
        labels = []
        for node in G.nodes:
            # Use the most recent label for each node
            labels.append(index_tracker[node][-1][1])
        pyg_data.y = torch.tensor(labels, dtype=torch.long)

        # Save the raw graph for later analysis
        graph_path = os.path.join(output_dir, f'graph_window_{window_start // window_size}.gpickle')
        # nx.write_gpickle(G, graph_path)
        nx.write_graphml(G, graph_path)
        
        # Save the visualization of the graph
        visualize_graph(G, window_start // window_size)
        
        
        pyg_data_list.append(pyg_data)
    
    return pyg_data_list

In [44]:

# Train the model
def train_model(model, train_loader, optimizer, criterion, epochs=20):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}')

# Evaluate the model
def evaluate_model(model, test_loader):
    model.eval()
    y_true = []
    y_pred = []
    for data in test_loader:
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        y_true.extend(data.y.tolist())
        y_pred.extend(pred.tolist())
    
    cm = confusion_matrix(y_true, y_pred)
    report = classification_report(y_true, y_pred,output_dict=True)
    return cm, report

# Save the model
def save_model(model, path):
    torch.save(model.state_dict(), path)

In [37]:
class EGraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)


In [38]:
pyg_data_list = preprocess_data(data)

100%|██████████| 2079/2079 [02:54<00:00, 11.92it/s]


In [39]:
train_size = int(0.7 * len(pyg_data_list))
train_data = pyg_data_list[:train_size]
test_data = pyg_data_list[train_size:]

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)





In [40]:
# Model initialization
model = EGraphSAGE(in_channels=2, hidden_channels=128, out_channels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

In [41]:
train_model(model, train_loader, optimizer, criterion)

Epoch 1, Loss: 0.3265642876858297
Epoch 2, Loss: 0.24534739780685175
Epoch 3, Loss: 0.22238780655290769
Epoch 4, Loss: 0.22390729885386385
Epoch 5, Loss: 0.222973865011464
Epoch 6, Loss: 0.21619846639425858
Epoch 7, Loss: 0.2178439812167831
Epoch 8, Loss: 0.21164145320653915
Epoch 9, Loss: 0.2077532358791517
Epoch 10, Loss: 0.2064644245673781
Epoch 11, Loss: 0.21197256397293962
Epoch 12, Loss: 0.2121329678465491
Epoch 13, Loss: 0.20609567460158598
Epoch 14, Loss: 0.2078935478044593
Epoch 15, Loss: 0.20318811500202055
Epoch 16, Loss: 0.2009108589719171
Epoch 17, Loss: 0.2017851961047753
Epoch 18, Loss: 0.20278253082347952
Epoch 19, Loss: 0.20055269498540007
Epoch 20, Loss: 0.19804564118385315


In [45]:
cm, report = evaluate_model(model, test_loader)
print('Confusion Matrix:\n', cm)
print('Classification Report:\n', report)

Confusion Matrix:
 [[1135  118]
 [  62  322]]
Classification Report:
 {'0': {'precision': 0.948203842940685, 'recall': 0.9058260175578612, 'f1-score': 0.926530612244898, 'support': 1253.0}, '1': {'precision': 0.7318181818181818, 'recall': 0.8385416666666666, 'f1-score': 0.7815533980582524, 'support': 384.0}, 'accuracy': 0.8900427611484423, 'macro avg': {'precision': 0.8400110123794334, 'recall': 0.8721838421122639, 'f1-score': 0.8540420051515751, 'support': 1637.0}, 'weighted avg': {'precision': 0.8974450806492731, 'recall': 0.8900427611484423, 'f1-score': 0.8925225180190751, 'support': 1637.0}}


In [46]:
report["accuracy"]

0.8900427611484423

In [27]:
save_model(model, os.path.join(output_dir, 'sageConv_model_probalistic_label.pth'))