In [None]:
import os
import random
from collections import Counter
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import from_dgl, to_networkx, k_hop_subgraph, subgraph
from torch_geometric.utils import subgraph, to_undirected
import torch_geometric.utils as utils
from tqdm import tqdm

In [None]:
# NOTE(review): from PyTorch 2.6 on, torch.load defaults to weights_only=True and refuses
# to unpickle a PyG Data object; on such versions pass weights_only=False here
# (only for files you trust, since that unpickles arbitrary objects).
reddit = torch.load('pyg_dataset/reddit.pt')
print(reddit)
# Indices of nodes with a non-zero label. flatten() (not squeeze()) keeps this a list even
# when there is exactly one such node -- squeeze().tolist() would return a bare int there,
# and random.choice(anomaly_indices) below would then fail.
anomaly_indices = torch.nonzero(reddit.y, as_tuple=False).flatten().tolist()
assert anomaly_indices, "no nodes with a non-zero label in reddit.y"

In [None]:
# Fix every RNG the sampling code draws from: torch, NumPy's global RNG
# (np.random.choice in the random walks) and Python's `random`
# (random.choice when picking anomalous centre nodes).
seed = 0
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed)

In [None]:
def _collect_walks(edge_index, num_nodes, start_node, walk_length, max_nodes, max_walks=None):
    """Concatenate random walks from ``start_node`` until ``max_nodes`` distinct nodes
    have been visited or ``max_walks`` walks have been run (whichever comes first).

    Each walk takes up to ``walk_length`` steps, moving to a neighbour chosen with
    ``np.random.choice``, and stops early at a node without neighbours.
    """
    if max_walks is None:
        max_walks = 100 * max_nodes

    # CSR view of the adjacency: the neighbours of u are nbrs[ptr[u]:ptr[u + 1]].
    # A stable sort by source keeps each node's neighbours in edge_index order, i.e. the
    # same order a boolean mask `edge_index[1][edge_index[0] == u]` would give, but with
    # an O(1) lookup per step instead of a scan over all edges.
    row, col = edge_index.cpu()
    row_sorted, perm = torch.sort(row, stable=True)
    nbrs = col[perm].numpy()
    counts = torch.bincount(row_sorted, minlength=num_nodes or 0).numpy()
    ptr = np.concatenate(([0], np.cumsum(counts)))

    walks = []
    seen = set()
    # Bounded loop: if fewer than max_nodes nodes are reachable within walk_length
    # steps, an unbounded "until max_nodes distinct" loop would never terminate.
    for _ in range(max_walks):
        if len(seen) >= max_nodes:
            break
        node = start_node
        walk = [node]
        for _ in range(walk_length):
            lo, hi = ptr[node], ptr[node + 1]
            if lo == hi:  # no neighbours: stop this walk
                break
            node = np.random.choice(nbrs[lo:hi])
            walk.append(node)
        walks.extend(walk)
        seen.update(walk)
    return walks


def random_walk_subgraph(pyg_graph, start_node, walk_length, max_nodes, onlyE=False, max_walks=None):
    """Sample a local subgraph around ``start_node``.

    Both the 2-hop neighbourhood and the walks use the undirected version of
    ``pyg_graph.edge_index``.

    If the 2-hop neighbourhood of ``start_node`` has at most ``max_nodes`` nodes, it
    is returned as-is. Otherwise, random walks of length ``walk_length`` are launched
    from ``start_node`` until ``max_nodes`` distinct nodes have been visited (or
    ``max_walks`` walks have been run). The ``max_nodes`` most frequently visited
    nodes induce the returned subgraph.

    Args:
        pyg_graph: PyG ``Data`` providing ``edge_index``, ``x`` and integer labels
            ``y`` (one-hot encoded here with 2 classes).
        start_node (int): index of the centre node.
        walk_length (int): maximum number of steps per random walk.
        max_nodes (int): maximum number of nodes in the returned subgraph.
        onlyE (bool): if True, node features ``x`` are a constant ``ones((n, 1))``
            column instead of the one-hot labels.
        max_walks (int, optional): cap on the number of random walks, so sampling
            always terminates. Defaults to ``100 * max_nodes``.

    Returns:
        Data with:
            - ``x``;
            - ``edge_index``, relabelled to ``0..n-1``;
            - ``edge_attr``, a constant ``[0, 1]`` row per edge;
            - ``extra_x``, the rows of ``pyg_graph.x`` for the sampled nodes;
            - ``num_nodes``;
            - ``node_mapping``, where local index i maps to original node id;
            - ``y``, an empty tensor of shape ``(1, 0)``.
    """
    # Symmetrise once and use the same adjacency for the k-hop extraction, the walks
    # and the induced subgraph so the three agree with each other.
    edge_index = to_undirected(pyg_graph.edge_index)

    hop2_subset, hop2_edge_index, _, _ = k_hop_subgraph(
        start_node, num_hops=2, edge_index=edge_index, relabel_nodes=True)

    if len(hop2_subset) > max_nodes:
        walks = _collect_walks(edge_index, pyg_graph.num_nodes, start_node,
                               walk_length, max_nodes, max_walks)
        subset = [int(node) for node, _ in Counter(walks).most_common(max_nodes)]
        # relabel_nodes maps subset[i] -> i, so node_mapping below stays aligned.
        subg_edge_index, _ = utils.subgraph(subset, edge_index, relabel_nodes=True,
                                            num_nodes=pyg_graph.num_nodes)
    else:
        subset = hop2_subset.tolist()
        subg_edge_index = hop2_edge_index

    num_sub_nodes = len(subset)
    if onlyE:
        # Structure-only mode: every node gets the same constant feature.
        x = torch.ones((num_sub_nodes, 1))
    else:
        x = torch.nn.functional.one_hot(pyg_graph.y[subset], num_classes=2).float()
    # repeat() keeps shape (num_edges, 2) even when there are no edges.
    edge_attr = torch.tensor([0, 1]).repeat(subg_edge_index.shape[1], 1)
    extra_x = pyg_graph.x[subset]
    node_mapping = torch.tensor(subset, dtype=torch.long)
    y = torch.empty(1, 0)

    return Data(x=x, edge_index=subg_edge_index, edge_attr=edge_attr, extra_x=extra_x,
                num_nodes=num_sub_nodes, node_mapping=node_mapping, y=y)

def random_walk(pyg_graph, start_node, walk_length=3):
    """Random walk over the directed edges of ``pyg_graph``.

    Starting at ``start_node``, takes up to ``walk_length`` steps. Each step moves to
    a node drawn with ``np.random.choice`` from the targets of edges whose source is
    the current node. The walk ends early at a node with no outgoing edges.

    Returns:
        list: visited nodes, beginning with ``start_node``; later entries are the
        values returned by ``np.random.choice``.
    """
    src, dst = pyg_graph.edge_index
    path = [start_node]
    current = start_node
    for _step in range(walk_length):
        candidates = dst[src == current]
        if candidates.numel() == 0:  # dead end
            break
        current = np.random.choice(candidates.cpu().numpy())
        path.append(current)
    return path

In [None]:
# Sanity check on one node: sample its subgraph, then print the size of its full
# 2-hop neighbourhood. random_walk_subgraph measures that neighbourhood on the
# undirected graph, so symmetrise here too; otherwise the printed count can differ
# from the one that decided whether random-walk sampling was used.
# Note: random_walk_subgraph may draw from np.random (when walk sampling is
# triggered), so running or skipping this cell changes the samples of later cells.
demo_node = 10
print(random_walk_subgraph(reddit, demo_node, 3, 150, onlyE=True))
hop2_subset, _, _, _ = k_hop_subgraph(
    demo_node, num_hops=2, edge_index=to_undirected(reddit.edge_index), relabel_nodes=True)
print(len(hop2_subset))

In [None]:
# Sample structure-only subgraphs centred on anomalous nodes (centres drawn with
# replacement via random.choice) and save them for the diffusion stage.
NUM_ANOMALY_SUBGRAPHS = 1500
WALK_LENGTH = 3
MAX_SUBGRAPH_NODES = 150
ANOMALY_SUBGRAPH_PATH = './pyg_dataset/reddit_diffusion/reddit_anomaly.pt'

anomaly_subgraphs = []
for _ in tqdm(range(NUM_ANOMALY_SUBGRAPHS)):
    node_idx = random.choice(anomaly_indices)
    # Not named `subgraph`: that would shadow torch_geometric.utils.subgraph imported above.
    sampled = random_walk_subgraph(reddit, node_idx, WALK_LENGTH, MAX_SUBGRAPH_NODES, onlyE=True)
    anomaly_subgraphs.append(sampled)

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(ANOMALY_SUBGRAPH_PATH), exist_ok=True)
torch.save(anomaly_subgraphs, ANOMALY_SUBGRAPH_PATH)


## Cluster-aware Sampling

In [1]:
# Load the Reddit graph through the project's GADDataset wrapper (local `utils` module).
from utils import GADDataset

dataset_name = 'reddit'
data = GADDataset(dataset_name)

  from .autonotebook import tqdm as notebook_tqdm


Data(edge_index=[2, 168016], train_masks=[10984, 20], val_masks=[10984, 20], test_masks=[10984, 20], num_nodes=10984, y=[10984], x=[10984, 64])


In [None]:
# Fully supervised split for trial 1, then cluster the anomalous nodes
# (presumably into k groups -- semantics defined in GADDataset; confirm there).
split_options = {'semi_supervised': False, 'trial_id': 1}
data.split(**split_options)

num_anomaly_clusters = 10
data.cluster_anomalous_nodes(k=num_anomaly_clusters)

## Subgraph augmentation

In [1]:
# get sampled local subgraph
import torch
from augment import augmentation
from torch_geometric.loader import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
local_subgraph = torch.load('local_subgraphs/reddit_0.pt')
# DataLoader batches a sequence of Data objects; wrap a single saved Data in a list.
if isinstance(local_subgraph, (list, tuple)):
    local_subgraphs = list(local_subgraph)
else:
    local_subgraphs = [local_subgraph]
# TODO(review): batch_size / shuffle are left at DataLoader defaults; set them for the
# augmentation step once its batching requirements are settled.
subgraph_loader = DataLoader(local_subgraphs)
local_subgraph

Data(x=[150, 1], edge_index=[2, 1004], edge_attr=[1004, 2], y=[1, 0], extra_x=[150, 64], num_nodes=150, node_mapping=[150])