In [None]:
import os
import random
import numpy as np
import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.datasets import CitationFull
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import HeterophilousGraphDataset

In [None]:
def setup_determinism(seed):
    """Seed every RNG used in this notebook and force deterministic torch kernels.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG and
            torch (CPU, plus all CUDA devices when CUDA is available).

    Side effects:
        Sets ``CUBLAS_WORKSPACE_CONFIG`` and ``PYTHONHASHSEED`` in ``os.environ``,
        enables ``torch.use_deterministic_algorithms`` and configures cuDNN for
        deterministic, non-benchmarked execution.

    Note:
        ``PYTHONHASHSEED`` is read at interpreter start-up. Setting it here only
        affects Python processes started afterwards; it does not change hash
        randomization of the already-running kernel.
    """
    # cuBLAS reads this variable when its handle is created. Set it before any
    # other torch/CUDA call so that deterministic GEMMs are guaranteed.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def sub_data(data, mask, device):
    """Build a paper/label ``HeteroData`` graph exposing only the labels selected by ``mask``.

    Node types:
        'paper': the original nodes, with features ``data.x``.
        'label': one node per class id ``0 .. y.max()``, featurized as rows of an
                 identity matrix.
    Edge types:
        ('paper', 'cites', 'paper'): ``data.edge_index`` unchanged.
        ('paper', 'is', 'label'):    one edge ``(i, y[i])`` for every node ``i``
                                     where ``mask`` is True.

    Args:
        data: homogeneous graph with ``x``, ``edge_index`` and integer labels ``y``.
            ``y`` may be ``[N]`` or ``[N, 1]`` (as in ogbn-arxiv).
        mask: boolean node mask selecting which labels become edges.
        device: device on which every tensor of the returned graph is placed.

    Returns:
        HeteroData with the node and edge types listed above.
    """
    # Flatten [N, 1] label tensors to [N] so they can be used as an edge row.
    labels = data.y.reshape(-1)
    # Size the label-node set by the largest class id, not by the number of
    # distinct ids. If some class is absent, ``len(unique)`` would leave edges
    # pointing past the last label node. This also matches ``y.max() + 1``
    # used for ``num_classes`` elsewhere.
    num_classes = int(labels.max()) + 1 if labels.numel() > 0 else 0

    s_data = HeteroData()
    s_data['paper'].x = data.x.to(device)
    s_data['paper', 'cites', 'paper'].edge_index = data.edge_index.to(device)
    s_data['label'].x = torch.eye(num_classes, dtype=torch.float32, device=device)

    mask_indices = torch.nonzero(mask, as_tuple=False).flatten()
    # Row 0: paper ids, row 1: their class ids (= label-node ids).
    label_edges = torch.stack([mask_indices, labels[mask_indices]]).to(device=device, dtype=torch.int64)
    s_data['paper', 'is', 'label'].edge_index = label_edges
    return s_data

def split_train_edges(data, split_ratio=0.3):
    """Randomly partition the ('paper', 'is', 'label') edges into two disjoint column sets.

    Args:
        data: graph whose ``data['paper', 'is', 'label'].edge_index`` holds the
            edges as columns.
        split_ratio: fraction of edges, rounded down, returned as message-passing
            edges. The remaining edges are returned as supervision edges.

    Returns:
        ``(msg_edge_index, sup_edge_index)``: column subsets of the original
        edge_index, drawn from a single ``torch.randperm`` call.
    """
    edge_index = data['paper', 'is', 'label'].edge_index
    n_edges = edge_index.size(1)
    cut = int(n_edges * split_ratio)
    shuffled = torch.randperm(n_edges)
    return edge_index[:, shuffled[:cut]], edge_index[:, shuffled[cut:]]

def _set_lp_edges(store, message_edge_index, target_edge_index):
    """Attach message-passing edges and all-positive supervision targets to one edge store."""
    store.edge_index = message_edge_index
    store.edge_label_index = target_edge_index
    # All targets are positives. float64 is kept for backward compatibility with
    # existing consumers. Creating the tensor on the targets' device keeps the
    # store on a single device, unlike a numpy round-trip, which lands on CPU.
    store.edge_label = torch.ones(
        target_edge_index.size(1), dtype=torch.float64, device=target_edge_index.device
    )


def prepare_lp_data(data, train_mask, val_mask, test_mask, device, split_ratio=0.3):
    """Build inductive train/val/test link-prediction graphs for paper->label edges.

    Each split is a ``HeteroData`` from ``sub_data``. Its
    ``('paper', 'is', 'label')`` store carries:

    * ``edge_index``       - label edges visible for message passing,
    * ``edge_label_index`` - label edges to predict,
    * ``edge_label``       - all-ones targets for those edges.

    Visible label edges grow split by split:

    * train: message = ``split_ratio`` of train labels, targets = the other train labels;
    * val:   message = all train labels,               targets = val labels;
    * test:  message = train + val labels,             targets = test labels.

    Finally ``T.ToUndirected`` is applied to every split. This also creates
    reverse edge types such as ``('label', 'rev_is', 'paper')``.

    Args:
        data: homogeneous graph with ``x``, ``edge_index`` and ``y``.
        train_mask, val_mask, test_mask: boolean node masks for the three splits.
        device: device passed to ``sub_data`` for the per-split graphs.
        split_ratio: fraction of train label edges used for message passing
            (default 0.3, as before).

    Returns:
        ``(train_data_LP, val_data_LP, test_data_LP)``.
    """
    label_rel = ('paper', 'is', 'label')

    train_data_LP = sub_data(data, train_mask, device)
    val_data_LP = sub_data(data, val_mask, device)
    test_data_LP = sub_data(data, test_mask, device)

    msg_edge_index, sup_edge_index = split_train_edges(train_data_LP, split_ratio=split_ratio)
    val_edge_index = val_data_LP[label_rel].edge_index
    test_edge_index = test_data_LP[label_rel].edge_index
    all_train_edges = torch.cat((msg_edge_index, sup_edge_index), dim=1)

    _set_lp_edges(train_data_LP[label_rel], msg_edge_index, sup_edge_index)
    _set_lp_edges(val_data_LP[label_rel], all_train_edges, val_edge_index)
    _set_lp_edges(
        test_data_LP[label_rel],
        torch.cat((all_train_edges, val_edge_index), dim=1),
        test_edge_index,
    )

    to_undirected = T.ToUndirected()
    return to_undirected(train_data_LP), to_undirected(val_data_LP), to_undirected(test_data_LP)

In [None]:
from torch_geometric.utils import degree, to_undirected
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
import pandas as pd

def get_n_hop_sizes(graph_data, n_hop, input_nodes=None):
    """For each input node, count its n-hop neighbourhood and the share of same-class neighbours.

    Each root's complete ``n_hop`` neighbourhood is gathered with
    ``NeighborLoader``; ``num_neighbors=-1`` keeps every neighbour. Node ids
    ``>= len(og_data.y)`` are treated as label nodes. They count toward the
    neighbourhood size but are excluded from the homophily ratio. A root that
    is itself a label node gets class -1 and therefore ratio 0 (or NaN).

    Args:
        graph_data: homogeneous PyG ``Data`` with ``edge_index`` and ``num_nodes``.
        n_hop: number of hops to expand.
        input_nodes: optional node ids to evaluate; defaults to all nodes.

    Returns:
        sizes: tensor with the number of neighbourhood nodes per root, root excluded.
        hp_numbers: float tensor with the fraction of non-label neighbours that
            share the root's class. NaN when the root has no non-label
            neighbours (e.g. an isolated node); this previously raised
            ZeroDivisionError.

    Note:
        Class labels are read from the notebook-global ``og_data``.
    """
    labels = og_data.y.reshape(-1)  # accept [N] or [N, 1] label tensors
    num_papers = labels.size(0)
    if input_nodes is None:
        input_nodes = torch.arange(graph_data.num_nodes)
    loader = NeighborLoader(
        graph_data,
        num_neighbors=[-1] * n_hop,
        batch_size=1,
        shuffle=False,
        input_nodes=input_nodes
    )
    sizes = []
    hp_numbers = []
    for batch in loader:
        # With batch_size=1 the seed node is n_id[0]; the rest is its neighbourhood.
        root = int(batch.n_id[0])
        neighbors = batch.n_id[1:]
        sizes.append(neighbors.numel())

        paper_neighbors = neighbors[neighbors < num_papers]
        if paper_neighbors.numel() == 0:
            hp_numbers.append(float('nan'))
            continue
        root_class = int(labels[root]) if root < num_papers else -1
        num_same_class = int((labels[paper_neighbors] == root_class).sum())
        hp_numbers.append(num_same_class / paper_neighbors.numel())

    return torch.tensor(sizes), torch.tensor(hp_numbers)
    
def graph_stats(new_edges, new_edges_rev, data, num_classes=None):
    """Per-node neighbourhood size and homophily before and after rewiring with label nodes.

    The rewired graph is ``data`` plus ``num_classes`` extra nodes appended after
    the paper nodes (label ``c`` gets id ``data.num_nodes + c``). They are
    connected through ``new_edges`` (paper -> label) and ``new_edges_rev``
    (label -> paper).

    Args:
        new_edges: 2-row edge tensor; row 1 holds class ids (as built by ``sub_data``).
        new_edges_rev: 2-row edge tensor; row 0 holds class ids.
        data: homogeneous graph with ``num_nodes``, ``edge_index`` and ``y``.
        num_classes: number of label nodes to append. Defaults to
            ``data.y.max() + 1``, which removes the dependency on a notebook global.

    Returns:
        pandas DataFrame with one row per node of the rewired graph, papers first
        and then label nodes. It holds the original 1/2/3-hop sizes and homophily
        (NaN for label nodes), the rewired degree and 2/3-hop sizes and homophily,
        and an ``Is Label Node`` flag.

    Note:
        ``new_edges`` and ``new_edges_rev`` are not modified.
    """
    num_paper_original = data.num_nodes
    original_edge_index = data.edge_index
    if num_classes is None:
        num_classes = int(data.y.max()) + 1

    # Original graph (before rewiring).
    original_1hop, original_hp_1hop = get_n_hop_sizes(data, n_hop=1)
    original_2hop, original_hp_2hop = get_n_hop_sizes(data, n_hop=2)
    original_3hop, original_hp_3hop = get_n_hop_sizes(data, n_hop=3)

    # Shift label ids past the paper ids. Work on copies so the caller's tensors stay intact.
    shifted_edges = new_edges.clone()
    shifted_edges[1] += num_paper_original
    shifted_edges_rev = new_edges_rev.clone()
    shifted_edges_rev[0] += num_paper_original
    combined_edge_index = torch.cat([original_edge_index, shifted_edges, shifted_edges_rev], dim=1)

    N_total = num_paper_original + num_classes
    homo_data = Data(edge_index=combined_edge_index, num_nodes=N_total)

    # Rewired graph. The 1-hop value is the edge count per source node; it equals
    # the neighbour count only if combined_edge_index has no duplicate edges.
    rewired_degrees = degree(homo_data.edge_index[0], num_nodes=N_total)
    rewired_2hop, rewired_hp_2hop = get_n_hop_sizes(homo_data, n_hop=2)
    rewired_3hop, rewired_hp_3hop = get_n_hop_sizes(homo_data, n_hop=3)

    original_df = pd.DataFrame({
        "Node Index": torch.arange(num_paper_original).numpy(),
        "Original 1-hop": original_1hop.numpy(),
        "Original 2-hop": original_2hop.numpy(),
        "Original 3-hop": original_3hop.numpy(),
        "Original 1-hop HP": original_hp_1hop.numpy(),
        "Original 2-hop HP": original_hp_2hop.numpy(),
        "Original 3-hop HP": original_hp_3hop.numpy(),
    })
    rewired_df = pd.DataFrame({
        "Node Index": torch.arange(N_total).numpy(),
        "Rewired 1-hop": rewired_degrees.numpy(),
        "Rewired 2-hop": rewired_2hop.numpy(),
        "Rewired 3-hop": rewired_3hop.numpy(),
        "Rewired 2-hop HP": rewired_hp_2hop.numpy(),
        "Rewired 3-hop HP": rewired_hp_3hop.numpy(),
        "Is Label Node": (np.arange(N_total) >= num_paper_original).astype(np.int64),
    })
    # Right join keeps label nodes; their "Original *" columns become NaN.
    return pd.merge(original_df, rewired_df, on="Node Index", how="right")

In [None]:
# Dataset selector. The loading cell handles:
# 'CiteSeer', 'Cora_ML', 'Chameleon', 'Roman_Empire', 'Squirrel', 'OGBN'.
ds = "CiteSeer"

In [None]:

# Lazily construct the dataset selected by `ds` (paths are relative to the notebook).
_dataset_loaders = {
    'CiteSeer': lambda: CitationFull(root='data/CitationFull', name='CiteSeer'),
    'Cora_ML': lambda: CitationFull(root='data/CitationFull', name='Cora_ML'),
    'Chameleon': lambda: WikipediaNetwork(root='data/chameleon', name='chameleon'),
    'Roman_Empire': lambda: HeterophilousGraphDataset(root='data/RomanEmpire', name='Roman-empire'),
    'Squirrel': lambda: WikipediaNetwork(root='data/squirrel', name='squirrel'),
    'OGBN': lambda: PygNodePropPredDataset(name='ogbn-arxiv', root='../../data/ogbn_arxiv'),
}
if ds not in _dataset_loaders:
    # Fail loudly: previously an unknown name only printed and then crashed on `None[0]`.
    raise ValueError(f"Invalid dataset name {ds!r}; expected one of {sorted(_dataset_loaders)}.")
dataset = _dataset_loaders[ds]()
og_data = dataset[0]
print(og_data)
# Drop any predefined splits; the notebook re-splits nodes with RandomNodeSplit.
for _split_attr in ('train_mask', 'val_mask', 'test_mask'):
    if hasattr(og_data, _split_attr):
        delattr(og_data, _split_attr)
# Some datasets (e.g. ogbn-arxiv) store labels as [N, 1]; flatten to [N] (no-op otherwise).
og_data.y = og_data.y.reshape(-1)
og_data = T.ToUndirected()(og_data)

In [None]:
# Seed everything before the random split so train/val/test masks are reproducible.
SEED = 1
setup_determinism(SEED)
# 10% validation, 10% test, remaining nodes train.
transform = RandomNodeSplit(num_val=0.1, num_test=0.1)
data = transform(og_data.clone())
num_classes = int(data.y.max()) + 1

In [None]:
train_data_LP, val_data_LP, test_data_LP = prepare_lp_data(
    data,
    train_mask=data.train_mask,
    val_mask=data.val_mask,
    test_mask=data.test_mask,
    device='cpu',
)

In [None]:
# Only the training split's message-passing label edges (and their reverse
# copies added by ToUndirected) are used to rewire the graph.
paper_to_label_edges = train_data_LP['paper', 'is', 'label'].edge_index.clone()
label_to_paper_edges = train_data_LP['label', 'rev_is', 'paper'].edge_index.clone()
df1 = graph_stats(paper_to_label_edges, label_to_paper_edges, og_data.clone())

In [None]:
summary_columns = [
    "Original 1-hop", "Original 2-hop", "Original 3-hop",
    "Original 1-hop HP", "Original 2-hop HP", "Original 3-hop HP",
    "Rewired 1-hop", "Rewired 2-hop", "Rewired 3-hop",
    "Rewired 2-hop HP", "Rewired 3-hop HP",
]
# Mean of every metric, separately for paper nodes (0) and label nodes (1).
df1_means = df1.groupby("Is Label Node")[summary_columns].mean()

In [None]:
# Display per-group means; rows are "Is Label Node" = 0 (paper nodes) and 1 (label nodes).
df1_means