In [29]:
import torch
from torch_geometric.data import Data
import torch_geometric.transforms as T
from itertools import combinations, chain

def add_metagraph(graph):
    '''
    Attach the initial metagraph to the graph.

    The metagraph is a list of levels, each level being a list of sets of
    node ids. This function creates level 0: one set per edge (column of
    ``graph.edge_index``) holding that edge's two endpoints. Any existing
    ``graph.metagraph`` is overwritten. Further levels are appended by
    ``gpu_update_metagraph`` after each linegraph iteration.

    Args:
        graph: object exposing ``edge_index``; rows 0 and 1 are read as the
            two endpoint ids of each edge.

    Returns:
        The same ``graph`` object, with ``metagraph`` set.

    Raises:
        AssertionError: if ``edge_index`` is missing or ``None``.
    '''
    # PyG ``Data`` exposes ``edge_index`` as a property that returns None
    # when unset, so ``hasattr`` alone can never detect a missing edge index.
    edge_index = getattr(graph, 'edge_index', None)
    assert edge_index is not None, 'Edge index does not exist in graph'

    # One bulk conversion per row instead of two ``.item()`` calls per edge.
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()
    graph.metagraph = [[{source, target} for source, target in zip(sources, targets)]]
    return graph

def gpu_update_metagraph(graph):
    '''
    Append the next metagraph level after a linegraph iteration.

    For every edge ``(a, b)`` in ``graph.edge_index`` the new level holds the
    union of entries ``a`` and ``b`` of the most recent level, so the edge
    endpoints are used as positions into ``graph.metagraph[-1]``.

    Despite the name, the work is done on the host in plain Python: the edge
    index is copied out of the tensor once and the sets are merged in a list
    comprehension.

    Args:
        graph: object exposing ``edge_index`` (rows 0 and 1 are read) and
            ``metagraph`` (a non-empty list of lists of sets).

    Returns:
        The same ``graph`` object; ``graph.metagraph`` is extended in place
        by one level.

    Raises:
        AssertionError: if ``metagraph`` is missing, or ``edge_index`` is
            missing or ``None``.
    '''
    assert hasattr(graph, 'metagraph'), 'Metagraph does not exist in graph'
    # PyG ``Data`` exposes ``edge_index`` as a property that returns None
    # when unset, so ``hasattr`` alone can never detect a missing edge index.
    edge_index = getattr(graph, 'edge_index', None)
    assert edge_index is not None, 'Edge index does not exist in graph'

    previous_level = graph.metagraph[-1]

    # ``tolist`` yields plain Python ints (and copies from the device if
    # needed), so no intermediate NumPy array is required for list indexing.
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()
    new_level = [previous_level[a] | previous_level[b] for a, b in zip(sources, targets)]

    graph.metagraph.append(new_level)
    return graph

# Shared instance of the PyG LineGraph transform; it is called on a graph below
# before each gpu_update_metagraph step.
linegraph_trafo = T.LineGraph()

def build_fc_graph(n_nodes):
    '''
    Build a fully connected graph with n_nodes nodes.

    There are no self-loops and every node pair appears exactly once, as the
    edge ``(i, j)`` with ``i < j``.

    Args:
        n_nodes: number of nodes.

    Returns:
        A ``Data`` object with
        - ``x``: the node ids ``0 .. n_nodes - 1`` as a long tensor,
        - ``num_nodes``: ``n_nodes``,
        - ``edge_index``: long tensor of shape ``(2, n_nodes * (n_nodes - 1) / 2)``;
          it has shape ``(2, 0)`` when there are fewer than two nodes.
    '''
    graph = Data()
    # Explicit dtype keeps ``x`` integral even when there are no nodes.
    graph.x = torch.tensor(list(range(n_nodes)), dtype=torch.long)
    graph.num_nodes = n_nodes

    edge_pairs = list(combinations(range(n_nodes), r=2))
    # With no pairs ``torch.tensor([])`` is a 1-D float tensor, so dtype and
    # shape are forced explicitly (a ``-1`` dimension is rejected for empty
    # tensors). ``.t().contiguous()`` gives the usual contiguous 2 x E layout.
    graph.edge_index = (
        torch.tensor(edge_pairs, dtype=torch.long)
        .reshape(len(edge_pairs), 2)
        .t()
        .contiguous()
    )
    return graph


graph = build_fc_graph(5)
graph = add_metagraph(graph)

# Report the initial graph, then the graph after each of two linegraph
# iterations. Step 0 only reports; every later step first applies the
# linegraph transform and appends the matching metagraph level.
for linegraph_step in range(3):
    if linegraph_step > 0:
        graph = linegraph_trafo(graph)
        graph = gpu_update_metagraph(graph)
    print(graph)
    print(graph.metagraph)
    # Total number of node sets accumulated over all metagraph levels.
    print(sum(len(level) for level in graph.metagraph))

def produce_longest_track_candidate(flattened_metagraph):
    '''
    Pick the largest set and drop everything that conflicts with it.

    Args:
        flattened_metagraph: non-empty list of sets.

    Returns:
        A tuple ``(longest, remaining)`` where ``longest`` is the first set of
        maximal size and ``remaining`` is a new list, in the original order,
        of the sets that are neither equal to ``longest`` nor share an element
        with it. The input list is not modified.
    '''
    winner = max(flattened_metagraph, key=len)

    survivors = []
    for candidate in flattened_metagraph:
        # The equality test also removes empty sets when the winner itself is
        # empty (an empty set is disjoint from everything, including itself).
        if candidate == winner or not candidate.isdisjoint(winner):
            continue
        survivors.append(candidate)

    return winner, survivors

def get_all_track_candidates(metagraph):
    '''
    Greedily select track candidates from all levels of a metagraph.

    All sets of all levels are pooled; the largest remaining set is selected
    repeatedly, and after each selection the pool is reduced by
    ``produce_longest_track_candidate``.

    Args:
        metagraph: iterable of levels, each an iterable of sets.

    Returns:
        List of the selected sets, in selection order. Empty if the metagraph
        holds no sets.
    '''
    pool = [node_set for level in metagraph for node_set in level]

    selected = []
    while len(pool) > 0:
        best, pool = produce_longest_track_candidate(pool)
        selected.append(best)

    return selected

# Greedy track-candidate selection over every set in the metagraph built above.
print(get_all_track_candidates(graph.metagraph))


Data(x=[5], num_nodes=5, edge_index=[2, 10], metagraph=[1])
[[{0, 1}, {0, 2}, {0, 3}, {0, 4}, {1, 2}, {1, 3}, {1, 4}, {2, 3}, {2, 4}, {3, 4}]]
10
Data(num_nodes=10, edge_index=[2, 10], metagraph=[2])
[[{0, 1}, {0, 2}, {0, 3}, {0, 4}, {1, 2}, {1, 3}, {1, 4}, {2, 3}, {2, 4}, {3, 4}], [{0, 1, 2}, {0, 1, 3}, {0, 1, 4}, {0, 2, 3}, {0, 2, 4}, {0, 3, 4}, {1, 2, 3}, {1, 2, 4}, {1, 3, 4}, {2, 3, 4}]]
20
Data(num_nodes=10, edge_index=[2, 5], metagraph=[3])
[[{0, 1}, {0, 2}, {0, 3}, {0, 4}, {1, 2}, {1, 3}, {1, 4}, {2, 3}, {2, 4}, {3, 4}], [{0, 1, 2}, {0, 1, 3}, {0, 1, 4}, {0, 2, 3}, {0, 2, 4}, {0, 3, 4}, {1, 2, 3}, {1, 2, 4}, {1, 3, 4}, {2, 3, 4}], [{0, 1, 2, 3}, {0, 1, 2, 4}, {0, 1, 3, 4}, {0, 2, 3, 4}, {1, 2, 3, 4}]]
25
[{0, 1, 2, 3}]


In [None]:
import h5py

def save_data_to_hdf5(data, file_name):
    '''
    Write a list of lists of sets to an HDF5 file.

    The file is opened in ``'w'`` mode. Sublist number ``i`` is stored as the
    dataset ``sublist_<i>``, with every set converted to a tuple first.

    NOTE(review): presumably all sets within one sublist must have the same
    size so the tuples form a rectangular array - confirm against h5py.

    Args:
        data: iterable of sublists, each an iterable of sets.
        file_name: path of the HDF5 file to create.
    '''
    with h5py.File(file_name, 'w') as h5_file:
        level = 0
        for node_sets in data:
            # Sets cannot be stored directly; tuples can.
            rows = list(map(tuple, node_sets))
            h5_file.create_dataset(f'sublist_{level}', data=rows)
            level += 1

def load_data_from_hdf5(file_name):
    '''
    Read back a list of lists of sets written by save_data_to_hdf5.

    Datasets named ``sublist_<i>`` are returned in numeric order of ``i``.
    Iterating the file's keys directly yields them in name order, which would
    place ``sublist_10`` before ``sublist_2`` and scramble the levels as soon
    as more than ten were saved. Datasets with any other name are still
    loaded, after the numbered ones, in name order.

    Args:
        file_name: path of the HDF5 file to read.

    Returns:
        List of sublists; each row of a dataset becomes one set of plain
        Python numbers (not NumPy scalars), so the result matches what was
        saved.
    '''
    prefix = 'sublist_'

    def saved_order(key):
        # Numbered datasets first, ordered by their integer suffix.
        suffix = key[len(prefix):]
        if key.startswith(prefix) and suffix.isdecimal():
            return (0, int(suffix), key)
        return (1, 0, key)

    data = []
    with h5py.File(file_name, 'r') as f:
        for key in sorted(f.keys(), key=saved_order):
            # ``tolist`` converts NumPy scalars back to native Python values.
            rows = f[key][:].tolist()
            data.append([set(row) for row in rows])
    return data

# Round-trip check: write the metagraph to HDF5, read it back, and print it.
# Relies on `graph` still being defined in the kernel from the previous cell.
save_data_to_hdf5(graph.metagraph, 'test_metagraph.h5')
loaded_data = load_data_from_hdf5('test_metagraph.h5')
print(loaded_data)


[[{0, 1}, {0, 2}, {0, 3}, {0, 4}, {1, 2}, {1, 3}, {1, 4}, {2, 3}, {2, 4}, {3, 4}], [{0, 1, 2}, {0, 1, 3}, {0, 1, 4}, {0, 2, 3}, {0, 2, 4}, {0, 3, 4}, {1, 2, 3}, {1, 2, 4}, {1, 3, 4}, {2, 3, 4}], [{0, 1, 2, 3}, {0, 1, 2, 4}, {0, 1, 3, 4}, {0, 2, 3, 4}, {1, 2, 3, 4}]]
