## TopK

In [12]:
import torch
import torch_geometric
from torch_geometric.utils import degree, cumsum, scatter, to_undirected
import torch_sparse

In [2]:
def sparse_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    """Sort ``src`` so that entries sharing an ``index`` value end up contiguous.

    Scores are min-max rescaled to roughly [0, 1] and offset by the (signed)
    group index. A single ``argsort`` then orders by group first and by score
    second (descending scores when ``descending=True``).

    Caveats (this is the "original" implementation kept for comparison):
      * ``argsort`` is called without ``stable=True``, so entries with equal keys
        (e.g. tied scores inside a group) come back in an unspecified order.
      * Keys are float32, so for large group ids the fractional score part may
        lose precision.
      * NOTE(review): the global min/max rescale to exactly the group offset, so
        an entry of one group can tie with an entry of the neighbouring group.

    Returns:
        (sorted ``src`` values, permutation applied to ``src``)
    """
    values = src.float()
    lo = values.min(dim=dim).values
    hi = values.max(dim=dim).values
    direction = -1 if descending else 1

    # Rescale, then push each group into its own integer "band".
    keys = (values - lo) / (hi - lo + eps) + index.float() * direction
    order = torch.argsort(keys, dim=dim, descending=descending)

    return src[order], order

def sparse_topk(src: torch.Tensor, index: torch.Tensor, ratio: float, dim=0, descending=False, eps=1e-12):
    """Keep the first ``ceil(ratio * group_size)`` entries of every group after ``sparse_sort``.

    This is the "original" top-k used for comparison in this notebook; it
    inherits the unstable tie-breaking of ``sparse_sort``.

    Returns:
        topk_perm: indices into ``src`` of the kept entries.
        exc_perm: indices into ``src`` of the remaining entries.
        rank: ``src`` reordered by the group-wise sort.
        perm: the group-wise sort permutation.
        mask: boolean mask over sorted positions, so that ``perm[mask]`` is ``topk_perm``.
    """
    rank, perm = sparse_sort(src, index, dim=dim, descending=descending, eps=eps)
    device = src.device

    # Per-group size, per-group quota, and the sorted position where each group
    # is expected to start (assumes sparse_sort kept every group contiguous).
    group_sizes = degree(index, dtype=torch.long)
    quota = (ratio * group_sizes.to(float)).ceil().to(torch.long)
    group_starts = torch.cat([torch.zeros((1, ), device=device, dtype=torch.long), group_sizes.cumsum(0)])

    # Sorted positions of the first `quota` entries of every group.
    selected = torch.cat(
        [torch.arange(q, dtype=torch.long, device=device) + start
         for q, start in zip(quota, group_starts)],
        dim=0,
    )
    keep = torch.zeros_like(index, device=index.device).index_fill(0, selected, 1).bool()

    return perm[keep], perm[~keep], rank, perm, keep

def topK(x, ratio, batch, min_score, tol=1e-7, debug=False):
    """Select the top-scoring entries of every graph in a batch, breaking ties stably.

    Two selection modes are supported:

    * ``min_score`` given: keep every entry whose score exceeds ``min_score``,
      but never drop all entries of a graph (the per-graph maximum minus
      ``tol`` caps the threshold). Returns a single index tensor.
    * otherwise ``ratio`` is used. ``ratio >= 1`` keeps every entry. For
      ``ratio < 1``, each graph keeps its ``ceil(ratio * num_entries)``
      highest-scoring entries. Ties are broken by original position (both
      sorts are stable), so the lowest indices win.

    Args:
        x: scores; flattened with ``view(-1)`` in the ratio path.
        ratio: fraction of entries to keep per graph, or ``None`` when
            ``min_score`` is used.
        batch: graph id of every entry of ``x``.
        min_score: optional absolute threshold; takes precedence over ``ratio``.
        tol: margin subtracted from the per-graph max in the ``min_score`` path.
        debug: unused; kept for interface compatibility.

    Returns:
        ``min_score`` path: indices of the kept entries (a single tensor).
        ``ratio`` path: ``(kept, dropped, perm, mask)``.
          * ``kept`` lists the kept indices graph by graph, in descending score order.
          * ``dropped`` is the sorted list of the remaining indices.
          * ``perm`` orders all entries by (graph id, descending score).
          * ``mask`` marks the kept positions of ``perm``, so ``perm[mask]`` equals ``kept``.
        For ``ratio >= 1``, ``perm`` and ``mask`` are ``None``.

    Raises:
        ValueError: if both ``ratio`` and ``min_score`` are ``None``.
    """
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)
        # NOTE(review): unlike the ratio path this returns a single tensor, not a
        # 4-tuple; callers that unpack four values must not pass `min_score`.
        return (x > scores_min).nonzero().view(-1)

    # Check for None first: comparing None with a float raises TypeError.
    if ratio is None:
        raise ValueError("either 'ratio' or 'min_score' must be given")

    if ratio >= 1.:
        # Keep everything. Allocate the empty "dropped" tensor on x's device so it
        # can be combined with the other outputs.
        kept = torch.arange(x.shape[0], device=x.device)
        dropped = torch.empty(0, dtype=torch.long, device=x.device)
        return kept, dropped, None, None

    num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
    k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

    # Two stable sorts: first by descending score, then by graph id. The second
    # sort preserves the score order inside each graph, and equal scores keep
    # their original order, so tie-breaking is deterministic.
    x_flat = x.view(-1)
    _, x_perm = torch.sort(x_flat, descending=True, stable=True)
    sorted_batch, batch_perm = torch.sort(batch[x_perm], descending=False, stable=True)
    # Compose both permutations so that `perm` and `mask` refer to the same order.
    perm = x_perm[batch_perm]

    # Rank of every entry inside its own graph (0 = highest score).
    ptr = cumsum(num_nodes)
    arange = torch.arange(x_flat.size(0), dtype=torch.long, device=x.device)
    rank_in_graph = arange - ptr[sorted_batch]
    mask = rank_in_graph < k[sorted_batch]

    return perm[mask], perm[~mask].sort()[0], perm, mask

In [3]:
print("\n".join(
    f"Using version {module.__version__} for {library}"
    for library, module in (("PyG", torch_geometric), ("Torch", torch))
))

Using version 2.4.0 for PyG
Using version 2.1.0+cu121 for Torch


In [4]:
# Toy input: N identical graphs with 4 edges each, all sharing the same score,
# so every top-k choice inside a graph is a tie.
# Fall back to CPU when no GPU is present; the outputs shown below were recorded on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N = 3
ratio = 0.5
edge_score_single = torch.tensor([0.9, 0.9, 0.9, 0.9], device=DEVICE)
# Tile (not interleave) the per-graph scores so block i of `edge_score` lines up
# with graph id i in `batch`.
edge_score = edge_score_single.repeat(N)
batch = torch.arange(N, dtype=torch.int64, device=DEVICE).repeat_interleave(edge_score_single.numel())

Using the original `sparse_topk` function

In [5]:
# descending=True: the largest scores are the top-k candidates.
idx_kept, idx_dropped, _, perm, mask = sparse_topk(
    edge_score, batch, ratio=ratio, descending=True
)
idx_kept

tensor([ 1,  0,  7,  6, 11, 10], device='cuda:0')

Using the stable version

In [6]:
# min_score=None selects the ratio-based path of topK.
idx_kept, idx_dropped, perm, mask = topK(
    edge_score, ratio=ratio, batch=batch, min_score=None
)
idx_kept

tensor([0, 1, 4, 5, 8, 9], device='cuda:0')

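When scores are tied, the stable version always picks the first elements of each graph as the top candidates. *sparse_topk*, by contrast, breaks the ties arbitrarily: in this run it kept the first two elements of the first graph but the last two elements of the other graphs.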

## Directed edge scores

In [9]:
print("\n".join(
    f"Using version {module.__version__} for {library}"
    for library, module in (
        ("PyG", torch_geometric),
        ("Torch", torch),
        ("TorchSparse", torch_sparse),
    )
))

Using version 2.4.0 for PyG
Using version 2.1.0+cu121 for Torch
Using version 0.6.18+pt21cu121 for TorchSparse


In [10]:
# Small directed graph over nodes 0, 1, 2; each column is one edge (row 0, row 1).
edge_index = torch.tensor([[0, 1, 2, 1], [1, 0, 1, 2]])
# One attention score per edge, in the same column order as edge_index.
att = torch.tensor([0.0, 1.0, 0.6, 0.8])

Using the original version

In [11]:
# "Original" symmetrisation: average each score with the value torch_sparse.transpose
# returns at the same position. The recorded output equals `att` unchanged, instead of
# averaging each edge with its reverse edge.
# NOTE(review): the graph has 3 nodes; the 4x4 size passed here looks oversized — confirm.
att_transposed = torch_sparse.transpose(edge_index, att, 4, 4, coalesced=False)[1]
att_avg = (att + att_transposed) / 2
att_avg

tensor([0.0000, 1.0000, 0.6000, 0.8000])

Using our fixed version

In [13]:
# Keep the edge index returned by to_undirected. PyG coalesces (sorts) it, and
# `att_avg` is aligned with that returned index, not with the original
# `edge_index` column order.
edge_index_undirected, att_avg = to_undirected(edge_index, att, reduce="mean")
att_avg

tensor([0.5000, 0.5000, 0.7000, 0.7000])