In [27]:
import numpy as np
from scipy.sparse import coo_matrix
import networkx as nx
import matplotlib.pyplot as plt

### Examining reddit_data.npz 

In [14]:
# Open the node-level archive; np.load returns an NpzFile whose member names are in .files
data = np.load('reddit/reddit_data.npz')
print("Columns of reddit_data:", list(data.files))

Columns of reddit_data: ['feature', 'node_types', 'node_ids', 'label']


In [15]:
# Pull the four per-node arrays out of the archive, in archive-key order.
features, node_types, node_ids, labels = (
    data[key] for key in ('feature', 'node_types', 'node_ids', 'label')
)

# Report the shape of each array so they can be compared at a glance.
for caption, array in (
    ("Features shape:", features),
    ("Node types shape:", node_types),
    ("Node IDs shape:", node_ids),
    ("Labels shape:", labels),
):
    print(caption, array.shape)

Features shape: (232965, 602)
Node types shape: (232965,)
Node IDs shape: (232965,)
Labels shape: (232965,)


In [25]:
# Heading / array pairs; each heading is printed followed by the array's first 10 entries.
preview_sections = (
    (" 10 features:", features),
    ("\n 10 node types:", node_types),
    ("\n 10 node IDs:", node_ids),
    ("\n 10 labels:", labels),
)
for heading, values in preview_sections:
    print(heading)
    print(values[:10])

 10 features:
[[ 1.23341458  9.04301168 -0.92328005 ... -0.2578954   0.31119258
  -0.37721241]
 [-0.13855163 -0.20221879  0.12771628 ...  0.15627094  0.10478096
  -0.65342009]
 [-0.13304173 -0.1962387  -0.02956016 ...  0.03580163  0.28636732
   0.27441325]
 ...
 [-0.13855163 -0.20520884 -2.05916421 ...  0.08524973  0.65964959
  -0.2079452 ]
 [-0.15508134 -0.19922874  4.0886282  ... -0.1134238  -0.14133508
   0.58370981]
 [-0.15508134 -0.20221879  0.93078404 ... -0.60601635  3.21872952
  -0.83696218]]

 10 node types:
[3 1 3 1 1 2 1 2 1 1]

 10 node IDs:
[0 1 2 3 4 5 6 7 8 9]

 10 labels:
[30 17 18 23 22 15 33 14 38 18]


In [16]:
# How many nodes carry each node type.
unique_types, counts = np.unique(node_types, return_counts=True)
type_distribution = dict(zip(unique_types, counts))
print("Node type distribution:", type_distribution)

# How many nodes fall into each class label.
unique_labels, label_counts = np.unique(labels, return_counts=True)
label_distribution = dict(zip(unique_labels, label_counts))
print("Label distribution:", label_distribution)

Node type distribution: {1: 153431, 2: 23831, 3: 55703}
Label distribution: {0: 13101, 1: 3550, 2: 3302, 3: 15181, 4: 2322, 5: 3597, 6: 3952, 7: 2138, 8: 11187, 9: 2246, 10: 4928, 11: 2964, 12: 1696, 13: 2731, 14: 4854, 15: 28272, 16: 1003, 17: 2639, 18: 13999, 19: 10308, 20: 1596, 21: 4066, 22: 8222, 23: 12146, 24: 328, 25: 1659, 26: 4239, 27: 5962, 28: 4673, 29: 5101, 30: 2846, 31: 4570, 32: 1575, 33: 4960, 34: 3429, 35: 4202, 36: 4180, 37: 4233, 38: 12797, 39: 3099, 40: 5112}


### Features of Reddit_Data

Features shape (232965, 602): 232,965 nodes, each with 602 features.

Node IDs: every node has a unique ID (the label dictionary built below has 232,965 distinct keys, one per node).

Label distribution: number of nodes in each class (41 classes in total, labelled 0–40).

### Examining reddit_graph.npz
This file stores the graph structure as `row`, `col`, `data` and `shape` arrays. Below we assemble them into a sparse adjacency matrix and examine its size and density.

In [20]:
GRAPH_PATH = 'reddit/reddit_graph.npz'

# Open the graph archive and list the arrays it contains.
graph_data = np.load(GRAPH_PATH)
print("Columns of reddit_graph:", list(graph_data.files))


Columns of reddit_graph: ['format', 'col', 'data', 'shape', 'row']


In [21]:
# Coordinate (COO) pieces of the adjacency matrix: row indices, column indices, values.
row, col, data_values = (graph_data[key] for key in ('row', 'col', 'data'))
shape = tuple(graph_data['shape'])

In [23]:
sparse_matrix = coo_matrix((data_values, (row, col)), shape=shape)
n_rows, n_cols = sparse_matrix.shape

print("Sparse matrix shape:", sparse_matrix.shape)
print("Number of non-zero elements:", sparse_matrix.nnz)

# Density = stored entries divided by the total number of cells in the matrix.
density = sparse_matrix.nnz / (n_rows * n_cols)
print("Matrix density:", density)


Sparse matrix shape: (232965, 232965)
Number of non-zero elements: 114615892
Matrix density: 0.002111852009048774


### Visualizing the Graph 

In [37]:
def create_reddit_graph(sparse_matrix):
    """Creates Reddit Graph from the sparse matrix.

    Builds an undirected, unweighted ``networkx.Graph``:

    * every row index ``0 .. n_rows - 1`` becomes a node, so nodes without
      any stored entry are kept as isolated nodes instead of being dropped;
    * every stored entry ``(i, j)`` becomes the edge ``i -- j``. Because the
      graph is undirected, ``(i, j)`` and ``(j, i)`` give the same edge, and
      repeated entries collapse into one. Stored values are ignored.

    Parameters
    ----------
    sparse_matrix : scipy sparse matrix or array-like
        Adjacency matrix, presumably square (rows and columns index the same
        nodes). Any SciPy sparse format or a dense array is accepted; it is
        converted to COO to read its (row, col) pairs.

    Returns
    -------
    networkx.Graph
    """
    adjacency = coo_matrix(sparse_matrix)

    G = nx.Graph()
    # Register every node up front: add_edges_from alone only creates nodes
    # that appear in some edge, which would lose isolated nodes.
    G.add_nodes_from(range(adjacency.shape[0]))
    # Feed edges lazily; list(zip(row, col)) would hold one Python tuple per
    # stored entry in memory before a single edge is inserted.
    G.add_edges_from(zip(adjacency.row, adjacency.col))

    return G

def visualize_reddit_graph(G, labels, max_nodes=1000, seed=42):
    """Visualizes Reddit Graph with Class Labels.

    Draws ``G`` with a spring layout, colouring each node by its class label
    and writing the label next to it.

    A spring layout of a graph with hundreds of thousands of nodes is not
    practical. When ``G`` has more than ``max_nodes`` nodes, a random sample
    of ``max_nodes`` nodes is taken. Only the subgraph induced by that sample
    is drawn.

    Parameters
    ----------
    G : networkx.Graph
        Graph to draw.
    labels : dict
        Mapping node -> class label. Entries for nodes that are not drawn are
        ignored; drawn nodes without an entry get no text label.
    max_nodes : int or None, optional
        Upper bound on the number of nodes drawn (default 1000). ``None``
        draws the whole graph.
    seed : int, optional
        Seed for both the node sample and the spring layout. The default is
        42, the value previously hard-coded for the layout.
    """
    if max_nodes is not None and G.number_of_nodes() > max_nodes:
        rng = np.random.default_rng(seed)
        all_nodes = list(G.nodes)
        picked = rng.choice(len(all_nodes), size=max_nodes, replace=False)
        G = G.subgraph(all_nodes[i] for i in picked)

    # draw_networkx_labels needs a layout position for every labelled node,
    # so keep only the labels of nodes that are actually drawn.
    drawn_labels = {n: labels[n] for n in G.nodes if n in labels}

    # Give each distinct class a small integer code so any hashable label type
    # can be mapped through the colormap; unlabelled nodes share the code of None.
    class_codes = {}
    node_colors = [class_codes.setdefault(labels.get(n), len(class_codes)) for n in G.nodes]

    _, ax = plt.subplots(figsize=(12, 12))
    pos = nx.spring_layout(G, seed=seed)
    nx.draw(G, pos, ax=ax, node_size=10, node_color=node_colors, cmap='nipy_spectral',
            edge_color='gray', with_labels=False)
    nx.draw_networkx_labels(G, pos, drawn_labels, font_size=8, font_color='red', ax=ax)
    ax.set_title(f"Reddit graph: {G.number_of_nodes()} nodes coloured by class label")

    plt.show()

In [35]:
# Create the label dictionary: node ID -> class label.
# node_ids and labels are parallel per-node arrays. Check that they line up
# before pairing them, because zip would silently truncate a length mismatch.
assert len(node_ids) == len(labels), "node_ids and labels have different lengths"
reddit_labels = dict(zip(node_ids, labels))
# A repeated node ID would silently overwrite an earlier label.
assert len(reddit_labels) == len(node_ids), "node_ids contains duplicate IDs"

print(len(reddit_labels)) # number of nodes 

unique_labels = set(labels)
print("number unique labels:", len(unique_labels)) 

232965
number unique labels: 41


In [38]:
# Building a NetworkX graph from all ~114.6M stored entries is impractical,
# and so is running a spring layout on 232,965 nodes. Instead, only the
# subgraph induced by the first N_VIS_NODES node indices is built and drawn.
# NOTE(review): assumes matrix index i corresponds to node ID i (node_ids
# starts 0, 1, 2, ... above) — TODO confirm node_ids == np.arange(n).
N_VIS_NODES = 1000
n_vis = min(N_VIS_NODES, sparse_matrix.shape[0])
in_window = (sparse_matrix.row < n_vis) & (sparse_matrix.col < n_vis)
vis_matrix = coo_matrix(
    (sparse_matrix.data[in_window], (sparse_matrix.row[in_window], sparse_matrix.col[in_window])),
    shape=(n_vis, n_vis),
)
G = create_reddit_graph(vis_matrix)

# Only pass labels for nodes that are in G: draw_networkx_labels looks up a
# layout position for every labelled node.
vis_labels = {n: reddit_labels[n] for n in G.nodes if n in reddit_labels}
visualize_reddit_graph(G, vis_labels)