# VoteNet

- Utilities
1. Radius Search
2. IoU for axis-aligned 3D bounding boxes
3. 3D non-maximum suppression (NMS)
- Modules
1. Voting module
2. Detection head

In [1]:
import random

import torch
import torch.nn as nn

## Useful utilities and modules we already implemented

In [2]:
def find_knn_general(query_points, dataset_points, k):
    """Find the ``k`` nearest dataset points for every query point.

    Args:
        query_points: (M, 3) tensor of query coordinates.
        dataset_points: (N, 3) tensor of candidate coordinates.
        k: number of neighbours to return per query (must not exceed N).

    Returns:
        A ``(knn_dist, knn_indices)`` tuple, both of shape (M, k): the *squared*
        Euclidean distances in ascending order and the matching row indices
        into ``dataset_points``.
    """
    num_queries = len(query_points)
    num_candidates = len(dataset_points)

    # Broadcast (M, 1, 3) against (1, N, 3) to obtain every pairwise offset.
    offsets = query_points.view(num_queries, 1, 3) - dataset_points.view(1, num_candidates, 3)
    sq_dist = offsets.pow(2).sum(dim=-1)  # (M, N)

    # largest=False selects the k smallest distances, sorted ascending.
    nearest_dist, nearest_idx = sq_dist.topk(k=k, dim=-1, largest=False)
    return nearest_dist, nearest_idx


def interpolate_knn(query_points, dataset_points, dataset_features, k):
    """Interpolate dataset features onto query points from their k nearest neighbours.

    Each neighbour is weighted by the reciprocal of its squared distance to the
    query, and the weights are normalised to sum to one.

    Args:
        query_points: (M, 3) tensor of target coordinates.
        dataset_points: (N, 3) tensor of source coordinates.
        dataset_features: (N, C) tensor of source features.
        k: number of neighbours to blend (must not exceed N).

    Returns:
        (M, C) tensor of interpolated features.
    """
    num_queries = len(query_points)
    _, channels = dataset_features.shape

    sq_dist, nbr_idx = find_knn_general(query_points, dataset_points, k)  # (M, k) each
    nbr_feats = dataset_features[nbr_idx.view(-1)].view(num_queries, k, channels)

    # 1 / d^2 weights; the epsilon avoids division by zero when a query point
    # coincides with a dataset point.
    inv_dist = 1. / (sq_dist + 1e-8)
    weights = inv_dist / inv_dist.sum(dim=-1, keepdim=True)  # rows sum to 1

    blended = weights.view(num_queries, k, 1) * nbr_feats  # (M, k, C)
    return blended.sum(dim=1)


def farthest_point_sampling(points, num_samples):
    """Greedy farthest point sampling (FPS).

    Starts from a uniformly random point. It then repeatedly selects the point
    whose squared distance to the already-selected set is largest.

    Args:
        points: (N, 3) tensor of coordinates.
        num_samples: number of indices to return. If it exceeds N, indices
            repeat once every point has been selected, because all remaining
            distances are then 0.

    Returns:
        LongTensor of shape (num_samples,) holding indices into ``points``.

    Raises:
        ValueError: if ``points`` is empty and ``num_samples > 0``.
    """
    N = len(points)
    sampled_indices = torch.zeros(num_samples, dtype=torch.long)
    if num_samples == 0:
        return sampled_indices
    if N == 0:
        raise ValueError("farthest_point_sampling: cannot sample from an empty point set")

    # Squared distance from every point to its nearest already-selected point.
    # Starting at +inf, rather than a finite sentinel such as 1e10, means the
    # first update succeeds at any coordinate scale. Allocating on
    # `points.device` keeps the update valid for GPU inputs.
    distance = torch.full((N,), float('inf'), device=points.device)

    # randrange(N) draws from [0, N); random.randint(0, N) would include N,
    # which is an out-of-range index.
    farthest_idx = random.randrange(N)

    for i in range(num_samples):
        # Record the current farthest point.
        sampled_indices[i] = farthest_idx

        # Distance from the new sample to every point.
        centroid = points[farthest_idx].view(1, 3)
        dist = torch.sum((points - centroid) ** 2, dim=-1)  # (N,)

        # Keep the running minimum. Unlike a masked in-place write, an
        # elementwise minimum lets dtype promotion handle float64 points.
        distance = torch.minimum(distance, dist)

        # The next sample is the point farthest from the selected set.
        farthest_idx = torch.argmax(distance)

    return sampled_indices


class PointTransformerLayer(nn.Module):
    """Vector self-attention over each point's k nearest neighbours.

    Each output feature sums, over the k neighbours, ``softmax(gamma(q_i - k_j + delta_ij)) * (v_j + delta_ij)``.
    Here ``delta_ij = mlp_pos(p_i - p_j)``, ``gamma`` is ``mlp_attn``, and the
    softmax runs across the k neighbours independently for each channel.

    Args:
        in_channels: input feature dimension.
        out_channels: output feature dimension.
        k: neighbourhood size (each point's own position is included, since
            it is its own nearest neighbour at distance 0).
    """

    def __init__(self, in_channels, out_channels, k):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        # Query / key / value projections (no bias).
        self.linear_q = nn.Linear(in_channels, out_channels, bias=False)
        self.linear_k = nn.Linear(in_channels, out_channels, bias=False)
        self.linear_v = nn.Linear(in_channels, out_channels, bias=False)

        # gamma: maps the vector similarity to per-channel attention logits.
        self.mlp_attn = nn.Sequential(
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, out_channels)
        )
        # delta: encodes relative positions p_i - p_j.
        self.mlp_pos = nn.Sequential(
            nn.Linear(3, 3),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, out_channels)
        )

        # Normalise across the neighbour axis of an (N, k, C) tensor.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, points, features):
        """Apply local vector attention.

        Args:
            points: (N, 3) coordinates.
            features: (N, in_channels) per-point features.

        Returns:
            (N, out_channels) attended features.
        """
        n_pts, k, c_out = len(points), self.k, self.out_channels

        # Project every point once; neighbours are gathered afterwards.
        query = self.linear_q(features)
        key = self.linear_k(features)
        value = self.linear_v(features)

        # Gather each point's k nearest neighbours (itself included).
        _, nbr = find_knn_general(points, points, k)  # (N, k)
        flat_nbr = nbr.view(-1)
        nbr_xyz = points[flat_nbr].view(n_pts, k, 3)
        nbr_key = key[flat_nbr].view(n_pts, k, c_out)
        nbr_value = value[flat_nbr].view(n_pts, k, c_out)

        # delta_ij = mlp_pos(p_i - p_j), shape (N, k, C_out).
        offset = points.view(n_pts, 1, 3) - nbr_xyz
        pos_enc = self.mlp_pos(offset.view(-1, 3)).view(n_pts, k, -1)

        # Per-channel attention weights over the neighbours.
        logits = query.view(n_pts, 1, c_out) - nbr_key + pos_enc
        attn = self.mlp_attn(logits.view(-1, c_out)).view(n_pts, k, c_out)
        attn = self.softmax(attn)

        return (attn * (nbr_value + pos_enc)).sum(dim=1)
    
    
class PointTransformerBlock(nn.Module):
    """Residual Point Transformer block: Linear -> PointTransformerLayer -> Linear, plus identity skip.

    Args:
        channels: feature dimension, unchanged by the block so the skip can be added.
        k: neighbourhood size passed to the inner PointTransformerLayer.
    """

    def __init__(self, channels, k):
        super().__init__()
        self.linear_in = nn.Linear(channels, channels)
        self.pt_layer = PointTransformerLayer(channels, channels, k)
        self.linear_out = nn.Linear(channels, channels)

    def forward(self, points, features):
        """Return (N, channels) features: the block output added to ``features``."""
        hidden = self.pt_layer(points, self.linear_in(features))
        return self.linear_out(hidden) + features
    
    
class TransitionDown(nn.Module):
    """Downsample a point cloud with FPS and aggregate features by local max pooling.

    Keeps roughly ``1 / ratio`` of the points, and at least one. For each
    sampled point, the features of its ``k`` nearest input points pass through
    a shared MLP and are max-pooled.

    Args:
        channels: feature dimension, which the MLP preserves.
        ratio: downsampling ratio.
        k: neighbourhood size, capped at the number of input points.
    """

    def __init__(self, channels, ratio, k):
        super(TransitionDown, self).__init__()
        self.channels = channels
        self.ratio = ratio
        self.k = k

        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=False),
            nn.BatchNorm1d(channels),
            nn.ReLU(inplace=True),
            nn.Linear(channels, channels, bias=False),
            nn.BatchNorm1d(channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, points, features):
        """Downsample and pool.

        Args:
            points: (N, 3) coordinates.
            features: (N, channels) per-point features.

        Returns:
            sampled_points: (M, 3) subset of ``points`` chosen by FPS.
            out_features: (M, channels) max-pooled neighbourhood features.
        """
        N = len(points)
        # int(N / ratio) truncates to 0 when N < ratio. That gave an empty
        # sample and a failing .view(0, k, -1), so always keep one point.
        M = max(1, int(N / self.ratio))
        # topk cannot return more neighbours than there are input points.
        k = min(self.k, N)

        # 1. Farthest point sampling
        sampled_indices = farthest_point_sampling(points, M)
        sampled_points = points[sampled_indices]

        # 2. kNN search around each sampled point
        _, knn_indices = find_knn_general(sampled_points, points, k)  # (M, k)

        # 3. Shared MLP over all gathered neighbour features
        knn_features = features[knn_indices.view(-1)]  # (M*k, C)
        out_knn_features = self.mlp(knn_features).view(M, k, -1)

        # 4. Local max pooling over the neighbour axis
        out_features = out_knn_features.max(dim=1)[0]

        return sampled_points, out_features
    
    
class TransitionUp(nn.Module):
    """Upsample features from a coarse point set back onto a finer one.

    Coarse ("down") features are projected by ``linear_down`` and interpolated
    onto the fine ("up") points from their nearest coarse neighbours. The fine
    features, projected by ``linear_up``, are then added as a skip connection.

    Args:
        up_channels: feature dimension of the fine points.
        down_channels: feature dimension of the coarse points.
        out_channels: output feature dimension.
        k: number of coarse neighbours used for interpolation (default 3),
            capped at the number of coarse points.
    """

    def __init__(self, up_channels, down_channels, out_channels, k=3):
        super(TransitionUp, self).__init__()
        self.linear_up = nn.Linear(up_channels, out_channels)
        self.linear_down = nn.Linear(down_channels, out_channels)
        self.k = k

    def forward(self, up_points, up_features, down_points, down_features):
        """Return (N_up, out_channels) features for ``up_points``."""
        # 1. Feed-forward with the down linear layer
        down_f = self.linear_down(down_features)

        # 2. Interpolation. The kNN search uses topk, which fails when asked
        #    for more neighbours than there are coarse points, so cap k.
        k = min(self.k, len(down_points))
        interp_f = interpolate_knn(up_points, down_points, down_f, k)  # (N_up, C_out)

        # 3. Skip-connection
        return interp_f + self.linear_up(up_features)
    
    
class SimplePointTransformer(nn.Module):
    """Minimal U-shaped Point Transformer: one attention layer plus one down/up transition pair.

    Args:
        in_channels: input feature dimension.
        out_channels: feature dimension used throughout and returned.
        ratio: downsampling ratio of the TransitionDown stage.
        k: neighbourhood size for attention and pooling.
    """

    def __init__(self, in_channels, out_channels, ratio, k):
        super().__init__()
        self.layer = PointTransformerLayer(in_channels, out_channels, k)
        self.down = TransitionDown(out_channels, ratio, k)
        self.up = TransitionUp(out_channels, out_channels, out_channels)

    def forward(self, points, features):
        """Return (N, out_channels) per-point features for ``points``."""
        fine = self.layer(points, features)
        coarse_xyz, coarse_feat = self.down(points, fine)
        # The fine-level attention output doubles as the skip connection.
        return self.up(points, fine, coarse_xyz, coarse_feat)
    
    
############ Test ##############
# Smoke test: run SimplePointTransformer on random inputs and print the
# output shape, which should be (N, C_out).
N = 100  # number of points
K = 5  # neighbourhood size for attention / pooling
ratio = 4  # TransitionDown keeps about N / ratio points
C_in = 3
C_out = 16

points = torch.randn(N, 3)
features = torch.randn(N, C_in)
net = SimplePointTransformer(C_in, C_out, ratio, K)

out_features = net(points, features)
print(out_features.shape)

torch.Size([100, 16])


## Utilities

### 1. Radius Search
- Input: dataset points (N, 3), query points (M, 3), and the radius R
- Output: a list of M index tensors; entry i holds the indices of the dataset points within distance R of query point i

In [3]:
def find_radius_general(query_points, dataset_points, r):
    """Radius (ball) search: dataset points within distance ``r`` of each query.

    Args:
        query_points: (M, 3) tensor of query coordinates.
        dataset_points: (N, 3) tensor of candidate coordinates.
        r: non-negative search radius, in the same units as the coordinates.

    Returns:
        A list of M 1-D LongTensors. Entry i holds the indices of the dataset
        points whose Euclidean distance to query i is strictly less than ``r``.
    """
    M = len(query_points)
    N = len(dataset_points)

    # 1. Pairwise squared distances
    delta = query_points.view(M, 1, 3) - dataset_points.view(1, N, 3)  # (M, N, 3)
    sq_dist = torch.sum(delta ** 2, dim=-1)  # (M, N)

    # 2. Compare the squared distance with the squared radius. Comparing it
    #    with r itself mixes units and gives the wrong ball for any r != 1.
    mask = sq_dist < r * r
    return [torch.nonzero(row, as_tuple=True)[0] for row in mask]

In [4]:
# Radius-search demo. The names are swapped relative to the markdown above:
# here N is the number of query points and M is the dataset size.
N = 20
M = 100
R = 0.6  # search radius

query_points = torch.randn(N, 3)
dataset_points = torch.randn(M, 3)

indices = find_radius_general(query_points, dataset_points, R)
print(len(indices))  # one index tensor per query point, i.e. N
print(indices[0])  # dataset indices found for the first query (random data)

20
tensor([ 3, 15, 37, 41, 63, 68, 77, 88, 92])


### 2. IoU for Axis-aligned 3D Bounding Boxes
- Input: two 3D bounding boxes, each of shape (6,) in the form (x1, y1, z1, x2, y2, z2)
- Output: IoU, a scalar

In [5]:
def cal_iou2d(bb1, bb2):
    """Intersection-over-union of two axis-aligned 2D boxes.

    Args:
        bb1, bb2: boxes as (x1, y1, x2, y2) with x1 <= x2 and y1 <= y2, i.e.
            the min corner first. Any indexable of length 4 works, including a
            1-D tensor. The result does not depend on which way the axes point.

    Returns:
        IoU in [0, 1]. Boxes that are disjoint, only touch along an edge, or
        have zero area return the Python float ``0.``. Otherwise the type
        follows the inputs (a 0-d tensor for tensor inputs).
    """
    # 1. Coordinates of the intersection rectangle.
    x_left = max(bb1[0], bb2[0])
    y_top = max(bb1[1], bb2[1])
    x_right = min(bb1[2], bb2[2])
    y_bottom = min(bb1[3], bb2[3])

    # 2. No positive-area overlap means IoU 0. Using <= rather than < also
    #    covers degenerate zero-area boxes, whose union would otherwise be 0
    #    (0/0). This matches the strict `> 0` test in cal_iou3d_multi.
    if x_right <= x_left or y_bottom <= y_top:
        return 0.

    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    bb1_area = (bb1[2] - bb1[0]) * (bb1[3] - bb1[1])
    bb2_area = (bb2[2] - bb2[0]) * (bb2[3] - bb2[1])
    return intersection_area / (bb1_area + bb2_area - intersection_area)

In [6]:
# Two 0.2 x 0.2 boxes overlapping in a 0.1 x 0.1 square.
bb1 = torch.Tensor([0.1, 0.1, 0.3, 0.3])
bb2 = torch.Tensor([0.2, 0.2, 0.4, 0.4])
iou = cal_iou2d(bb1, bb2) # should be 0.142857... (= 0.01 / (0.04 + 0.04 - 0.01))
print(iou)

tensor(0.1429)


In [7]:
def cal_iou3d(bb1, bb2):
    """Intersection-over-union of two axis-aligned 3D boxes.

    Args:
        bb1, bb2: boxes as (x1, y1, z1, x2, y2, z2) with the min corner first
            (x1 <= x2, y1 <= y2, z1 <= z2). Any indexable of length 6 works.

    Returns:
        IoU in [0, 1]. Boxes without positive-volume overlap (disjoint,
        touching on a face, or degenerate) return the Python float ``0.``.
        Otherwise the type follows the inputs.
    """
    # 1. Coordinates of the intersection cuboid.
    x_small = max(bb1[0], bb2[0])
    y_small = max(bb1[1], bb2[1])
    z_small = max(bb1[2], bb2[2])
    x_large = min(bb1[3], bb2[3])
    y_large = min(bb1[4], bb2[4])
    z_large = min(bb1[5], bb2[5])

    # 2. <= (not <) treats face contact and zero-volume boxes as "no overlap",
    #    which avoids a 0/0 union when both boxes are degenerate.
    if x_large <= x_small or y_large <= y_small or z_large <= z_small:
        return 0.

    intersection_volume = (x_large - x_small) * (y_large - y_small) * (z_large - z_small)
    bb1_volume = (bb1[3] - bb1[0]) * (bb1[4] - bb1[1]) * (bb1[5] - bb1[2])
    bb2_volume = (bb2[3] - bb2[0]) * (bb2[4] - bb2[1]) * (bb2[5] - bb2[2])
    return intersection_volume / (bb1_volume + bb2_volume - intersection_volume)

In [8]:
# Two 0.2^3 cubes overlapping in a 0.1^3 cube.
bb1 = torch.Tensor([0.1, 0.1, 0.1, 0.3, 0.3, 0.3])
bb2 = torch.Tensor([0.2, 0.2, 0.2, 0.4, 0.4, 0.4])
iou = cal_iou3d(bb1, bb2) # should be 0.066666... (= 0.001 / (0.008 + 0.008 - 0.001))
print(iou)

tensor(0.0667)


In [9]:
def cal_iou3d_multi(box, boxes):
    """IoU between one axis-aligned 3D box and every box in a batch.

    Args:
        box: (6,) tensor (x1, y1, z1, x2, y2, z2) with the min corner first.
        boxes: (N, 6) tensor in the same format.

    Returns:
        (N,) tensor with dtype ``box.dtype`` on ``boxes.device``. Entries for
        boxes without positive-volume overlap with ``box`` are 0.
    """
    # 1. Side lengths of the intersection cuboid with each box. A non-positive
    #    length on any axis means there is no overlap.
    x_delta = boxes[:, 3].clamp(max=box[3]) - boxes[:, 0].clamp(min=box[0])
    y_delta = boxes[:, 4].clamp(max=box[4]) - boxes[:, 1].clamp(min=box[1])
    z_delta = boxes[:, 5].clamp(max=box[5]) - boxes[:, 2].clamp(min=box[2])

    # 2. Allocate the result next to `boxes` so the masked write below stays
    #    on one device for GPU inputs.
    iou = torch.zeros((len(boxes),), dtype=box.dtype, device=boxes.device)
    has_overlap = (x_delta > 0) & (y_delta > 0) & (z_delta > 0)
    if not has_overlap.any():
        return iou

    # 3. IoU only where the intersection has positive volume, so the union is
    #    positive for well-formed boxes.
    intersection_volume = x_delta[has_overlap] * y_delta[has_overlap] * z_delta[has_overlap]
    box_volume = (box[3] - box[0]) * (box[4] - box[1]) * (box[5] - box[2])
    valid = boxes[has_overlap]
    boxes_volume = (valid[:, 3] - valid[:, 0]) \
                    * (valid[:, 4] - valid[:, 1]) \
                    * (valid[:, 5] - valid[:, 2])
    iou_valid = intersection_volume / (box_volume + boxes_volume - intersection_volume)

    # Masked assignment requires matching dtypes; cast to the result dtype.
    iou[has_overlap] = iou_valid.to(iou.dtype)
    return iou

In [10]:
# The first box overlaps `box` partially (IoU 1/15); the second lies entirely
# outside it (IoU 0).
box = torch.tensor([0.1, 0.1, 0.1, 0.3, 0.3, 0.3])
boxes = torch.tensor([[0.2, 0.2, 0.2, 0.4, 0.4, 0.4], [0.01, 0.01, 0.01, 0.03, 0.03, 0.03]])
iou = cal_iou3d_multi(box, boxes) 
print(iou)

tensor([0.0667, 0.0000])


### 3. 3D Non-Maximum Suppression
- Input: a set of bounding boxes (N, 6), the corresponding scores (N,), iou_threshold
- Output: the output boxes (M, 6)

In [11]:
def nms(boxes, scores, threshold):
    """Greedy 3D non-maximum suppression.

    Args:
        boxes: (N, 6) tensor of axis-aligned boxes (x1, y1, z1, x2, y2, z2).
        scores: (N,) tensor of confidences; higher is better.
        threshold: a remaining box is suppressed when its IoU with the box
            just kept is >= ``threshold``.

    Returns:
        (M, 6) tensor of the kept boxes, highest score first. Empty input
        returns an empty (0, 6) slice of ``boxes`` instead of raising.
    """
    # 1. Ascending sort: the best remaining candidate is always at the end.
    order = scores.argsort()

    # 2. Iteratively keep the best box and drop those that overlap it too much.
    keep = []
    while len(order) > 0:
        idx = order[-1]
        keep.append(idx)
        order = order[:-1]

        if len(order) == 0:
            break

        iou = cal_iou3d_multi(boxes[idx], boxes[order])
        order = order[iou < threshold]

    if not keep:
        # torch.stack([]) raises on an empty list.
        return boxes[:0]
    return boxes[torch.stack(keep)]

In [12]:
# Boxes 0/1 and 2/3 are near-duplicate pairs (shifted by 0.01 per axis, IoU
# about 0.75 each), while boxes 0 and 2 overlap only slightly (IoU 1/15).
# With threshold 0.5, NMS keeps boxes 0 and 2.
boxes = torch.tensor([
    [0.1, 0.1, 0.1, 0.3, 0.3, 0.3],
    [0.11, 0.11, 0.11, 0.31, 0.31, 0.31],
    [0.2, 0.2, 0.2, 0.4, 0.4, 0.4],
    [0.21, 0.21, 0.21, 0.41, 0.41, 0.41]
])
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
nms_boxes = nms(boxes, scores, threshold=0.5)
print(nms_boxes)

tensor([[0.1000, 0.1000, 0.1000, 0.3000, 0.3000, 0.3000],
        [0.2000, 0.2000, 0.2000, 0.4000, 0.4000, 0.4000]])


## Modules

### 1. Voting Module (including feature extraction)
- Input: input points (N, 3), input features (N, C_in), output feature dimension (C_out), the number of votes (num_votes), and the downsampling ratio and K used by the Point Transformer
- Output: the output votes (num_votes, 3), the output features (num_votes, C_out)

In [13]:
class VotingModule(nn.Module):
    """Extract point features, pick seed points with FPS, and let each seed vote.

    Each seed predicts a 3D offset and a feature offset. The votes are the
    seed positions and seed features plus those offsets.

    Args:
        in_channels: input feature dimension.
        out_channels: feature dimension of the extractor and of the votes.
        num_votes: number of seeds (and therefore votes) sampled with FPS.
        ratio: downsampling ratio of the Point Transformer extractor.
        k: neighbourhood size of the Point Transformer extractor.
    """

    def __init__(self, in_channels, out_channels, num_votes, ratio, k):
        super().__init__()
        self.num_votes = num_votes

        # Point feature extractor.
        self.pfe = SimplePointTransformer(in_channels, out_channels, ratio, k)
        # Predicts [delta_xyz (3) | delta_feature (C_out)] per seed.
        self.voter = nn.Sequential(
            nn.Linear(out_channels, out_channels, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, 3 + out_channels)
        )

    def forward(self, points, features):
        """Return ``(vote_points, vote_features)``: (num_votes, 3) and (num_votes, C_out)."""
        point_feats = self.pfe(points, features)

        seed_idx = farthest_point_sampling(points, self.num_votes)
        seed_xyz = points[seed_idx]
        seed_feat = point_feats[seed_idx]

        offsets = self.voter(seed_feat)  # (num_votes, 3 + C_out)
        xyz_offset, feat_offset = offsets[:, :3], offsets[:, 3:]
        return seed_xyz + xyz_offset, seed_feat + feat_offset

In [14]:
# Voting smoke test on random data. Expected shapes: (num_votes, 3) for the
# vote positions and (num_votes, C_out) for the vote features.
N = 50
C_in = 3
C_out = 16
num_votes = 32
ratio = 4
K = 5

points = torch.randn(N, 3)
features = torch.randn(N, C_in)
voting_m = VotingModule(C_in, C_out, num_votes, ratio, K)

vote_points, vote_features = voting_m(points, features)
print(vote_points.shape)
print(vote_features.shape)

torch.Size([32, 3])
torch.Size([32, 16])


### 2. Detection Head
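- Input: vote_points (N, 3), vote_features (N, C_in), num_clusters, radius, nms_iou_threshold
- Output: the detected bounding boxes after NMS, (M, 6) with M <= num_clusters. The head predicts 1 (objectness) + 6 (box coordinates) values per cluster; the objectness score is used only to rank boxes during NMS and is not part of the returned tensor.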

In [15]:
class DetectionHead(nn.Module):
    """Cluster votes, aggregate each cluster, regress boxes, and apply 3D NMS.

    Args:
        in_channels: vote feature dimension C_in.
        num_clusters: number of cluster centres sampled from the votes with FPS.
        radius: grouping radius around each centre. Relative positions are
            divided by it before MLP1.
        nms_iou_threshold: boxes whose IoU with a higher-scored kept box is
            >= this value are dropped.
    """

    def __init__(self, in_channels, num_clusters, radius, nms_iou_threshold):
        super(DetectionHead, self).__init__()
        self.num_clusters = num_clusters
        self.radius = radius
        self.nms_iou_threshold = nms_iou_threshold

        # Per-vote MLP applied before max pooling within a cluster.
        self.mlp1 = nn.Sequential(
            nn.Linear(3 + in_channels, in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, in_channels),
            nn.ReLU(inplace=True)
        )
        # Per-cluster MLP applied after pooling.
        self.mlp2 = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, in_channels),
            nn.ReLU(inplace=True)
        )
        # 1 objectness logit + 6 box coordinates.
        self.final = nn.Linear(in_channels, 7)

    def forward(self, vote_points, vote_features):
        """Detect boxes from votes.

        Args:
            vote_points: (N, 3) vote positions.
            vote_features: (N, C_in) vote features.

        Returns:
            (M, 6) tensor of boxes (x1, y1, z1, x2, y2, z2) kept by NMS, with
            M <= num_clusters. Objectness is used only to rank boxes in NMS
            and is not returned.
        """
        # 1. Sample cluster centroids.
        sampled_indices = farthest_point_sampling(vote_points, self.num_clusters)
        cluster_points = vote_points[sampled_indices]

        # 2. Find cluster neighbors.
        indices = find_radius_general(cluster_points, vote_points, self.radius)  # List[torch.LongTensor]

        # 3. Grouping (MLP1 and MLP2)
        grouped_features = []
        for group_center, group_indices in zip(cluster_points, indices):
            # 3-1. Relative position, normalised by the grouping radius.
            features_in_group = vote_features[group_indices]
            relative_pos = (group_center.unsqueeze(0) - vote_points[group_indices]) / self.radius
            features_with_pos = torch.cat([relative_pos, features_in_group], dim=1)

            # 3-2. MLP1 -> MaxPool -> MLP2
            group_feature = self.mlp1(features_with_pos).max(dim=0)[0]
            group_feature = self.mlp2(group_feature)
            grouped_features.append(group_feature)
        grouped_features = torch.stack(grouped_features)

        # 4. Predict bounding boxes
        boxes = self.final(grouped_features)
        box_scores = boxes[:, 0].sigmoid()

        # The regressed corners are unconstrained, so x2 < x1 (etc.) can occur.
        # cal_iou3d_multi never reports overlap for such an inverted box, which
        # would silently disable NMS. Order each axis as (min, max) first.
        raw_coords = boxes[:, 1:]
        corner_a, corner_b = raw_coords[:, :3], raw_coords[:, 3:]
        box_coordinates = torch.cat(
            [torch.min(corner_a, corner_b), torch.max(corner_a, corner_b)], dim=1
        )

        # 5. Non-maximum suppression
        final_boxes = nms(box_coordinates, box_scores, self.nms_iou_threshold)

        return final_boxes

In [16]:
# Run an untrained detection head on random votes. It prints (M, 6) with
# M <= num_clusters, the number of boxes that survive NMS.
N = 50
C_in = 8
num_clusters = 16
radius = 0.9
iou_threshold = 0.5

vote_points = torch.randn(N, 3)
vote_features = torch.randn(N, C_in)
detection_h = DetectionHead(C_in, num_clusters, radius, iou_threshold)

pred_boxes = detection_h(vote_points, vote_features)
print(pred_boxes.shape)

torch.Size([16, 6])
