In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    Queries, keys and values are each linearly projected to ``d_model``
    features, split into ``num_heads`` heads of size ``d_model // num_heads``,
    attended independently, then concatenated and passed through a final
    linear projection.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Initialize the multi-head attention module.

        Args:
            d_model (int): The dimension of the model (embedding dimension).
                Must be divisible by ``num_heads``.
            num_heads (int): Number of attention heads
            dropout (float): Dropout probability applied to the attention
                weights
        """
        super().__init__()

        # Each head gets an equal slice of the embedding dimension.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension of each head

        # Linear projections for Q, K, V and output.
        # NOTE: creation order determines the order in which the global RNG
        # is consumed for weight initialisation; keep it stable.
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        self.output_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (num_heads, d_k)
        and transpose to get shape (batch_size, num_heads, seq_len, d_k)

        Args:
            x: Tensor whose elements can be viewed as
                (batch_size, seq_len, num_heads * d_k)
            batch_size (int): Size of the leading (batch) dimension

        Returns:
            Tensor of shape (batch_size, num_heads, seq_len, d_k)
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """
        Forward pass for multi-head attention

        Args:
            query: Query tensor (batch_size, seq_len_q, d_model)
            key: Key tensor (batch_size, seq_len_k, d_model)
            value: Value tensor (batch_size, seq_len_v, d_model);
                seq_len_v must equal seq_len_k for the weighted sum
            mask: Optional mask broadcastable to
                (batch_size, num_heads, seq_len_q, seq_len_k). Positions
                where ``mask == 0`` (or ``False``) are excluded from attention.

        Returns:
            output: Attention output (batch_size, seq_len_q, d_model)
            attention_weights: Attention weights of shape
                (batch_size, num_heads, seq_len_q, seq_len_k), returned
                after dropout has been applied
        """
        batch_size = query.size(0)

        # Linear projections and split heads
        Q = self.split_heads(self.query_proj(query), batch_size)  # (batch_size, num_heads, seq_len_q, d_k)
        K = self.split_heads(self.key_proj(key), batch_size)      # (batch_size, num_heads, seq_len_k, d_k)
        V = self.split_heads(self.value_proj(value), batch_size)  # (batch_size, num_heads, seq_len_v, d_k)

        # Scaled dot-product attention
        # (batch_size, num_heads, seq_len_q, d_k) @ (batch_size, num_heads, d_k, seq_len_k)
        # -> (batch_size, num_heads, seq_len_q, seq_len_k)
        # The scale is a plain Python float so the tensor's dtype/device are
        # preserved and no NumPy scalar leaks into the torch graph.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Apply mask if provided
        if mask is not None:
            # Use the most negative finite value of the score dtype rather than
            # a hard-coded -1e9, which is not representable in float16 and
            # makes masked_fill raise under half precision. A finite value
            # (instead of -inf) keeps fully masked rows uniform rather than NaN.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        # Apply softmax over the key dimension to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention weights to values
        # (batch_size, num_heads, seq_len_q, seq_len_k) @ (batch_size, num_heads, seq_len_v, d_k)
        # -> (batch_size, num_heads, seq_len_q, d_k)
        context = torch.matmul(attention_weights, V)

        # Reshape back to (batch_size, seq_len_q, d_model); permute yields a
        # non-contiguous tensor, so make it contiguous before view.
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, -1, self.d_model)

        # Final linear projection
        output = self.output_proj(context)

        return output, attention_weights

In [None]:
# Example usage
def test_multi_head_attention():
    """
    Smoke-test MultiHeadAttention on random inputs with a causal mask.

    Prints the input, output and attention-weight shapes.

    Returns:
        The (output, attention) pair produced by the layer.
    """
    num_heads, d_model = 8, 64
    batch_size, seq_len = 2, 10
    input_shape = (batch_size, seq_len, d_model)

    # Three independent random tensors, drawn in query, key, value order
    query, key, value = (torch.randn(*input_shape) for _ in range(3))

    # Causal mask: True (keep) on and below the diagonal, False (mask) above it.
    # Two leading singleton dims let it broadcast over batch and heads.
    keep = torch.tril(torch.ones(seq_len, seq_len)).bool()
    mask = keep[None, None]

    # Build the layer and run a single forward pass
    mha = MultiHeadAttention(d_model, num_heads)
    output, attention = mha(query, key, value, mask)

    print(f"Input shape: {query.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attention.shape}")

    return output, attention


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_multi_head_attention()

Input shape: torch.Size([2, 10, 64])
Output shape: torch.Size([2, 10, 64])
Attention weights shape: torch.Size([2, 8, 10, 10])
