In [2]:
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch_sparse import SparseTensor


def triplets(
    edge_index: Tensor,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Enumerate directed two-hop paths ``k -> j -> i`` in a graph.

    Edges are identified by their column position in ``edge_index``. Each
    edge ``j -> i`` is paired with every edge ``k -> j`` entering its source
    node ``j``. Back-tracking pairs where ``k == i`` are dropped.

    Args:
        edge_index: Edge list unpacked as ``row, col``. ``row`` holds the
            source nodes ``j`` and ``col`` holds the target nodes ``i``.
            Presumably a ``(2, E)`` integer tensor (TODO confirm at callers).
        num_nodes: Number of nodes in the graph. If ``None``, it is inferred
            as ``edge_index.max() + 1`` (``0`` for an empty edge list).

    Returns:
        A 7-tuple ``(col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji)``:

        - ``col`` / ``row``: the target / source node of every edge.
        - ``idx_i``, ``idx_j``, ``idx_k``: the node indices of each triplet.
        - ``idx_kj``, ``idx_ji``: the positions in ``edge_index`` of the
          edges ``k -> j`` and ``j -> i`` that form the triplet.
    """
    if num_nodes is None:
        # ``max()`` of an empty tensor raises, so guard the empty edge list.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    row, col = edge_index  # j->i

    # Store each edge's position as the sparse value so edge ids survive the
    # row gathering below.
    value = torch.arange(row.size(0), device=row.device)
    # Transposed adjacency: entry (i, j) holds the id of edge j -> i, so
    # row ``j`` lists the edges entering node ``j``.
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    # One sparse row per edge j -> i: the incoming edges k -> j of node j.
    adj_t_row = adj_t[row]
    # Values are cleared so the row sum counts stored entries rather than
    # adding up edge ids.
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # Remove i == k triplets.
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k->j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji

In [3]:
# Toy directed graph as (source, target) pairs: 1->0, 2->1, 3->2, 0->3, 2->0.
# Transposed into the layout `triplets` unpacks: sources in row 0 and targets
# in row 1. `.contiguous()` keeps the memory layout of a directly built tensor.
edge_index = torch.tensor(
    [(1, 0), (2, 1), (3, 2), (0, 3), (2, 0)]
).t().contiguous()

In [5]:
# Node count is one past the largest node id in the edge list.
triplets(edge_index, num_nodes=int(edge_index.max()) + 1)

(tensor([0, 1, 2, 3, 0]),
 tensor([1, 2, 3, 0, 2]),
 tensor([0, 1, 2, 3, 3, 0]),
 tensor([1, 2, 3, 0, 0, 2]),
 tensor([2, 3, 0, 1, 2, 3]),
 tensor([1, 2, 3, 0, 4, 2]),
 tensor([0, 1, 2, 3, 3, 4]))

In [12]:
import numpy as np

# sqrt(2) as a NumPy float64 scalar, displayed as this cell's output.
np.sqrt(2.0)

1.4142135623730951

In [19]:
# Reproducible 3x3 sample from a normal distribution with mean 0, std 2/9.
# A dedicated generator fixes the draw across re-runs without changing
# torch's global RNG state. `.data` is unnecessary: `torch.empty` does not
# track gradients, and in-place `normal_` works on it directly.
# NOTE(review): 2/9 is used as the standard deviation; confirm it was not
# meant as the variance (std = (2 / 9) ** 0.5).
init_generator = torch.Generator().manual_seed(0)
init_weights = torch.empty(3, 3).normal_(
    mean=0, std=2 / 9, generator=init_generator
)
init_weights

tensor([[ 0.3375, -0.2287,  0.4344],
        [-0.2857, -0.1572,  0.0026],
        [ 0.1044,  0.2744, -0.2282]])