# Implementation for Multi-head attention 

Ramin Anushiravani \
09/14/2024

This notebook implements multi-head self-attention using PyTorch.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, each projection is
    split into ``num_heads`` chunks of ``head_dim`` features, attention is
    computed independently per head across the sequence positions, and the
    concatenated head outputs are passed through a final linear layer.
    """

    def __init__(self, embed_size: int, num_heads: int) -> None:
        """Create the projection layers.

        Args:
            embed_size: Size of the last dimension of the input.
            num_heads: Number of attention heads; must divide ``embed_size``.

        Raises:
            AssertionError: If ``embed_size`` is not divisible by
                ``num_heads``.
        """
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size must be divisible by num_heads"

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(N, seq_len, embed_size)``.

        Returns:
            Tensor of shape ``(N, seq_len, embed_size)``.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        split_shape = (batch_size, seq_len, self.num_heads, self.head_dim)

        # Project, then split the embedding axis into (num_heads, head_dim).
        values = self.values(x).view(split_shape)
        keys = self.keys(x).view(split_shape)
        queries = self.queries(x).view(split_shape)

        # Dot product of every query position with every key position, per
        # head. The subscripts match the (N, seq_len, num_heads, head_dim)
        # layout, so "q" and "k" index sequence positions and "h" the head.
        attention_score = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(head_dim), the per-head key dimension, and normalise
        # over the key axis so that each query's weights sum to one.
        attention_weights = F.softmax(
            attention_score / (self.head_dim ** 0.5), dim=-1
        )

        # Weighted sum of the values for each query position and head.
        out = torch.einsum("nhqk,nkhd->nqhd", attention_weights, values)

        # Concatenate the heads back into a single embedding axis.
        out = out.reshape(batch_size, seq_len, self.embed_size)
        return self.fc_out(out)


In [None]:
# Smoke test: push a random batch through MultiHeadSelfAttention.
# embed_size must be a multiple of num_heads.
embed_size, num_heads = 256, 8

self_attn = MultiHeadSelfAttention(embed_size=embed_size, num_heads=num_heads)

# Random input laid out as (batch, sequence, embedding).
N, seq_len = 10, 20
x = torch.rand(N, seq_len, embed_size)

output = self_attn(x)
# The printed shape should be (N, seq_len, embed_size).
print(output.shape)
