## Backpropagation through `knn`

In [None]:
import torch
import torch.nn.functional as F

# Example: feature matrix of shape [B, N, D]
# (B batch elements, N points per element, D features per point).
B, N, D = 2, 5, 3
features = torch.randn(B, N, D, requires_grad=True)

# Pairwise Euclidean distances between the points of each batch element.
pairwise_dist = torch.cdist(features, features, p=2)  # [B, N, N]

# Indices of the k smallest distances per point (k=2). A point's distance to
# itself is 0, so it will normally be picked as one of its own neighbours.
k = 2
_, knn_indices = pairwise_dist.topk(k, dim=-1, largest=False)  # [B, N, k]

# Gather neighbour features.
# The gather source must satisfy source[b, i, j, :] == features[b, j, :] so
# that indexing dim 2 with a neighbour id j returns point j. That requires
# inserting the broadcast axis at dim 1; inserting it at dim 2 would make
# every "neighbour" a copy of the centre point i.
neighbors = torch.gather(
    features.unsqueeze(1).expand(-1, N, -1, -1),
    2,
    knn_indices.unsqueeze(-1).expand(-1, -1, -1, D),
)  # [B, N, k, D]

# Simple loss on neighbours (mean of neighbour features). The integer indices
# carry no gradient; it reaches `features` only through the gathered values.
loss = neighbors.mean()
loss.backward()

print("Gradients on features:\n", features.grad)

## Edge Convolution with edge features

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeConvWithEdgeFeatures(nn.Module):
    """EdgeConv layer over a k-NN graph that also consumes per-edge features.

    For every point the layer finds its ``k`` nearest points (Euclidean
    distance in the input feature space), builds one message per neighbour
    from ``[central, neighbour - central, edge_features]``, passes the
    messages through a shared Linear -> BatchNorm1d -> ReLU stack and
    max-pools them over the neighbours.
    """

    def __init__(self, in_channels, edge_in_channels, out_channels, k):
        """
        Args:
            in_channels: Point feature dimension D.
            edge_in_channels: Edge feature dimension E (may be 0).
            out_channels: Output feature dimension per point.
            k: Number of nearest neighbours per point.
        """
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels + edge_in_channels, out_channels, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, edge_features):
        """
        Args:
            x: Input point cloud data, shape [B, N, D]
               B - batch size, N - number of points, D - feature dimensions
            edge_features: Input edge features, shape [B, N, k, E]
               E - edge feature dimensions. Slot j along dim 2 is concatenated
               with the j-th neighbour returned by ``topk``.
               NOTE(review): callers presumably need to supply edge features
               in that same neighbour order - confirm.
        Returns:
            x_out: Updated features after EdgeConv, shape [B, N, out_channels]
        """
        B, N, D = x.size()
        E = edge_features.size(-1)

        # Step 1: pairwise distances and indices of the k nearest points.
        pairwise_dist = torch.cdist(x, x, p=2)  # [B, N, N]
        idx = pairwise_dist.topk(k=self.k, dim=-1, largest=False)[1]  # [B, N, k]

        # Step 2: gather neighbour features. The source needs
        # source[b, i, j, :] == x[b, j, :], so the broadcast axis goes at
        # dim 1. Broadcasting at dim 2 would return the centre point itself
        # and make every relative feature below zero.
        neighbors = torch.gather(
            x.unsqueeze(1).expand(-1, N, -1, -1),
            2,
            idx.unsqueeze(-1).expand(-1, -1, -1, D),
        )  # [B, N, k, D]

        # Central point repeated for its k neighbours: [B, N, k, D]
        central = x.unsqueeze(2).expand(-1, -1, self.k, -1)

        # Step 3: build the per-edge message.
        relative_features = neighbors - central  # [B, N, k, D]
        combined_features = torch.cat(
            [central, relative_features, edge_features], dim=-1
        )  # [B, N, k, 2*D + E]

        # Step 4: shared MLP over all edges (BatchNorm1d expects 2-D input),
        # then restore the neighbour axis.
        combined_features = self.mlp(
            combined_features.reshape(B * N * self.k, 2 * D + E)
        )  # [B * N * k, out_channels]
        combined_features = combined_features.view(B, N, self.k, -1)

        # Aggregate (max pooling across neighbours).
        x_out = combined_features.max(dim=2)[0]  # [B, N, out_channels]

        return x_out


### Without Edge Features

In [None]:
# Demo with E = 0: the edge tensor has an empty last axis, so the layer only
# sees the point features.
B, N, D, k = 1, 4, 3, 2
x = torch.rand((B, N, D))  # random point (node) features
edge_features = torch.zeros((B, N, k, 0))  # empty edge features

model = EdgeConvWithEdgeFeatures(D, 0, 16, k)
output = model(x, edge_features)
print("Output Shape:", output.shape)  # expected [B, N, 16]

Output Shape: torch.Size([1, 4, 16])


### With Edge Features

In [None]:
# Demo with E = 2: every (point, neighbour-slot) pair carries a random
# 2-dimensional edge feature.
B, N, D, k, E = 1, 4, 3, 2, 2
x = torch.rand((B, N, D))  # random point (node) features
edge_features = torch.rand((B, N, k, E))  # random edge features

model = EdgeConvWithEdgeFeatures(D, E, 16, k)
output = model(x, edge_features)
print("Output Shape:", output.shape)  # expected [B, N, 16]

Output Shape: torch.Size([1, 4, 16])


### Custom Attention with Interaction Matrix

In [None]:
class AugmentedMultiHeadAttention(nn.Module):
    """Multi-head self-attention with an additive pairwise interaction bias.

    The attention weights are ``softmax(Q K^T / sqrt(head_dim) + U)``, where
    ``U`` is obtained by projecting the per-pair interaction features ``u`` to
    one scalar per head.
    """

    def __init__(self, embed_dim, u_feat=4, dropout=0.0, num_heads=1, qkv_bias=False):
        """
        Args:
            embed_dim: Output width of the Q/K/V/output projections.
            u_feat: Number of features per token pair in ``u``.
            dropout: Dropout probability applied to the attention weights.
            num_heads: Number of attention heads; must divide ``embed_dim``.
            qkv_bias: Whether the Q/K/V projections carry a bias term.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # per-head width

        # Maps the u_feat interaction features of a token pair to one bias per head.
        self.u_embd = nn.Linear(u_feat, self.num_heads)

        # Lazy layers infer the input width on the first forward pass.
        self.W_q = nn.LazyLinear(embed_dim, bias=qkv_bias)
        self.W_k = nn.LazyLinear(embed_dim, bias=qkv_bias)
        self.W_v = nn.LazyLinear(embed_dim, bias=qkv_bias)
        self.W_o = nn.LazyLinear(embed_dim)  # combines the head outputs
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, u=None):
        """
        Args:
            x: Token features, shape (batch, num_tokens, in_dim).
            u: Optional interaction features, shape
               (batch, num_tokens, num_tokens, u_feat), where ``u[b, i, j]``
               describes the pair (query i, key j). When ``None`` the layer
               is plain multi-head self-attention.
        Returns:
            Tensor of shape (batch, num_tokens, embed_dim).
        """
        bs, num_tokens, _ = x.shape

        def split_heads(t):
            # (b, num_tokens, embed_dim) -> (b, num_heads, num_tokens, head_dim)
            return t.view(bs, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        # Scaled dot product per head: (b, num_heads, num_tokens, num_tokens)
        attn_scores = (Q @ K.transpose(2, 3)) / self.head_dim ** 0.5

        if u is not None:
            # Change the feature dimension of the interaction matrix to one
            # value per head, then move the head axis forward while keeping
            # the (query, key) order:
            # (b, i, j, num_heads) -> (b, num_heads, i, j).
            # The bias is added as-is. Filling every non-zero entry with -inf
            # would mask whole rows and make the softmax return NaN.
            attn_scores = attn_scores + self.u_embd(u).permute(0, 3, 1, 2)

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ V).transpose(1, 2)

        # Merge heads: embed_dim == num_heads * head_dim
        context_vec = context_vec.contiguous().view(bs, num_tokens, self.embed_dim)
        context_vec = self.W_o(context_vec)  # output projection

        return context_vec

### Particle Transformer Simple Implementation

In [None]:
import torch
import torch.nn as nn

class ParticleAttentionBlock(nn.Module):
    """Attention block: LayerNorm -> multi-head self-attention -> LayerNorm
    -> residual, followed by an MLP with its own residual connection.

    An optional tensor ``u`` is forwarded to ``nn.MultiheadAttention`` as
    ``attn_mask``; a floating-point mask is added to the attention scores.
    """

    def __init__(self, embed_dim, expansion_factor=4, num_heads=1):
        """
        Args:
            embed_dim: Width of the token features.
            expansion_factor: Hidden width of the MLP as a multiple of ``embed_dim``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.pmha = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, expansion_factor * embed_dim),
            nn.GELU(),
            nn.LayerNorm(expansion_factor * embed_dim),
            nn.Linear(expansion_factor * embed_dim, embed_dim),
        )

    def forward(self, x, u=None):
        """
        Args:
            x: Token features, shape (batch, num_tokens, embed_dim).
            u: Optional attention mask/bias. A 4-D tensor
               (batch, num_heads, num_tokens, num_tokens) is flattened to the
               (batch * num_heads, num_tokens, num_tokens) layout that
               ``nn.MultiheadAttention`` expects; tensors of other rank are
               passed through unchanged. ``None`` disables the mask.
        Returns:
            Tensor of shape (batch, num_tokens, embed_dim).
        """
        x_res = x
        x = self.norm1(x)

        attn_mask = None
        if u is not None:
            # Merge the batch and head axes only when both are present.
            attn_mask = u.flatten(start_dim=0, end_dim=1) if u.dim() == 4 else u

        attn_output, _ = self.pmha(x, x, x, need_weights=False, attn_mask=attn_mask)
        x = self.norm2(attn_output)
        h = x + x_res  # first residual connection
        z = self.mlp(h)
        return z + h  # second residual connection

class ParticleTransformer(nn.Module):
    """Stack of ``ParticleAttentionBlock``s with a pairwise interaction bias.

    Particle features are linearly embedded to ``embed_dim``. Interaction
    features are linearly embedded to one value per attention head and handed
    to every block as an additive attention bias. The block outputs are
    mean-pooled over particles and passed through a Linear + Sigmoid head.
    """

    def __init__(self, feat_particles, feat_interaction, embed_dim, num_heads, num_blocks, num_classes=1):
        """
        Args:
            feat_particles: Number of input features per particle.
            feat_interaction: Number of input features per particle pair.
            embed_dim: Embedding width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads per block.
            num_blocks: Number of attention blocks.
            num_classes: Width of the output head.
        """
        super().__init__()
        assert (embed_dim % num_heads == 0), \
            "embed_dim must be divisible by num_heads"
        self.particle_embed = nn.Linear(feat_particles, embed_dim)
        self.interaction_embed = nn.Linear(feat_interaction, num_heads)
        self.blocks = nn.ModuleList([
            ParticleAttentionBlock(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_blocks)
        ])
        self.mlp_head = nn.Sequential(
            nn.Linear(embed_dim, num_classes),
            nn.Sigmoid()
        )

    def forward(self, particles, interactions):
        """
        Args:
            particles: Particle features; the last axis must have size
                ``feat_particles``. Expected layout (batch, num_particles,
                feat_particles), since dim 1 is pooled as the particle axis.
            interactions: Pairwise features; the last axis must have size
                ``feat_interaction``. Expected layout (batch, num_particles,
                num_particles, feat_interaction) with ``[b, i, j]`` describing
                the pair (query i, key j).
        Returns:
            Sigmoid outputs (values in (0, 1), not raw logits) of shape
            (batch, num_classes).
        """
        x = self.particle_embed(particles)

        # (b, i, j, num_heads) -> (b, num_heads, i, j). permute keeps the
        # (query, key) order of the pair axes, whereas transpose(3, 1) would
        # swap them and transpose any non-symmetric interaction matrix.
        u = self.interaction_embed(interactions).permute(0, 3, 1, 2)

        for block in self.blocks:
            x = block(x, u)

        # Aggregate features: mean pooling across particles.
        x = x.mean(dim=1)

        probs = self.mlp_head(x)
        return probs

# Example usage: run a single ParticleAttentionBlock on random inputs.
batch_size = 5
num_particles = 10
embed_dim = 64
# NOTE(review): num_blocks and num_classes are not used by this single-block
# demo; presumably they are meant for a full ParticleTransformer - confirm.
num_blocks = 4
num_classes = 5

particles = torch.rand(batch_size, num_particles, embed_dim)
# One additive attention bias per (batch, head) pair; the block below uses
# its default of a single head, hence the size-1 head axis.
interactions = torch.rand(batch_size, 1, num_particles, num_particles)

model = ParticleAttentionBlock(embed_dim=embed_dim)
output = model(particles, interactions)
# The block preserves the input layout: [batch_size, num_particles, embed_dim],
# i.e. torch.Size([5, 10, 64]) here.
print(output.shape)
