https://www.youtube.com/watch?v=HQn1QKQYXVg&list=PLTl9hO2Oobd97qfWC40gOSU8C0iu0m2l4&index=2

![pic](../doc/img//multi_head_attention.png)

Each word in the input sequence has three vectors associated with it:
- $q$ = query vector: what I am looking for
- $k$ = key vector: what I can offer
- $v$ = value vector: what I actually offer

Each of these vectors has the same number of dimensions as the embedding vector.

The q, k, v vectors are then split into multiple attention heads of equal size. In the paper, they use 8 attention heads, so each vector is split into 8 vectors of size 1x64 (given an embedding dimension of 512).

Then, for each head, every word's query is compared with the keys of all words in that same head, producing one attention matrix per head.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Toy dimensions for a step-by-step walkthrough.
batch_size = 1
# In practice sequences are padded/truncated to a fixed maximum length so that
# every input tensor in a batch has the same size.
sequence_length = 4
input_dim = 512  # size of each incoming word vector
d_model = 512  # size of the attention output produced for every word
# Random stand-in for the tensor fed into a multi-head attention block of the
# Transformer, shaped (batch, sequence, features).
x = torch.randn((batch_size, sequence_length, input_dim))

In [3]:
x  # display the random input, shape (batch, seq, input_dim)

tensor([[[ 0.2875,  0.4003, -1.5862,  ...,  0.4906,  0.1627,  0.7256],
         [-0.2272, -0.6080,  1.0705,  ..., -0.1585, -1.4422,  0.3697],
         [ 0.9237,  0.2495,  1.9846,  ...,  1.0509,  0.0411,  1.3766],
         [ 0.0165, -0.1440,  1.4073,  ...,  0.0226, -0.4571,  1.0191]]])

In [4]:
# A single linear layer produces the query, key and value projections for all
# heads at once: 3 * d_model output features, concatenated.
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [7]:
qkv_layer  # show the layer configuration: input_dim -> 3 * d_model features

Linear(in_features=512, out_features=1536, bias=True)

In [9]:
# Project every word vector into its concatenated q/k/v representation:
# (batch, seq, input_dim) -> (batch, seq, 3 * d_model).
qkv = qkv_layer(input=x)

In [10]:
qkv.shape  # (batch, seq, 3 * d_model)

torch.Size([1, 4, 1536])

In [11]:
num_heads = 8
# Every head gets an equal slice of d_model. The split must be exact: otherwise
# the integer division below silently drops features and the per-head reshape
# fails later with a much less obvious error.
if d_model % num_heads != 0:
    raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
head_dim = d_model // num_heads  # 64 for d_model=512, num_heads=8

In [12]:
# Give each head its own axis:
# (batch, seq, 3 * d_model) -> (batch, seq, heads, 3 * head_dim).
# Within a head's slice the q, k and v parts sit back to back.
qkv = torch.reshape(qkv, (batch_size, sequence_length, num_heads, 3 * head_dim))

In [14]:
qkv.shape  # (batch, seq, heads, 3 * head_dim)

torch.Size([1, 4, 8, 192])

In [15]:
# Swap the sequence and head axes -> (batch, heads, seq, 3 * head_dim), so the
# last two dimensions hold one head's whole sequence and all heads can be
# processed in parallel.
qkv = qkv.transpose(1, 2)
qkv.shape

torch.Size([1, 8, 4, 192])

In [16]:
# Cut the last axis into three equal parts: the query, key and value of every head.
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)

In [17]:
q.shape, k.shape, v.shape  # each (batch, heads, seq, head_dim)

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

## Self Attention

self attention = $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)$, where $M$ is the (optional) mask  
new V = self attention $\cdot$ V

Each word compares its query (Q) vector with the key (K) vector of every word, including itself, to obtain its attention values.

In [18]:
d_k = q.shape[-1]  # per-head width used for scaling (64 here)
# k.T would not do here: on this 4-D tensor we only want to swap the last two
# axes (seq, head_dim), so transpose(-2, -1) is used instead.
# The result holds one (seq x seq) score matrix per head and batch element.
scaled = q @ k.transpose(-2, -1) / math.sqrt(d_k)

In [19]:
scaled.shape  # (batch, heads, seq, seq): one score matrix per head

torch.Size([1, 8, 4, 4])

In [20]:
# Causal mask with the same shape as the scores: -inf strictly above the main
# diagonal and 0 on and below it. torch.triu(..., diagonal=1) keeps the entries
# from the first super-diagonal upward and sets everything else to 0.
# After softmax exp(-inf) == 0, so no word gets attention weight on later words.
mask = torch.triu(torch.full(scaled.size(), float('-inf')), diagonal=1)
mask[0, 0]  # mask for a single head

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [21]:
(scaled + mask)[0][0]  # masked scores for the first head of the first batch element

tensor([[ 0.2036,    -inf,    -inf,    -inf],
        [ 0.7623,  0.3230,    -inf,    -inf],
        [-0.3756,  0.3354, -0.0095,    -inf],
        [-0.3710, -0.1733, -0.2924,  0.0059]], grad_fn=<SelectBackward0>)

In [22]:
# Decoder self-attention: add the causal mask (in place) so each word can only
# attend to itself and the words before it. This keeps tasks such as
# translation autoregressive.
scaled.add_(mask)

In [23]:
# Normalise each row of scores into attention weights that sum to 1.
attention = scaled.softmax(dim=-1)

In [24]:
# Blend the value vectors with the attention weights. Each word's new value
# vector is more context-aware than its previous one.
values = attention @ v

torch.Size([1, 4, 512])

In [31]:
values.shape  # (batch, heads, seq, head_dim), one result per head

torch.Size([1, 8, 4, 64])

In [32]:
# Concatenate the heads back into one d_model-wide vector per word.
# `values` is (batch, heads, seq, head_dim): the head axis must first be moved
# next to head_dim -> (batch, seq, heads, head_dim). Reshaping without that
# permute would pack different positions and heads into the same output row.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)

In [33]:
values.shape  # (batch, seq, num_heads * head_dim) after merging the heads

torch.Size([1, 4, 512])

In [None]:
# Output projection: a d_model -> d_model linear layer over the concatenated
# head outputs, which lets information from the different heads mix.
linear_layer = nn.Linear(in_features=d_model, out_features=d_model)
out = linear_layer(values)


## Putting it all together

In [38]:
# All the self-attention steps above, collected into one function.
# The mask is optional: encoder self-attention uses none, decoder
# self-attention uses a causal mask.
def scaled_dot_product(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Args:
        q, k, v: Tensors whose last two axes are (seq, head_dim). Leading axes
            (e.g. batch, heads) must be compatible for ``torch.matmul``.
        mask: Optional additive mask broadcastable to the score shape
            (..., seq_q, seq_k), e.g. 0 for allowed and -inf for blocked positions.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted values and the
        softmax attention weights.
    """
    d_k = q.size()[-1]
    # transpose(-2, -1) swaps only the last two axes (seq, head_dim), which is
    # what we want here instead of k.T on a 4-D tensor.
    scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    # Compare against None explicitly: truth-testing a multi-element tensor
    # raises "Boolean value of Tensor with more than one value is ambiguous".
    if mask is not None:
        # Out-of-place add, so a mask that broadcasts (e.g. a single (seq, seq)
        # causal mask shared by all batches and heads) works as well.
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    # Weighted sum of value vectors: each word's output is more context-aware
    # than its previous value vector.
    values = torch.matmul(attention, v)
    return values, attention

In [42]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused q/k/v projection.

    Args:
        input_dim: Feature size of each input word vector.
        d_model: Total attention width, split evenly across ``num_heads``.
        num_heads: Number of parallel attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        # Fail early: a non-divisible d_model would otherwise only surface as a
        # reshape error inside forward().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One linear layer produces q, k and v for every head at once.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        # Output projection over the concatenated head outputs.
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``, printing intermediate shapes.

        Args:
            x: Tensor of shape (batch, seq, input_dim).
            mask: Optional additive attention mask passed to
                ``scaled_dot_product``. It must broadcast to
                (batch, num_heads, seq, seq).

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        # Give each head its own axis: (batch, seq, heads, 3 * head_dim).
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        # Put the head axis before the sequence axis so the last two dimensions
        # hold one head's sequence and all heads are processed in parallel.
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        # Split the last axis into the separate q, k, v tensors.
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q.size(): {q.size()} k.size(): {k.size()}, v.size(): {v.size()}")
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        print(f"values.size(): {values.size()}, attention.size(): {attention.size()}")
        # values is (batch, heads, seq, head_dim). Move the head axis back next
        # to head_dim before flattening, otherwise positions and heads get mixed.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size() : {out.size()}")
        return out
        

In [43]:
# Larger demo batch for the module version. input_dim may differ from d_model
# (e.g. 1024); the qkv projection maps it to 3 * d_model.
input_dim = 512
num_head = 8

batch_size = 30
sequence_length = 5

x = torch.randn((batch_size, sequence_length, input_dim))

In [44]:
model = MultiheadAttention(input_dim, d_model, num_head)
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks around forward().
out = model(x)

x.size(): torch.Size([30, 5, 512])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q.size(): torch.Size([30, 8, 5, 64]) k.size(): torch.Size([30, 8, 5, 64]), v.size(): torch.Size([30, 8, 5, 64])
values.size(): torch.Size([30, 8, 5, 64]), attention.size(): torch.Size([30, 8, 5, 5])
values.size(): torch.Size([30, 5, 512])
out.size() : torch.Size([30, 5, 512])


In [45]:
out  # module output, shape (batch, seq, d_model)

tensor([[[-6.2482e-03,  1.8415e-01,  1.3552e-01,  ..., -3.3307e-01,
           8.7331e-02,  1.9772e-01],
         [-2.4796e-01,  1.9846e-01,  2.6852e-02,  ...,  1.5789e-02,
           9.6477e-03,  3.0723e-02],
         [ 6.4421e-02, -1.4961e-01, -1.7929e-02,  ...,  2.6415e-01,
          -8.2693e-02,  7.8689e-02],
         [ 7.5609e-02, -2.8181e-01, -1.5872e-01,  ...,  1.2011e-02,
          -1.2643e-03, -5.6821e-02],
         [ 1.8020e-01,  3.6276e-01, -2.0193e-01,  ...,  2.4924e-01,
          -8.4760e-02, -4.8158e-01]],

        [[ 2.3121e-02,  5.3049e-02,  1.3574e-01,  ..., -1.1809e-03,
           1.6767e-01, -1.9779e-01],
         [ 2.0019e-01,  1.6813e-01, -7.7353e-02,  ...,  2.4081e-01,
           3.4660e-02,  4.4718e-02],
         [-1.5035e-01,  1.6051e-01,  2.4635e-01,  ...,  5.8178e-02,
           2.5456e-01,  7.6498e-02],
         [ 2.8846e-01, -4.7870e-02,  1.2407e-01,  ...,  1.4104e-01,
          -1.8601e-01, -3.1178e-01],
         [-8.7335e-02, -2.8972e-01, -7.5568e-03,  ...