In [27]:
import torch
from torch_geometric.utils import from_dgl, to_networkx, k_hop_subgraph
import random
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, to_undirected
from torch_cluster import random_walk
from utils import *

In [54]:
def get_khop_subgraph(pyg_data, node_idx, maxN, num_hops=2):
    """Build a size-capped k-hop ego subgraph around ``node_idx``.

    The ``num_hops``-hop neighbourhood of ``node_idx`` is extracted with
    ``k_hop_subgraph`` (``flow='source_to_target'``) and made undirected. If it
    holds more than ``maxN`` nodes, random walks started at the center node
    are used to pick a subset of at most ``maxN`` nodes.

    Args:
        pyg_data: Graph object exposing ``x``, ``y`` and ``edge_index``.
        node_idx: Index (in ``pyg_data``) of the center node.
        maxN: Maximum number of nodes in the returned subgraph (>= 1).
        num_hops: Neighbourhood radius, also used as the random-walk length.
            Defaults to 2, the previously hard-coded value.

    Returns:
        A ``Data`` object with the selected nodes' features (``x``), the
        relabelled undirected ``edge_index``, ``num_nodes``, the center node's
        label as a length-1 long tensor (``y``) and ``center_node_idx`` set to
        the ORIGINAL ``node_idx`` (not the center's position in the subgraph).

    Raises:
        ValueError: If ``maxN`` is smaller than 1.
    """
    if maxN < 1:
        raise ValueError(f"maxN must be >= 1, got {maxN}")

    # Upper bound on resampling rounds so a walk that cannot discover new
    # nodes does not spin forever; the result is then simply smaller than maxN.
    max_rounds = 1000

    # With relabel_nodes=True the returned edge_index is expressed in *local*
    # ids (positions inside `hop_subset`), and `hop_mapping` holds the local id
    # of the center node. `hop_subset` itself holds the original node ids.
    hop_subset, hop_edge_index, hop_mapping, _ = k_hop_subgraph(
        node_idx, num_hops, pyg_data.edge_index, relabel_nodes=True, flow='source_to_target')

    # Make the neighbourhood undirected so walks can move in both directions.
    hop_edge_index = to_undirected(hop_edge_index)

    if len(hop_subset) > maxN:
        center_local = hop_mapping[0].item()
        start = torch.tensor([center_local], device=hop_edge_index.device)
        row, col = hop_edge_index[0], hop_edge_index[1]

        # random_walk_until_maxN is not defined in this notebook (presumably it
        # comes from the wildcard `utils` import); its result is only used as a
        # tensor of visited local node ids.
        walks = random_walk_until_maxN(row, col, start, maxN, walk_length=num_hops)
        sampled = torch.unique(walks.flatten())

        rounds = 0
        while len(sampled) < maxN and rounds < max_rounds:
            # Keep walking and accumulating until maxN unique nodes are seen.
            walks = random_walk_until_maxN(row, col, start, maxN, walk_length=num_hops)
            sampled = torch.unique(torch.cat((sampled, walks.flatten())))
            rounds += 1

        if len(sampled) > maxN:
            # torch.unique returns sorted ids, so a plain [:maxN] slice could
            # drop the center node. Always keep the center and fill the
            # remaining slots with the other sampled nodes.
            others = sampled[sampled != center_local]
            center = sampled.new_tensor([center_local])
            sampled = torch.sort(torch.cat((center, others[:maxN - 1])))[0]

        # `sampled` stays sorted, so relabelled edge ids line up with the row
        # order of the feature matrix built below. num_nodes is given
        # explicitly instead of being inferred from the largest edge id.
        sub_edge_index, _ = subgraph(
            sampled, hop_edge_index, relabel_nodes=True, num_nodes=len(hop_subset))

        # Translate local ids back to original node ids before touching
        # pyg_data.x: indexing x with local ids would read the wrong rows.
        node_ids = hop_subset[sampled]
    else:
        node_ids = hop_subset
        sub_edge_index = hop_edge_index

    # Assemble the subgraph data object
    subgraph_data = Data(x=pyg_data.x[node_ids], edge_index=sub_edge_index, num_nodes=len(node_ids))

    # The subgraph inherits the label of its center node
    label = pyg_data.y[node_idx].item()
    subgraph_data.y = torch.tensor([label], dtype=torch.long)
    subgraph_data.center_node_idx = node_idx

    return subgraph_data

In [49]:
# Load the pre-built Tolokers graph from disk and display its summary.
# NOTE(review): torch.load unpickles Python objects, so only load files from a
# trusted source. Recent PyTorch releases reportedly default to
# weights_only=True, which would presumably reject a pickled graph object
# unless weights_only=False is passed - confirm against the installed version.
pyg_data = torch.load('./pyg_dataset/tolokers.pt')
pyg_data 

Data(edge_index=[2, 530758], train_masks=[11758], val_masks=[11758], test_masks=[11758], num_nodes=11758, y=[11758], x=[11758, 10])

In [50]:
def _mask_to_indices(mask):
    """Return the positions of the non-zero entries of ``mask`` as a Python list.

    ``squeeze(-1)`` removes only the trailing size-1 column produced by
    ``torch.nonzero`` on a 1-D mask. A bare ``squeeze()`` would also collapse a
    single match into a 0-d tensor, making ``tolist()`` return an int instead
    of a one-element list.
    """
    return torch.nonzero(mask, as_tuple=False).squeeze(-1).tolist()


# Node indices belonging to each split of the dataset.
train_indices = _mask_to_indices(pyg_data.train_masks)
vali_indices = _mask_to_indices(pyg_data.val_masks)
test_indices = _mask_to_indices(pyg_data.test_masks)

# Sanity check: split sizes
print(len(train_indices), len(vali_indices), len(test_indices))

5879 2939 2940


In [53]:
# Indices of all nodes labelled 1. squeeze(-1) drops only the trailing column
# from nonzero(); a bare squeeze() would turn a single match into a 0-d tensor,
# whose tolist() is an int and breaks the set() call below.
anomaly_idx = (pyg_data.y == 1).nonzero().squeeze(-1)

# Restrict to the training split.
# NOTE: the result follows set iteration order, not ascending index order;
# wrap it in sorted() if a stable, ordered selection is needed.
train_anomaly_idx = list(set(anomaly_idx.tolist()).intersection(set(train_indices)))
train_anomaly_idx[:20]

[0,
 4,
 8,
 9,
 8204,
 19,
 8230,
 8234,
 51,
 52,
 8250,
 8251,
 58,
 59,
 62,
 8260,
 8265,
 74,
 75,
 76]

In [55]:
# Smoke test: extract a subgraph capped at 100 nodes around each of the first
# 20 training nodes with label 1, printing the center index and the summary.
for idx in train_anomaly_idx[:20]:
    print(idx)
    subgraph_data = get_khop_subgraph(pyg_data, node_idx=idx, maxN=100)
    print(subgraph_data)

0
Data(x=[1, 10], edge_index=[2, 1], num_nodes=1, y=[1], center_node_idx=0)
4
Data(x=[2, 10], edge_index=[2, 4], num_nodes=2, y=[1], center_node_idx=4)
8
Data(x=[1, 10], edge_index=[2, 1], num_nodes=1, y=[1], center_node_idx=8)
9
Data(x=[1, 10], edge_index=[2, 1], num_nodes=1, y=[1], center_node_idx=9)
8204
Data(x=[19, 10], edge_index=[2, 139], num_nodes=19, y=[1], center_node_idx=8204)
19
Data(x=[1, 10], edge_index=[2, 1], num_nodes=1, y=[1], center_node_idx=19)
8230
Data(x=[88, 10], edge_index=[2, 1672], num_nodes=88, y=[1], center_node_idx=8230)
8234
Data(x=[100, 10], edge_index=[2, 2284], num_nodes=100, y=[1], center_node_idx=8234)
51
Data(x=[1, 10], edge_index=[2, 1], num_nodes=1, y=[1], center_node_idx=51)
52
Data(x=[11, 10], edge_index=[2, 43], num_nodes=11, y=[1], center_node_idx=52)
8250
Data(x=[100, 10], edge_index=[2, 1860], num_nodes=100, y=[1], center_node_idx=8250)
8251
Data(x=[100, 10], edge_index=[2, 4026], num_nodes=100, y=[1], center_node_idx=8251)
58
Data(x=[1, 10], 

In [68]:
# Display whatever `subgraph_data` currently holds.
# NOTE(review): the recorded output shows center_node_idx=1, which is not one
# of the centers printed by the preview loop, so it appears to come from a
# different run - re-run the notebook top to bottom to refresh it.
subgraph_data

Data(x=[2, 10], edge_index=[2, 4], num_nodes=2, y=[1], center_node_idx=1)

In [53]:
# NOTE(review): `subset` is not assigned in any cell visible in this notebook;
# presumably it is left over from an earlier kernel session. Define it (or
# drop this cell) so the notebook survives Restart & Run All.
len(subset)

10109

In [40]:
import networkx as nx

# Check whether the most recently extracted subgraph forms one connected piece.
# `subgraph_data` (the Data object produced by get_khop_subgraph) is passed
# here: the bare name `subgraph` refers to the torch_geometric.utils.subgraph
# function imported at the top of the notebook, which to_networkx cannot convert.
nx_graph = to_networkx(subgraph_data, to_undirected=True)
components = list(nx.connected_components(nx_graph))
if len(components) > 1:
    print(f"Graph contains {len(components)} connected components.")
else:
    print("Graph is connected.")

Graph is connected.


In [62]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph

# Toy directed graph: 0->2, 1->2, 2->4, 3->4, 4->6, 5->6.
edge_index = torch.tensor(
    [
        [0, 1, 2, 3, 4, 5],  # first row of each edge pair
        [2, 2, 4, 4, 6, 6],  # second row of each edge pair
    ],
    dtype=torch.long,
)

# Extraction settings
node_idx = 6  # seed node of the k-hop extraction
num_hops = 2  # neighbourhood radius

# Run the same extraction once per `flow` direction to compare the results.
subset_s2t, edge_index_s2t, _, _ = k_hop_subgraph(
    node_idx, num_hops, edge_index, flow='source_to_target'
)
subset_t2s, edge_index_t2s, _, _ = k_hop_subgraph(
    node_idx, num_hops, edge_index, flow='target_to_source'
)

# Report: each heading is followed by its tensor on the next line.
print("Original Edge Index:", edge_index, sep="\n")

print("\nSubset nodes (source to target):", subset_s2t, sep="\n")
print("Edge Index (source to target):", edge_index_s2t, sep="\n")

print("\nSubset nodes (target to source):", subset_t2s, sep="\n")
print("Edge Index (target to source):", edge_index_t2s, sep="\n")


Original Edge Index:
tensor([[0, 1, 2, 3, 4, 5],
        [2, 2, 4, 4, 6, 6]])

Subset nodes (source to target):
tensor([2, 3, 4, 5, 6])
Edge Index (source to target):
tensor([[2, 3, 4, 5],
        [4, 4, 6, 6]])

Subset nodes (target to source):
tensor([6])
Edge Index (target to source):
tensor([], size=(2, 0), dtype=torch.int64)
