In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    The input is linearly projected to queries, keys and values, split into
    ``heads`` heads of size ``embed_size // heads``, combined with scaled
    dot-product attention over the sequence dimension, re-concatenated, and
    passed through a final output projection.

    Args:
        embed_size: Size of the last dimension of the input and output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by heads"

        # Q/K/V projections are bias-free; the output projection keeps a bias.
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(N, length, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, length, length)`` (query x key). Entries where
                ``mask == 0`` are excluded from attention. ``None`` (the
                default) lets every query attend to every key.

        Returns:
            Tensor of shape ``(N, length, embed_size)``.
        """
        N, length, _ = x.shape

        # Split the embedding into heads. Keep the (N, length, heads, head_dim)
        # layout: the einsum subscripts below are written for exactly this
        # layout, so no permute is needed (permuting here would make the
        # einsums attend across heads instead of across sequence positions).
        values = self.values(x).view(N, length, self.heads, self.head_dim)
        keys = self.keys(x).view(N, length, self.heads, self.head_dim)
        queries = self.queries(x).view(N, length, self.heads, self.head_dim)

        # Query/key similarity per head, scaled by sqrt(d_k) where d_k is the
        # per-head dimension (not the full embedding size).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)  # (N, heads, q, k)
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Most negative finite value instead of -inf, so a fully masked
            # row produces uniform weights rather than NaN.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Normalise over the key axis.
        attention = F.softmax(energy, dim=-1)

        # Weighted sum of values per head, then concatenate heads back into
        # (N, length, embed_size).
        out = torch.einsum("nhqk,nkhd->nqhd", attention, values)
        out = out.reshape(N, length, self.embed_size)

        # Pass through the final linear layer
        return self.fc_out(out)


In [None]:
# Example usage: push one random batch through the layer and check the shape.
torch.manual_seed(0)  # reproducible weights and sample input on Restart & Run All

embed_size = 256  # Embedding size
num_heads = 8     # Number of attention heads
batch_size, seq_len = 64, 10
multihead_attention = MultiHeadAttention(embed_size, num_heads)

# Sample input (batch_size, sequence_length, embedding_size)
x = torch.rand(batch_size, seq_len, embed_size)
output = multihead_attention(x)

# Self-attention preserves the input shape.
assert output.shape == (batch_size, seq_len, embed_size), output.shape
output.shape