In [1]:
from tqdm import tqdm
import torch

from torch_geometric.utils import degree, cumsum, coalesce, sort_edge_index, subgraph, mask_to_index, k_hop_subgraph

import pathpyG as pp
from pathpyG import Graph

from pathpyG import TemporalGraph



In [2]:
def lift_order_temporal(g: TemporalGraph, delta: int = 1):
    """Lift a temporal graph to its second-order (edge-to-edge) index.

    An edge ``j = (v, w)`` with time stamp ``t_j`` continues an edge
    ``i = (u, v)`` with time stamp ``t_i`` when ``i`` ends in the node that
    ``j`` starts from and ``t_i < t_j <= t_i + delta``.

    Args:
        g: Temporal graph. Only ``g.data.edge_index`` (indexed here as a
            ``(2, num_edges)`` tensor) and ``g.data.time`` (one time stamp per
            edge) are read.
        delta: Largest admissible time difference between an edge and its
            continuation.

    Returns:
        A ``(2, m)`` tensor of first-order edge positions. A column ``(i, j)``
        states that edge ``j`` continues edge ``i``. Columns are grouped by the
        time stamp of ``i`` in increasing order; within a group they are
        ordered by ``i`` and then by ``j``. If no edge has a continuation, an
        empty ``(2, 0)`` int64 tensor is returned.
    """
    # first-order edge index
    edge_index, timestamps = g.data.edge_index, g.data.time

    indices = torch.arange(0, edge_index.size(1), device=edge_index.device)

    unique_t = torch.unique(timestamps, sorted=True)
    second_order = []

    # lift order: find possible continuations for edges in each time stamp
    for t in unique_t:
        # indices of all source edges that occur at time stamp t
        # (never empty, because t is taken from the time stamps themselves)
        src_edge_idx = indices[timestamps == t]

        # indices of all edges inside the time window (t, t + delta] that
        # start in a node where at least one source edge ends
        dst_time_mask = (timestamps > t) & (timestamps <= t + delta)
        dst_node_mask = torch.isin(edge_index[0], edge_index[1, src_edge_idx])
        dst_edge_idx = indices[dst_time_mask & dst_node_mask]

        if dst_edge_idx.size(0) == 0:
            continue

        # enumerate every (source edge, candidate edge) pair and keep those
        # where the end node of the source edge is the start node of the candidate
        pairs = torch.cartesian_prod(src_edge_idx, dst_edge_idx).t()
        is_continuation = edge_index[1, pairs[0]] == edge_index[0, pairs[1]]
        second_order.append(pairs[:, is_continuation])

    if not second_order:
        # torch.cat raises on an empty list; return a well-formed empty index instead
        return torch.empty((2, 0), dtype=torch.long, device=edge_index.device)

    return torch.cat(second_order, dim=1)

In [10]:
def lift_order_temporal_1(g: TemporalGraph, delta: int = 1):
    """Lift a temporal graph to its second-order index without a cartesian product.

    An edge ``j`` continues an edge ``i`` when ``i`` ends in the node that
    ``j`` starts from and ``t_i < t_j <= t_i + delta``. For every time stamp
    the candidate continuation edges are sorted by their start node, so that
    the continuations of each source edge form one contiguous slice which can
    be addressed with prefix sums instead of enumerating all edge pairs.

    Args:
        g: Temporal graph. ``g.data.edge_index`` (indexed here as a
            ``(2, num_edges)`` tensor), ``g.data.time`` (one time stamp per
            edge) and ``g.n`` are read. ``g.n`` is passed to ``degree`` as
            ``num_nodes`` -- presumably the number of nodes; confirm against
            ``TemporalGraph``.
        delta: Largest admissible time difference between an edge and its
            continuation.

    Returns:
        A ``(2, m)`` tensor of first-order edge positions. A column ``(i, j)``
        states that edge ``j`` continues edge ``i``. If no edge has a
        continuation, an empty ``(2, 0)`` int64 tensor is returned.
    """
    # first-order edge index
    edge_index, timestamps = g.data.edge_index, g.data.time

    indices = torch.arange(0, edge_index.size(1), device=edge_index.device)

    unique_t = torch.unique(timestamps, sorted=True)
    second_order = []

    # lift order: find possible continuations for edges in each time stamp
    for i in range(unique_t.size(0)):
        t = unique_t[i]

        # find indices of all source edges that occur at unique timestamp t
        src_time_mask = timestamps == t
        src_edge_idx = indices[src_time_mask]

        # find indices of all edges that can possibly continue edges occurring at time t for the given delta
        dst_time_mask = (timestamps > t) & (timestamps <= t + delta)
        dst_node_mask = torch.isin(edge_index[0], edge_index[1, src_edge_idx])
        dst_edge_idx = indices[dst_time_mask & dst_node_mask]

        if dst_edge_idx.size(0) > 0 and src_edge_idx.size(0) > 0:

            src_edges = edge_index[:, src_edge_idx]
            dst_edges = edge_index[:, dst_edge_idx]

            # Group candidate edges by their start node. The sort must be stable:
            # candidates sharing a start node then stay in increasing edge order,
            # which makes the output deterministic and gives the same column order
            # as the cartesian-product implementation `lift_order_temporal`.
            sorted_idx = torch.argsort(dst_edges[0], stable=True)
            dst_edge_idx = dst_edge_idx[sorted_idx]
            dst_edges = dst_edges[:, sorted_idx]

            # number of candidate edges starting in each node ...
            outdegree = degree(dst_edges[0], dtype=torch.long, num_nodes=g.n)
            # ... looked up at the end node of every source edge
            outdegree_per_dst = outdegree[src_edges[1]]
            num_new_edges = outdegree_per_dst.sum()

            # position (within src_edge_idx) of the source edge of every new pair
            ho_edge_srcs = torch.repeat_interleave(outdegree_per_dst)
            # torch_geometric's cumsum prepends a zero, so dropping the last entry
            # leaves, per node, the offset of its first edge in the sorted candidates
            ptrs = cumsum(outdegree, dim=0)[:-1]
            ho_edge_dsts = torch.repeat_interleave(ptrs[src_edges[1]], outdegree_per_dst)
            # running position inside each source edge's slice: 0, 1, ..., outdegree - 1
            idx_correction = torch.arange(num_new_edges, dtype=torch.long, device=edge_index.device)
            idx_correction -= cumsum(outdegree_per_dst, dim=0)[ho_edge_srcs]
            ho_edge_dsts += idx_correction
            second_order.append(torch.stack([src_edge_idx[ho_edge_srcs], dst_edge_idx[ho_edge_dsts]], dim=0))

    if not second_order:
        # torch.cat raises on an empty list; return a well-formed empty index instead
        return torch.empty((2, 0), dtype=torch.long, device=edge_index.device)

    return torch.cat(second_order, dim=1)

In [11]:
def lift_order_temporal_2(g: TemporalGraph, delta: int = 1):
    """Lift a temporal graph to its second-order index via a broadcast node match.

    An edge ``j`` continues an edge ``i`` when ``i`` ends in the node that
    ``j`` starts from and ``t_i < t_j <= t_i + delta``. For every time stamp
    the end nodes of the source edges are compared against the start nodes of
    all edges inside the time window in one boolean matrix.

    NOTE(review): this function previously appended the ``...`` placeholder and
    could never return. The strategy below is one straightforward completion;
    it yields the columns in the same order as ``lift_order_temporal``.

    Args:
        g: Temporal graph. Only ``g.data.edge_index`` (indexed here as a
            ``(2, num_edges)`` tensor) and ``g.data.time`` (one time stamp per
            edge) are read.
        delta: Largest admissible time difference between an edge and its
            continuation.

    Returns:
        A ``(2, m)`` tensor of first-order edge positions. A column ``(i, j)``
        states that edge ``j`` continues edge ``i``. If no edge has a
        continuation, an empty ``(2, 0)`` int64 tensor is returned.
    """
    # first-order edge index
    edge_index, timestamps = g.data.edge_index, g.data.time

    unique_t = torch.unique(timestamps, sorted=True)
    second_order = []

    # lift order: find possible continuations for edges in each time stamp
    for i in range(unique_t.size(0)):
        t = unique_t[i]

        # source edges that occur at unique timestamp t, and their positions
        src_time_mask = timestamps == t
        src_edges = edge_index[:, src_time_mask]
        src_edge_idx = torch.nonzero(src_time_mask).flatten()

        # edges inside the time window (t, t + delta], and their positions
        dst_time_mask = (timestamps > t) & (timestamps <= t + delta)
        dst_edge_idx = torch.nonzero(dst_time_mask).flatten()
        if dst_edge_idx.numel() == 0:
            continue
        dst_edges = edge_index[:, dst_time_mask]

        # match[r, c] is True when source edge r ends where window edge c starts
        match = src_edges[1].unsqueeze(1) == dst_edges[0].unsqueeze(0)
        # nonzero lists the hits row by row, i.e. by source edge, then by window edge
        src_pos, dst_pos = torch.nonzero(match, as_tuple=True)
        if src_pos.numel() > 0:
            second_order.append(torch.stack([src_edge_idx[src_pos], dst_edge_idx[dst_pos]], dim=0))

    if not second_order:
        # torch.cat raises on an empty list; return a well-formed empty index instead
        return torch.empty((2, 0), dtype=torch.long, device=edge_index.device)

    return torch.cat(second_order, dim=1)

In [12]:
def long_temporal_graph() -> TemporalGraph:
    """Build the example temporal graph used by the checks and timings below.

    The graph has 25 time-stamped edges between the nine nodes ``a`` to ``i``.
    Edges are given as ``(source, target, time)`` triples in non-decreasing
    time order, with time stamps ranging from 1 to 50.
    """
    timestamped_edges = [
        # t = 1
        ("a", "b", 1), ("c", "b", 1), ("c", "a", 1), ("f", "c", 1),
        # t = 5
        ("b", "c", 5), ("a", "d", 5),
        # t = 9
        ("c", "d", 9), ("a", "d", 9), ("c", "e", 9),
        # t = 11 .. 21
        ("c", "f", 11), ("f", "a", 13), ("a", "g", 18), ("b", "f", 21),
        # t = 26 .. 28
        ("a", "g", 26), ("c", "f", 27), ("h", "f", 27), ("g", "h", 28),
        # t = 30 .. 33
        ("a", "c", 30), ("a", "b", 31), ("c", "h", 32), ("f", "h", 33),
        # t = 42 .. 50
        ("b", "i", 42), ("i", "b", 42), ("c", "i", 47), ("h", "i", 50),
    ]
    return TemporalGraph.from_edge_list(timestamped_edges)

In [13]:
# Example temporal graph shared by the result, timing and comparison cells below.
g = long_temporal_graph()

In [14]:
# Second-order edge index of the example graph for delta=10, from the
# cartesian-product implementation; row 0 holds source edge positions,
# row 1 the positions of their continuations.
lift_order_temporal(g, delta=10)

tensor([[ 0,  1,  2,  2,  3,  3,  3,  4,  4,  4,  9, 10, 11, 13, 14, 15, 17],
        [ 4,  4,  5,  7,  6,  8,  9,  6,  8,  9, 10, 11, 16, 16, 20, 20, 19]])

In [15]:
# Time the cartesian-product implementation on the example graph.
%timeit lift_order_temporal(g, delta=10)

2.72 ms ± 90.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
# Time the sort/prefix-sum implementation on the same graph and delta for comparison.
%timeit lift_order_temporal_1(g, delta=10)

5.23 ms ± 289 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
lift_order_res = lift_order_temporal(g, delta=10)
lift_order_res_1 = lift_order_temporal_1(g, delta=10)

# torch.equal checks size and values together: if the two implementations ever
# return a different number of columns this yields False, whereas an elementwise
# `==` would raise (or silently broadcast) on the shape mismatch.
torch.equal(lift_order_res, lift_order_res_1)

True