In [1]:
from typing import Optional

from tqdm import trange
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import coalesce

import pathpyG as pp

# DeBruijn Transformations using GNNs

In [2]:
class ConcatAggregation(Aggregation):
    """Aggregation that keeps every message instead of reducing them to one value.

    The messages belonging to the same index are laid out next to each other in a
    dense batch; slots without a message hold the padding value -1.
    """

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:
        """Return the densely batched messages.

        All arguments are handed to ``to_dense_batch`` unchanged. Only the first
        element of the pair it returns is kept; the second one is discarded.
        """
        # Concatenate all messages, using -1 as the padding value
        dense_messages = self.to_dense_batch(
            x,
            index,
            ptr,
            dim_size,
            dim,
            fill_value=-1,
            max_num_elements=max_num_elements,
        )[0]
        return dense_messages


class DeBruijnTransform(MessagePassing):
    """Lift an edge index by one order using message passing.

    A first round (``propagate``) builds, for every node, the set of higher-order
    nodes that start at it. A second round (``edge_updater``) connects two
    higher-order nodes whenever the tail of the first one equals the head of the
    second one.
    """

    def __init__(self):
        # ConcatAggregation keeps all messages of a node (padded with -1) instead of reducing them
        super().__init__(aggr=ConcatAggregation(), flow="target_to_source")

    def forward(self, node_idx, edge_index):
        """Compute the higher-order nodes and the higher-order edge index.

        Args:
            node_idx: Tensor with one row per node. For a first-order graph every
                row holds the node's own index, i.e. the shape is (N, 1).
            edge_index: Standard edge index with two rows (sources and targets).

        Returns:
            A tuple ``(node_idx_higher_order, edge_index_higher_order)``:
            the unique higher-order nodes, one per row, and the higher-order edges
            as a tensor whose first entry holds the source and whose second entry
            holds the target higher-order node of every edge.
        """
        # Sort edge_index because otherwise propagate will not work in combination with the ConcatAggregation
        edge_index = coalesce(edge_index, sort_by_row=True)
        # Set the dimension along which the node feature tensor is expected
        # This is the default value, but we need to set it explicitly here
        # because we change it later (and it stays changed after a previous call)
        self.node_dim = -2
        # Update the node idx by passing the messages and aggregating them
        # In the message function, we concatenate the node idx of the source node
        # with the node idx of the target node
        # In the aggregation function, we concatenate all messages from the neighbors
        # The resulting feature for every node is a tensor of shape (max_degree, 2)
        # where max_degree is the maximum degree of the graph
        # If a node has fewer neighbors than max_degree, the remaining entries are filled with -1
        node_idx_set_higher_order = self.propagate(edge_index, node_idx=node_idx)
        # Since our node features changed from shape (N, 1) to (N, max_degree, 2)
        # we need to set the node_dim to -3
        self.node_dim = -3
        # We use the function that is used to update node features to create the higher order edges
        edge_index_higher_order = self.edge_updater(edge_index, node_idx=node_idx_set_higher_order)
        # Select unique higher order nodes from the set of higher order neighbors.
        # Padding rows are removed explicitly with the same criterion as in edge_update.
        # Dropping the first row of the sorted unique rows instead is only correct when
        # at least one padding row exists; if every node has the same number of
        # neighbors there is no padding and a real higher order node would be lost.
        candidates = node_idx_set_higher_order.reshape(-1, node_idx_set_higher_order.size(-1))
        is_real = (candidates != -1).all(dim=-1)
        node_idx_higher_order = candidates[is_real].unique(dim=0)

        return node_idx_higher_order, edge_index_higher_order

    def message(self, node_idx_i, node_idx_j):
        """Build one higher-order node per edge from the features of its two end points."""
        # Concatenate the node idx of the source node with the node idx of the target node
        # The shape changes from (N, 1) to (N, 2)
        return torch.cat([node_idx_i, node_idx_j[:, -1:]], dim=-1)

    def edge_update(self, node_idx_i, node_idx_j) -> Tensor:
        """Create the higher-order edges that run over each edge of the input graph.

        Args:
            node_idx_i: Padded higher-order node sets of the first end point of every edge.
            node_idx_j: Padded higher-order node sets of the second end point of every edge.

        Returns:
            Tensor that stacks the source higher-order nodes (entry 0) and the target
            higher-order nodes (entry 1) of all higher-order edges.
        """
        # We take the higher order node idx sets that have been created for each node adjacent
        # to the edge (node_idx_i, node_idx_j) and repeat each node idx across different dimensions
        # so that we can compare them with each other
        #
        # Example:
        #
        #   node_idx_i = [[0, 3], [1, 3], [2, 3]]
        #   node_idx_j = [[3, 4], [2, 4], [-1, -1]]
        #
        #   strided_node_idx_i = [[
        #                           [0, 3],
        #                           [1, 3],
        #                           [2, 3]
        #                         ],
        #                         [
        #                           [0, 3],
        #                           [1, 3],
        #                           [2, 3]
        #                         ],
        #                         [
        #                           [0, 3],
        #                           [1, 3],
        #                           [2, 3]
        #                         ]]
        #   strided_node_idx_j = [[
        #                           [3, 4],
        #                           [3, 4],
        #                           [3, 4]
        #                         ],
        #                         [
        #                           [2, 4],
        #                           [2, 4],
        #                           [2, 4]
        #                         ],
        #                         [
        #                           [-1, -1],
        #                           [-1, -1],
        #                           [-1, -1]
        #                         ]]
        strided_node_idx_i = node_idx_i.unsqueeze(1).expand(-1, node_idx_j.size(1), -1, -1)
        strided_node_idx_j = node_idx_j.unsqueeze(2).expand(-1, -1, node_idx_i.size(1), -1)
        # Only create a higher order edge if everything but the first entry of the first
        # higher order node is equal to everything but the last entry of the second one
        edge_mask = (strided_node_idx_i[..., 1:] == strided_node_idx_j[..., :-1]).all(dim=-1)
        # Also, we need to remove the -1 padding values
        padd_mask = (
            (strided_node_idx_i != -1).all(dim=-1) &
            (strided_node_idx_j != -1).all(dim=-1)
        )
        # For the above, the following mask is:
        #
        #   mask = [[True, True, True],
        #           [False, False, False],
        #           [False, False, False]]
        mask = edge_mask & padd_mask
        # Stack the remaining higher order edges to create a new edge index
        higher_order_edges = torch.stack([strided_node_idx_i[mask], strided_node_idx_j[mask]], dim=0)
        return higher_order_edges

## Toy Example

In [46]:
# Toy DAG as a standard edge index: the first row lists the source node and the
# second row the target node of every edge
edge_index = torch.tensor(
    [
        [0, 0, 1, 1, 3, 4, 1, 6, 5],
        [1, 3, 2, 3, 4, 5, 6, 5, 7],
    ]
)
# Alternative input (a simple path 0 -> 1 -> ... -> 7), kept for quick experiments:
# edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6],
#                            [1, 2, 3, 4, 5, 6, 7]])

# First-order node features: a single column that holds every node's own index
node_idx = torch.arange(edge_index.max() + 1).unsqueeze(-1)
# Lift to second order with the message-passing based transform
node_idx_2, edge_index_2_fast = DeBruijnTransform()(node_idx, edge_index)
print(edge_index_2_fast)

tensor([[[0, 1],
         [0, 1],
         [0, 1],
         [0, 3],
         [1, 3],
         [1, 6],
         [3, 4],
         [4, 5],
         [6, 5]],

        [[1, 2],
         [1, 3],
         [1, 6],
         [3, 4],
         [3, 4],
         [6, 5],
         [4, 5],
         [5, 7],
         [5, 7]]])


In [4]:
# Reference result: lift the same edges with pathpyG's existing lift_order_dag.
# unsqueeze(-1) appends a trailing dimension of size 1 to the edge index before the call.
edge_index_2 = pp.DAGData.lift_order_dag(edge_index.unsqueeze(-1))
print(edge_index_2)

tensor([[[0, 1],
         [0, 1],
         [0, 1],
         [0, 3],
         [1, 3],
         [3, 4],
         [4, 5],
         [6, 5],
         [1, 6]],

        [[1, 2],
         [1, 3],
         [1, 6],
         [3, 4],
         [3, 4],
         [4, 5],
         [5, 7],
         [5, 7],
         [6, 5]]])


In [5]:
# Compare both results as multisets of complete higher-order edges.
# A plain `.sort(dim=1)` orders every coordinate column on its own, which loses the
# pairing between the source and the target of an edge, so two different edge sets
# could still compare equal. Instead, every edge is flattened into one row
# (source entries followed by target entries); `unique(dim=0)` returns these rows in
# sorted order together with their multiplicities.
# If the number of distinct edges differs, the comparison raises a shape error.
rows_ref, counts_ref = torch.cat([edge_index_2[0], edge_index_2[1]], dim=-1).unique(dim=0, return_counts=True)
rows_fast, counts_fast = torch.cat([edge_index_2_fast[0], edge_index_2_fast[1]], dim=-1).unique(dim=0, return_counts=True)
print((rows_ref == rows_fast).all() & (counts_ref == counts_fast).all())

tensor(True)


### 3rd order
Currently, the transformation only accepts a standard edge index. However, you can still pass in the higher-order `node_idx`, in which case the output is directly a third-order edge index.

In [6]:
# Transform the edge index since the DeBruijnTransform only works for the normal edge index.
# Every end point of a second-order edge is replaced by the position of its row in node_idx_2,
# because that is the tensor the transform indexes into in the next step.
# The end points are ranked together with node_idx_2: ranking them on their own would skip
# every first-order edge that is not part of any second-order edge and shift all following
# indices away from their rows in node_idx_2.
# NOTE(review): this assumes node_idx_2 holds sorted, unique rows that include every end point.
edges_2 = edge_index_2_fast.reshape(-1, edge_index_2_fast.size(-1))
uniques, inverse_idx = torch.cat([node_idx_2, edges_2], dim=0).unique(dim=0, return_inverse=True)
# Drop the entries that belong to the node_idx_2 rows prepended above
inverse_idx = inverse_idx[node_idx_2.size(0):]
transformed_edge_index_2_fast = inverse_idx.reshape(2, -1)

In [7]:
# Third order: feed the second-order node indices and the re-indexed second-order edges
# into a fresh transform
node_idx_3, edge_index_3_fast = DeBruijnTransform()(node_idx_2, transformed_edge_index_2_fast)
print(edge_index_3_fast)

tensor([[[0, 1, 3],
         [0, 1, 6],
         [0, 3, 4],
         [1, 3, 4],
         [1, 6, 5],
         [3, 4, 5]],

        [[1, 3, 4],
         [1, 6, 5],
         [3, 4, 5],
         [3, 4, 5],
         [6, 5, 7],
         [4, 5, 7]]])


In [8]:
# Reference third-order result: apply pathpyG's lift_order_dag to its own second-order output
edge_index_3 = pp.DAGData.lift_order_dag(edge_index_2)
print(edge_index_3)

tensor([[[0, 1, 3],
         [0, 1, 6],
         [0, 3, 4],
         [1, 3, 4],
         [3, 4, 5],
         [1, 6, 5]],

        [[1, 3, 4],
         [1, 6, 5],
         [3, 4, 5],
         [3, 4, 5],
         [4, 5, 7],
         [6, 5, 7]]])


# Exponentially Large DAG

In [None]:
# Tree-shaped DAG: starting from the root 0, every node of a layer gets `branches`
# new children, repeated for `layers` layers
layers = 5
branches = 15

edges = []
frontier = [0]
next_id = 1
for _ in trange(layers):
    children = []
    for parent in frontier:
        # The next `branches` unused ids become the children of `parent`
        child_ids = range(next_id, next_id + branches)
        edges.extend((f"{parent}", f"{child}") for child in child_ids)
        children.extend(child_ids)
        next_id += branches
    frontier = children

dag = pp.Graph.from_edge_list(edges)
# Edge index with an additional trailing dimension, as passed to lift_order_dag below
dag_edge_index = dag.data.edge_index.unsqueeze(-1)

# One row per node holding its own index, plus GPU copies of the inputs for the GPU benchmark
num_dag_nodes = dag_edge_index.max().item() + 1
node_idx = torch.arange(num_dag_nodes).unsqueeze(-1)
node_idx_gpu = node_idx.cuda()
dag_edge_index_gpu = dag.data.edge_index.cuda()

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 16.04it/s]


### Current implementation

In [None]:
# Baseline timing: pathpyG's lift_order_dag on the DAG edge index built above
%timeit dag_edge_index_order_2 = pp.DAGData.lift_order_dag(dag_edge_index)

1min 11s ± 877 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Message Passing based implementation (CPU)

In [None]:
# Message-passing transform on the tensors that were not moved to the GPU.
# A new DeBruijnTransform is constructed in every timed run.
%timeit node_idx_2, dag_edge_index_order_2_fast = DeBruijnTransform()(node_idx, dag.data.edge_index)

642 ms ± 21.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Message Passing based implementation (GPU)

In [None]:
# Same transform on the copies that were moved to the GPU with .cuda()
%timeit node_idx_2, dag_edge_index_order_2_fast = DeBruijnTransform()(node_idx_gpu, dag_edge_index_gpu)

74 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
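# Disabled: names assigned inside a `%timeit` statement are not kept afterwards, so
# `dag_edge_index_order_2` and `dag_edge_index_order_2_fast` are undefined here.
# Re-run both lifts outside of `%timeit` before enabling this check.
# print((dag_edge_index_order_2.sort(dim=1)[0] == dag_edge_index_order_2_fast.sort(dim=1)[0]).all())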

## Many Walks

In [3]:
# Benchmark input: n_walks identical walks, each visiting the nodes 0 .. walk_length - 1 in order
n_walks = 1000
walk_length = 100

walks = [[*range(walk_length)] for _ in range(n_walks)]

# Dict-based walk container, filled one walk at a time
orig_walk = pp.WalkData()
for walk in walks:
    orig_walk.add_walk_seq(walk)

# Nested-tensor container built from the paths and frequencies stored in orig_walk
path_list = [*orig_walk.paths.values()]
path_freq_tensor = torch.tensor([*orig_walk.path_freq.values()])
mapping = pp.IndexMap()
nested_walk = pp.WalkDataNested(path_list, path_freq=path_freq_tensor, mapping=mapping)

  self.paths = nested_tensor(paths, dtype=torch.long)


### Original Walk Implementation

In [4]:
# Time edge_index_k_weighted(2) on the dict-based WalkData object (presumably k = 2 is the order)
%timeit orig_walk.edge_index_k_weighted(2)

266 ms ± 3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Nested Tensor Implementation (CPU)

In [5]:
# Set pathpyG's torch device option to "cpu" before timing the nested-tensor variant
pp.config['torch']["device"] = "cpu"
%timeit nested_walk.edge_index_k_weighted(2)

7.03 ms ± 449 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Nested Tensor (GPU)

In [6]:
# Set pathpyG's torch device option to "cuda"; the option keeps this value for later cells
pp.config['torch']["device"] = "cuda"
# Move every path and the frequency tensor to the GPU and rebuild the nested container from them
cuda_path_list = [path.cuda() for path in path_list]
cuda_path_freq = path_freq_tensor.cuda()
cuda_nested_walk = pp.WalkDataNested(cuda_path_list, path_freq=cuda_path_freq, mapping=mapping)
%timeit cuda_nested_walk.edge_index_k_weighted(2)

28.2 ms ± 4.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Message Passing Implementation (CPU + GPU)

In [7]:
# Wrap every walk in its own Data object and merge all of them into one big graph.
# Each walk visits the nodes 0 .. walk_length - 1, i.e. walk_length distinct nodes, so
# num_nodes has to be walk_length. Batching shifts the node indices of every graph by
# the num_nodes of the graphs before it; with walk_length - 1 the last node of one walk
# would get the same index as the first node of the next walk and chain the walks together.
data_list = [Data(edge_index=path.long(), num_nodes=walk_length) for path in path_list]
# A single batch that contains all walks as disjoint components
walk_graph = next(iter(DataLoader(data_list, batch_size=n_walks)))
edge_index = walk_graph.edge_index
# One row per node of the batched graph holding its own index
node_idx = torch.arange(edge_index.max() + 1).unsqueeze(-1)

In [8]:
# Message-passing transform on the batched walk graph, using the tensors created in the previous cell
%timeit node_idx_2, edge_index_2_fast = DeBruijnTransform()(node_idx, edge_index)

146 ms ± 2.55 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
# Copy the batched walk graph to the GPU once, outside of the timed statement
cuda_edge_index = edge_index.cuda()
cuda_node_idx = node_idx.cuda()
%timeit node_idx_2, edge_index_2_fast = DeBruijnTransform()(cuda_node_idx, cuda_edge_index)

2.64 ms ± 154 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
