# Point Transformer

- Utilities
1. k-Nearest Neighbor Search
2. k-Nearest Neighbor Linear Interpolation
3. Farthest Point Sampling
- Modules
1. Point Transformer Layer
2. TransitionDown
3. TransitionUp

In [1]:
import torch

## Utilities

### 1. k-Nearest Neighbor Search
- Input: points (N, 3) and the number of neighbors K
- Output: squared kNN distances (N, K), sorted in ascending order, and kNN indices (N, K)

In [2]:
def find_knn(point_cloud, k):
    """Find the k nearest neighbors of every point within the same point cloud.

    Args:
        point_cloud: Tensor of shape (N, D) holding N points. D is 3 in this
            notebook, but any dimensionality works.
        k: Number of neighbors to return per point; ``torch.topk`` requires k <= N.

    Returns:
        A tuple ``(knn_dist, knn_indices)``, both of shape (N, k).
        ``knn_dist`` holds *squared* Euclidean distances in ascending order.
        ``knn_indices`` holds the matching row indices into ``point_cloud``.
    """
    # 1. Pairwise squared distances by broadcasting (N, 1, D) - (1, N, D).
    #    unsqueeze (instead of view(N, 1, 3)) works for any point dimensionality D.
    delta = point_cloud.unsqueeze(1) - point_cloud.unsqueeze(0)  # (N, N, D)
    dist = torch.sum(delta ** 2, dim=-1)  # (N, N)

    # 2. k smallest distances per row; topk(sorted=True by default) returns
    #    them in ascending order when largest=False.
    knn_dist, knn_indices = dist.topk(k=k, dim=-1, largest=False)

    return knn_dist, knn_indices

In [3]:
# Smoke test: every one of the N random points gets K neighbors.
N, K = 100, 5

points = torch.randn(N, 3)
knn_dist, knn_indices = find_knn(points, K)
print(knn_dist.shape, knn_indices.shape, sep='\n')

torch.Size([100, 5])
torch.Size([100, 5])


### 1. k-Nearest Neighbor Search (General Case)
- Input: query points (M, 3), dataset points (N, 3), and the number of neighbors K
- Output: squared kNN distances (M, K), sorted in ascending order, and kNN indices into the dataset points (M, K)

In [4]:
def find_knn_general(query_points, dataset_points, k):
    """Find, for each query point, its k nearest points in a separate dataset.

    Args:
        query_points: Tensor of shape (M, D) with the points to search for.
        dataset_points: Tensor of shape (N, D) with the points searched over.
            D is 3 in this notebook, but any dimensionality works as long as
            both tensors agree.
        k: Number of neighbors per query; ``torch.topk`` requires k <= N.

    Returns:
        A tuple ``(knn_dist, knn_indices)``, both of shape (M, k).
        ``knn_dist`` holds *squared* Euclidean distances in ascending order.
        ``knn_indices`` holds row indices into ``dataset_points``.
    """
    # 1. Pairwise squared distances by broadcasting (M, 1, D) - (1, N, D).
    delta = query_points.unsqueeze(1) - dataset_points.unsqueeze(0)  # (M, N, D)
    dist = torch.sum(delta ** 2, dim=-1)  # (M, N)

    # 2. k smallest distances per query row, sorted ascending.
    knn_dist, knn_indices = dist.topk(k=k, dim=-1, largest=False)  # (M, k)

    return knn_dist, knn_indices

In [5]:
# Smoke test: M random queries against N random dataset points, K neighbors each.
N, M, K = 100, 25, 3

query_points = torch.randn(M, 3)
dataset_points = torch.randn(N, 3)

knn_dist, knn_indices = find_knn_general(query_points, dataset_points, K)
print(knn_dist.shape, knn_indices.shape, sep='\n')

torch.Size([25, 3])
torch.Size([25, 3])


### 2. k-Nearest Neighbor Linear Interpolation
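- Input: query points (M, 3), dataset points (N, 3) with the corresponding features (N, C), and the number of neighbors K
- Output: interpolated query features (M, C). Each one is a weighted average of the features of its K nearest dataset points, with weights proportional to inverse squared distance.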

In [6]:
def interpolate_knn(query_points, dataset_points, dataset_features, k):
    """Interpolate features at query locations from their k nearest dataset points.

    Each output row is a convex combination of the features of the query's k
    nearest dataset points. The weights are 1 / (squared distance + 1e-8),
    normalized to sum to 1.

    Args:
        query_points: (M, 3) tensor of locations to interpolate at.
        dataset_points: (N, 3) tensor of locations with known features.
        dataset_features: (N, C) tensor of features attached to ``dataset_points``.
        k: number of neighbors blended per query point.

    Returns:
        (M, C) tensor of interpolated features.
    """
    num_queries = len(query_points)
    _, num_channels = dataset_features.shape

    # Squared distances and dataset indices of each query's k nearest points.
    sq_dist, nbr_idx = find_knn_general(query_points, dataset_points, k)  # (M, k)
    nbr_feats = dataset_features[nbr_idx.view(-1)].view(num_queries, k, num_channels)

    # Inverse-distance weights; the epsilon keeps an exact hit (distance 0) finite.
    inv_dist = 1. / (sq_dist + 1e-8)
    weights = inv_dist / inv_dist.sum(dim=-1, keepdim=True)  # (M, k), rows sum to 1

    # Blend neighbor features along the neighbor axis.
    return (weights.unsqueeze(-1) * nbr_feats).sum(dim=1)  # (M, C)

In [7]:
# Smoke test: interpolate C-dimensional features from N points onto M queries.
N, M, K, C = 100, 25, 3, 32

query_points = torch.randn(M, 3)
dataset_points = torch.randn(N, 3)
dataset_features = torch.randn(N, C)

interpolated_features = interpolate_knn(
    query_points, dataset_points, dataset_features, K
)
print(interpolated_features.shape)

torch.Size([25, 32])


### 3. Farthest Point Sampling
- Input: points (N, 3), the number of samples M
- Output: sampled_indices (M,)

In [8]:
import random

In [9]:
def farthest_point_sampling(points, num_samples):
    """Select ``num_samples`` well-spread points by iterative farthest-point sampling.

    The first point is chosen uniformly at random. Each later step picks the
    point whose squared distance to its nearest already-selected point is
    largest. Ties go to the lowest index.

    Args:
        points: (N, D) tensor of coordinates (D = 3 in this notebook).
        num_samples: number M of indices to return. If M > N, indices repeat
            once every point has been selected.

    Returns:
        (M,) long tensor of indices into ``points``, on the same device as ``points``.

    Raises:
        ValueError: if ``points`` is empty (raised by ``random.randrange(0)``).
    """
    N = len(points)

    # 1. Initialization
    sampled_indices = torch.zeros(num_samples, dtype=torch.long, device=points.device)
    # Squared distance from each point to its nearest selected point so far.
    # It starts at +inf so the first update always applies. A finite sentinel
    # such as 1e10 would never be replaced for points farther than 1e5 and
    # would corrupt the farthest-point choice. The buffer keeps the points'
    # float dtype and device.
    if points.is_floating_point():
        dist_dtype = points.dtype
    else:
        dist_dtype = torch.get_default_dtype()
    distance = torch.full((N,), float('inf'), dtype=dist_dtype, device=points.device)
    # randrange(N) draws from [0, N). random.randint(0, N) includes N and
    # could pick an out-of-range index.
    farthest_idx = random.randrange(N)

    # 2. Iteratively sample the farthest points
    for i in range(num_samples):
        # 2-1. Record the current farthest point.
        sampled_indices[i] = farthest_idx

        # 2-2. Shrink each point's distance to the selected set using the new point.
        dist = torch.sum((points - points[farthest_idx]) ** 2, dim=-1)  # (N,)
        distance = torch.minimum(distance, dist)

        # 2-3. The next sample is the point farthest from the selected set.
        #      argmax returns the first index on ties. It stays a tensor to avoid
        #      a host sync per step on accelerators.
        farthest_idx = torch.argmax(distance)

    return sampled_indices

In [10]:
# Smoke test: pick M of N random points with farthest-point sampling.
N, M = 100, 25

points = torch.randn(N, 3)
sampled_indices = farthest_point_sampling(points, M)
sampled_points = points[sampled_indices]
print(sampled_indices.shape, sampled_points.shape, sep='\n')

torch.Size([25])
torch.Size([25, 3])


## Modules

In [11]:
import torch.nn as nn

### 1. Point Transformer Layer (Block)
- Input: points (N, 3), the corresponding features (N, C_in), the number of neighbors K
- Output: the output features (N, C_out)

In [12]:
class PointTransformerLayer(nn.Module):
    """Local vector attention over each point's k nearest neighbors.

    For a point i and each neighbor j returned by ``find_knn``, the layer computes::

        delta_ij = mlp_pos(p_i - p_j)
        w_ij     = softmax over j of mlp_attn(q_i - k_j + delta_ij)   # per channel
        y_i      = sum over j of w_ij * (v_j + delta_ij)

    where q, k, v are bias-free linear projections of the input features.

    Args:
        in_channels: input feature width C_in.
        out_channels: output feature width C_out.
        k: number of neighbors attended to per point.
    """

    def __init__(self, in_channels, out_channels, k):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        # Query / key / value projections (no bias).
        self.linear_q = nn.Linear(in_channels, out_channels, bias=False)
        self.linear_k = nn.Linear(in_channels, out_channels, bias=False)
        self.linear_v = nn.Linear(in_channels, out_channels, bias=False)

        # Maps the relation vector (q_i - k_j + delta_ij) to per-channel attention logits.
        self.mlp_attn = nn.Sequential(
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, out_channels),
        )
        # Encodes the 3-D offset p_i - p_j into a C_out-dim positional term.
        self.mlp_pos = nn.Sequential(
            nn.Linear(3, 3),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, out_channels),
        )

        # Normalizes attention logits across the neighbor axis (dim 1 of (N, k, C_out)).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, points, features):
        """Apply the layer.

        Args:
            points: (N, 3) point coordinates.
            features: (N, C_in) per-point features.

        Returns:
            (N, C_out) attended features.
        """
        n, k, c = len(points), self.k, self.out_channels

        query = self.linear_q(features)  # (N, C_out)
        key = self.linear_k(features)
        value = self.linear_v(features)

        # Gather coordinates, keys and values of every point's k nearest neighbors.
        _, nbr = find_knn(points, k)  # (N, k)
        flat = nbr.view(-1)
        nbr_xyz = points[flat].view(n, k, 3)
        nbr_key = key[flat].view(n, k, c)
        nbr_value = value[flat].view(n, k, c)

        # Relative positional encoding delta_ij.
        offsets = points.view(n, 1, 3) - nbr_xyz  # (N, k, 3)
        pos_enc = self.mlp_pos(offsets.view(-1, 3)).view(n, k, -1)  # (N, k, C_out)

        # Per-channel attention weights from the subtraction relation.
        relation = query.view(n, 1, c) - nbr_key + pos_enc
        attn = self.mlp_attn(relation.view(-1, c)).view(n, k, c)
        attn = self.softmax(attn)  # (N, k, C_out)

        # Aggregate position-augmented values over the neighbors.
        return torch.sum(attn * (nbr_value + pos_enc), dim=1)  # (N, C_out)

In [13]:
# Smoke test: map C_in-dim features on N points to C_out dims with K neighbors.
N, C_in, C_out, K = 100, 32, 64, 7

points = torch.randn(N, 3)
features = torch.randn(N, C_in)
pt_layer = PointTransformerLayer(C_in, C_out, K)

out_features = pt_layer(points, features)
print(out_features.shape)

torch.Size([100, 64])


In [14]:
class PointTransformerBlock(nn.Module):
    """Residual block: linear -> PointTransformerLayer -> linear, plus identity skip.

    Every stage keeps the feature width at ``channels``, so the input can be
    added back to the output.

    Args:
        channels: feature width C (input and output).
        k: number of neighbors used by the inner PointTransformerLayer.
    """

    def __init__(self, channels, k):
        super().__init__()
        self.linear_in = nn.Linear(channels, channels)
        self.pt_layer = PointTransformerLayer(channels, channels, k)
        self.linear_out = nn.Linear(channels, channels)

    def forward(self, points, features):
        """Return the (N, C) block output for (N, 3) points and (N, C) features."""
        hidden = self.pt_layer(points, self.linear_in(features))
        return self.linear_out(hidden) + features

In [15]:
# Smoke test: the residual block preserves the feature width C.
N, C, K = 100, 32, 7

points = torch.randn(N, 3)
features = torch.randn(N, C)
pt_block = PointTransformerBlock(C, K)

out_features = pt_block(points, features)
print(out_features.shape)

torch.Size([100, 32])


### 2. TransitionDown
- Input: points (N, 3), the corresponding features (N, C), the number of samples M, the number of neighbors K
- Output: the sampled points (M, 3) and the corresponding features (M, C)

In [16]:
class TransitionDown(nn.Module):
    """Downsample a point cloud and pool features from each sample's neighborhood.

    Steps:
      1. Farthest point sampling picks ``num_samples`` centers.
      2. Each center's ``k`` nearest input points are gathered.
      3. A shared MLP (Linear -> BatchNorm -> ReLU, twice) is applied to every
         gathered feature.
      4. The results are max-pooled over the neighbor axis.

    Args:
        channels: feature width C (unchanged by the module).
        num_samples: number of output points M.
        k: neighborhood size used for pooling.
    """

    def __init__(self, channels, num_samples, k):
        super().__init__()
        self.channels = channels
        self.num_samples = num_samples
        self.k = k

        # Two identical Linear(no bias) -> BatchNorm1d -> ReLU stages.
        layers = []
        for _ in range(2):
            layers += [
                nn.Linear(channels, channels, bias=False),
                nn.BatchNorm1d(channels),
                nn.ReLU(inplace=True),
            ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, points, features):
        """Return ``(sampled_points, out_features)`` of shapes (M, 3) and (M, C)."""
        centers = points[farthest_point_sampling(points, self.num_samples)]

        # Indices of each center's k nearest input points: (M, k).
        _, group_idx = find_knn_general(centers, points, self.k)

        grouped = self.mlp(features[group_idx.view(-1)])  # (M*k, C)
        grouped = grouped.view(self.num_samples, self.k, -1)

        pooled, _ = grouped.max(dim=1)  # (M, C)
        return centers, pooled

In [17]:
# Smoke test: reduce N points to M points, pooling over K neighbors.
N, C, M, K = 100, 32, 25, 7

points = torch.randn(N, 3)
features = torch.randn(N, C)
td_module = TransitionDown(C, M, K)

down_points, down_features = td_module(points, features)
print(down_points.shape, down_features.shape, sep='\n')

torch.Size([25, 3])
torch.Size([25, 32])


### 3. TransitionUp
- Input: up_points (N, 3), up_features (N, C_up), down_points (M, 3), and down_features (M, C_down)
- Output: out_features (N, C_out)

In [18]:
class TransitionUp(nn.Module):
    """Upsample coarse features onto a denser point set with a skip connection.

    The coarse (down) features are projected to ``out_channels`` and
    interpolated onto ``up_points`` from their ``k`` nearest down points. The
    projected fine (up) features are then added.

    Args:
        up_channels: feature width C_up of the dense (skip) features.
        down_channels: feature width C_down of the coarse features.
        out_channels: output feature width C_out.
        k: number of coarse neighbors used for interpolation (default 3, the
            value previously hard-coded). It is clamped to the number of coarse
            points so very small coarse sets do not crash ``topk``.
    """

    def __init__(self, up_channels, down_channels, out_channels, k=3):
        super(TransitionUp, self).__init__()
        self.linear_up = nn.Linear(up_channels, out_channels)
        self.linear_down = nn.Linear(down_channels, out_channels)
        self.k = k

    def forward(self, up_points, up_features, down_points, down_features):
        """Return (N, C_out) features for the N ``up_points``.

        Args:
            up_points: (N, 3) dense coordinates.
            up_features: (N, C_up) dense features.
            down_points: (M, 3) coarse coordinates.
            down_features: (M, C_down) coarse features.
        """
        # 1. Feed-forward with the down linear layer
        down_f = self.linear_down(down_features)

        # 2. Interpolation. topk cannot return more neighbors than there are
        #    coarse points, so k is clamped to M.
        k = min(self.k, len(down_points))
        interp_f = interpolate_knn(up_points, down_points, down_f, k)  # (N, C_out)

        # 3. Skip-connection
        out_f = interp_f + self.linear_up(up_features)

        return out_f

In [19]:
# Smoke test: lift M coarse points' features onto N dense points.
N, M = 100, 25
C_up, C_down, C_out = 32, 64, 128

up_points = torch.randn(N, 3)
up_features = torch.randn(N, C_up)
down_points = torch.randn(M, 3)
down_features = torch.randn(M, C_down)
tu_module = TransitionUp(C_up, C_down, C_out)

out_features = tu_module(up_points, up_features, down_points, down_features)
print(out_features.shape)

torch.Size([100, 128])
