
Fast Line-Graph-based higher-order edge index for CPU and GPU #132

Closed
wants to merge 15 commits

Conversation

M-Lampert
Contributor

@M-Lampert M-Lampert commented Mar 1, 2024

This PR will eventually add the new implementation idea for higher-order transformations based on the message-passing framework. For now, it is only a proof of concept (in dag_conv_test.ipynb) that will be extended later.

It is based on the branch from #128 to enable runtime comparisons.

@M-Lampert M-Lampert changed the title Fast GPU-based *Message Passing* higher-order edge index Fast GPU-based Message Passing higher-order edge index Mar 1, 2024
@IngoScholtes
Member

IngoScholtes commented Mar 6, 2024

Following our discussion yesterday (and thanks to Moritz' useful pointer to the LineGraph transformation in PyG), I have now found a much simpler solution (no for-loops, no message passing) that should also be fast ...

The following code calculates a k-th-order edge index for a given DAG and maps the node indices back to tensors of first-order node sequences, preserving the structure and semantics of our higher-order edge indices.

Mapping the indices back to tensors of first-order node indices was actually a bit of a head-scratcher, but the function below does this efficiently.

import torch
from torch_geometric.transforms import LineGraph

def map_higher_order_index(edge_indices, k):
    # Start from the k-th-order edge index; each entry is the id of a
    # (k-1)-th-order edge.
    ei = edge_indices[k].reshape(2, -1, 1)

    j = 0
    for i in range(k - 1, 0, -1):
        # Resolve the current ids into their i-th-order source/target edges.
        src_edge, tgt_edge = ei
        src = edge_indices[i][:, src_edge]
        tgt = edge_indices[i][:, tgt_edge]
        if j == 0:
            ei = torch.cat([src, tgt], dim=2)
        else:
            # Consecutive walks overlap, so drop the duplicated suffix.
            ei = torch.cat([src[:, :, :j], tgt], dim=2)
        j -= 1
    # Shape (2, num_ho_edges, k): first-order source and target node sequences.
    return ei

def lift_order_dag(dag, k):
    # Apply PyG's LineGraph transform k-1 times, keeping every intermediate
    # edge index for the back-mapping step.
    d = dag.data
    lg = LineGraph()
    edge_indices = {1: d.edge_index}
    for i in range(1, k):
        d = lg(d)
        edge_indices[i + 1] = d.edge_index

    return map_higher_order_index(edge_indices, k)

As you see, the code is very simple. I also think I have a simple solution for how we can do this for an arbitrarily large collection of walks in a single shot. I will push it later.

I will do some performance tests later.

@M-Lampert
Contributor Author

So I did some performance tests on the PyG version and the message passing (MP) one. The PyG version is very slow on GPU (probably because of copying between devices caused by the for-loop). The MP version works with random graphs of about 20,000 nodes and 450,000 edges on my laptop (8 GB GPU RAM) and is almost ten times faster than the PyG version on CPU:

Nodes: 21000, Edges: 441000
	PyG: 4.61s
	MP: 2.16s
	CUDA PyG: 74.25s
	CUDA MP: 0.51s

I was also able to convert the idea behind the PyG code to pure torch functions without for-loops (see 497c519). The problem is that I have to construct an (M, M) matrix, which is even less memory-efficient than the MP approach.
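For reference, the dense variant can be sketched as follows (an illustration of the idea, not commit 497c519 itself): every pair of edges is compared at once, which is exactly where the O(M^2) memory cost comes from.

```python
import torch

# First-order edge index of the path DAG 0 -> 1 -> 2 -> 3
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
src, tgt = edge_index

# Dense successor test: edge i is followed by edge j iff tgt[i] == src[j].
# This materializes an (M, M) boolean matrix, hence the memory blow-up.
follows = tgt.unsqueeze(1) == src.unsqueeze(0)
ho_edge_index = follows.nonzero().t()
print(ho_edge_index.tolist())  # [[0, 1], [1, 2]]
```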

So in my opinion, the MP approach is still the way to go. For graphs of this size you would typically not train a GNN on the whole graph anyway, but on subgraphs sampled with some kind of NeighborLoader or LinkNeighborLoader. We could do the same here: compute the transformation on subgraphs (since we only need the two-hop neighborhood of every node) and then aggregate the subgraphs afterwards.
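To illustrate the subgraph idea in plain torch (a hypothetical helper, not part of this PR; in practice one would use PyG's samplers as described above), extracting the two-hop out-neighborhood of a seed node is itself a cheap tensor operation:

```python
import torch

def two_hop_subgraph_edges(edge_index, seed):
    # Sketch (assumption, not the PR's code): keep only edges within the
    # two-hop out-neighborhood of `seed`, which is all the line-graph
    # transformation needs when processed per node.
    src, tgt = edge_index
    one_hop = tgt[src == seed]
    mask = (src == seed) | torch.isin(src, one_hop)
    return edge_index[:, mask]

# DAG: 0 -> 1, 1 -> 2, 1 -> 3, 2 -> 4, 3 -> 4
ei = torch.tensor([[0, 1, 1, 2, 3], [1, 2, 3, 4, 4]])
print(two_hop_subgraph_edges(ei, 0).tolist())  # [[0, 1, 1], [1, 2, 3]]
```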

Since LineGraph is PyG functionality, we could also open an issue there and see if they have more insight, but I do not think we will find a solution that is much more efficient (memory- and speed-wise) than what we have now.

@IngoScholtes
Member

Thanks a lot for the evaluation, very insightful! I'll have a look later today!

@M-Lampert
Contributor Author

Correction: I finally found a viable solution that is based on the indexing tricks used by PyG's LineGraph but relies solely on tensor operations. 🥳
The runtime should be O(n log n) (for the initial sorting) or O(m_ho) (where m_ho is the number of higher-order edges), and the memory consumption O(m_ho). For now, you can find the method in the second cell of the Jupyter notebook docs/tutorial/dag_conv_test.ipynb.
As for the runtime evaluation (the new method is called "Indexing" here; I also added a simple forward pass of a 1-layer SAGEConv as a baseline):
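A loop-free line-graph construction along these lines might look as follows (a sketch of the sorting/bucketing idea, not the actual code from dag_conv_test.ipynb): sort edge ids by source node so that each node's outgoing edges form one contiguous block, then gather each edge's successor block with pure indexing.

```python
import torch

def line_graph_edge_index(edge_index):
    # Edge i is followed by edge j iff tgt[i] == src[j].
    src, tgt = edge_index
    num_nodes = int(edge_index.max()) + 1
    # Sort edge ids by source node: successors of a node become contiguous.
    perm = torch.argsort(src)
    out_deg = torch.bincount(src, minlength=num_nodes)
    ptr = torch.cat([out_deg.new_zeros(1), out_deg.cumsum(0)])  # block starts
    # Edge i has out_deg[tgt[i]] successor edges.
    counts = out_deg[tgt]
    ho_src = torch.repeat_interleave(torch.arange(src.numel()), counts)
    # Enumerate each edge's successor block without a Python loop.
    block_start = torch.repeat_interleave(ptr[tgt], counts)
    local = torch.arange(int(counts.sum())) - torch.repeat_interleave(
        counts.cumsum(0) - counts, counts)
    ho_tgt = perm[block_start + local]
    return torch.stack([ho_src, ho_tgt])

# DAG with edges 0->1, 1->2, 1->3: edge 0 is followed by edges 1 and 2.
ei = torch.tensor([[0, 1, 1], [1, 2, 3]])
print(line_graph_edge_index(ei).tolist())
```

The sort dominates the cost at O(m log m); everything after it is linear in the number of higher-order edges, which matches the complexity stated above.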

Iteration 1: Nodes: 21000, Edges: 441000
	GNN baseline: 0.0238s
	PyG: 4.8755s
	Indexing: 0.0535s
	CUDA GNN baseline: 0.0020s
	CUDA PyG: 74.7793s
	CUDA Indexing: 0.0191s
	CUDA MP: 0.5288s
Iteration 1: Nodes: 21000, Edges: 651000
	GNN baseline: 0.0319s
	PyG: 6.8808s
	Indexing: 0.1153s
	CUDA GNN baseline: 0.0008s
	CUDA PyG: 108.8860s
	CUDA Indexing: 0.0059s
	CUDA MP: OOM error

@M-Lampert M-Lampert changed the title Fast GPU-based Message Passing higher-order edge index Fast Line-Graph-based higher-order edge index for CPU and GPU Mar 12, 2024