In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        N = x.shape[0]  # Batch size
        length = x.shape[1]  # Sequence length

        # Split the embedding into multiple heads
        values = self.values(x).view(N, length, self.heads, self.head_dim)
        keys = self.keys(x).view(N, length, self.heads, self.head_dim)
        queries = self.queries(x).view(N, length, self.heads, self.head_dim)

        # Transpose to get dimensions for attention calculation
        values = values.permute(0, 2, 1, 3)  # (N, heads, length, head_dim)
        keys = keys.permute(0, 2, 1, 3)      # (N, heads, length, head_dim)
        queries = queries.permute(0, 2, 1, 3)  # (N, heads, length, head_dim)

        # Calculate attention scores
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (N, heads, length, length)
        attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Apply attention to values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, length, self.embed_size)
        
        # Pass through the final linear layer
        out = self.fc_out(out)
        return out


In [5]:
# Example usage
embed_size = 256  # Embedding size
num_heads = 8     # Number of attention heads
multihead_attention = MultiHeadAttention(embed_size, num_heads)

# Sample input (batch_size, sequence_length, embedding_size)
x = torch.rand(64, 32, embed_size)  # Example input
output = multihead_attention(x)
print(output.shape)  # Should be (64, 10, 256)

print(output)

torch.Size([64, 10, 256])
tensor([[[ 0.0359, -0.0094,  0.1310,  ...,  0.0443, -0.0459,  0.0820],
         [ 0.1151,  0.0488,  0.1066,  ..., -0.0399, -0.0244, -0.0112],
         [ 0.0831, -0.0396,  0.0956,  ...,  0.0780, -0.0878,  0.0266],
         ...,
         [ 0.0779, -0.0408,  0.0997,  ...,  0.0731, -0.0889,  0.0288],
         [ 0.0772,  0.0261,  0.0332,  ...,  0.0459, -0.0716,  0.0105],
         [ 0.0982,  0.0036,  0.0868,  ...,  0.0141, -0.0214,  0.0481]],

        [[ 0.0586, -0.0611,  0.0391,  ...,  0.0371, -0.0770,  0.0222],
         [ 0.0758,  0.0019,  0.1511,  ...,  0.0600, -0.0333, -0.0160],
         [ 0.0576, -0.0013,  0.1217,  ...,  0.0369, -0.0620,  0.0041],
         ...,
         [ 0.0622, -0.0016,  0.1240,  ...,  0.0355, -0.0647,  0.0038],
         [ 0.0351, -0.0278,  0.1407,  ...,  0.0837, -0.0816,  0.0079],
         [ 0.0376, -0.0127,  0.0706,  ...,  0.0588, -0.0509,  0.0212]],

        [[-0.0069, -0.0388,  0.1328,  ...,  0.0380, -0.1013,  0.0230],
         [ 0.0482, 