In [1]:
import gc
import time
from typing import Optional

from tqdm import trange
import torch
import numpy as np
from torch import Tensor
from torch_geometric import EdgeIndex
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.transforms import LineGraph
from torch_geometric.utils import scatter, cumsum, coalesce, degree
from torch_geometric.utils import to_torch_sparse_tensor, to_torch_csc_tensor, to_torch_csr_tensor, to_torch_coo_tensor


import pathpyG as pp

In [2]:
def lift_order_edge_index(edge_index: EdgeIndex | torch.Tensor, num_nodes: int = None) -> torch.Tensor:
    num_edges = edge_index.size(1)
    if num_nodes is None:
        num_nodes = int(edge_index.max().item() + 1)
    if isinstance(edge_index, torch.Tensor):
        edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, num_nodes))
    if not edge_index.is_sorted_by_row:
        print("Sorting edge index")
        edge_index = edge_index.sort_by("row")[0]
    
    i = torch.arange(num_edges, dtype=torch.long, device=edge_index.device)
    outdegree = degree(edge_index[0], dtype=torch.long, num_nodes=num_nodes)
    # For each center node, we need to combine each outgoing edge with each incoming edge
    # We achieve this by creating `outdegree` edges for each destination node of the old edge index
    outdegree_per_dst = outdegree[edge_index[1]]
    # We use each edge from the edge index as new nodes and assign the new indices in the order of the original edge index
    # Each higher order node has one outgoing edge for each outgoing edge of the original destination node
    # Since we keep the ordering, we can just repeat each node using the outdegree_per_dst tensor and the bucketize function
    # First we create a tensor of shape (num_edges,) that contains the index at which the next node index starts for each node
    # (Note here we use Torch's cumsum function because it does not add a 0 at the beginning of the tensor)
    expanded_ptrs = torch.cumsum(outdegree_per_dst, dim=0)
    # Using this tensor with the starting indices, we can create a new tensor of shape (num_ho_edges,) that contains the index of the new higher order source nodes
    ho_edge_srcs = torch.bucketize(torch.arange(expanded_ptrs[-1], dtype=torch.long, device=edge_index.device), expanded_ptrs, right=True)

    # Get the corresponding destination nodes for each higher order edge:
    # For each node, we calculate pointers of shape (num_nodes,) that indicate the start of the original edges (new higher order nodes) that have the node as source node
    # (Note here we use PyG's cumsum function because it adds a 0 at the beginning of the tensor)
    ptrs = cumsum(outdegree, dim=0)[:-1]
    # Use these pointers to get the start of the edges for each higher order source node and repeat it `outdegree` times
    # Since we keep the ordering, all new higher order edges that have the same source node are indexed consecutively
    ho_edge_dsts = ptrs[edge_index[1]][ho_edge_srcs]
    # Since the above only repeats the start of the edges, we need to add (0, 1, 2, 3, ...) for all `outdegree` edges to get the correct destination nodes
    # We can achieve this by constructing a matrix of shape (num_nodes, num_edges) that contain the numbers 0, 1, 2, ... for each row
    # Next, we get a mask that is True for all elements of the matrix that are smaller than the outdegree of the corresponding node
    # Finally, we use this mask to get the correct numbers we need to add for each destination node in the edge index
    ho_edge_dsts += i.unsqueeze(0).expand(num_edges, -1)[(i.unsqueeze(0).expand(num_nodes, -1) < outdegree.unsqueeze(1))[edge_index[1]]]

    return torch.stack([ho_edge_srcs, ho_edge_dsts], dim=0)

# DeBruijn Transformations using GNNs

In [3]:
class ConcatAggregation(Aggregation):
    
    def __init__(self):
        super().__init__()

    # Not sure how this aggregation works, only that it does
    # Inspired by the LSTMAggregation implementation in PyG:
    # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/aggr/lstm.html
    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        # Concetenate all messages with padding value -1
        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, max_num_elements=max_num_elements, fill_value=-1)
        return x


class DeBruijnTransform(MessagePassing):
    # freq_aggr is a string specifying how we should count each path in the network
    # This can either be `propagation` where every path is counted once
    # or `diffusion` where every path is counted depending on the number of outgoing edges
    # of the source node.
    def __init__(self, freq_aggr: str = "propagation"):
        super().__init__(aggr=ConcatAggregation(), flow="target_to_source")
        self.freq_aggr = freq_aggr

    def forward(self, node_idx, edge_index, edge_attr=None):
        # Sort edge_index because otherwise propagate will not work in combination with the ConcatAggregation
        edge_index = coalesce(edge_index, sort_by_row=True)
        # Set the dimension along which the node feature tensor is expected
        # This is the default value, but we need to set it explicitly here
        # because we change it later
        self.node_dim = -2
        # Update the node idx by passing the messages and aggregating them
        # In the message function, we concatenate the node idx of the source node
        # with the node idx of the target node
        # In the aggregation function, we concatenate all messages from the neighbors
        # The resulting feature for every node is a tensor of shape (max_degree, 2)
        # where max_degree is the maximum degree of the graph
        # If a node has less neighbors than max_degree, the remaining entries are filled with -1
        node_idx_set_higher_order = self.propagate(edge_index, node_idx=node_idx)
        # Since our node features changed from shape (N, 1) to (N, max_degree, 2)
        # we need to set the node_dim to -3
        self.node_dim = -3
        # We use the function that is used to update node features to create the higher order edges
        edge_index_higher_order, edge_attr_higher_order = self.edge_updater(
            edge_index, 
            node_idx=node_idx_set_higher_order,
            edge_attr=edge_attr
            )

        if edge_attr_higher_order is not None:
            return edge_index_higher_order, edge_attr_higher_order
        return edge_index_higher_order

    def message(self, node_idx_i, node_idx_j):
        # Concatenate the node idx of the source node with the node idx of the target node
        # The shape changes from (N, k) to (N, k+1) where k is the order of the nodes before
        return torch.cat([node_idx_i, node_idx_j[:, -1:]], dim=-1)
    
    def edge_update(self, node_idx_i, node_idx_j, edge_attr=None) -> tuple[Tensor, Optional[Tensor]]:
        # We take the higher order node idx sets that have been created for each node adjacent
        # to the edge (node_idx_i, node_idx_j) and repeat each node idx across different dimensions
        # so that we can compare them with each other
        #
        # Example:
        #
        #   node_idx_i = [[0, 3], [1, 3], [2, 3]]
        #   node_idx_j = [[3, 4], [2, 4], [-1, -1]]
        #
        #   strided_node_idx_i = [[
        #                           [0, 3],
        #                           [1, 3], 
        #                           [2, 3]
        #                         ],
        #                         [
        #                           [0, 3],
        #                           [1, 3],
        #                           [2, 3]
        #                         ],
        #                         [
        #                           [0, 3],
        #                           [1, 3],
        #                           [2, 3]
        #                         ]]
        #   strided_node_idx_j = [[
        #                           [3, 4],
        #                           [3, 4],
        #                           [3, 4]
        #                         ],
        #                         [
        #                           [2, 4],
        #                           [2, 4],
        #                           [2, 4]
        #                         ],
        #                         [
        #                           [-1, -1],
        #                           [-1, -1],
        #                           [-1, -1]
        #                         ]]
        strided_node_idx_i = node_idx_i.unsqueeze(1).expand(-1, node_idx_j.size(1), -1, -1)
        strided_node_idx_j = node_idx_j.unsqueeze(2).expand(-1, -1, node_idx_i.size(1), -1)
        # Only create an higher order edge if the target node idx of the first edge is equal to
        # the source node idx of the second edge
        edge_mask = (strided_node_idx_i[:, :, :, 1:] == strided_node_idx_j[:, :, :, :-1]).all(dim=-1)
        # Also, we need to remove the -1 padding values
        padd_mask = (
            (strided_node_idx_i[:, :, :] != -1).all(dim=-1) &
            (strided_node_idx_j[:, :, :] != -1).all(dim=-1)
        )
        # For the above, the following mask is:
        #
        #   mask = [[True, True, True],
        #           [False, False, False],
        #           [False, False, False]]
        mask = (edge_mask & padd_mask)
        # Concetenate the remaining higher order edges to create a new edge index
        higher_order_edges = torch.cat([strided_node_idx_i[mask].unsqueeze(0), strided_node_idx_j[mask].unsqueeze(0)], dim=0)
        
        if edge_attr is not None:
            # If edge attributes are given, we need to fit the shape to apply the same mask
            strided_edge_attr = edge_attr.unsqueeze(1).unsqueeze(1).expand(-1, node_idx_j.size(1), node_idx_i.size(1))
            # Apply the mask and use some way to combine the edge attributes of source and target node
            # For now, we just take the edge attribute of the source node
            if self.freq_aggr == "propagation":
                higher_order_edge_attr = strided_edge_attr[mask]
            elif self.freq_aggr == "diffusion":
                higher_order_edge_attr = strided_edge_attr[mask] / mask.sum(dim=(1,2), keepdim=True).expand(-1, node_idx_j.size(1), node_idx_i.size(1))[mask]
            else:
                raise ValueError(f"Unknown frequency aggregation method {self.freq_aggr}")
            return higher_order_edges, higher_order_edge_attr
        return higher_order_edges, None

## General evaluation of correctness and runtime

In [4]:
from time import time
from torch_geometric.testing import get_random_edge_index

gnn_conv = SAGEConv(128, 16)
cuda_gnn_conv = SAGEConv(128, 16).to("cuda")
line_graph_transform = LineGraph(force_directed=True)

GNN_baseline_times, cuda_GNN_baseline_times = [], []
PyG_times, MP_times, indexing_times = [], [], []
cuda_PyG_times, cuda_MP_times, cuda_indexing_times = [], [], []
for i in range(1, 201, 20):
    num_nodes = 1000*i
    x = torch.randn(num_nodes, 128)
    cuda_x = x.to("cuda")
    for j in range(1, 50, 10):
        gc.collect()
        torch.cuda.empty_cache()

        num_edges = num_nodes*j
        print(f"Iteration {int(i/20)}: Nodes: {num_nodes}, Edges: {num_edges}")
        edge_index = get_random_edge_index(num_nodes, num_nodes, num_edges)
        edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, num_nodes))
        edge_index = coalesce(edge_index, num_nodes=num_nodes)

        # GNN baseline
        
        t = time()
        gnn_conv(x, edge_index)
        GNN_baseline_times.append(time() - t)
        print(f"\tGNN baseline: {GNN_baseline_times[-1]}s")

        t = time()
        PyG_line_graph_data = line_graph_transform(Data(edge_index=edge_index, num_nodes=num_nodes))
        PyG_times.append(time() - t)
        PyG_line_graph = PyG_line_graph_data.edge_index
        print(f"\tPyG: {PyG_times[-1]}s")

        t = time()
        MP_line_graph = DeBruijnTransform()(torch.arange(num_edges).unsqueeze(-1), edge_index)
        MP_times.append(time() - t)
        print(f"\tMP: {MP_times[-1]}s")

        # t = time()
        # indexing_line_graph = lift_order_edge_index(edge_index, num_nodes)
        # indexing_times.append(time() - t)
        # print(f"\tIndexing: {indexing_times[-1]}s")

        cuda_edge_index = edge_index.to("cuda")

        t = time()
        cuda_gnn_conv(cuda_x, cuda_edge_index)
        cuda_GNN_baseline_times.append(time() - t)
        print(f"\tCUDA GNN baseline: {cuda_GNN_baseline_times[-1]}s")

        t = time()
        cuda_PyG_line_graph_data = line_graph_transform(Data(edge_index=cuda_edge_index, num_nodes=num_nodes))
        cuda_PyG_times.append(time() - t)
        cuda_PyG_line_graph = cuda_PyG_line_graph_data.edge_index
        print(f"\tCUDA PyG: {cuda_PyG_times[-1]}s")

        t = time()
        cuda_MP_line_graph = DeBruijnTransform()(torch.arange(num_edges).unsqueeze(-1).to("cuda"), cuda_edge_index)
        cuda_MP_times.append(time() - t)
        print(f"\tCUDA MP: {cuda_MP_times[-1]}s")

        # t = time()
        # cuda_indexing_line_graph = lift_order_edge_index(cuda_edge_index, num_nodes)
        # cuda_indexing_times.append(time() - t)
        # print(f"\tCUDA Indexing: {cuda_indexing_times[-1]}s")

        # check if the line graphs are equal
        # if not (edge_index.T[PyG_line_graph] == MP_line_graph).all():
        #     print(f"Iteration {i}: Message Passing and PyG are equal")
        #     print((edge_index.T[PyG_line_graph] != MP_line_graph).nonzero())
        #     break
        # if not (PyG_line_graph == indexing_line_graph).all():
        #     print(f"Iteration {i}: Indexing and PyG are not equal")
        #     print(PyG_line_graph)
        #     print(indexing_line_graph)
        #     print((PyG_line_graph != indexing_line_graph).nonzero())
        #     print(edge_index)
        #     break
        # if not (cuda_PyG_line_graph == cuda_indexing_line_graph).all():
        #     print(f"Iteration {i}: CUDA Indexing and CUDA PyG are not equal")
        #     break

print(f"Avg PyG time: {np.mean(PyG_times)}s +/- {np.std(PyG_times)}s")
print(f"Avg MP time: {sum(MP_times) / len(MP_times)}")
print(f"Avg Indexing time: {np.mean(indexing_times)}s +/- {np.std(indexing_times)}s")
# print(f"Avg CUDA PyG time: {np.mean(cuda_PyG_times)}s +/- {np.std(cuda_PyG_times)}s")
print(f"Avg CUDA MP time: {sum(cuda_MP_times) / len(cuda_MP_times)}")
# print(f"Avg CUDA Indexing time: {np.mean(cuda_indexing_times)}s +/- {np.std(cuda_indexing_times)}s")

Iteration 0: Nodes: 1000, Edges: 1000
	GNN baseline: 0.009559392929077148s
	PyG: 0.011368513107299805s
	MP: 0.0021982192993164062s
	CUDA GNN baseline: 0.24655723571777344s
	CUDA PyG: 0.17524504661560059s
	CUDA MP: 0.022670745849609375s
Iteration 0: Nodes: 1000, Edges: 11000
	GNN baseline: 0.002803802490234375s
	PyG: 0.09699869155883789s
	MP: 0.025521039962768555s
	CUDA GNN baseline: 0.001293182373046875s
	CUDA PyG: 1.4798369407653809s
	CUDA MP: 0.0063130855560302734s
Iteration 0: Nodes: 1000, Edges: 21000
	GNN baseline: 0.0033135414123535156s
	PyG: 0.18238019943237305s
	MP: 0.08893465995788574s
	CUDA GNN baseline: 0.001046895980834961s
	CUDA PyG: 2.9825801849365234s
	CUDA MP: 0.018616437911987305s
Iteration 0: Nodes: 1000, Edges: 31000
	GNN baseline: 0.001140594482421875s
	PyG: 0.256192684173584s
	MP: 0.2131354808807373s
	CUDA GNN baseline: 0.001361846923828125s
	CUDA PyG: 4.372312784194946s
	CUDA MP: 0.03851914405822754s
Iteration 0: Nodes: 1000, Edges: 41000
	GNN baseline: 0.00192570

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.97 GiB. GPU 0 has a total capacty of 8.00 GiB of which 1.59 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 4.25 GiB is allocated by PyTorch, and 1.08 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Toy Example

In [None]:
edge_index = torch.tensor([[0, 0, 1, 1, 3, 4, 1, 6, 5],
                           [1, 3, 2, 3, 4, 5, 6, 5, 7]])
node_idx = torch.arange(edge_index.max() + 1).reshape(-1, 1)
edge_index_2_fast = DeBruijnTransform()(node_idx, edge_index)
print(edge_index_2_fast)

tensor([[[0, 1],
         [0, 1],
         [0, 1],
         [0, 3],
         [1, 3],
         [1, 6],
         [3, 4],
         [4, 5],
         [6, 5]],

        [[1, 2],
         [1, 3],
         [1, 6],
         [3, 4],
         [3, 4],
         [6, 5],
         [4, 5],
         [5, 7],
         [5, 7]]])


In [None]:
edge_index_2 = pp.DAGData.lift_order_dag(edge_index.unsqueeze(-1))
print(edge_index_2)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA_gather)

In [None]:
edge_index_new = lift_order_edge_index(EdgeIndex(edge_index))
print(edge_index_new)

tensor([[[0, 1],
         [0, 1],
         [0, 1],
         [0, 3],
         [1, 3],
         [1, 6],
         [3, 4],
         [4, 5],
         [6, 5]],

        [[1, 2],
         [1, 3],
         [1, 6],
         [3, 4],
         [3, 4],
         [6, 5],
         [4, 5],
         [5, 7],
         [5, 7]]])


In [None]:
print((edge_index_2_fast == edge_index_new).all())

tensor(True)


### With Edge Weights

Depending on how you count each walk, you will get different statistics. We can choose the aggregation via `freq_aggr` to be either "propagation", i.e. each walk counts with its weight, or "diffusion" i.e. each walk is counted with the probability of a random walker starting at the first node to end up in the last. 

In [None]:
edge_index = torch.tensor([[0, 0, 1, 1, 3, 4, 1, 6, 5],
                           [1, 3, 2, 3, 4, 5, 6, 5, 7]])
node_idx = torch.arange(edge_index.max() + 1).reshape(-1, 1)
edge_attr = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32)
edge_index_2_fast, edge_attr_2 = DeBruijnTransform(freq_aggr="diffusion")(node_idx, edge_index, edge_attr)
print(edge_index_2_fast)

tensor([[[0, 1],
         [0, 1],
         [0, 1],
         [0, 3],
         [1, 3],
         [1, 6],
         [3, 4],
         [4, 5],
         [6, 5]],

        [[1, 2],
         [1, 3],
         [1, 6],
         [3, 4],
         [3, 4],
         [6, 5],
         [4, 5],
         [5, 7],
         [5, 7]]])


In [None]:
print(edge_attr_2)

tensor([0.3333, 0.3333, 0.3333, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


### 3rd order
Currently, the transformation works only with a standard edge index. The good thing is, you can still pass in the higher_order node_idx and then the output is directly a third order edge index.

In [None]:
# Transform the edge index since the DeBruijnTransform only works for the normal edge index
edges_2 = edge_index_2_fast.reshape(-1, 2)
uniques, inverse_idx = edges_2.unique(dim=0, return_inverse=True)
transformed_edge_index_2_fast = inverse_idx.reshape(2, -1)

In [None]:
edge_index_3_fast = DeBruijnTransform()(uniques, transformed_edge_index_2_fast)
print(edge_index_3_fast)

tensor([[[0, 1, 3],
         [0, 1, 6],
         [0, 3, 4],
         [1, 3, 4],
         [1, 6, 5],
         [3, 4, 5]],

        [[1, 3, 4],
         [1, 6, 5],
         [3, 4, 5],
         [3, 4, 5],
         [6, 5, 7],
         [4, 5, 7]]])


In [None]:
edge_index_3 = pp.DAGData.lift_order_dag(edge_index_2)
print(edge_index_3)

tensor([[[0, 1, 3],
         [0, 1, 6],
         [0, 3, 4],
         [1, 3, 4],
         [3, 4, 5],
         [1, 6, 5]],

        [[1, 3, 4],
         [1, 6, 5],
         [3, 4, 5],
         [3, 4, 5],
         [4, 5, 7],
         [6, 5, 7]]])


# Exponentionally Large DAG

In [None]:
layers = 5
branches = 15

edges = []
prev_layer_nodes = [0]
j = 1
for _ in trange(layers):
    layer_nodes = []
    for node in prev_layer_nodes:
        for _ in range(branches):
            layer_nodes.append(j)
            edges.append((f"{node}", f"{j}"))
            j+=1
    prev_layer_nodes = layer_nodes

dag = pp.Graph.from_edge_list(edges)
dag_edge_index = dag.data.edge_index.unsqueeze(-1)

node_idx = torch.arange(dag_edge_index.max().item() + 1).unsqueeze(-1)
node_idx_gpu = node_idx.cuda()
dag_edge_index_gpu = dag.data.edge_index.cuda()

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 18.74it/s]


### Current implementation

In [None]:
%timeit pp.DAGData.lift_order_dag(dag_edge_index)

1min 9s ± 3.52 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Message Passing based implementation (CPU)

In [None]:
%timeit DeBruijnTransform()(node_idx, dag.data.edge_index)

717 ms ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Message Passing based implementation (GPU)

In [None]:
%timeit DeBruijnTransform()(node_idx_gpu, dag_edge_index_gpu)

83.1 ms ± 2.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### New implementation (CPU)

In [None]:
%timeit lift_order_edge_index(dag.data.edge_index)

116 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%timeit lift_order_edge_index(dag_edge_index_gpu)

98.9 ms ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
data_gpu = dag.data.to("cuda")

In [None]:
%timeit linegraph(data_gpu)

KeyboardInterrupt: 

In [None]:
linegraph = LineGraph()
%timeit linegraph(dag.data)

6.82 s ± 167 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# print((dag_edge_index_order_2.sort(dim=1)[0] == dag_edge_index_order_2_fast.sort(dim=1)[0]).all())

## Many Walks

In [None]:
n_walks = 10000
walk_length = 1000

walks = [list(range(walk_length)) for _ in range(n_walks)]
orig_walk = pp.WalkData()
for walk in walks:
    orig_walk.add_walk_seq(walk)

path_list = list(orig_walk.paths.values())
path_freq_tensor = torch.tensor(list(orig_walk.path_freq.values()))
mapping = pp.IndexMap()
nested_walk = pp.WalkDataNested(path_list, path_freq=path_freq_tensor, mapping=mapping)

  self.paths = nested_tensor(paths, dtype=torch.long)


### Original Walk Implementation

In [None]:
%timeit orig_walk.edge_index_k_weighted(2)

37.5 s ± 7.25 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Nested Tensor Implementation (CPU)

In [None]:
pp.config['torch']["device"] = "cpu"
%timeit nested_walk.edge_index_k_weighted(2)

466 ms ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Nested Tensor (GPU)

In [None]:
pp.config['torch']["device"] = "cuda"
cuda_path_list = [path.cuda() for path in path_list]
cuda_path_freq = path_freq_tensor.cuda()
cuda_nested_walk = pp.WalkDataNested(cuda_path_list, path_freq=cuda_path_freq, mapping=mapping)
%timeit cuda_nested_walk.edge_index_k_weighted(2)

440 ms ± 24.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Message Passing Implementation (CPU + GPU)

In [None]:
# We create a list of Data objects where each Data object contains the edge index of a path (could also be a DAG in theory)
data_list = [Data(edge_index=path.long(), num_nodes=walk_length) for path in path_list]
# We use a dataloader from PyG to combine all the edge indices into a single graph with multiple disjoint subgraphs
# If two paths share a node, the node is duplicated in the resulting graph and the new higher order edges need to be aggregated afterwards
# Note that due to the `batch_size` parameter, we can also do computations on a set of paths that are too large to fit into memory at once
walk_graph = next(iter(DataLoader(data_list, batch_size=n_walks)))
edge_index = walk_graph.edge_index
node_idx = torch.arange(edge_index.max() + 1).unsqueeze(-1)

The following measures the time to do the De Bruijn graph transformation for the edge index that contains all paths as disjunct subgraphs. Since the aggregations afterwards are omitted, the runtimes are not exactly comparable to the above. See the next section (With Weights and the Aggregation) for a full `edge_index_k_weighted` transformation.

In [None]:
%timeit DeBruijnTransform()(node_idx, edge_index)

576 ms ± 76.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
cuda_edge_index = edge_index.cuda()
cuda_node_idx = node_idx.cuda()
%timeit DeBruijnTransform()(cuda_node_idx, cuda_edge_index)

71.2 ms ± 5.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit lift_order_edge_index(edge_index)

2.24 s ± 26.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit lift_order_edge_index(cuda_edge_index)

128 ms ± 135 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### With Weights and the Aggregation

In [None]:
def edge_index_k_weighted(path_list, path_freq, aggregation="propagation", device="cuda"):
    data_list = [
        Data(
            edge_index=path.long(), 
            num_nodes=walk_length,
            edge_attr=torch.ones(path.size(1), dtype=torch.float32) * path_freq[i],
            node_idx=torch.arange(walk_length).unsqueeze(-1)
        ) for i, path in enumerate(path_list)
        ]
    walk_graph = next(iter(DataLoader(data_list, batch_size=n_walks, follow_batch=["node_idx"]))).to(device)
    edge_index = walk_graph.edge_index
    edge_attr = walk_graph.edge_attr
    node_idx = torch.arange(edge_index.max() + 1, device=device).unsqueeze(-1)
    edge_index_2, edge_attr_2 = DeBruijnTransform(aggregation)(node_idx, edge_index, edge_attr)
    orig_edge_index_2 = walk_graph.node_idx.squeeze()[edge_index_2]
    unique_edge_index_2, inverse_idx = orig_edge_index_2.unique(dim=1, return_inverse=True)
    edge_attr_2 = torch.zeros(unique_edge_index_2.size(1), device=device).index_add(0, inverse_idx, edge_attr_2)
    return unique_edge_index_2, edge_attr_2

In [None]:
%timeit edge_index_k_weighted(path_list, path_freq_tensor, device="cuda")

774 ms ± 48.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
edge_index = EdgeIndex(torch.tensor([
                                        [0, 0, 1, 0, 0, 2, 2, 3, 3, 4, 4],
                                        [1, 2, 2, 4, 6, 3, 4, 4, 5, 5, 6]
]), sparse_size=(7, 7))
edge_index = edge_index.sort_by("row")[0]
N = 7
A = edge_index.to_dense()

In [None]:
data = Data(edge_index=edge_index)
line_graph = LineGraph(force_directed=True)(data)
edge_index.T[line_graph.edge_index]

tensor([[[0, 1],
         [0, 2],
         [0, 2],
         [0, 4],
         [0, 4],
         [1, 2],
         [1, 2],
         [2, 3],
         [2, 3],
         [2, 4],
         [2, 4],
         [3, 4],
         [3, 4]],

        [[1, 2],
         [2, 3],
         [2, 4],
         [4, 5],
         [4, 6],
         [2, 3],
         [2, 4],
         [3, 4],
         [3, 5],
         [4, 5],
         [4, 6],
         [4, 5],
         [4, 6]]])

In [None]:
# Fast dense implementation
(A * A.unsqueeze(2)).nonzero()

tensor([[0, 1, 2],
        [0, 2, 3],
        [0, 2, 4],
        [0, 4, 5],
        [0, 4, 6],
        [1, 2, 3],
        [1, 2, 4],
        [2, 3, 4],
        [2, 3, 5],
        [2, 4, 5],
        [2, 4, 6],
        [3, 4, 5],
        [3, 4, 6]])

In [None]:
edge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 3, 4, 1, 6, 5],
                           [1, 3, 2, 3, 4, 2, 4, 7, 5, 6, 5, 7]])
num_nodes = int(edge_index.max().item() + 1)
num_edges = edge_index.size(1)
edge_index = EdgeIndex(edge_index, sparse_size=(8, 8)).sort_by("row")[0]

In [None]:
lift_order_edge_index(edge_index)

tensor([[ 0,  0,  0,  1,  1,  1,  2,  3,  3,  3,  4,  5,  6,  7,  9, 11],
        [ 2,  3,  4,  6,  7,  8,  5,  6,  7,  8, 11,  9,  5,  9, 10, 10]])

In [None]:
LineGraph()(Data(edge_index=edge_index, num_nodes=num_nodes)).edge_index

tensor([[ 0,  0,  0,  1,  1,  1,  2,  3,  3,  3,  4,  5,  6,  7,  9, 11],
        [ 2,  3,  4,  6,  7,  8,  5,  6,  7,  8, 11,  9,  5,  9, 10, 10]])