In [1]:
!pip install torch

Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/d0/5f/f41b14a398d484bf218d5167ec9061c1e76f500d9e25166117818c8bacda/torch-2.3.1-cp311-none-macosx_11_0_arm64.whl.metadata
  Downloading torch-2.3.1-cp311-none-macosx_11_0_arm64.whl.metadata (26 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Obtaining dependency information for typing-extensions>=4.8.0 from https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl.metadata
  Downloading typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Downloading torch-2.3.1-cp311-none-macosx_11_0_arm64.whl (61.0 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 MB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hUsing cached typing_extensions-4.12.2-py3-none-any.whl (37 kB)
Installing collected packages: typing-exte

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LongformerSelfAttention(nn.Module):
    """Multi-head self-attention with a Longformer-style sparse attention pattern.

    A query at position ``i`` may attend to key ``j`` when any of these hold:

    * ``|i - j| <= attention_window`` (local sliding window),
    * ``i`` is a global token (global tokens attend to every position),
    * ``j`` is a global token (every position attends to the global tokens).

    All other pairs are masked to ``-inf`` before a single softmax over keys.
    Because the window always contains the diagonal, no row is fully masked,
    so the softmax has no all ``-inf`` rows.

    Note: this reference version materialises the full ``seq_len x seq_len``
    score matrix, so its cost is dense O(seq_len^2). It reproduces the
    Longformer attention *pattern*, not its memory savings.
    """

    def __init__(self, embed_dim, num_heads, attention_window, global_tokens):
        """
        embed_dim: Dimensionality of the input embeddings
        num_heads: Number of attention heads (must divide embed_dim)
        attention_window: One-sided size of the local attention window; position i
            attends to keys j with |i - j| <= attention_window. Must be >= 0.
        global_tokens: Indices of tokens that will have global attention
            (iterable of ints; may be empty)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attention_window = attention_window
        self.global_tokens = global_tokens

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads."
        # A negative window would exclude the diagonal and leave rows fully masked (NaN).
        assert attention_window >= 0, "attention_window must be non-negative."
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_projection = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply sparse self-attention.

        x: tensor of shape (batch_size, seq_len, embed_dim)
        Returns a tensor of shape (batch_size, seq_len, embed_dim).
        """
        batch_size, seq_len, _ = x.size()

        # Linear projections, split into heads: (batch, heads, seq, head_dim)
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        # Score every (query, key) pair exactly once, then apply ONE combined mask.
        # Summing separately masked local/global score matrices would leave each
        # non-global query row entirely -inf (NaN after softmax) and would double
        # the in-window logits of global queries.
        scores = torch.einsum("bnqd,bnkd->bnqk", q, k) / (self.head_dim ** 0.5)
        outside_window = self.sliding_window_mask(seq_len, self.attention_window, device=x.device)
        blocked = outside_window & ~self._global_mask(seq_len, device=x.device)
        scores = scores.masked_fill(blocked, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back to (batch, seq, embed_dim)
        context = torch.matmul(attention_weights, v)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)

        return self.out_projection(context)

    def _split_heads(self, t):
        """Reshape (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)."""
        batch_size, seq_len, _ = t.size()
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _global_mask(self, seq_len, device=None):
        """Boolean (1, 1, seq_len, seq_len) mask, True where global attention applies.

        Pair (i, j) is covered when i OR j is a global token: global tokens attend
        to every position, and every position attends to the global tokens.
        """
        is_global = torch.zeros(seq_len, dtype=torch.bool, device=device)
        indices = torch.as_tensor(list(self.global_tokens), dtype=torch.long, device=device)
        is_global[indices] = True
        mask = is_global.unsqueeze(1) | is_global.unsqueeze(0)
        return mask.unsqueeze(0).unsqueeze(0)  # For batch and heads

    def local_attention(self, q, k):
        """Scaled dot-product scores restricted to the sliding window.

        q, k: tensors of shape (batch, heads, seq_len, head_dim)
        Returns (batch, heads, seq_len, seq_len) scores with -inf outside the window.
        """
        seq_len, head_dim = q.size(-2), q.size(-1)
        scores = torch.einsum("bnqd,bnkd->bnqk", q, k) / (head_dim ** 0.5)
        mask = self.sliding_window_mask(seq_len, self.attention_window, device=q.device)
        return scores.masked_fill(mask, float('-inf'))

    def global_attention(self, q, k):
        """Scaled dot-product scores restricted to global-attention pairs.

        q, k: tensors of shape (batch, heads, seq_len, head_dim)
        Returns (batch, heads, seq_len, seq_len) scores that are finite wherever the
        query or the key is a global token, and -inf elsewhere.
        """
        seq_len, head_dim = q.size(-2), q.size(-1)
        scores = torch.einsum("bnqd,bnkd->bnqk", q, k) / (head_dim ** 0.5)
        return scores.masked_fill(~self._global_mask(seq_len, device=q.device), float('-inf'))

    def sliding_window_mask(self, seq_len, window, device=None):
        """Boolean (1, 1, seq_len, seq_len) mask, True where |i - j| > window.

        True marks pairs OUTSIDE the local window (i.e. pairs to be masked out).
        """
        positions = torch.arange(seq_len, device=device)
        mask = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs() > window
        return mask.unsqueeze(0).unsqueeze(0)  # For batch and heads

# Example usage
# Both the input and the layer's weights are random; seed so re-runs are reproducible.
torch.manual_seed(0)

embed_dim = 256
num_heads = 8
attention_window = 5       # each token sees up to 5 neighbours on each side
global_tokens = [0, 1, 2]  # token positions given global attention

x = torch.rand((2, 20, embed_dim))  # (batch_size=2, seq_len=20, embed_dim=256)
attention_layer = LongformerSelfAttention(embed_dim, num_heads, attention_window, global_tokens)
output = attention_layer(x)
# Output keeps the input shape: (batch_size, seq_len, embed_dim)
assert output.shape == (2, 20, embed_dim)
print(output.shape)  # torch.Size([2, 20, 256])
                  

torch.Size([2, 20, 256])
