In [None]:
pip install torch torchvision torchaudio

In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [56]:
def patchify_keypoints(keypoints):
    """
    Converts keypoints into patchified stem input.

    Every (frame, keypoint) pair becomes one graph node and the D coordinates
    become the channel axis.

    Args:
        keypoints: tensor (B, T, K, D) -> B=batch, T=frames, K=keypoints, D=(x,y,z).
            May be non-contiguous (e.g. the result of a permute or slice).
    Returns:
        Tensor in shape (B, D, T*K, 1) suitable for graph modules, where
        out[b, d, t * K + k, 0] == keypoints[b, t, k, d].
    Raises:
        ValueError: if `keypoints` is not 4-dimensional.
    """
    if keypoints.dim() != 4:
        raise ValueError(
            f"keypoints must have shape (B, T, K, D); got {tuple(keypoints.shape)}"
        )
    B, T, K, D = keypoints.shape
    # reshape (not view): identical for contiguous input, but also accepts
    # strided input, where view would raise a RuntimeError.
    x = keypoints.reshape(B, T * K, D)
    # (B, T*K, D) -> (B, D, T*K) -> (B, D, T*K, 1)
    x = x.permute(0, 2, 1).unsqueeze(-1)
    return x

In [60]:
def pairwise_distance(x):
    """
    Squared Euclidean distance between every pair of points, per batch.

    Uses the expansion |a - b|^2 = |a|^2 - 2 a.b + |b|^2 and runs under
    `torch.no_grad()`, so the result carries no gradient.

    Args:
        x: tensor (batch_size, n_points, n_dims).
    Returns:
        Tensor (batch_size, n_points, n_points) whose [b, i, j] entry is the
        squared distance between points i and j of batch b.
    """
    with torch.no_grad():
        # |a|^2 for every point, kept as a column so it broadcasts over rows.
        sq_norms = torch.mul(x, x).sum(dim=-1, keepdim=True)
        # -2 * a.b for every pair of points.
        cross_term = -2 * (x @ x.transpose(2, 1))
        return sq_norms + cross_term + sq_norms.transpose(2, 1)

In [62]:
def dense_knn_matrix(x, k=16, relative_pos=None):
    """
    Build k-nearest-neighbour edge indices from dense node features.

    Args:
        x: tensor expected as (batch_size, n_dims, n_points, 1); the trailing
            singleton axis is squeezed away before distances are computed.
        k: number of neighbours selected per point.
        relative_pos: optional tensor added in place to the
            (batch_size, n_points, n_points) squared-distance matrix before
            neighbour selection.
    Returns:
        Integer tensor (2, batch_size, n_points, k). Index 0 holds the
        neighbour indices, index 1 holds each point's own index repeated k times.
    """
    with torch.no_grad():
        # (B, D, N, 1) -> (B, N, D)
        points = x.transpose(2, 1).squeeze(-1)
        batch_size, n_points, _ = points.shape

        dist = pairwise_distance(points)
        if relative_pos is not None:
            dist += relative_pos

        # topk on the negated distances picks the k smallest distances.
        nn_idx = torch.topk(-dist, k=k).indices

        # Each point's own index, laid out as (B, N, k) to pair with nn_idx.
        center_idx = (
            torch.arange(0, n_points, device=points.device)
            .view(1, n_points, 1)
            .expand(batch_size, n_points, k)
        )
    return torch.stack((nn_idx, center_idx), dim=0)

In [64]:
class DenseDilated(nn.Module):
    """
    Pick k neighbours per node out of k * dilation candidates.

    Args:
        k: number of neighbours kept per node.
        dilation: stride used when sub-sampling the candidate axis.
        stochastic: if True, training-mode calls may pick a random subset of
            candidates instead of the strided one.
        epsilon: probability of taking the random subset on a given
            training-mode call (only used when `stochastic` is True).
    """

    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        """
        Sub-sample the last (candidate) axis of a 4-D edge index tensor.

        Args:
            edge_index: 4-D tensor whose last axis enumerates neighbour
                candidates (k * dilation of them for the random branch).
        Returns:
            `edge_index` with the last axis reduced either to k randomly
            chosen candidates or to every `dilation`-th candidate.
        """
        # self.training is tested before drawing the random number so that
        # eval-mode calls never consume the global RNG; the draw could not
        # change the outcome there anyway.
        if self.stochastic and self.training and torch.rand(1) < self.epsilon:
            num = self.k * self.dilation
            # Random k-subset of the candidate positions.
            randnum = torch.randperm(num)[:self.k]
            edge_index = edge_index[:, :, :, randnum]
        else:
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index

In [66]:
class DenseDilatedKnnGraph(nn.Module):
    """
    Build a dilated k-NN graph from dense node features.

    Features are L2-normalised along dim 1, k * dilation nearest neighbours
    are found, and the result is reduced to k neighbours by `DenseDilated`.

    Args:
        k: number of neighbours kept per node.
        dilation: dilation factor; k * dilation candidates are searched.
        stochastic: forwarded to `DenseDilated`.
        epsilon: forwarded to `DenseDilated`.
    """

    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        """
        Compute the dilated edge index for `x` (optionally against `y`).

        Args:
            x: node features; normalised along dim 1 before the search.
            y: optional second feature tensor; when given, neighbours are
                searched with `xy_dense_knn_matrix(x, y, ...)`.
            relative_pos: optional tensor passed through to the k-NN routine.
        Returns:
            The edge index produced by the k-NN routine after dilation.
        """
        n_candidates = self.k * self.dilation
        x = F.normalize(x, p=2.0, dim=1)

        if y is None:
            edge_index = dense_knn_matrix(x, n_candidates, relative_pos)
        else:
            y = F.normalize(y, p=2.0, dim=1)
            # NOTE(review): xy_dense_knn_matrix is not defined in the visible
            # part of this notebook -- confirm it exists before using `y`.
            edge_index = xy_dense_knn_matrix(x, y, n_candidates, relative_pos)

        return self._dilated(edge_index)

In [72]:
if __name__ == "__main__":
    # Smoke test: build a k-NN graph over random keypoints and report shapes.
    # B = batch, T = frames, K = keypoints per frame, D = coordinates.
    B, T, K, D = (2, 30, 543, 3)
    keypoints = torch.randn((B, T, K, D))

    # Each (frame, keypoint) pair becomes one node: N = T * K = 16290.
    patchified = patchify_keypoints(keypoints)

    # The k-NN step builds a dense (B, N, N) float distance matrix, roughly
    # 2 GB per copy at this size, so this cell is memory-hungry.
    knn_graph = DenseDilatedKnnGraph(k=9, dilation=1)
    edge_index = knn_graph(patchified)

    print("Patchified shape:", patchified.size())
    print("Edge index shape:", edge_index.size())

Patchified shape: torch.Size([2, 3, 16290, 1])
Edge index shape: torch.Size([2, 2, 16290, 9])
