In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class InfiniAttention(nn.Module):
    def __init__(self, d_model, n_heads, compressive_memory_size, segment_length):
        super(InfiniAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.segment_length = segment_length
        
        # Ensure d_model is divisible by n_heads for multi-head attention
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        
        # Projection matrices for Query, Key, Value
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        
        # Compressive memory matrix
        self.memory = torch.zeros((compressive_memory_size, self.d_k))
        self.memory_key_sum = torch.zeros((self.d_k,))
        
        # Output projection
        self.WO = nn.Linear(d_model, d_model)
        
        # Gating parameter for combining local and memory states
        self.beta = nn.Parameter(torch.tensor(0.0))
        
    def scaled_dot_product_attention(self, Q, K, V):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V), attn_weights

    def compressive_memory_retrieve(self, Q):
        # Reshape Q to match the dimensions for matrix multiplication
        Q = Q.view(Q.size(0), Q.size(1), self.n_heads, self.d_k).transpose(1, 2)  # (batch_size, n_heads, seq_length, d_k)
        
        retrieval = torch.matmul(F.elu(Q) + 1, self.memory.T)  # (batch_size, n_heads, seq_length, compressive_memory_size)
        normalization = torch.sum(F.elu(Q) + 1, dim=-1, keepdim=True)
        return retrieval / (normalization + 1e-6)
    
    def compressive_memory_update(self, K, V):
        # Reshape K and V to match the dimensions for matrix multiplication
        K = K.view(K.size(0), K.size(1), self.n_heads, self.d_k).transpose(1, 2)  # (batch_size, n_heads, seq_length, d_k)
        V = V.view(V.size(0), V.size(1), self.n_heads, self.d_k).transpose(1, 2)  # (batch_size, n_heads, seq_length, d_k)
        
        self.memory += torch.matmul(F.elu(K).transpose(-2, -1) + 1, V)
        self.memory_key_sum += torch.sum(F.elu(K) + 1, dim=0)
    
    def forward(self, X):
        # Compute Q, K, V
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)
        
        # Local scaled dot-product attention
        local_context, _ = self.scaled_dot_product_attention(Q, K, V)
        
        # Compressive memory retrieval
        global_context = self.compressive_memory_retrieve(Q)
        
        # Combine local and global contexts (Equation 10)
        combined_context = torch.sigmoid(self.beta) * global_context + (1 - torch.sigmoid(self.beta)) * local_context
        
        # Update memory with the current segment's K and V
        self.compressive_memory_update(K, V)
        
        # Output projection
        output = self.WO(combined_context.transpose(1, 2).contiguous().view(X.size(0), -1, self.d_model))
        
        return output

In [8]:
# Example usage
d_model = 512
n_heads = 8
compressive_memory_size = 1024
segment_length = 2048

infi_attention_layer = InfiniAttention(d_model, n_heads, compressive_memory_size, segment_length)

# Dummy input: Batch size of 10, segment length of 2048, model dimension of 512
input_data = torch.randn(10, segment_length, d_model)
output = infi_attention_layer(input_data)
print(output.shape)  # Expected shape: [10, 2048, 512]

RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 3