In [1]:
from typing import Optional

from tqdm import trange
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import coalesce, degree, cumsum
from torch_geometric import EdgeIndex

import pathpyG as pp

In [2]:
# Three small example DAGs, each given as an edge index [[sources], [targets]].
# NOTE(review): the second example contains both 0 -> 2 and 2 -> 0, i.e. a cycle; confirm this is intended.
dags = pp.DAGData()
for dag_edges in (
    [[3, 0, 1], [0, 1, 2]],
    [[1, 0, 2], [0, 2, 0]],
    [[0, 1], [1, 2]],
):
    dags.append(torch.tensor(dag_edges))

In [3]:
# The string form of DAGData summarises the number of DAGs and their total weight.
print(str(dags))

DAGData with 3 dags and total weight 3


In [4]:
def lift_order_edge_index(edge_index: EdgeIndex | torch.Tensor, num_nodes: int | None = None) -> torch.Tensor:
    """Lift an edge index by one order (line-graph construction).

    Every edge (column) ``i`` of ``edge_index`` becomes higher-order node ``i``. A higher-order
    edge ``i -> j`` is created for every pair of edges ``i = (u, v)`` and ``j = (v, w)``, i.e.
    whenever the destination of edge ``i`` is the source of edge ``j``.

    Args:
        edge_index: Edge index of shape ``(2, E)`` with source indices in row 0 and destination
            indices in row 1. Edges may be in any order; if they are not sorted by source, an
            extra stable ``argsort`` maps the result back to the original column indices.
        num_nodes: Number of nodes of the input graph. When ``None`` it is inferred as
            ``max(edge_index) + 1`` over both rows (so sink nodes are counted too).

    Returns:
        Long tensor of shape ``(2, E')`` whose entries index the columns of ``edge_index``.
        Sources are non-decreasing; for one source, destinations follow the order of their edges
        in ``edge_index``.

    Example:
        edge_index = [[0, 0, 1, 1, 1, 3, 4, 5, 6],
                      [1, 3, 2, 3, 6, 4, 5, 7, 5]]
        result     = [[0, 0, 0, 1, 3, 4, 5, 6, 8],
                      [2, 3, 4, 5, 5, 8, 6, 7, 7]]
    """
    src, dst = edge_index[0], edge_index[1]

    # Infer the node count from BOTH rows: inferring it from the sources alone (as `degree` would)
    # makes `outdegree` too short when a pure sink has the largest id, and indexing it with `dst` fails.
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    # Out-degree of every node: each edge (u, v) is combined with every edge leaving v.
    # Example: outdegree = [2, 3, 0, 1, 1, 1, 1, 0]
    outdegree = degree(src, dtype=torch.long, num_nodes=num_nodes)

    # Number of higher-order edges leaving each higher-order node (= out-degree of the edge's destination).
    # Example: outdegree_per_dst = [3, 1, 0, 1, 1, 1, 1, 0, 1], num_new_edges = 9
    outdegree_per_dst = outdegree[dst]
    num_new_edges = outdegree_per_dst.sum()

    # Higher-order node i is the source of `outdegree_per_dst[i]` consecutive new edges.
    # Example: ho_edge_srcs = [0, 0, 0, 1, 3, 4, 5, 6, 8]
    ho_edge_srcs = torch.repeat_interleave(outdegree_per_dst)

    # ptrs[v] = position of the first edge leaving v when edges are ordered by source.
    # PyG's cumsum prepends a 0, so dropping the last entry yields the left interval boundaries.
    # Example: ptrs = [0, 2, 5, 5, 6, 7, 8, 9]
    ptrs = cumsum(outdegree, dim=0)[:-1]

    # Offset (0, 1, 2, ...) of each new edge within the group of new edges sharing its source:
    # a global range minus the start index of the group.
    # Example: idx_correction = [0, 1, 2, 0, 0, 0, 0, 0, 0]
    idx_correction = torch.arange(num_new_edges, dtype=torch.long, device=edge_index.device)
    idx_correction -= cumsum(outdegree_per_dst, dim=0)[ho_edge_srcs]

    # Start of the destination's out-edge block plus the offset = position (in source-sorted order)
    # of every edge that continues edge i.
    # Example: ho_edge_dsts = [2, 3, 4, 5, 5, 8, 6, 7, 7]
    ho_edge_dsts = torch.repeat_interleave(ptrs[dst], outdegree_per_dst) + idx_correction

    # The positions above are only column indices if the edges are sorted by source. Otherwise map
    # them back through a stable argsort (ties keep their input order). This check costs O(E) and
    # is skipped entirely for sorted input, which therefore gives exactly the same result as before.
    if src.numel() > 1 and bool((src[1:] < src[:-1]).any()):
        perm = torch.argsort(src, stable=True)
        ho_edge_dsts = perm[ho_edge_dsts]

    return torch.stack([ho_edge_srcs, ho_edge_dsts], dim=0)

In [95]:
def _aggregate_layer(edge_index: torch.Tensor, fo_nodes: torch.Tensor) -> pp.Graph:
    """Merge nodes with identical first-order paths and count parallel edges as weights.

    Args:
        edge_index: ``(2, E)`` edge index over nodes whose first-order paths are the rows of ``fo_nodes``.
        fo_nodes: ``(N, k)`` tensor; row ``i`` is the first-order node sequence of node ``i``.

    Returns:
        A ``pp.Graph`` whose nodes are the unique rows of ``fo_nodes`` (stored as ``fo_nodes``) and whose
        ``edge_weights`` count how often each aggregated edge occurred in ``edge_index``.
    """
    unique_nodes, inverse_idx = torch.unique(fo_nodes, dim=0, return_inverse=True)
    num_unique = unique_nodes.size(0)
    # One unit of weight per original edge; coalesce sums the weights of edges that become identical.
    ones = torch.ones(edge_index.size(1), device=edge_index.device)
    aggregated_index, edge_weights = coalesce(inverse_idx[edge_index], edge_attr=ones, num_nodes=num_unique)
    return pp.Graph(
        Data(edge_index=aggregated_index, num_nodes=num_unique, fo_nodes=unique_nodes, edge_weights=edge_weights)
    )


def from_DAGs(data: pp.DAGData, max_order: int = 1) -> pp.MultiOrderModel:
    """Creates multiple higher-order De Bruijn graphs for paths in DAGData.

    All DAGs are batched into one disjoint graph, whose edge index is lifted one order at a time with
    ``lift_order_edge_index``. Every layer, including layer 1, is then aggregated: nodes with the same
    first-order node sequence are merged, and ``edge_weights`` counts how often each edge occurs
    across all DAGs.

    Args:
        data: The DAGs; each entry of ``data.dags`` is used (after ``.long()``) as a ``(2, E)`` edge index.
        max_order: Highest order to build; layers ``1..max_order`` are created.

    Returns:
        A ``pp.MultiOrderModel`` whose ``layers[k]`` is the aggregated order-``k`` graph, with the
        first-order node sequence of every node stored in ``fo_nodes``.

    Raises:
        ValueError: If ``data`` contains no DAGs.
    """
    if len(data.dags) == 0:
        raise ValueError("DAGData contains no DAGs")

    # TODO (from original note): outsource this batching to DAGData.
    data_list = []
    for dag in data.dags:
        # Relabel node ids to 0..n-1. Batching offsets each DAG's indices by its num_nodes, and fo_nodes
        # is indexed by these local ids, so non-contiguous ids (e.g. a DAG over nodes {3, 4, 5}) would
        # otherwise point into other DAGs or past the end of fo_nodes. torch.unique returns sorted
        # values, so the relabeling is monotonic and the coalesced (source-sorted) edge order is kept,
        # which lift_order_edge_index relies on.
        unique_nodes, local_edge_index = torch.unique(coalesce(dag.long()), return_inverse=True)
        data_list.append(
            Data(edge_index=local_edge_index, fo_nodes=unique_nodes.unsqueeze(1), num_nodes=unique_nodes.size(0))
        )
    dag_graph = next(iter(DataLoader(data_list, batch_size=len(data_list))))
    edge_index = dag_graph.edge_index
    fo_nodes = dag_graph.fo_nodes

    m = pp.MultiOrderModel()
    m.layers[1] = _aggregate_layer(edge_index, fo_nodes)

    for k in range(2, max_order + 1):
        # Each order-(k-1) edge (u, v) becomes an order-k node: u's path extended by v's last node.
        ho_fo_nodes = torch.cat([fo_nodes[edge_index[0]], fo_nodes[edge_index[1]][:, -1:]], dim=1)
        edge_index = lift_order_edge_index(edge_index, num_nodes=fo_nodes.size(0))
        fo_nodes = ho_fo_nodes
        m.layers[k] = _aggregate_layer(edge_index, fo_nodes)

    return m

In [97]:
# Build the layers of orders 1 to 3 from the example DAGs.
m = from_DAGs(data=dags, max_order=3)

In [104]:
# Label each order-2 node with its first-order path, e.g. "[0, 1]".
pp.plot(m.layers[2], node_label=[str(path) for path in m.layers[2].data.fo_nodes.tolist()])

<pathpyG.visualisations.network_plots.StaticNetworkPlot at 0x7f8aa8ac5bd0>

In [105]:
# Aggregated order-2 edges (indices into the rows of fo_nodes); shown via rich display rather than print.
m.layers[2].data.edge_index

EdgeIndex([[0, 1, 2, 4, 5],
           [3, 4, 1, 1, 0]], sparse_size=(6, 6), nnz=5, sort_order=row)


In [106]:
# Number of times each aggregated order-2 edge occurs across all DAGs.
m.layers[2].data.edge_weights

tensor([2., 1., 1., 1., 1.])


In [107]:
# First-order path represented by each order-2 node.
m.layers[2].data.fo_nodes

tensor([[0, 1],
        [0, 2],
        [1, 0],
        [1, 2],
        [2, 0],
        [3, 0]])
