# Attention Mechanism

This notebook presents the self-attention mechanism, a key component of the Transformer model. Self-attention lets the model weigh every position of the input sequence when computing the representation of each token, so relevant context can be used no matter how far away it is. This is particularly useful for tasks with long-range dependencies, such as machine translation and text generation.

The input sentence is first turned into a sequence of vectors, from which attention scores between every pair of tokens are computed. The scores are normalized with a softmax and used to take a weighted sum of the (value) vectors. In a full Transformer block, this result is then passed through a feedforward neural network to produce the output.

Input sentence -> Input vectors -> Attention scores -> Weighted sum -> Feedforward neural network -> Output

The input vectors are typically obtained by embedding the input tokens into a high-dimensional vector space. Each input vector is then projected into three vectors. The attention score between two tokens is a dot product that measures how well the query of one token matches the key of the other:
- Query vector: what this token is looking for
- Key vector: what this token offers to be matched against
- Value vector: the information this token passes on when it is attended to
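In a trained Transformer, the three vectors for each token come from multiplying its embedding by three learned weight matrices. The sketch below is a minimal NumPy version that assumes a toy embedding size of 8 and a projection size of 4. The names `W_q`, `W_k` and `W_v` and the random values stand in for learned weights and are purely illustrative. The notebook cells below skip this step and draw `q`, `k` and `v` at random directly, which is enough to show the mechanics.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_embed, d_head = 6, 8, 4  # 6 tokens, toy embedding and projection sizes

# Stand-in token embeddings, one row per token of "The cat sat on the mat."
X = rng.standard_normal((seq_len, d_embed))

# Hypothetical projection matrices; in a real model these are learned
W_q = rng.standard_normal((d_embed, d_head))
W_k = rng.standard_normal((d_embed, d_head))
W_v = rng.standard_normal((d_embed, d_head))

Q = X @ W_q  # what each token is looking for
K = X @ W_k  # what each token can be matched on
V = X @ W_v  # what each token passes on when attended to

print(Q.shape, K.shape, V.shape)  # (6, 4) (6, 4) (6, 4)
```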

Input sentence example: "The cat sat on the mat."

In [17]:
import numpy as np
from math import inf

In [2]:
sentence = ["The", "cat", "sat", "on", "the", "mat."]

In [28]:
# Random stand-ins for the query, key and value vectors: 6 tokens, dimension 4
q = np.random.randn(6, 4)
k = np.random.randn(6, 4)
v = np.random.randn(6, 4)

print(q)
print(k)
print(v)

[[-0.17726915 -0.60715105  0.27040047  1.32262995]
 [-0.35466501 -1.00190095  0.49971724  0.52404003]
 [-0.96664612 -0.92880235  0.31031606  1.47667384]
 [ 2.58801299  0.91252951  0.47811754 -1.90542747]
 [ 0.39023975  1.89121848 -0.77398807  0.55986816]
 [-0.11781247  0.66704404  0.74030888  0.80908705]]
[[ 1.31088812  0.03025595  0.02526448 -0.6110974 ]
 [ 0.03582415  0.53263921 -0.34596446  0.67464531]
 [-0.21136     1.38581759 -2.44668208 -0.46287517]
 [-0.4030099  -0.90153947 -0.72863338 -1.69069868]
 [-0.05841235  0.88533137  0.23723818  1.64457024]
 [ 0.90458896 -0.71120042 -0.77826392  1.28070143]]
[[ 0.76514641 -1.69868336 -1.59656269 -0.76914076]
 [-0.81664305 -0.10608875 -1.38933315 -2.52314108]
 [ 0.20812633  0.43433177 -1.68935144 -0.07477778]
 [ 1.00473554  0.71972715 -0.40171648 -0.90225516]
 [-1.03943008 -2.32277988  1.68366562  0.53501308]
 [ 1.49572558  0.46566221 -0.26506452  1.31825037]]


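\begin{align}
\text{SelfAttention}(Q, K, V) = \operatorname{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} + M \right) V
\end{align}

Here $d_k$ is the dimension of the key vectors and $M$ is the masking matrix described below.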
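The $\sqrt{d_k}$ in the denominator keeps the scores at a manageable size. Suppose the components of a query and a key are independent with zero mean and unit variance (an idealizing assumption). Their dot product is then a sum of $d_k$ unit-variance terms, so it has variance $d_k$. Without scaling, a large $d_k$ would push the softmax toward a nearly one-hot distribution with very small gradients. A quick NumPy check of the variance argument:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n_pairs = 64, 10_000  # dim plays the role of d_k

# Independent query/key components with zero mean and unit variance (an assumption)
q_samples = rng.standard_normal((n_pairs, dim))
k_samples = rng.standard_normal((n_pairs, dim))

raw_scores = np.sum(q_samples * k_samples, axis=1)  # one dot product per pair
scaled_scores = raw_scores / np.sqrt(dim)

print(raw_scores.var())     # roughly dim = 64
print(scaled_scores.var())  # roughly 1
```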

## Masking
The masking matrix $M$ prevents the model from attending to certain positions of the sequence. For example, in autoregressive text generation, or in the decoder of a machine translation model, a token must not attend to the tokens that come after it. This is achieved by setting $M$ to $-\infty$ at future positions and to $0$ at the current and past positions. After the softmax, the $-\infty$ entries receive exactly zero weight, while the $0$ entries leave the scores unchanged.
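The mask is applied by adding it to the scores before the softmax, as in the formula above. An entry of $-\infty$ turns into exactly zero weight because $e^{-\infty} = 0$. A small check on a single row of made-up scores:

```python
import numpy as np

scores = np.array([2.0, 1.0, 0.5, 3.0])        # made-up scores for one query
row_mask = np.array([0.0, 0.0, 0.0, -np.inf])  # hide the last (future) position

masked = scores + row_mask
exp_scores = np.exp(masked - masked.max())     # exp(-inf) == 0.0
weights = exp_scores / exp_scores.sum()
print(weights)        # last entry is exactly 0.0
print(weights.sum())  # 1.0, up to floating-point rounding

# A multiplicative 0/1 mask would NOT hide the position: the masked score
# becomes 0, and exp(0) still gives it a positive weight after the softmax.
zeroed = scores * np.array([1.0, 1.0, 1.0, 0.0])
zeroed_weights = np.exp(zeroed) / np.exp(zeroed).sum()
print(zeroed_weights[-1])  # greater than 0
```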

In [29]:
def create_mask(size):
    mask = np.tril(np.ones((size, size)))
    mask[mask == 0] = -inf
    mask[mask == 1] = 0
    return mask

mask = create_mask(len(sentence))
mask

array([[  0., -inf, -inf, -inf, -inf, -inf],
       [  0.,   0., -inf, -inf, -inf, -inf],
       [  0.,   0.,   0., -inf, -inf, -inf],
       [  0.,   0.,   0.,   0., -inf, -inf],
       [  0.,   0.,   0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.,   0.,   0.]])

In [15]:
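def softmax(x):
    # Subtract the row-wise maximum for numerical stability (does not change the result)
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    # Normalize each row so that its weights sum to 1
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)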

In [30]:
# Unscaled scores plus the causal mask (the formula above also divides by sqrt(d_k))
w = softmax(np.dot(q, k.T) + mask)
w

array([[1.        , 0.        , 0.        , 0.        , 0.        ,
        0.        ],
       [0.39241945, 0.60758055, 0.        , 0.        , 0.        ,
        0.        ],
       [0.06890137, 0.8818494 , 0.04924923, 0.        , 0.        ,
        0.        ],
       [0.95480366, 0.00402563, 0.01479945, 0.02637126, 0.        ,
        0.        ],
       [0.01492158, 0.06423303, 0.78734935, 0.00128508, 0.13221096,
        0.        ],
       [0.04566295, 0.15950434, 0.02440449, 0.00717102, 0.68879132,
        0.07446588]])

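In this run, row 2 shows "sat" attending mostly to "cat" (weight ≈ 0.88), and row 3 shows "on" attending mostly to the first token, "The" (≈ 0.95). Because `q` and `k` are random rather than learned, these patterns are arbitrary. They illustrate the mechanics, not linguistic structure.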

In [26]:
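def attention(q, k, v, mask=None):
    # Similarity scores between every query and every key
    # (note: unlike the formula, this omits the 1/sqrt(d_k) scaling)
    scores = np.dot(q, k.T)
    if mask is not None:
        # Add the mask: -inf entries become zero weight after the softmax
        scores = scores + mask
    weights = softmax(scores)
    # Weighted sum of the value vectors
    output = np.dot(weights, v)
    return output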

In [31]:
self_attention = attention(q, k, v, mask)
self_attention

array([[ 0.76514641, -1.69868336, -1.59656269, -0.76914076],
       [-0.19591809, -0.73105386, -1.47065405, -1.83483723],
       [-0.65718648, -0.18920541, -1.41838721, -2.28170805],
       [ 0.75685338, -1.59692818, -1.56559208, -0.76943592],
       [-0.01330302,  0.00363734, -1.22109125, -0.1628469 ],
       [-0.68760497, -1.64396236,  0.80133911,  0.02080885]])
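The `attention` function above leaves out the $1/\sqrt{d_k}$ factor that appears in the formula. The sketch below includes that factor and also returns the weights. `scaled_attention` and `softmax_rows` are hypothetical names chosen so they do not overwrite the notebook's functions. Because of the scaling and the different random inputs, its numbers will differ from the output above.

```python
import numpy as np

def softmax_rows(x):
    # Subtract the row maximum for numerical stability; -inf entries stay -inf
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def scaled_attention(q, k, v, mask=None):
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)  # (seq_len, seq_len) scaled similarity scores
    if mask is not None:
        scores = scores + mask       # -inf hides future positions
    weights = softmax_rows(scores)
    return weights @ v, weights

rng = np.random.default_rng(0)
q_demo, k_demo, v_demo = (rng.standard_normal((6, 4)) for _ in range(3))
causal = np.triu(np.full((6, 6), -np.inf), k=1)  # -inf above the diagonal, 0 elsewhere

out, weights = scaled_attention(q_demo, k_demo, v_demo, causal)
print(out.shape)                     # (6, 4)
print(weights.sum(axis=-1))          # each row sums to 1
print(np.triu(weights, k=1).max())   # 0.0: no weight on future tokens
```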

## Multi-head Attention
Multi-head attention runs several attention operations ("heads") in parallel. The input vectors are projected into queries, keys and values, which are then split into $h$ heads of smaller dimension. Each head computes scaled dot-product attention independently. The outputs of the heads are concatenated and passed through a linear layer to produce the final output. Because each head works on its own slice of the projections, different heads can learn to capture different relationships in the sequence.

Input sentence -> Input vectors -> Q, K, V projections -> Split into heads -> Attention per head -> Concatenate outputs -> Linear transformation -> Output

Parameters involved:
- $d_{model}$: Dimension of the input vectors
- $d_{k}$: Dimension of the query and key vectors in each head
- $d_{v}$: Dimension of the value vectors in each head
- $h$: Number of heads
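In the implementation below, each head gets an equal slice of the model dimension, so $d_k = d_v = d_{model}/h$. Plain arithmetic shows one consequence: the parameter count of the projections does not depend on $h$, because the heads only partition the same $d_{model}$ columns. This sketch assumes, as in the code below, one fused `nn.Linear(d_model, 3 * d_model)` for Q, K and V plus one `nn.Linear(d_model, d_model)` output layer, each with a bias. `head_dims_and_params` is a hypothetical helper.

```python
def head_dims_and_params(d_model, num_heads):
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_v = d_model // num_heads
    # Fused Q/K/V projection: weight (3*d_model x d_model) plus bias (3*d_model)
    qkv_params = d_model * 3 * d_model + 3 * d_model
    # Output projection: weight (d_model x d_model) plus bias (d_model)
    out_params = d_model * d_model + d_model
    return d_k, d_v, qkv_params + out_params

for h in (1, 8, 16):
    print(h, head_dims_and_params(512, h))
# 1 (512, 512, 1050624)
# 8 (64, 64, 1050624)
# 16 (32, 32, 1050624)
```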

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [36]:
sequence_length = 6 # Length of the input sequence
batch_size = 1 # Number of sequences in a batch
d_model = 512 # Dimension of the input vectors

x = torch.randn(batch_size, sequence_length, d_model) # random stand-in for the embedded input sequence
x.size()

torch.Size([1, 6, 512])

In [38]:
qkv_layer = nn.Linear(d_model, 3 * d_model)
qkv = qkv_layer(x)
qkv.size()

torch.Size([1, 6, 1536])

In [41]:
num_heads = 8 # Number of heads
d_k = d_model // num_heads # Dimension of the key vectors
d_v = d_model // num_heads # Dimension of the value vectors

qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * d_k)
qkv.size()

torch.Size([1, 6, 8, 192])

In [42]:
qkv = qkv.permute(0, 2, 1, 3)
qkv.size()

torch.Size([1, 8, 6, 192])

In [43]:
q, k, v = qkv.chunk(3, dim=-1)
q.size(), k.size(), v.size() # [batch_size, num_heads, sequence_length, d_k]

(torch.Size([1, 8, 6, 64]),
 torch.Size([1, 8, 6, 64]),
 torch.Size([1, 8, 6, 64]))

In [44]:
scaled = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)
scaled.size()

torch.Size([1, 8, 6, 6])

In [46]:
mask = torch.full(scaled.size(), float(-inf))
mask = torch.triu(mask, diagonal=1)
mask[0][1]

tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0.]])
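A full [batch_size, num_heads, sequence_length, sequence_length] mask like the one above is not strictly needed. A single (sequence_length, sequence_length) matrix is enough, because PyTorch broadcasting repeats it over the batch and head dimensions when it is added to the scores. The sketch below uses illustrative names and a batch of 2 to show the broadcasting:

```python
import torch

seq_len, batch, n_heads = 6, 2, 8

# -inf strictly above the diagonal, 0 on and below it
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

scores = torch.randn(batch, n_heads, seq_len, seq_len)
masked = scores + causal_mask            # broadcasts over batch and heads
weights = torch.softmax(masked, dim=-1)

print(masked.shape)                               # torch.Size([2, 8, 6, 6])
print(torch.all(weights.triu(diagonal=1) == 0))   # tensor(True): no weight on future positions
```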

In [47]:
scaled = scaled + mask

attention_weights = F.softmax(scaled, dim=-1)
attention_weights.size()

torch.Size([1, 8, 6, 6])

In [48]:
values = torch.matmul(attention_weights, v)
values.size()

torch.Size([1, 8, 6, 64])

In [49]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size(-1)  # per-head key dimension, read from the input instead of a global
    scaled = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask  # -inf entries get zero weight after the softmax
    attention_weights = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention_weights, v)
    return attention_weights, values

In [53]:
attention_weights, values = scaled_dot_product(q, k, v, mask)
attention_weights.size(), values.size()

(torch.Size([1, 8, 6, 6]), torch.Size([1, 8, 6, 64]))
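If PyTorch 2.0 or later is available, the built-in `torch.nn.functional.scaled_dot_product_attention` computes the same quantity and can serve as a cross-check. This sketch assumes PyTorch ≥ 2.0. `reference_sdpa` is a local restatement of the notebook's `scaled_dot_product` that takes $d_k$ from the tensor, so the block runs on its own.

```python
import math
import torch
import torch.nn.functional as F

def reference_sdpa(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    return torch.matmul(F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
q_t, k_t, v_t = (torch.randn(1, 8, 6, 64) for _ in range(3))
causal = torch.triu(torch.full((6, 6), float("-inf")), diagonal=1)

ours = reference_sdpa(q_t, k_t, v_t, causal)
# is_causal=True applies the same lower-triangular mask internally
builtin = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True)
print(torch.allclose(ours, builtin, atol=1e-5))  # True
```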

In [54]:
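# Move the head axis back next to d_v before merging: [batch, heads, seq, d_v] -> [batch, seq, heads, d_v]
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * d_v)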
values.size()

torch.Size([1, 6, 512])
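Merging the heads back has to undo the split. The per-head output is laid out as [batch_size, num_heads, sequence_length, d_v], so the head axis must be moved back next to $d_v$ with `permute(0, 2, 1, 3)` before reshaping. Otherwise the reshape mixes values from different positions. A small round-trip check, using hypothetical helpers `split_heads` and `merge_heads`:

```python
import torch

def split_heads(x, num_heads):
    # [batch, seq, d_model] -> [batch, heads, seq, d_head]
    batch, seq, d_model = x.shape
    return x.reshape(batch, seq, num_heads, d_model // num_heads).permute(0, 2, 1, 3)

def merge_heads(x):
    # [batch, heads, seq, d_head] -> [batch, seq, heads * d_head]
    batch, heads, seq, d_head = x.shape
    return x.permute(0, 2, 1, 3).reshape(batch, seq, heads * d_head)

x_demo = torch.randn(1, 6, 512)
split = split_heads(x_demo, 8)

print(torch.equal(merge_heads(split), x_demo))        # True: exact round trip
print(torch.equal(split.reshape(1, 6, 512), x_demo))  # False: reshape alone scrambles positions
```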

In [55]:
linear_layer = nn.Linear(num_heads * d_v, d_model)
output = linear_layer(values)
output.size()

torch.Size([1, 6, 512])

In [68]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(num_heads * self.d_v, d_model)

    def forward(self, x, mask=None):
        print(f'x.size(): {x.size()}')
        batch_size, sequence_length, _ = x.size()
        qkv = self.qkv_layer(x)
        print(f'qkv.size(): {qkv.size()}')
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.d_k)
        print(f'qkv.size(): {qkv.size()}')
        qkv = qkv.permute(0, 2, 1, 3)
        print(f'qkv.size(): {qkv.size()}')
        q, k, v = qkv.chunk(3, dim=-1)
        print(f'q.size(): {q.size()}, k.size(): {k.size()}, v.size(): {v.size()}')
        attention_weights, values = scaled_dot_product(q, k, v, mask)
        print(f'attention_weights.size(): {attention_weights.size()}, values.size(): {values.size()}')
        # Move the head axis back next to d_v before merging the heads
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.d_v)
        output = self.linear_layer(values)
        print(f'output.size(): {output.size()}')
        return output

In [69]:
multi_head_attention = MultiHeadAttention(d_model, num_heads)
new_values = multi_head_attention(x, mask)

x.size(): torch.Size([1, 6, 512])
qkv.size(): torch.Size([1, 6, 1536])
qkv.size(): torch.Size([1, 6, 8, 192])
qkv.size(): torch.Size([1, 8, 6, 192])
q.size(): torch.Size([1, 8, 6, 64]), k.size(): torch.Size([1, 8, 6, 64]), v.size(): torch.Size([1, 8, 6, 64])
attention_weights.size(): torch.Size([1, 8, 6, 6]), values.size(): torch.Size([1, 8, 6, 64])
output.size(): torch.Size([1, 6, 512])
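PyTorch also ships a ready-made module, `torch.nn.MultiheadAttention`. It likewise projects the input to queries, keys and values, splits them into heads, applies scaled dot-product attention and finishes with an output projection. Its weights are initialized independently, so its outputs will not match the class above numerically. The sketch below assumes a PyTorch version that supports `batch_first=True` (1.9 or later) and only checks that the interface and shapes line up.

```python
import torch
import torch.nn as nn

d_model, num_heads, seq_len = 512, 8, 6

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

x_in = torch.randn(1, seq_len, d_model)  # [batch, seq, d_model]
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Self-attention: query, key and value are all the same input
out, attn = mha(x_in, x_in, x_in, attn_mask=causal)

print(out.shape)   # torch.Size([1, 6, 512])
print(attn.shape)  # torch.Size([1, 6, 6]): weights averaged over heads by default
```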
