In [1]:
import torch_geometric.datasets as datasets
import torch_geometric.transforms as T
import networkit as nk
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import warnings

def summarize_graph(g):
    print(f'Nodes: {g.numberOfNodes()}')
    print(f'Edges: {g.numberOfEdges()}')

def degree_distribution(g):
    degree_sequence = []
    for u in g.iterNodes():
        degree_sequence.append(g.degree(u))

    degree_count = np.bincount(degree_sequence)

    plt.bar(range(len(degree_count)), degree_count)
    plt.title("Degree Distribution")
    plt.xlabel("Degree")
    plt.ylabel("Count")
    plt.show()

def sparsify(graph, sparsifier, ratio=0.5):
    graph.indexEdges()

    sparsifier = nk.sparsification.LocalDegreeSparsifier()
    sparse_graph = sparsifier.getSparsifiedGraph(graph, ratio)

    return sparse_graph

In [2]:
dataset = datasets.Planetoid(root='data/Planetoid', name='PubMed', transform=T.NormalizeFeatures())
data = dataset[0]

edge_index = data.edge_index
node_features = data.x if data.x is not None else None
edges = edge_index.t().tolist()

num_nodes = data.num_nodes
nk_graph = nk.graph.Graph(num_nodes, weighted=False, directed=False)
for u, v in edges:
    nk_graph.addEdge(u, v)

nk_graph.indexEdges()

node_features.shape

torch.Size([2708, 1433])

In [3]:
summarize_graph(nk_graph)

Nodes: 2708
Edges: 10556


In [4]:
lds_g = sparsify(nk_graph, nk.sparsification.LocalDegreeSparsifier())
summarize_graph(lds_g)

Nodes: 2708
Edges: 5482


In [5]:
warnings.filterwarnings('ignore')
nk.profiling.Profile.create(nk_graph).show()

In [6]:
nk.profiling.Profile.create(lds_g).show()

In [10]:
from torch_geometric.utils import from_networkit

edge_index, edge_weight = from_networkit(lds_g)

edge_index

tensor([[   0, 1862,    0,  ..., 2695, 2694, 2695],
        [1862,    0, 1862,  ..., 2694, 2695, 2694]])

In [12]:
scan_g = sparsify(nk_graph, nk.sparsification.SCANSparsifier())
summarize_graph(scan_g)

Nodes: 2708
Edges: 5482


In [19]:
scan_g_profile = nk.profiling.Profile.create(scan_g)
scan_g_profile.show()

In [None]:
scan_g_profile.getStat()