In [1]:
# Multi-Head Attention
# MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
# where each head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i)

In [1]:
# Intuition
# head 1: might learn syntactic relationships
# head 2: might learn semantic relationships
# head 3: might learn positional patterns
# head 4: might learn topic relationships

# Each head specializes in a different aspect!

In [2]:
import torch
import torch.nn as nn

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention assembled from independent single-head modules.

    Architecture:
      1. Run ``num_heads`` independent attention heads on the same input.
      2. Concatenate their outputs along the last (feature) dimension.
      3. Project the concatenation with a linear layer and apply dropout.

    NOTE(review): ``SelfAttention`` is not defined in the visible part of this
    notebook. It is assumed to be constructible as
    ``SelfAttention(embedding_dim, head_dim)`` and to return a pair
    ``(output, attention_weights)`` where ``output`` has shape
    ``(batch, seq_len, head_dim)`` -- confirm against its definition.
    """

    def __init__(self, embedding_dim, num_heads, dropout = 0.1):
        """Build the per-head attention modules and the output projection.

        Args:
            embedding_dim: Size of the input/output feature dimension. Must be
                divisible by ``num_heads``.
            num_heads: Number of independent attention heads.
            dropout: Dropout probability applied after the output projection.

        Raises:
            AssertionError: If ``embedding_dim`` is not divisible by
                ``num_heads``.
        """
        super().__init__()

        # embedding_dim must split evenly so that the concatenated head
        # outputs (num_heads * head_dim) match embedding_dim again.
        assert embedding_dim % num_heads == 0, (
            f"embedding_dim ({embedding_dim}) must be divisible by "
            f"num_heads ({num_heads})"
        )
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.embedding_dim = embedding_dim

        # Create multiple attention heads; ModuleList registers each head so
        # its parameters are tracked by this module.
        self.heads = nn.ModuleList([
            SelfAttention(embedding_dim, self.head_dim)
            for _ in range(num_heads)
        ])

        # Output projection applied to the concatenated head outputs.
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x``, concatenate, project and apply dropout.

        Args:
            x: Input tensor of shape ``(batch, seq_len, embedding_dim)``.

        Returns:
            A tuple ``(output, all_attention_weights)`` where ``output`` is the
            projected concatenation of all head outputs and
            ``all_attention_weights`` is a list holding the second return
            value of each head, in head order.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape (batch, seq_len, embedding_dim), "
                f"got {tuple(x.shape)}"
            )

        # Step 1: run every head on the same input and collect the results.
        head_outputs = []
        all_attention_weights = []
        for head in self.heads:
            head_output, attn_weights = head(x)
            head_outputs.append(head_output)
            all_attention_weights.append(attn_weights)

        # Step 2: concatenate along the feature dimension.
        # each head output: (batch, seq_len, head_dim)
        # concatenated:     (batch, seq_len, num_heads * head_dim)
        #                 = (batch, seq_len, embedding_dim)
        multi_head_output = torch.cat(head_outputs, dim=-1)

        # Step 3: output projection followed by dropout. This must happen
        # after the loop so that all heads contribute to the result.
        output = self.output_proj(multi_head_output)
        output = self.dropout(output)

        return output, all_attention_weights

