# TODO: Create a Notebook that explains the new concepts for order lifting for DAGs and temporal graphs.

In [None]:
def lift_order_edge_index(edge_index: torch.Tensor, num_nodes: int, edge_weights: torch.Tensor) -> torch.Tensor:
    """
    Lift the order of a graph by one through a line graph transformation of its edge index.

    Every column (edge) of `edge_index` becomes a node of the lifted graph and is numbered by
    its column position. A lifted edge is created from edge `(u, v)` to every edge `(v, w)`,
    i.e. once for each outgoing edge of the destination node `v`.

    Args:
        edge_index: A **sorted** edge index tensor of shape (2, num_edges).
        num_nodes: The number of nodes in the graph.
        edge_weights: Not read anywhere in this function.

    Returns:
        The lifted edge index: the higher-order source indices stacked on top of the
        higher-order destination indices, giving two rows with one column per lifted edge.
    """
    # Running example used in all comments below:
    #   edge_index = [[0, 0, 1, 1, 1, 3, 4, 5, 6],
    #                 [1, 3, 2, 3, 6, 4, 5, 7, 5]]
    first_nodes = edge_index[0]
    second_nodes = edge_index[1]

    # Number of outgoing edges of every node.
    #   out_deg = [2, 3, 0, 1, 1, 1, 1, 0]
    out_deg = degree(first_nodes, dtype=torch.long, num_nodes=num_nodes)

    # An original edge (u, v) can be continued by every outgoing edge of v, so it spawns
    # out_deg[v] lifted edges. The total of these counts is the number of lifted edges (9 here).
    #   fan_out = [3, 1, 0, 1, 1, 1, 1, 0, 1]
    fan_out = out_deg[second_nodes]

    # Lifted sources: the position of each original edge, repeated once per continuation.
    # The original ordering is kept, so repeating the positions by `fan_out` is sufficient.
    #   lifted_src = [0, 0, 0, 1, 3, 4, 5, 6, 8]
    lifted_src = torch.repeat_interleave(fan_out)

    # Because the edge index is sorted, all edges leaving the same node sit next to each other.
    # `block_start[v]` is the position of the first edge whose source is v. PyG's cumsum
    # prepends a 0, which yields the left interval boundaries once the last entry is dropped.
    #   block_start = [0, 2, 5, 5, 6, 7, 8, 9]
    block_start = cumsum(out_deg, dim=0)[:-1]

    # For now every lifted edge points at the *first* continuation of its source edge.
    #   first_continuation = [2, 2, 2, 5, 5, 8, 6, 7, 7]
    first_continuation = torch.repeat_interleave(block_start[second_nodes], fan_out)

    # Walk through the continuations: within each run of lifted edges sharing a source we need
    # the offsets 0, 1, 2, ... These are the global positions minus the position at which the
    # run of the respective source starts.
    #   positions = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    #   run_start = [0, 0, 0, 3, 4, 5, 6, 7, 8]
    #   offsets   = [0, 1, 2, 0, 0, 0, 0, 0, 0]
    positions = torch.arange(fan_out.sum(), dtype=torch.long, device=edge_index.device)
    run_start = cumsum(fan_out, dim=0)[lifted_src]
    offsets = positions - run_start

    # Shifting the first continuation by the offset gives the final lifted destinations.
    #   lifted_dst = [2, 3, 4, 5, 5, 8, 6, 7, 7]
    lifted_dst = first_continuation + offsets

    # Result for the running example:
    #   tensor([[0, 0, 0, 1, 3, 4, 5, 6, 8],
    #           [2, 3, 4, 5, 5, 8, 6, 7, 7]])
    return torch.stack([lifted_src, lifted_dst], dim=0)