In [1]:
import torch
from torch_geometric.utils import from_dgl, to_networkx, k_hop_subgraph, subgraph
import random
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, to_undirected
from torch import Tensor
from typing import Optional, Tuple, Union
from collections import Counter
import numpy as np
from tqdm import tqdm
import torch_geometric.utils as utils

In [2]:
from utils import GADDataset

data = GADDataset('tolokers')
pyg_graph = data.get_pyg_graph(save=False)

train_masks = pyg_graph.train_masks
train_mask = train_masks[:, 0]
print(train_mask.shape)
print(train_mask)

  from .autonotebook import tqdm as notebook_tqdm


Data(edge_index=[2, 530758], train_masks=[11758, 20], val_masks=[11758, 20], test_masks=[11758, 20], num_nodes=11758, y=[11758], x=[11758, 10])
Data(edge_index=[2, 530758], train_masks=[11758, 20], val_masks=[11758, 20], test_masks=[11758, 20], num_nodes=11758, y=[11758], x=[11758, 10])
torch.Size([11758])
tensor([1, 1, 1,  ..., 1, 1, 1], dtype=torch.uint8)


In [3]:
reddit = torch.load('pyg_dataset/tolokers.pt')
print(reddit)
anomaly_indices = torch.nonzero(reddit.y, as_tuple=False).squeeze().tolist()

Data(edge_index=[2, 530758], train_masks=[11758, 20], val_masks=[11758, 20], test_masks=[11758, 20], num_nodes=11758, y=[11758], x=[11758, 10])


In [4]:
# set seed
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [10]:
def random_walk_subgraph(pyg_graph, start_node, walk_length, max_nodes, onlyE=False):
    edge_index = to_undirected(pyg_graph.edge_index)

    # Extract a 2-hop subgraph around the start_node
    hop2_subset, hop2_edge_index, mapping, _ = k_hop_subgraph(start_node, num_hops=2, edge_index=edge_index, relabel_nodes=True)
    node_mapping = {i: hop2_subset[i].item() for i in range(len(hop2_subset))}
    if len(hop2_subset) > max_nodes:
        walks = []
        while len(set(walks)) < max_nodes:
            walk = random_walk(pyg_graph, start_node, walk_length)
            walks.extend(walk)
            
        subset = [item[0] for item in Counter(walks).most_common(max_nodes)]
        subg_edge_index, _ = utils.subgraph(subset, edge_index, relabel_nodes=True)
        node_mapping = {i: subset[i] for i in range(len(subset))}
    else:
        subset = hop2_subset
        subg_edge_index = hop2_edge_index

    x = pyg_graph.y[subset]
    x = torch.nn.functional.one_hot(x, num_classes=2).float()
    edge_attr = torch.tensor([[0, 1] for _ in range(subg_edge_index.shape[1])])
    extra_x = pyg_graph.x[subset]
    node_mapping = torch.tensor(list(node_mapping.values()))
    y = torch.empty(1, 0)
    # remove self-loops or not 
    if onlyE:
        x = torch.ones((len(subset), 1))
        
    # Create a new data object for the subgraph
    d = Data(x=x, edge_index=subg_edge_index, edge_attr = edge_attr, extra_x = extra_x,
             num_nodes=len(subset), node_mapping=node_mapping, y = y)
    return d

def random_walk(pyg_graph, start_node, walk_length=3):
    walk = [start_node]
    edge_index = pyg_graph.edge_index
    for _ in range(walk_length):
        neighbors = edge_index[1][edge_index[0] == walk[-1]]
        if len(neighbors) == 0:  # If no neighbors, stop the walk
            break
        next_node = np.random.choice(neighbors.cpu().numpy())
        walk.append(next_node)
    return walk

In [9]:
import torch
import numpy as np
from torch_geometric.utils import to_undirected, k_hop_subgraph, subgraph
from torch_geometric.data import Data
from collections import deque

def downsample_connected_subgraph(pyg_graph, start_node, max_nodes, onlyE=False):
    edge_index = to_undirected(pyg_graph.edge_index)
    num_nodes = pyg_graph.num_nodes
    
    # Start BFS from the start_node
    visited = set()
    queue = deque([start_node])
    visited.add(start_node)
    
    while queue and len(visited) < max_nodes:
        node = queue.popleft()
        neighbors = edge_index[1][edge_index[0] == node].cpu().numpy()
        np.random.shuffle(neighbors)  # Shuffle to introduce randomness in selection
        for neighbor in neighbors:
            if len(visited) >= max_nodes:
                break
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    
    # Convert the visited set to a list for indexing
    subset = list(visited)
    subg_edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True)
    node_mapping = {i: subset[i] for i in range(len(subset))}

    x = pyg_graph.y[subset]
    x = torch.nn.functional.one_hot(x, num_classes=2).float()
    edge_attr = torch.tensor([[0, 1] for _ in range(subg_edge_index.shape[1])])
    extra_x = pyg_graph.x[subset]
    node_mapping = torch.tensor(list(node_mapping.values()))
    y = torch.empty(1, 0)

    if onlyE:
        x = torch.ones((len(subset), 1))
        
    d = Data(x=x, edge_index=subg_edge_index, edge_attr=edge_attr, extra_x=extra_x,
             num_nodes=len(subset), node_mapping=node_mapping, y=y, center_node_idx=start_node)
    return d


In [11]:
anomaly_subgraphs = []

for i in tqdm(range(1500)):
    node_idx = random.choice(anomaly_indices)
    subgraph = downsample_connected_subgraph(reddit, node_idx, 3, 150)
    anomaly_subgraphs.append(subgraph)


# torch.save(anomaly_subgraphs, f'./pyg_dataset/reddit_diffusion/reddit_anomaly.pt')


  0%|          | 1/1500 [00:00<01:11, 21.02it/s]


TypeError: Data.__call__() got an unexpected keyword argument 'relabel_nodes'

## Cluster-aware Sampling

In [1]:
from utils import GADDataset
data = GADDataset('reddit')

  from .autonotebook import tqdm as notebook_tqdm


Data(edge_index=[2, 168016], train_masks=[10984, 20], val_masks=[10984, 20], test_masks=[10984, 20], num_nodes=10984, y=[10984], x=[10984, 64])


In [None]:
data.split(semi_supervised=False, trial_id=1)
data.cluster_anomalous_nodes(k=10)

## Subgraph augmentation

In [22]:
# get sampled local subgraph
import torch
from augment import augmentation
from torch_geometric.loader import DataLoader 
import yaml
from omegaconf import DictConfig
with open('configs/config.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

cfg = DictConfig(cfg)
print(cfg)

{'hydra': {'job': {'chdir': True}, 'run': {'dir': './'}}, 'general': {'name': 'asn', 'wandb': 'online', 'gpus': 1, 'setting': 'train_scratch', 'resume': None, 'ckpt_path': None, 'sample_every_val': 4, 'check_val_every_n_epochs': 1, 'samples_to_generate': 100, 'samples_to_save': 3, 'chains_to_save': 1, 'log_every_steps': 50, 'number_chain_steps': 8, 'final_model_samples_to_generate': 100, 'final_model_samples_to_save': 30, 'final_model_chains_to_save': 20, 'num_train': -1}, 'model': {'type': 'discrete', 'transition': 'marginal', 'model': 'graph_tf', 'diffusion_steps': 500, 'diffusion_noise_schedule': 'cosine', 'n_layers': 5, 'extra_features': 'all', 'hidden_mlp_dims': {'X': 256, 'E': 128, 'y': 128}, 'hidden_dims': {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}, 'lambda_train': [5, 0]}, 'train': {'n_epochs': 300, 'batch_size': 8, 'accumulate_grad_batches': 1, 'lr': 0.0002, 'clip_grad': None, 'save_model': True, 'num_workers': 0, 'ema_decay': 

In [25]:
local_subgraph = torch.load('local_subgraphs/reddit_0.pt')
subgraph_loader = DataLoader(local_subgraph, batch_size=10, shuffle=True)
reddit = torch.load('pyg_dataset/reddit.pt')

In [26]:
augmentation(cfg, reddit, 'reddit', subgraph_loader)

augmenting reddit ...
True
Marginal distribution of the classes: tensor([1.]) for nodes, tensor([0.9419, 0.0581]) for edges




MisconfigurationException: `Trainer(strategy='ddp_find_unused_parameters_true')` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy='dp'|'ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.