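This notebook computes aggregate statistics for a dataset of graphs: average node and edge counts, average degree, density, clustering coefficient, diameter (connected graphs only), and the distribution of graph labels.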

In [1]:
import torch
from torch_geometric.utils import to_networkx
import networkx as nx
from collections import Counter

def print_average_graph_statistics(graphs):
    """Print dataset-level averages of basic graph statistics.

    Every item of ``graphs`` is converted to an undirected NetworkX graph
    with ``to_networkx``. Node count, edge count, average degree, density and
    average clustering coefficient are averaged over all graphs; the diameter
    is averaged over connected graphs only. Labels found in ``data.y`` are
    tallied and reported as a percentage of the total number of graphs.

    Args:
        graphs: Iterable of graph objects accepted by ``to_networkx``
            (PyTorch Geometric ``Data`` objects). Any iterable is accepted,
            including generators; it is consumed exactly once.
            Labels are read with ``data.y.item()``, so ``y`` is presumably a
            single-element tensor per graph -- confirm against the dataset.

    Returns:
        None. All results are written to stdout.
    """
    # Count while iterating instead of calling len() afterwards, so unsized
    # iterables (generators, iterators) are supported too.
    num_graphs = 0
    total_num_nodes = 0
    total_num_edges = 0
    total_avg_degree = 0
    total_density = 0
    total_avg_clustering = 0
    total_diameter = 0
    connected_graphs = 0  # Also the denominator for the average diameter
    label_counter = Counter()

    for data in graphs:
        num_graphs += 1

        # Convert PyTorch Geometric Data to an undirected NetworkX graph
        graph = to_networkx(data, to_undirected=True)

        num_nodes = graph.number_of_nodes()
        total_num_nodes += num_nodes
        total_num_edges += graph.number_of_edges()

        # nx.density is documented to return 0 for a graph without edges,
        # so it needs no guard for empty graphs.
        total_density += nx.density(graph)

        # Average degree, clustering and connectivity are undefined for a
        # graph with no nodes; such a graph contributes 0 to the running sums
        # and is never counted as connected.
        if num_nodes > 0:
            total_avg_degree += sum(degree for _, degree in graph.degree()) / num_nodes
            total_avg_clustering += nx.average_clustering(graph)

            # The diameter is only defined for connected graphs
            if nx.is_connected(graph):
                connected_graphs += 1
                total_diameter += nx.diameter(graph)

        # A graph may lack a label entirely or carry y=None; skip both cases.
        label = getattr(data, 'y', None)
        if label is not None:
            label_counter[label.item()] += 1

    print("Average graph statistics:")
    print(f"  Number of graphs: {num_graphs}")

    if num_graphs == 0:
        # Nothing to average; avoid dividing by zero below.
        print("  No graphs provided; nothing to average")
        return

    print(f"  Average number of nodes: {total_num_nodes / num_graphs:.2f}")
    print(f"  Average number of edges: {total_num_edges / num_graphs:.2f}")
    print(f"  Average degree: {total_avg_degree / num_graphs:.2f}")
    print(f"  Average density: {total_density / num_graphs:.4f}")
    print(f"  Average clustering coefficient: {total_avg_clustering / num_graphs:.4f}")

    if connected_graphs > 0:
        print(f"  Average diameter (connected graphs only): {total_diameter / connected_graphs:.2f}")
    else:
        print("  No connected graphs found for diameter calculation")

    print("Label distribution (percentages):")
    for label, count in label_counter.items():
        # Percentages are relative to all graphs, labeled or not.
        percentage = (count / num_graphs) * 100
        print(f"  Label {label}: {percentage:.2f}%")

def read_graphs_from_file(file_path):
    """Load and return the object serialized at ``file_path``.

    Thin wrapper around ``torch.load``; whatever that call yields is handed
    back unchanged.

    NOTE(review): ``torch.load`` may unpickle arbitrary Python objects
    depending on the installed PyTorch version and its ``weights_only``
    default, so only use this on files from a trusted source.
    """
    return torch.load(file_path)

In [3]:
# NOTE(review): this looks like a placeholder path -- point it at the
# serialized graph dataset before running the cell.
file_path = 'path_to_graphs.pt'
# Load the dataset, then print the aggregate statistics over all of its graphs.
graphs = read_graphs_from_file(file_path)
print_average_graph_statistics(graphs)

Average graph statistics:
  Number of graphs: 4488
  Average number of nodes: 73.03
  Average number of edges: 2378.26
  Average degree: 37.46
  Average density: 0.5130
  Average clustering coefficient: 0.8818
  Average diameter (connected graphs only): 1.87
Label distribution (percentages):
  Label 0: 52.14%
  Label 1: 15.71%
  Label 2: 32.15%
