## Backpropagation through `knn`

In [None]:
import torch
import torch.nn.functional as F

# Example: Feature matrix [B, N, D]
B, N, D = 2, 5, 11
features = torch.randn(B, N, D, requires_grad=True)

# Compute pairwise distances using only the last two feature columns
pairwise_dist = torch.cdist(features[..., -2:], features[..., -2:], p=2)  # [B, N, N]

# Get the k nearest neighbors of every point. Here k = N, i.e. all points
# ordered by increasing distance.
k = N
_, knn_indices = pairwise_dist.topk(k, dim=-1, largest=False)  # [B, N, k]

# Gather neighbor features: neighbors[b, i, m] = features[b, knn_indices[b, i, m]].
# The gather source must vary along the gathered dim (dim 2), so the points are
# broadcast along dim 1 via unsqueeze(1). Broadcasting along dim 2 instead would
# copy point i into every neighbor slot.
neighbors = torch.gather(
    features.unsqueeze(1).expand(-1, N, -1, -1),  # [B, N, N, D]; [b, i, j] = features[b, j]
    2,
    knn_indices.unsqueeze(-1).expand(-1, -1, -1, D),  # [B, N, k, D]
)  # [B, N, k, D]
print(neighbors.shape, pairwise_dist.shape, knn_indices.shape)

# Simple loss on neighbors (mean of neighbor features). Gradients flow back into
# `features` through the gather; the topk indices themselves are not differentiable.
loss = neighbors.mean()
loss.backward()

#print("Gradients on features:\n", features.grad)

## Edge Convolution with edge features

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeConvWithEdgeFeatures(nn.Module):
    """EdgeConv layer whose edge MLP also consumes externally supplied edge features.

    For every point, the k nearest neighbours are found by Euclidean distance in
    the full feature space. Each edge (i -> j) is described by
    ``[x_i, x_j - x_i, e_ij]``, passed through a shared
    ``Linear -> BatchNorm1d -> ReLU`` MLP, and max-pooled over the k neighbours.

    Args:
        in_channels: Point feature dimension D.
        edge_in_channels: Edge feature dimension E (may be 0).
        out_channels: Output feature dimension.
        k: Number of neighbours per point. The point itself is at distance 0,
            so it is normally among them.
    """

    def __init__(self, in_channels, edge_in_channels, out_channels, k):
        super(EdgeConvWithEdgeFeatures, self).__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels + edge_in_channels, out_channels, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, edge_features):
        """
        Args:
            x: Input point cloud data, shape [B, N, D]
               B - batch size, N - number of points, D - feature dimensions
            edge_features: Input edge features, shape [B, N, k, E]
               E - edge feature dimensions; entry [b, i, m] belongs to the edge
               from point i to its m-th nearest neighbour.
        Returns:
            x_out: Updated features after EdgeConv, shape [B, N, out_channels]
        """
        B, N, D = x.size()
        _, _, _, E = edge_features.size()

        # Step 1: Compute pairwise distance and get k-nearest neighbors
        pairwise_dist = torch.cdist(x, x, p=2)  # [B, N, N]
        idx = pairwise_dist.topk(k=self.k, dim=-1, largest=False)[1]  # [B, N, k]

        # Step 2: Gather neighbor features, neighbors[b, i, m] = x[b, idx[b, i, m]].
        # The gather source must vary along the gathered dim 2, hence unsqueeze(1).
        # With unsqueeze(2), every "neighbour" would be the centre point itself.
        neighbors = torch.gather(
            x.unsqueeze(1).expand(-1, N, -1, -1),  # [B, N, N, D]; [b, i, j] = x[b, j]
            2,
            idx.unsqueeze(-1).expand(-1, -1, -1, D)
        )  # [B, N, k, D]

        # Central point repeated for k neighbors: [B, N, k, D]
        central = x.unsqueeze(2).expand(-1, -1, self.k, -1)

        # Step 3: Compute edge features
        relative_features = neighbors - central  # [B, N, k, D]
        combined_features = torch.cat([central, relative_features, edge_features], dim=-1)  # [B, N, k, 2*D + E]

        # Step 4: Apply the shared MLP to every edge (BatchNorm1d needs a 2-D input)
        combined_features = self.mlp(combined_features.view(-1, 2 * D + E))  # [B * N * k, out_channels]
        combined_features = combined_features.view(B, N, self.k, -1)  # [B, N, k, out_channels]

        # Aggregate (max pooling across neighbors)
        x_out = combined_features.max(dim=2)[0]  # [B, N, out_channels]

        return x_out


### Without Edge Features

In [None]:
# EdgeConv without edge features: an empty last axis (E = 0) means the per-edge
# MLP input is just [x_i, x_j - x_i] of width 2 * D.
B, N, D, k = 1, 4, 3, 2
x = torch.rand(B, N, D)  # Point cloud (Graph) features
edge_features = torch.zeros(B, N, k, 0)  # No edge features (E = 0)

model = EdgeConvWithEdgeFeatures(in_channels=D, edge_in_channels=0, out_channels=16, k=k)
output = model(x, edge_features)
print("Output Shape:", output.shape)  # [B, N, 16]

### With Edge Features

In [None]:
# EdgeConv with E = 2 features per edge. edge_features must have shape
# [B, N, k, E]: one feature vector for each of the k edges of every point.
B, N, D, k, E = 1, 4, 3, 2, 2
x = torch.rand(B, N, D)  # Point cloud (Graph) features
edge_features = torch.rand(B, N, k, E)  # Edge features

model = EdgeConvWithEdgeFeatures(in_channels=D, edge_in_channels=E, out_channels=16, k=k)
output = model(x, edge_features)
print("Output Shape:", output.shape)  # [B, N, 16]

### Custom Attention with Interaction Matrix

In [None]:
class AugmentedMultiHeadAttention(nn.Module):
    """Multi-head self-attention with an additive per-head pairwise bias.

    The pairwise interaction tensor ``u`` is projected from ``u_feat`` features
    to one scalar per head by ``u_embd``. That scalar is added to the scaled
    attention logits.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        u_feat: Number of features per token pair in ``u``.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the Q/K/V projections use a bias.
    """

    def __init__(self, embed_dim, u_feat=4, dropout=0.0, num_heads=1, qkv_bias=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # Per-head width so heads concatenate back to embed_dim

        # One bias scalar per head for every token pair
        self.u_embd = nn.Linear(u_feat, self.num_heads)

        # Lazy layers: the input width of x is inferred on the first forward call
        self.W_q = nn.LazyLinear(embed_dim, bias=qkv_bias)
        self.W_k = nn.LazyLinear(embed_dim, bias=qkv_bias)
        self.W_v = nn.LazyLinear(embed_dim, bias=qkv_bias)
        self.W_o = nn.LazyLinear(embed_dim)  # Linear output layer to combine head outputs
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, u=None):
        """
        Args:
            x: Token features, shape (batch, num_tokens, in_dim).
            u: Optional pairwise features, shape (batch, num_tokens, num_tokens, u_feat).
               Pairs whose raw features are all exactly zero are masked out
               (treated as "no interaction"). If None, plain scaled dot-product
               attention is used. A query whose every pair is masked gets an
               all -inf row, and its output is NaN.

        Returns:
            Tensor of shape (batch, num_tokens, embed_dim).
        """
        bs, num_tokens, _ = x.shape

        # Project and split heads: (b, T, embed_dim) -> (b, H, T, head_dim).
        # The K, Q, V call order is kept because the lazy layers are
        # initialised on first use.
        K = self.W_k(x).view(bs, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        Q = self.W_q(x).view(bs, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(bs, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product logits, one (T x T) map per head (no causal mask)
        attn_scores = (Q @ K.transpose(2, 3)) / self.head_dim**0.5

        if u is not None:
            # Build the mask from the RAW u: after the biased linear projection,
            # exact zeros practically never occur.
            no_interaction = (u == 0).all(dim=-1, keepdim=True)  # (b, T, T, 1)
            bias = self.u_embd(u).masked_fill(no_interaction, float("-inf"))  # (b, T, T, H)
            # (b, T, T, H) -> (b, H, T, T). transpose(3, 1) also swaps the two
            # token axes, so (query q, key k) reads u[:, k, q]. This only
            # matters when u is not symmetric.
            attn_scores = attn_scores + bias.transpose(3, 1)

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, H, T, head_dim) -> (b, T, H, head_dim) -> (b, T, embed_dim)
        context_vec = (attn_weights @ V).transpose(1, 2)
        context_vec = context_vec.contiguous().view(bs, num_tokens, self.embed_dim)
        return self.W_o(context_vec)

### Particle Transformer Simple Implementation

In [4]:
import math
import torch
import torch.nn as nn

def lambda_init_fn(depth):
    """Return the initial differential-attention lambda for a layer at ``depth``.

    Schedule: 0.8 - 0.6 * exp(-0.3 * depth). It gives 0.2 at depth 0 and
    approaches 0.8 as depth grows.
    """
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay

class InteractionInputEncoding(nn.Module):
    """Per-pair MLP embedding of interaction features.

    Maps the last dimension from ``input_dim`` to ``output_dim``. It uses three
    hidden ``Linear(64) -> GELU -> LayerNorm`` stages, then a final
    ``Linear -> GELU``. All leading dimensions are treated as batch dimensions.
    """

    def __init__(self, input_dim=4, output_dim=8):
        super(InteractionInputEncoding, self).__init__()
        hidden = 64
        layers = []
        fan_in = input_dim
        for _ in range(3):
            layers += [nn.Linear(fan_in, hidden), nn.GELU(), nn.LayerNorm(hidden)]
            fan_in = hidden
        layers += [nn.Linear(hidden, output_dim), nn.GELU()]
        self.embed = nn.Sequential(*layers)

    def forward(self, x):
        """Embed ``x`` of shape (..., input_dim) into shape (..., output_dim)."""
        embedded = self.embed(x)
        return embedded


class MultiheadDiffAttn(nn.Module):
    """Multi-head differential attention with an optional additive pairwise bias.

    Each head computes two softmax attention maps from split query/key halves and
    combines them as ``A1 - lambda * A2``. The combined map is applied to a value
    of width ``2 * head_dim``, normalised per head with LayerNorm, and projected
    back to ``embed_dim``.

    Args:
        embed_dim: Model width; must be divisible by ``2 * num_heads``.
        depth: Layer depth passed to ``lambda_init_fn`` to set ``lambda_init``.
        num_heads: Number of differential heads (each uses two attention maps).
    """

    def __init__(self, embed_dim, depth, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert embed_dim % 2 == 0, "embed_dim must be divisible by 2"
        # Q and K are split into 2 * num_heads maps of head_dim each, so the width
        # must divide by 2 * num_heads. The two checks above alone accept, e.g.,
        # embed_dim=12 with num_heads=4, which would then fail in forward's view().
        assert embed_dim % (2 * num_heads) == 0, "embed_dim must be divisible by 2 * num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Dimension per attention map (each head owns 2 maps)
        self.head_dim = embed_dim // num_heads // 2
        self.scaling = self.head_dim**-0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Lambda re-parameterisation: lambda = exp(q1 . k1) - exp(q2 . k2) + lambda_init
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )

        # Per-head normalisation over the 2 * head_dim value width
        self.subln = nn.LayerNorm(self.head_dim * 2)

    def forward(self, x, u=None):
        """
        Args:
            x: Input tokens, shape (batch, num_tokens, embed_dim).
            u: Optional pairwise bias, shape (batch, num_tokens, num_tokens, 2 * num_heads)
               with one channel per attention map. Entries exactly equal to zero
               are masked with -inf. ``u`` itself is not modified.

        Returns:
            Tensor of shape (batch, num_tokens, embed_dim).
        """
        batch_size, num_tokens, _ = x.size()

        # Reshape into heads: Q/K -> (bs, 2 * num_heads, T, head_dim),
        # V -> (bs, num_heads, T, 2 * head_dim)
        q = self.q_proj(x).view(batch_size, num_tokens, self.num_heads * 2, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, num_tokens, self.num_heads * 2, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, num_tokens, self.num_heads, 2 * self.head_dim).transpose(1, 2)

        # Scaled logits for all 2 * num_heads maps: (bs, 2H, T, T)
        attn_weights = torch.matmul(q * self.scaling, k.transpose(-1, -2))
        if u is not None:
            # Out of place: the caller may reuse the same u (e.g. for every
            # block), so it must not be mutated here.
            bias = u.masked_fill(~u.bool(), float("-inf"))
            # (bs, T, T, 2H) -> (bs, 2H, T, T). transpose(3, 1) also swaps the two
            # token axes, so (query q, key k) reads u[:, k, q]. This only matters
            # when u is not symmetric.
            attn_weights = attn_weights + bias.transpose(3, 1)

        attn_weights = F.softmax(attn_weights, dim=-1)

        # Compute lambda
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Pair maps (2h, 2h + 1) of each head and take their weighted difference
        attn_weights = attn_weights.view(
            batch_size, self.num_heads, 2, num_tokens, num_tokens
        )
        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]

        # Weighted sum: (bs, num_heads, T, 2 * head_dim)
        attn = torch.matmul(attn_weights, v)

        # NOTE(review): the Diff Transformer reference implementation also scales
        # the normalised heads by (1 - lambda_init); that factor is not applied here.
        attn = self.subln(attn)
        attn = attn.transpose(1, 2).reshape(batch_size, num_tokens, self.embed_dim)

        # Final projection
        return self.out_proj(attn)


class ParticleAttentionBlock(nn.Module):
    """Transformer block: normalised differential attention, then an MLP, each with a residual.

    Computes ``h = x + LN2(Attn(LN1(x), u))`` and returns ``h + MLP(h)``. The MLP is
    ``LN -> Linear(expansion) -> GELU -> LN -> Linear``.

    Args:
        embed_dim: Token width.
        expansion_factor: Hidden-width multiplier of the MLP.
        num_heads: Number of differential-attention heads.
        num_layers: Forwarded as the attention's ``depth`` (sets its lambda_init).
    """

    def __init__(self, embed_dim, expansion_factor=4, num_heads=1, num_layers=1):
        super().__init__()
        hidden_dim = expansion_factor * embed_dim
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.pmha = MultiheadDiffAttn(embed_dim=embed_dim, num_heads=num_heads, depth=num_layers)
        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x, u=None):
        """Apply the block to tokens ``x`` with optional pairwise bias ``u``."""
        attended = self.norm2(self.pmha(self.norm1(x), u))
        h = attended + x
        return self.mlp(h) + h

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicEdgeConv(nn.Module):
    """EdgeConv layer whose k-NN graph is built from selected coordinate columns.

    Neighbours are found by Euclidean distance on ``x[..., coord_idx]``. Each edge
    (i -> j) is encoded as ``[x_i, x_j - x_i]``, passed through a shared MLP, and
    mean-pooled over the neighbours.

    Args:
        in_channels: Input feature dimension D.
        embed_dim: Hidden width of the edge MLP.
        k: Number of neighbours per point (clipped to the number of points).
        out_channels: Output width; defaults to ``embed_dim``.
        coord_idx: Feature columns used for the k-NN distance. Defaults to
            ``(8, 9)``, the columns previously hard-coded here (presumably eta
            and phi — TODO confirm against the input feature layout).
    """

    def __init__(self, in_channels, embed_dim, k, out_channels=None, coord_idx=(8, 9)):
        super(DynamicEdgeConv, self).__init__()
        self.k = k
        self.coord_idx = list(coord_idx)
        out_channels = embed_dim if out_channels is None else out_channels
        self.mlp = nn.Sequential(
            nn.Linear(in_channels * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, out_channels),
            nn.LayerNorm(out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        """
        Args:
            x: Input point cloud data, shape [B, N, D]
               B - batch size, N - number of points, D - feature dimensions
        Returns:
            x_out: Updated features after EdgeConv, shape [B, N, out_channels]
        """
        B, N, D = x.size()
        # topk fails when k > N, so use at most N neighbours
        k = min(self.k, N)

        # Step 1: k nearest neighbours on the selected coordinate columns
        coords = x[..., self.coord_idx]  # [B, N, len(coord_idx)]
        pairwise_dist = torch.cdist(coords, coords, p=2)  # [B, N, N]
        idx = pairwise_dist.topk(k=k, dim=-1, largest=False)[1]  # [B, N, k]

        # Step 2: Gather neighbor features, neighbors[b, i, m] = x[b, idx[b, i, m]].
        # The gather source must vary along the gathered dim 2, hence unsqueeze(1).
        # With unsqueeze(2), every "neighbour" would be the centre point itself.
        neighbors = torch.gather(
            x.unsqueeze(1).expand(-1, N, -1, -1),  # [B, N, N, D]; [b, i, j] = x[b, j]
            2,
            idx.unsqueeze(-1).expand(-1, -1, -1, D)
        )  # [B, N, k, D]

        # Central point repeated for k neighbors: [B, N, k, D]
        central = x.unsqueeze(2).expand(-1, -1, k, -1)

        # Step 3: Compute edge features
        relative_features = neighbors - central  # [B, N, k, D]
        combined_features = torch.cat([central, relative_features], dim=-1)  # [B, N, k, 2*D]

        # Step 4: Apply MLP and aggregation
        combined_features = self.mlp(combined_features.view(-1, 2 * D))  # [B * N * k, out_channels]
        combined_features = combined_features.view(B, N, k, -1)  # [B, N, k, out_channels]

        # Aggregate (avg pooling across neighbors)
        x_out = combined_features.mean(dim=2)  # [B, N, out_channels]

        return x_out

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ParticleNet(nn.Module):
    """Stack of DynamicEdgeConv layers producing per-particle embeddings.

    Args:
        in_channels: Feature dimension of the input particles.
        num_layers: Number of DynamicEdgeConv layers.
        embed_dims: Output width of each layer (list or tuple of length
            ``num_layers``). Defaults to ``[64, 128, 256]``.
        k: Number of neighbours for each layer (list or tuple of length
            ``num_layers``). Defaults to ``[16, 16, 16]``.
    """

    def __init__(self, in_channels=11, num_layers=3, embed_dims=None, k=None):
        super(ParticleNet, self).__init__()
        # None defaults avoid sharing mutable default lists between calls
        if embed_dims is None:
            embed_dims = [64, 128, 256]
        if k is None:
            k = [16, 16, 16]

        # Ensure embed_dims and k are sequences
        assert isinstance(embed_dims, (list, tuple)), f"Expected embed_dims to be a list or tuple, but got {type(embed_dims)}"
        assert isinstance(k, (list, tuple)), f"Expected k to be a list or tuple, but got {type(k)}"

        # Ensure embed_dims and k provide exactly one entry per layer
        assert len(embed_dims) == num_layers, f"Expected embed_dims to have the same length as 'num_layers={num_layers}', but got {len(embed_dims)}"
        assert len(k) == num_layers, f"Expected k to have the same length as 'num_layers={num_layers}', but got {len(k)}"

        # Layer i consumes the output width of layer i - 1 (the raw input for i == 0)
        self.edge_conv = nn.ModuleList([
            DynamicEdgeConv(
                in_channels=in_channels if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
                k=k[i]
            )
            for i in range(num_layers)
        ])

    def forward(self, x):
        """Map x of shape [B, N, in_channels] to per-particle features [B, N, embed_dims[-1]]."""
        for conv in self.edge_conv:
            x = conv(x)
        return x
    
class ParticleTransformer(nn.Module):
    """ParticleNet embedding, then differential-attention blocks with a pairwise bias, then a pooled MLP head.

    Args:
        feat_particles_dim: Per-particle input feature dimension.
        feat_interaction_dim: Per-pair input feature dimension.
        embed_dim: Transformer width. It is also the output width of the last
            ParticleNet layer.
        num_heads: Number of differential-attention heads.
        num_blocks: Number of attention blocks.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        feat_particles_dim,
        feat_interaction_dim,
        embed_dim,
        num_heads,
        num_blocks,
        num_classes=1,
    ):
        super(ParticleTransformer, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        # The last ParticleNet layer must emit embed_dim features for the blocks
        # (this matches the previous hard-wired widths when embed_dim == 256).
        self.particle_embed = ParticleNet(
            in_channels=feat_particles_dim, embed_dims=[64, 128, embed_dim]
        )
        # One bias channel per attention map (each differential head uses two)
        self.interaction_embed = InteractionInputEncoding(
            input_dim=feat_interaction_dim, output_dim=num_heads * 2
        )
        # ParticleAttentionBlock forwards num_layers as the attention depth that
        # sets lambda_init, so each block receives its own layer index.
        self.blocks = nn.ModuleList(
            [
                ParticleAttentionBlock(
                    embed_dim=embed_dim, num_heads=num_heads, num_layers=depth
                )
                for depth in range(num_blocks)
            ]
        )
        self.mlp_head = nn.Sequential(nn.Linear(embed_dim, embed_dim),
                                      nn.GELU(),
                                      nn.Dropout(0.1),
                                      nn.Linear(embed_dim, num_classes))

    def forward(self, particles, interactions):
        """
        Args:
            particles: shape (batch, num_particles, feat_particles_dim).
            interactions: shape (batch, num_particles, num_particles, feat_interaction_dim).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.particle_embed(particles)
        u = self.interaction_embed(interactions)

        for block in self.blocks:
            x = block(x, u)

        # Mean-pool across particles
        x = x.mean(dim=1)

        return self.mlp_head(x)

# Smoke test with random inputs: batch of 5, 128 particles with 11 features
# each, plus 4 interaction features for every particle pair.
x = torch.rand(5, 128, 11)
u = torch.rand(5, 128, 128, 4)

model = ParticleTransformer(feat_particles_dim=11, feat_interaction_dim=4, embed_dim=256, num_heads=8, num_blocks=4)

# Forward pass; yields one logit per batch element (shape (5, 1))
model(x, u)

tensor([[0.8302],
        [0.7175],
        [0.6700],
        [0.2269],
        [0.6973]], grad_fn=<AddmmBackward0>)

In [3]:
import torch
import torch.nn as nn

def symmetrize(t: torch.Tensor):
    """Symmetrize ``t`` with respect to its dims 1 and 2.

    Returns ``(t + t.transpose(1, 2)) / 2``, so ``out[:, i, j] == out[:, j, i]``.
    It is intended for pairwise tensors of shape (N, n, n, F); dims 1 and 2 must
    have equal sizes.
    """
    swapped = torch.transpose(t, 1, 2)  # swaps dims 1 and 2 (not the last two)
    return (t + swapped) / 2

class InteractionInputEncoding(nn.Module):
    """Pointwise (kernel-size-1 Conv1d) embedding of pairwise interaction features.

    Every (i, j) pair's feature vector is mapped independently from ``input_dim``
    to ``output_dim`` through four ``Conv1d(k=1) -> BatchNorm1d -> GELU`` stages.
    BatchNorm statistics are computed over all pairs in the batch.
    """

    def __init__(self, input_dim=4, output_dim=8):
        super(InteractionInputEncoding, self).__init__()
        self.conv = nn.Sequential(nn.Conv1d(input_dim, 64, kernel_size=1),
                                  nn.BatchNorm1d(64),
                                  nn.GELU(),
                                  nn.Conv1d(64, 64, kernel_size=1),
                                  nn.BatchNorm1d(64),
                                  nn.GELU(),
                                  nn.Conv1d(64, 64, kernel_size=1),
                                  nn.BatchNorm1d(64),
                                  nn.GELU(),
                                  nn.Conv1d(64, output_dim, kernel_size=1),
                                  nn.BatchNorm1d(output_dim),
                                  nn.GELU())

    def forward(self, x):
        """
        Args:
            x: Pairwise features, shape (bs, n_i, n_j, input_dim). Need not be
               contiguous, and n_i may differ from n_j.

        Returns:
            Tensor of shape (bs, n_i, n_j, output_dim).
        """
        bs, n_i, n_j, _ = x.shape
        # (bs, n_i, n_j, C) -> (bs, C, n_i * n_j) because Conv1d expects channels first.
        # reshape (not view) so non-contiguous inputs such as u.transpose(1, 2) work.
        flat = x.reshape(bs, n_i * n_j, x.size(-1)).permute(0, 2, 1)
        out = self.conv(flat).permute(0, 2, 1)
        return out.reshape(bs, n_i, n_j, out.size(-1))
            
# Shape check: a (5, 128, 128, 4) pairwise tensor is embedded to the default
# output_dim of 8 channels per pair.
u = torch.rand(5, 128, 128, 4)
model = InteractionInputEncoding(input_dim=4)
model(u).shape

torch.Size([5, 128, 128, 8])