In [1]:
import pandas as pd
import networkx as nx
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.utils import from_networkx
from sklearn.metrics import confusion_matrix, classification_report
import os
from tqdm import tqdm
import joblib
import matplotlib.pyplot as plt

data_path = 'stage3_data_cleaning/v2/type1_label_merged_final_decoded_clean3.xlsx'
data = pd.read_excel(data_path)
data['can_id'] = data['can_id'].astype(str)

output_dir = "can_graphs/v8"
visualization_dir = os.path.join(output_dir, "visualizations")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visualization_dir, exist_ok=True)

In [9]:
def calculate_optimized_pagerank(G, damping_factor=0.7):
    N = len(G)
    pagerank = {node: 1 / N for node in G}
    for _ in range(100):  # Iterate 100 times for convergence
        new_pagerank = {}
        for node in G:
            rank_sum = 0
            for neighbor in G.predecessors(node):
                weight_sum = sum([G[neighbor][succ]['weight'] for succ in G.successors(neighbor)])
                rank_sum += pagerank[neighbor] * (G[neighbor][node]['weight'] / weight_sum)
            new_pagerank[node] = (1 - damping_factor) / N + damping_factor * rank_sum
        pagerank = new_pagerank
    nx.set_node_attributes(G, pagerank, 'pagerank')
    return G

In [10]:
def create_graph(window_df):
    G = nx.DiGraph()
    index_tracker = {}
    
    for i in range(len(window_df) - 1):
        node1 = window_df.iloc[i]['can_id']
        node2 = window_df.iloc[i + 1]['can_id']
        timestamp_diff = window_df.iloc[i + 1]['timestamp'] - window_df.iloc[i]['timestamp']
        label = window_df.iloc[i]['label']
        transfer_id1 = window_df.iloc[i]['transfer_ID']
        transfer_id2 = window_df.iloc[i + 1]['transfer_ID']
        
        if node1 != node2 or transfer_id1 != transfer_id2:  # Avoid self-loops
            if G.has_edge(node1, node2):
                G[node1][node2]['weight'] += timestamp_diff
            else:
                G.add_edge(node1, node2, weight=timestamp_diff)
        
        if node1 not in index_tracker:
            index_tracker[node1] = []
        index_tracker[node1].append((i, label))
        
        # if i == len(window_df)-1:
        if node2 not in index_tracker:
            index_tracker[node2] = []
        index_tracker[node2].append((i, label))

    index_tracker = {k: sorted(list(v)) for k, v in index_tracker.items()}
    
    # print(index_tracker)
    # Convert sets to sorted lists to ensure consistent ordering
    index_tracker = {k: sorted(list(v)) for k, v in index_tracker.items()}
    
    # Calculate optimized PageRank
    G = calculate_optimized_pagerank(G)
    
    return G, index_tracker

In [13]:
# Function to visualize the graph and save to file
def visualize_graph_old_v1(G, window_index):
    pos = nx.spring_layout(G)
    pagerank = nx.get_node_attributes(G, 'pagerank')
    # indegree = nx.get_node_attributes(G, 'indegree')
    labels = {node: f'{node}\nPR: {pagerank[node]:.2f}\nInDeg: {indegree[node]}' for node in G.nodes()}
    
    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, labels=labels, node_size=7000, node_color='skyblue', font_size=10, edge_color='gray')
    plt.title(f"Graph for Window {window_index}")
    output_path = os.path.join(visualization_dir, f'graph_window_{window_index}.png')
    plt.savefig(output_path)
    plt.close()

def visualize_graph(G, window_index):
    pos = nx.spring_layout(G)
    pagerank = nx.get_node_attributes(G, 'pagerank')
    
    # Ensure all nodes have a pagerank value, set default if missing
    for node in G.nodes():
        if node not in pagerank:
            pagerank[node] = 0.0  # Default PageRank value
    
    labels = {node: f'{node}\nPR: {pagerank[node]:.2f}' for node in G.nodes()}
    
    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, labels=labels, node_size=7000, node_color='skyblue', font_size=10, edge_color='gray')
    plt.title(f"Graph for Window {window_index}")
    output_path = os.path.join(visualization_dir, f'graph_window_{window_index}.png')
    plt.savefig(output_path)
    plt.close()


def preprocess_data(data, window_size=100):
    pyg_data_list = []
    for window_start in tqdm(range(0, len(data), window_size)):
        window_end = min(window_start + window_size, len(data))
        window_data = data.iloc[window_start:window_end]
        G, index_tracker = create_graph(window_data)
        
        # Convert networkx graph to PyG data object
        pyg_data = from_networkx(G, group_node_attrs=['pagerank'])
        pyg_data.x = pyg_data.x.float()  # Ensure x is Float
        
        
        # Add labels to PyG data object
        labels = []
        for node in G.nodes:
            # Use the most recent label for each node
            labels.append(index_tracker[node][-1][1])
        pyg_data.y = torch.tensor(labels, dtype=torch.long)
        
        # Save the raw graph for later analysis
        graph_path = os.path.join(output_dir, f'graph_window_{window_start // window_size}.graphml')
        nx.write_graphml(G, graph_path)
        
        # Save the visualization of the graph
        visualize_graph(G, window_start // window_size)
        
        pyg_data_list.append(pyg_data)
    
    return pyg_data_list



# def preprocess_data(data, window_size=100):
#     pyg_data_list = []
#     for window_start in tqdm(range(0, len(data), window_size)):
#         window_end = min(window_start + window_size, len(data))
#         window_data = data.iloc[window_start:window_end]
#         G, index_tracker = create_graph(window_data)
#         # break
#         # Convert networkx graph to PyG data object
#         pyg_data = from_networkx(G, group_node_attrs=['pagerank'])

#         pyg_data.x = pyg_data.x.float()
        
#         # Add labels to PyG data object
#         labels = []
#         for node in G.nodes:
#             # Use the most recent label for each node
#             labels.append(index_tracker[node][-1][1])
#         pyg_data.y = torch.tensor(labels, dtype=torch.long)

#         # Save the raw graph for later analysis
#         graph_path = os.path.join(output_dir, f'graph_window_{window_start // window_size}.gpickle')
#         # nx.write_gpickle(G, graph_path)
#         nx.write_graphml(G, graph_path)
        
#         # Save the visualization of the graph
#         visualize_graph(G, window_start // window_size)
        
        
#         pyg_data_list.append(pyg_data)
    
#     return pyg_data_list


# Train the model
def train_model(model, train_loader, optimizer, criterion, epochs=20):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}')

    

# Evaluate the model
def evaluate_model(model, test_loader):
    model.eval()
    y_true = []
    y_pred = []
    for data in test_loader:
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        y_true.extend(data.y.tolist())
        y_pred.extend(pred.tolist())
    
    cm = confusion_matrix(y_true, y_pred)
    report = classification_report(y_true, y_pred)
    return cm, report

# Save the model
def save_model(model, path):
    torch.save(model.state_dict(), path)


class EGraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)
    



In [14]:
pyg_data_list = preprocess_data(data)

  0%|          | 0/2079 [00:00<?, ?it/s]

100%|██████████| 2079/2079 [02:43<00:00, 12.69it/s]


In [16]:

train_size = int(0.7 * len(pyg_data_list))
train_data = pyg_data_list[:train_size]
test_data = pyg_data_list[train_size:]

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

model = EGraphSAGE(in_channels=1, hidden_channels=128, out_channels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

train_model(model, train_loader, optimizer, criterion)

Epoch 1, Loss: 0.302394542357196
Epoch 2, Loss: 0.23712661635616553
Epoch 3, Loss: 0.2260263686918694
Epoch 4, Loss: 0.21370285328315652
Epoch 5, Loss: 0.21267390380734982
Epoch 6, Loss: 0.21187588928834253
Epoch 7, Loss: 0.206527818966171
Epoch 8, Loss: 0.20887750387191772
Epoch 9, Loss: 0.20584902879984482
Epoch 10, Loss: 0.20111454796531927
Epoch 11, Loss: 0.21092909266767296
Epoch 12, Loss: 0.20280337576632915
Epoch 13, Loss: 0.2064149243676144
Epoch 14, Loss: 0.20343439604925073
Epoch 15, Loss: 0.20026986884034198
Epoch 16, Loss: 0.2019865843264953
Epoch 17, Loss: 0.20383687155402225
Epoch 18, Loss: 0.19858395972329637
Epoch 19, Loss: 0.20254863685239916
Epoch 20, Loss: 0.19529277507377707


In [17]:
cm, report = evaluate_model(model, test_loader)
print('Confusion Matrix:\n', cm)
print('Classification Report:\n', report)

Confusion Matrix:
 [[1117  136]
 [  40  344]]
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.89      0.93      1253
           1       0.72      0.90      0.80       384

    accuracy                           0.89      1637
   macro avg       0.84      0.89      0.86      1637
weighted avg       0.91      0.89      0.90      1637



In [18]:
save_model(model, os.path.join(output_dir, 'graphsage_model_optimized_pagerank_no_indegree.pth'))

In [21]:
from torch_geometric.nn import GCNConv
class GCNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)

In [22]:

model2 = GCNN(in_channels=1, hidden_channels=128, out_channels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

In [23]:

train_model(model2, train_loader, optimizer, criterion)

Epoch 1, Loss: 0.6961166158966396
Epoch 2, Loss: 0.6960158309210902
Epoch 3, Loss: 0.6961866293264471
Epoch 4, Loss: 0.6962129875369694
Epoch 5, Loss: 0.696018108855123
Epoch 6, Loss: 0.6955929074598395
Epoch 7, Loss: 0.6959004596523617
Epoch 8, Loss: 0.696125618789507
Epoch 9, Loss: 0.6957598808019058
Epoch 10, Loss: 0.6961407700310582
Epoch 11, Loss: 0.6963170753872913
Epoch 12, Loss: 0.6954345651294874
Epoch 13, Loss: 0.6959811578626218
Epoch 14, Loss: 0.6966515520344609
Epoch 15, Loss: 0.6963633739429972
Epoch 16, Loss: 0.6953185187733691
Epoch 17, Loss: 0.6959081616090692
Epoch 18, Loss: 0.6956511969151704
Epoch 19, Loss: 0.6956994533538818
Epoch 20, Loss: 0.6959719696770543


In [24]:
cm, report = evaluate_model(model, test_loader)
print('Confusion Matrix:\n', cm)
print('Classification Report:\n', report)

Confusion Matrix:
 [[1117  136]
 [  40  344]]
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.89      0.93      1253
           1       0.72      0.90      0.80       384

    accuracy                           0.89      1637
   macro avg       0.84      0.89      0.86      1637
weighted avg       0.91      0.89      0.90      1637

