# **Multi-Head Attention**

In [1]:
import torch
import torch.nn as nn

In [11]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention.

  The input is projected to queries, keys and values of width ``d_out``,
  which are split into ``num_heads`` heads of width ``d_out // num_heads``.
  Each head runs scaled dot-product attention with a causal mask (a token
  cannot attend to later tokens); the heads are then concatenated and passed
  through a final linear projection.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Total output width across all heads; must be divisible by
      ``num_heads``.
    context_len: Maximum sequence length supported; sets the size of the
      causal mask buffer.
    dropout: Dropout probability applied to the attention weights.
    num_heads: Number of attention heads.
    qkv_bias: Whether the query/key/value projections use a bias term.

  Raises:
    AssertionError: If ``d_out`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_in, d_out, context_len, dropout, num_heads, qkv_bias=False):
    super().__init__()
    assert (d_out%num_heads==0),"d_out must be divisible by num_heads"

    # Kept for backward compatibility; it is not referenced anywhere else in
    # this class and the module is not moved to this device automatically.
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.d_in = d_in
    self.d_out = d_out
    self.num_heads = num_heads
    # Width of each head, e.g. d_out=6 with num_heads=2 gives head_dim=3.
    self.head_dim = d_out//self.num_heads

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.output_projection = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)

    # Upper-triangular matrix of ones above the main diagonal: entry (i, j)
    # is 1 when j > i, i.e. when key j lies in the future of query i.
    self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_len, context_len), diagonal=1)
    )

  def forward(self, x):
    """Compute causal multi-head attention over ``x``.

    Args:
      x: Tensor of shape ``(b, num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(b, num_tokens, d_out)``.

    Raises:
      ValueError: If ``num_tokens`` exceeds the ``context_len`` the module
        was built with.
    """
    # The last input dimension is d_in (not d_out); it is not needed below.
    b, num_tokens, _ = x.shape

    # Without this check a too-long sequence either fails with an opaque
    # broadcasting error inside masked_fill_, or (when context_len == 1)
    # silently broadcasts the 1x1 mask and applies no causal masking at all.
    max_tokens = self.mask.shape[0]
    if num_tokens > max_tokens:
      raise ValueError(
          f"sequence length {num_tokens} exceeds context_len {max_tokens}"
      )

    # Linear projections, each of shape (b, num_tokens, d_out).
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # Split the last dimension into heads:
    # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    # e.g. (1, 3, 6) -> (1, 3, 2, 3) for num_heads=2.
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)

    # Group by head instead of by token:
    # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1,2)
    queries = queries.transpose(1,2)
    values = values.transpose(1,2)

    # Per-head dot products between every query and every key:
    # (b, num_heads, num_tokens, head_dim) @ (b, num_heads, head_dim, num_tokens)
    # -> (b, num_heads, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(2,3)

    # Causal masking: future positions get -inf so softmax assigns them 0.
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
    attn_scores.masked_fill_(mask_bool, -torch.inf)

    # Scale by sqrt(head_dim) before the softmax over the key dimension.
    attn_weights = torch.softmax(attn_scores/self.head_dim**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
    context_vecs = (attn_weights@values).transpose(1,2)

    # Concatenate the heads back together: d_out = num_heads * head_dim.
    # contiguous() is required because view() cannot operate on the
    # transposed (non-contiguous) tensor.
    context_vecs = context_vecs.contiguous().view(b, num_tokens, self.d_out)
    context_vecs = self.output_projection(context_vecs)

    return context_vecs



In [15]:
# torch is already imported in the notebook's first cell, so it is not re-imported here.
# Seed before constructing the module so its randomly initialised weights are reproducible.
torch.manual_seed(123)

# One sequence of three tokens, each a 6-dimensional embedding.
inputs = torch.tensor([
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
    [9.0, 8.0, 7.0, 6.0, 5.0, 4.0]
], dtype=torch.float32)

# Stack the same sequence twice to form a batch of two.
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)


batch_size, num_tokens, d_in = batch.shape
d_out = d_in  # output width is kept equal to the input embedding width

# The sequence length is passed as context_len, so the causal mask is exactly num_tokens x num_tokens.
mha = MultiHeadAttention(d_in, d_out, num_tokens, 0.0, num_heads=2)
context_vecs = mha(batch)
print("Context vectors: \n",context_vecs)
print("\nShape of context vectors: ",context_vecs.shape)

torch.Size([2, 3, 6])
Context vectors: 
 tensor([[[ 1.4622e-01,  5.7911e-02,  9.8725e-01,  5.6106e-01, -1.0803e+00,
          -1.0395e-01],
         [ 1.4315e-01, -1.0073e-03,  8.3840e-01,  5.8250e-01, -1.0526e+00,
          -9.7606e-02],
         [ 1.6335e-01, -3.4609e-03,  8.1940e-01,  5.8568e-01, -1.0263e+00,
          -8.2506e-02]],

        [[ 1.4622e-01,  5.7911e-02,  9.8725e-01,  5.6106e-01, -1.0803e+00,
          -1.0395e-01],
         [ 1.4315e-01, -1.0073e-03,  8.3840e-01,  5.8250e-01, -1.0526e+00,
          -9.7606e-02],
         [ 1.6335e-01, -3.4609e-03,  8.1940e-01,  5.8568e-01, -1.0263e+00,
          -8.2506e-02]]], grad_fn=<ViewBackward0>)

Shape of context vectors:  torch.Size([2, 3, 6])
