In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    """Multi-head attention with a dynamic, token-type driven (NAtS-style) mask.

    Every query position is scored by a linear layer into ``num_token_types``
    classes. The first three classes are interpreted as 0 = global, 1 = local
    and 2 = sliding window (so ``num_token_types`` must be at least 3), and the
    class whose probability exceeds ``type_threshold`` decides which key
    positions that query may attend to.

    Inputs follow the sequence-first layout ``(seq_len, batch_size, embed_dim)``.
    The dynamic mask is square in ``seq_len``, so key and value are expected to
    have the same sequence length as the query (self-attention).

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the q/k/v/out projections carry a bias term.
        num_token_types: Output size of the token-type scorer.
        window_size: How many preceding positions a sliding-window token may
            attend to in addition to itself.
        type_threshold: Probability a token type must exceed to be selected.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, num_token_types=3,
                 window_size=3, type_threshold=0.5):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.num_token_types = num_token_types  # Number of token types (e.g., Global, Local, Sliding Window)
        self.window_size = window_size
        self.type_threshold = type_threshold

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

        # NAtS-related layers
        self.token_type_scorer = nn.Linear(embed_dim, num_token_types)  # Predicts token type
        self.softmax = nn.Softmax(dim=-1)  # Convert scores to probabilities

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
        """Run attention under the combined user-supplied and NAtS masks.

        Args:
            query, key, value: Tensors of shape ``(seq_len, batch_size, embed_dim)``.
            key_padding_mask: Optional ``(batch_size, seq_len)`` mask; truthy
                entries mark keys that must be ignored.
            attn_mask: Optional mask broadcastable to
                ``(batch_size * num_heads, seq_len, seq_len)``. A floating mask
                is added to the scores; a boolean mask blocks the positions
                where it is True.

        Returns:
            ``(attn_output, attn_weights)`` with shapes
            ``(seq_len, batch_size, embed_dim)`` and
            ``(batch_size * num_heads, seq_len, seq_len)``.
        """
        seq_len, batch_size, embed_dim = query.size()

        # Project, then split into heads: (batch_size * num_heads, len, head_dim)
        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)

        # NAtS: determine token types dynamically from the query tokens
        token_type_scores = self.token_type_scorer(query)   # (seq_len, batch_size, num_token_types)
        token_type_probs = self.softmax(token_type_scores)  # (seq_len, batch_size, num_token_types)

        # (batch_size * num_heads, seq_len, seq_len), additive 0 / -inf
        nats_attn_mask = self.construct_nats_attn_mask(token_type_probs, seq_len, batch_size)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Adding a bool mask arithmetically would turn True into +1.0;
                # translate it into the additive 0 / -inf convention first.
                attn_mask = torch.zeros_like(
                    attn_mask, dtype=nats_attn_mask.dtype
                ).masked_fill(attn_mask, float('-inf'))
            attn_mask = attn_mask + nats_attn_mask
        else:
            attn_mask = nats_attn_mask

        attn_output, attn_weights = self.scaled_dot_product_attention(q, k, v, key_padding_mask, attn_mask)

        # Merge heads back into the embedding dimension and project
        attn_output = attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights

    def _split_heads(self, x, batch_size):
        """Reshape ``(len, batch, embed_dim)`` to ``(batch * num_heads, len, head_dim)``."""
        return (
            x.contiguous()
            .view(x.size(0), batch_size * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    def construct_nats_attn_mask(self, token_type_probs, seq_len, batch_size):
        """Build the additive attention mask implied by the token types.

        The type of query position ``i`` in a given sample decides which key
        positions ``j`` are blocked for that sample only:

        * global (index 0): nothing is blocked;
        * local (index 1): every ``j < i`` is blocked;
        * sliding window (index 2): ``j > i`` and ``i - j > window_size`` are
          blocked.

        A position whose probabilities all stay at or below ``type_threshold``
        is left unmasked. If several types exceed the threshold, global wins
        and the local and sliding-window blocks are combined.

        Args:
            token_type_probs: ``(seq_len, batch_size, num_token_types)``.
            seq_len: Sequence length of the query.
            batch_size: Number of samples in the batch.

        Returns:
            Tensor of shape ``(batch_size * num_heads, seq_len, seq_len)``
            holding 0.0 where attention is allowed and ``-inf`` where it is not.
        """
        device = token_type_probs.device

        # Per-sample, per-position type flags, each (batch_size, seq_len).
        # Keeping the batch axis prevents one sample's token types from
        # deciding the mask of another sample.
        selected = token_type_probs > self.type_threshold
        is_global = selected[..., 0].transpose(0, 1)
        is_local = selected[..., 1].transpose(0, 1)
        is_sliding = selected[..., 2].transpose(0, 1)

        # offset[i, j] = i - j (query index minus key index)
        positions = torch.arange(seq_len, device=device)
        offset = positions.unsqueeze(1) - positions.unsqueeze(0)

        # NOTE(review): the local rule blocks all *earlier* keys and leaves
        # later ones visible; kept as in the original, confirm this is intended.
        local_block = offset > 0
        sliding_block = (offset > self.window_size) | (offset < 0)

        blocked = (is_local.unsqueeze(-1) & local_block) | (is_sliding.unsqueeze(-1) & sliding_block)
        blocked = blocked & ~is_global.unsqueeze(-1)  # (batch_size, seq_len, seq_len)

        per_sample = torch.zeros(
            (batch_size, seq_len, seq_len), dtype=token_type_probs.dtype, device=device
        ).masked_fill(blocked, float('-inf'))

        # Heads are laid out batch-major (index = sample * num_heads + head),
        # matching the head split used in forward().
        return (
            per_sample.unsqueeze(1)
            .expand(batch_size, self.num_heads, seq_len, seq_len)
            .reshape(batch_size * self.num_heads, seq_len, seq_len)
        )

    def scaled_dot_product_attention(self, q, k, v, key_padding_mask=None, attn_mask=None):
        """Compute ``softmax(q k^T / sqrt(d_k) + masks) v``.

        Args:
            q, k, v: ``(batch_size * num_heads, len, head_dim)`` tensors.
            key_padding_mask: Optional ``(batch_size, src_len)`` mask; truthy
                entries are excluded from attention.
            attn_mask: Optional additive mask broadcastable to the scores.

        Returns:
            ``(attn_output, attn_weights)``. Query rows for which every key is
            masked get all-zero weights (and therefore a zero output row)
            instead of NaN.
        """
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)

        if attn_mask is not None:
            scores = scores + attn_mask

        if key_padding_mask is not None:
            # Scores are flattened over (batch, heads); unflatten so that the
            # per-sample padding mask lines up with the batch axis.
            flat_shape = scores.shape
            batch_size = key_padding_mask.size(0)
            tgt_len, src_len = flat_shape[-2], flat_shape[-1]
            pad = key_padding_mask.to(torch.bool)[:, None, None, :]
            scores = (
                scores.reshape(batch_size, -1, tgt_len, src_len)
                .masked_fill(pad, float('-inf'))
                .reshape(flat_shape)
            )

        # A row of only -inf makes softmax return NaN, which would also poison
        # gradients. Neutralise such rows before softmax and zero them after.
        fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
        scores = scores.masked_fill(fully_masked, 0.0)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)

        return attn_output, attn_weights

In [20]:
# Example usage: push one random batch through the attention layer and
# report the shapes of what comes back.
embed_dim, num_heads, num_token_types = 512, 8, 3
seq_len, batch_size = 32, 16

# Build the module under test.
mha = MultiheadAttention(embed_dim, num_heads, num_token_types=num_token_types)

# Random inputs in sequence-first layout (seq_len, batch_size, embed_dim),
# drawn in the order query, key, value.
query, key, value = (
    torch.randn(seq_len, batch_size, embed_dim) for _ in range(3)
)

# Forward pass; no padding mask or extra attention mask is supplied.
attn_output, attn_weights = mha(query=query, key=key, value=value)

# Report the resulting tensor shapes.
print("Attention Output Shape:", attn_output.shape)
print("Attention Weights Shape:", attn_weights.shape)

Attention Output Shape: torch.Size([32, 16, 512])
Attention Weights Shape: torch.Size([128, 32, 32])
