In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        N = x.shape[0]  # Batch size
        length = x.shape[1]  # Sequence length

        # Split the embedding into multiple heads
        values = self.values(x).view(N, length, self.heads, self.head_dim)
        keys = self.keys(x).view(N, length, self.heads, self.head_dim)
        queries = self.queries(x).view(N, length, self.heads, self.head_dim)

        # Transpose to get dimensions for attention calculation
        values = values.permute(0, 2, 1, 3)  # (N, heads, length, head_dim)
        keys = keys.permute(0, 2, 1, 3)      # (N, heads, length, head_dim)
        queries = queries.permute(0, 2, 1, 3)  # (N, heads, length, head_dim)

        # Calculate attention scores
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (N, heads, length, length)
        attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Apply attention to values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, length, self.embed_size)
        
        # Pass through the final linear layer
        out = self.fc_out(out)
        return out


In [6]:
# Example usage
embed_size = 256  # Embedding size
num_heads = 8     # Number of attention heads
multihead_attention = MultiHeadAttention(embed_size, num_heads)

# Sample input (batch_size, sequence_length, embedding_size)
x = torch.rand(64, 32, embed_size)  # Example input
output = multihead_attention(x)
print(output.shape)  # Should be (64, 10, 256)

print(output)

torch.Size([64, 32, 256])
tensor([[[ 0.0025,  0.0996, -0.0957,  ..., -0.0349,  0.1079, -0.0818],
         [-0.0396,  0.1765, -0.1102,  ..., -0.0694,  0.1320, -0.0667],
         [-0.0265,  0.1216, -0.0877,  ..., -0.0087,  0.0482, -0.0539],
         ...,
         [-0.0362,  0.1734, -0.1068,  ..., -0.0709,  0.1313, -0.0634],
         [-0.0278,  0.1208, -0.0864,  ..., -0.0125,  0.0451, -0.0537],
         [-0.0769,  0.0321, -0.1482,  ..., -0.0271,  0.1362, -0.0634]],

        [[-0.0301,  0.1330, -0.0632,  ..., -0.0793,  0.1258, -0.0237],
         [-0.0393,  0.1000, -0.0926,  ...,  0.0104,  0.1036, -0.0437],
         [-0.0874,  0.0553, -0.0957,  ..., -0.0500,  0.1250, -0.1058],
         ...,
         [-0.0369,  0.0989, -0.0921,  ...,  0.0080,  0.1016, -0.0424],
         [-0.0864,  0.0557, -0.0990,  ..., -0.0474,  0.1272, -0.1068],
         [-0.0970,  0.1262, -0.1569,  ..., -0.0162,  0.0853, -0.0452]],

        [[-0.0248,  0.0954, -0.0913,  ..., -0.0661,  0.1111,  0.0180],
         [ 0.0015, 