## Backpropagation through `knn`

In [None]:
import torch
import torch.nn.functional as F

# Example: a random feature matrix of shape [B, N, D]
B, N, D = 2, 5, 3
features = torch.randn(B, N, D, requires_grad=True)

# Pairwise Euclidean distances between all points of each batch element: [B, N, N]
pairwise_dist = torch.cdist(features, features, p=2)

# Indices of the k nearest neighbors of every point: [B, N, k].
# topk returns integer indices, which carry no gradient: the neighbor
# *selection* is not differentiable, and gradients reach `features` only
# through the gathered values below. The search does not exclude the point
# itself, so (at distance 0) it is normally one of its own neighbors.
k = 2
_, knn_indices = pairwise_dist.topk(k, dim=-1, largest=False)

# Gather neighbor features so that neighbors[b, i, j] = features[b, knn_indices[b, i, j]].
# The candidate axis (dim 2) must range over *other* points, so features are
# broadcast along dim 1 ([B, N, N, D] with [b, i, m] = features[b, m]).
# Broadcasting along dim 2 instead would make every "neighbor" a copy of point i.
neighbors = torch.gather(
    features.unsqueeze(1).expand(-1, N, -1, -1),
    2,
    knn_indices.unsqueeze(-1).expand(-1, -1, -1, D),
)  # [B, N, k, D]

# Simple loss on neighbors (mean of neighbor features)
loss = neighbors.mean()
loss.backward()

# Each point's gradient equals (times it was selected as a neighbor) / neighbors.numel().
print("Gradients on features:\n", features.grad)

## Edge Convolution with edge features

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeConvWithEdgeFeatures(nn.Module):
    """EdgeConv layer whose per-edge MLP also consumes caller-supplied edge features.

    For every point i and each of its k nearest neighbors j (Euclidean distance
    in the input feature space), the edge input is ``[x_i, x_j - x_i, e_ij]``.
    It is passed through ``Linear -> BatchNorm1d -> ReLU`` and the results are
    max-pooled over the k neighbors.

    Args:
        in_channels: Point feature size D.
        edge_in_channels: Edge feature size E (use 0 for no edge features).
        out_channels: Output feature size.
        k: Number of nearest neighbors per point. The point itself is not
            excluded from the search, so it is normally one of its own
            neighbors (distance 0).
    """

    def __init__(self, in_channels, edge_in_channels, out_channels, k):
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels + edge_in_channels, out_channels, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, edge_features):
        """
        Args:
            x: Input point cloud data, shape [B, N, D]
               B - batch size, N - number of points, D - feature dimensions
            edge_features: Input edge features, shape [B, N, k, E]
               E - edge feature dimensions; entry [b, i, j] is paired with
               the j-th nearest neighbor of point i.
        Returns:
            x_out: Updated features after EdgeConv, shape [B, N, out_channels]
        Raises:
            ValueError: if k exceeds N, or edge_features is not [B, N, k, E].
        """
        B, N, D = x.size()
        if self.k > N:
            raise ValueError(f"k={self.k} exceeds the number of points N={N}")
        if edge_features.dim() != 4 or tuple(edge_features.shape[:3]) != (B, N, self.k):
            raise ValueError(
                f"edge_features must have shape [B, N, k, E] = [{B}, {N}, {self.k}, E], "
                f"got {list(edge_features.shape)}"
            )
        E = edge_features.size(-1)

        # Step 1: Compute pairwise distance and get k-nearest neighbors
        pairwise_dist = torch.cdist(x, x, p=2)  # [B, N, N]
        idx = pairwise_dist.topk(k=self.k, dim=-1, largest=False).indices  # [B, N, k]

        # Step 2: Gather neighbor features: neighbors[b, i, j] = x[b, idx[b, i, j]].
        # x is broadcast along dim 1 so the gathered axis (dim 2) ranges over all
        # candidate points; broadcasting along dim 2 would just repeat x_i.
        neighbors = torch.gather(
            x.unsqueeze(1).expand(-1, N, -1, -1),
            2,
            idx.unsqueeze(-1).expand(-1, -1, -1, D),
        )  # [B, N, k, D]

        # Central point repeated for k neighbors: [B, N, k, D]
        central = x.unsqueeze(2).expand(-1, -1, self.k, -1)

        # Step 3: Compute edge features
        relative_features = neighbors - central  # [B, N, k, D]
        combined_features = torch.cat(
            [central, relative_features, edge_features], dim=-1
        )  # [B, N, k, 2*D + E]

        # Step 4: Apply MLP to every edge. Flattening to [B*N*k, C] means that in
        # training mode BatchNorm1d normalizes over all edges of the batch jointly.
        combined_features = self.mlp(combined_features.reshape(-1, 2 * D + E))  # [B*N*k, out]
        combined_features = combined_features.view(B, N, self.k, -1)  # [B, N, k, out]

        # Aggregate (max pooling across neighbors)
        x_out = combined_features.max(dim=2).values  # [B, N, out_channels]

        return x_out


### Without Edge Features

In [None]:
# Tiny smoke test: one cloud of four 3-D points, two neighbors per point.
B = 1
N = 4
D = 3
k = 2

# Random point features plus an empty (E = 0) edge-feature tensor, so the
# layer sees only the point features.
x = torch.rand(B, N, D)
edge_features = torch.zeros(B, N, k, 0)

model = EdgeConvWithEdgeFeatures(
    in_channels=D,
    edge_in_channels=0,
    out_channels=16,
    k=k,
)
output = model(x, edge_features)
print("Output Shape:", output.shape)  # expect [B, N, 16]

Output Shape: torch.Size([1, 4, 16])


### With Edge Features

In [None]:
# Tiny smoke test: one cloud of four 3-D points, two neighbors per point,
# and a 2-dimensional feature vector attached to every edge.
B = 1
N = 4
D = 3
k = 2
E = 2

x = torch.rand(B, N, D)  # random point features, [B, N, D]
edge_features = torch.rand(B, N, k, E)  # random edge features, [B, N, k, E]

model = EdgeConvWithEdgeFeatures(
    in_channels=D,
    edge_in_channels=E,
    out_channels=16,
    k=k,
)
output = model(x, edge_features)
print("Output Shape:", output.shape)  # expect [B, N, 16]

Output Shape: torch.Size([1, 4, 16])
