# Attention

At its core, attention is a mechanism that allows a token vector to include information about the context of that token (i.e. the surrounding token vectors).

As an example, consider the token sequence $T_0, T_1, T_2$.
It seems sensible that token $T_2$ should have some information about $T_0$ and $T_1$ if we want to model the sequence successfully.

As usual, we will need a few imports from `torch`:

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

<torch._C.Generator at 0x7e85061a7790>

## Linear Combinations

We will continue working with our example sequence $T_0, T_1, T_2$.
We will call the vectors that represent the tokens $\mathbf{t_0}, \mathbf{t_1}$ and $\mathbf{t_2}$ respectively.

Let's say that every token is represented by a vector of dimension $5$, i.e. the entire sequence is represented by a tensor of dimension $3\times 5$.

We will use a random tensor throughout this section. In reality, this would be the output of the layer that precedes the attention layer:

In [3]:
T = torch.randn(3, 5)
T

tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674],
        [ 0.5349,  0.8094,  1.1103, -1.6898, -0.9890]])

Now let's say that we would like the token vectors to be able to "look" at each other.
The simplest way would be to calculate linear combinations.

For example, if we would like to get the information contained in $T_0$ and $T_1$, we might compute a linear combination of $\mathbf{t_0}$ and $\mathbf{t_1}$:

In [8]:
1 / 2 * T[0] + 1 / 2 * T[1]

tensor([ 0.0752,  1.1685, -0.2018,  0.3460, -0.4278])

Similarly, if we would like to combine the information in $T_0, T_1$ and $T_2$ we might compute a linear combination of $\mathbf{t_0}, \mathbf{t_1}$ and $\mathbf{t_2}$:

In [5]:
1 / 3 * T[0] + 1 / 3 * T[1] + 1 / 3 * T[2]

tensor([ 0.2284,  1.0488,  0.2356, -0.3326, -0.6148])

This doesn't look like a great idea, primarily because not every token is equally important to every other token.
Instead, we want the weights in our linear combinations to be data-driven, i.e. we would like the linear combination to be:

$w_0 \cdot \mathbf{t_0} + w_1 \cdot \mathbf{t_1} + w_2 \cdot \mathbf{t_2}$

where the weights $w_0, w_1$ and $w_2$ are computed from the data rather than fixed in advance.

That is, the input of a hypothetical attention layer would be a tensor containing the vectors $\mathbf{t_0}, \mathbf{t_1}$ and $\mathbf{t_2}$, while the output would be another tensor containing new vectors $\mathbf{\hat{t}_0}, \mathbf{\hat{t}_1}$ and $\mathbf{\hat{t}_2}$ that are linear combinations of the input vectors.

Put differently, attention "mixes" the input vectors together and gives us new output vectors that were able to "communicate" with each other in some sense.
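
The weighted combination above can also be written as a vector-matrix product. A minimal sketch, reusing the values of `T` printed earlier (rounded to four decimals); the weights in `w` are made up by hand purely for illustration, whereas in an attention layer they would be computed from the data:

```python
import torch

# Token vectors as printed earlier in this section (rounded to four decimals)
T = torch.tensor([[0.3367, 0.1288, 0.2345, 0.2303, -1.1229],
                  [-0.1863, 2.2082, -0.6380, 0.4617, 0.2674],
                  [0.5349, 0.8094, 1.1103, -1.6898, -0.9890]])

# Hypothetical weights, chosen by hand for illustration only
w = torch.tensor([0.2, 0.3, 0.5])

# Explicit linear combination w_0 * t_0 + w_1 * t_1 + w_2 * t_2
explicit = w[0] * T[0] + w[1] * T[1] + w[2] * T[2]

# The same combination as a vector-matrix product
as_matmul = w @ T

print(torch.allclose(explicit, as_matmul))
```

Stacking one such weight vector per token into a matrix gives all output vectors in a single matrix multiplication.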

## Naive Attention

Let's take a stab at a very naive attention mechanism.
The idea would be to calculate the similarity of every token with every other token.
We could use the dot product for this:

In [9]:
for i in range(3):
    for j in range(3):
        dot_product = torch.dot(T[i], T[j])
        print(f"attention({i}, {j}) = {dot_product}")

attention(0, 0) = 1.4987846612930298
attention(0, 1) = -0.12174583971500397
attention(0, 2) = 1.265914797782898
attention(1, 0) = -0.12174582481384277
attention(1, 1) = 5.602515697479248
attention(1, 2) = -0.06531322002410889
attention(2, 0) = 1.2659146785736084
attention(2, 1) = -0.06531322002410889
attention(2, 2) = 6.007389068603516


This can of course be done much more efficiently via matrix multiplication:

In [11]:
S = torch.matmul(T, T.transpose(0, 1))
S

tensor([[ 1.4988, -0.1217,  1.2659],
        [-0.1217,  5.6025, -0.0653],
        [ 1.2659, -0.0653,  6.0074]])

Now we have a matrix of "attention scores" indicating how much attention vector $\mathbf{t_i}$ should pay to vector $\mathbf{t_j}$.
We can now compute a linear combination using data-driven weights:

In [12]:
S[0, 0] * T[0] + S[0, 1] * T[1] + S[0, 2] * T[2]

tensor([ 1.2045,  0.9488,  1.8346, -1.8501, -2.9674])

Again we can realize this much more efficiently via matrix multiplication:

In [45]:
torch.matmul(S, T)

tensor([[ 1.2045,  0.9488,  1.8346, -1.8501, -2.9674],
        [-1.1198, 12.3029, -3.6754,  2.6688,  1.6991],
        [ 3.6518,  4.8810,  7.0084, -9.8898, -7.3800]])
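
Each row of this product is one data-driven linear combination of the token vectors, weighted by the corresponding row of the score matrix. A small sketch that checks this row by row, using the rounded values of `T` printed earlier (so results agree with the outputs above only up to rounding):

```python
import torch

# Token vectors as printed earlier in this section (rounded to four decimals)
T = torch.tensor([[0.3367, 0.1288, 0.2345, 0.2303, -1.1229],
                  [-0.1863, 2.2082, -0.6380, 0.4617, 0.2674],
                  [0.5349, 0.8094, 1.1103, -1.6898, -0.9890]])

# Naive attention scores: dot product of every token vector with every other
S = T @ T.T

# All three linear combinations at once
combined = S @ T

# Row i of the product equals sum_j S[i, j] * T[j]
for i in range(3):
    row = sum(S[i, j] * T[j] for j in range(3))
    print(i, torch.allclose(row, combined[i], atol=1e-5))
```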

Unfortunately this naive attention will not work well in practice.
The reason is that we need to be able to differentiate between "information that a token vector represents" and "information that a token vector is interested in".

For example, if a token vector encodes that it is the subject of a sentence it will currently pay high attention to other subjects of the sentence (mostly to itself).
Instead, it should probably pay attention to e.g. token vectors that encode predicates of a sentence, articles etc.

We note that this is a hand-wavy intuition: token vectors mostly represent inscrutable high-dimensional concepts that often have no real analogy in linguistics.
Despite this, the overall idea still works well in practice.
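
Two structural properties of the naive scores illustrate the limitation. The score matrix is symmetric, so token $i$ attends to token $j$ exactly as much as $j$ attends to $i$, and each diagonal entry is the squared norm of the token vector, so a token always gets a non-negative score with itself. A quick check, again using the rounded values of `T` printed earlier:

```python
import torch

# Token vectors as printed earlier in this section (rounded to four decimals)
T = torch.tensor([[0.3367, 0.1288, 0.2345, 0.2303, -1.1229],
                  [-0.1863, 2.2082, -0.6380, 0.4617, 0.2674],
                  [0.5349, 0.8094, 1.1103, -1.6898, -0.9890]])

S = T @ T.T

# "What I represent" and "what I look for" are the same vector,
# so the scores cannot distinguish the two directions
print(torch.allclose(S, S.T))

# Self-scores are squared norms and therefore never negative
squared_norms = (T ** 2).sum(dim=1)
print(torch.allclose(torch.diagonal(S), squared_norms))
```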

## Key and Query Vectors

We now introduce the first important component of the real self-attention mechanism: the **key vectors** and **query vectors**.

Every token vector generates a key vector and a query vector.
The key vector indicates the information that the token represents and the query vector contains the information the token is interested in.

To continue our informal example, a token might have the key vector "I am the subject of the sentence" and the query vector "I am interested in the predicate of the sentence".

The key vectors and query vectors are computed from the token vectors via simple linear layers.

For our example, we will create random linear layers. In reality, these would be parameters that our neural network has to learn.
Let's call the matrix that will produce the key vectors $W_K$ and the matrix that will produce the query vectors $W_Q$:

In [13]:
W_K = torch.randn(5, 4)
W_Q = torch.randn(5, 4)

We can now compute the key vectors:

In [14]:
K = torch.matmul(T, W_K)
K

tensor([[ 0.6023, -0.7260,  1.1799,  0.2383],
        [-0.6521,  4.4224, -3.7460, -1.2657],
        [-0.7106, -4.3429,  4.2984, -2.3664]])

We can also compute the query vectors:

In [15]:
Q = torch.matmul(T, W_Q)
Q

tensor([[-1.6964,  1.3355, -0.5133,  0.0674],
        [ 1.6595, -0.4445, -0.1917,  1.7729],
        [-0.1650, -2.9899, -3.8893,  1.2756]])
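
The matrix products above correspond to bias-free linear layers. A minimal sketch of that correspondence, assuming `nn.Linear` with `bias=False`; the local generator, its seed, and the name `key_layer` are illustrative choices and not part of the example above:

```python
import torch
import torch.nn as nn

# Local generator so the global random state is left untouched (arbitrary seed)
g = torch.Generator().manual_seed(0)
T = torch.randn(3, 5, generator=g)
W_K = torch.randn(5, 4, generator=g)

# nn.Linear stores its weight with shape (out_features, in_features),
# so we copy in the transpose of W_K
key_layer = nn.Linear(5, 4, bias=False)
with torch.no_grad():
    key_layer.weight.copy_(W_K.T)

K_matmul = T @ W_K
K_layer = key_layer(T)

print(torch.allclose(K_matmul, K_layer, atol=1e-6))
```

In a trained model the weights of such a layer would be learned rather than copied in.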

Now we will compute the attention scores in a similar way as before.
The big difference is that instead of computing the similarity of the token vectors with each other, we will _compute the similarity between the key vectors and the query vectors_:

In [16]:
S = torch.matmul(Q, K.transpose(0, 1))
S

tensor([[-2.5809,  8.8498, -6.9600],
        [ 1.5185, -4.5739, -4.2686],
        [-2.2135, -0.1601, -6.6347]])

A minor but important technical detail is that we need to scale the attention scores by $1/\sqrt{d_k}$, where $d_k$ is the dimension of the key vectors (here $4$), to avoid numerical instability:

In [17]:
S = S * (4**-0.5)
S

tensor([[-1.2905,  4.4249, -3.4800],
        [ 0.7593, -2.2869, -2.1343],
        [-1.1067, -0.0801, -3.3173]])
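
A rough way to see why the scaling factor depends on the key dimension: if the components of the query and key vectors are independent with zero mean and unit variance (an assumption made here for illustration, not something the example above guarantees), the variance of their dot product grows linearly with the dimension, and multiplying by the inverse square root brings it back to roughly one. The dimension `256` and the sample count are arbitrary choices:

```python
import torch

# Local generator so the global random state is left untouched (arbitrary seed)
g = torch.Generator().manual_seed(0)

d_k = 256  # hypothetical key dimension, larger than in the example above
q = torch.randn(10_000, d_k, generator=g)
k = torch.randn(10_000, d_k, generator=g)

# 10,000 independent query-key dot products
raw = (q * k).sum(dim=1)
scaled = raw * d_k ** -0.5

# The raw variance is close to d_k, the scaled variance is close to 1
print(raw.var().item(), scaled.var().item())
```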

We still have one problem: right now our tokens can "look into the future".
For example token $T_0$ can "see" $T_1$ and $T_2$.

But at inference time this will not be possible, because we can't take into account tokens that haven't been generated yet.
Therefore we should disable this during training as well.

We will "mask" the attention scores of future tokens:

In [19]:
trimat = torch.tril(torch.ones(3, 3))
S = S.masked_fill(trimat == 0, float("-inf"))
S

tensor([[-1.2905,    -inf,    -inf],
        [ 0.7593, -2.2869,    -inf],
        [-1.1067, -0.0801, -3.3173]])

Finally, we will normalize the attention scores by computing the `softmax`:

In [20]:
S = F.softmax(S, dim=-1)
S

tensor([[1.0000, 0.0000, 0.0000],
        [0.9546, 0.0454, 0.0000],
        [0.2563, 0.7156, 0.0281]])
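
The mask uses `-inf` rather than `0` because the softmax exponentiates its input: $e^{-\infty} = 0$, so masked positions receive exactly zero weight while the remaining entries of each row still sum to one. A small sketch that checks both properties, starting from the scaled scores printed above (rounded to four decimals) and building the mask for an arbitrary sequence length `n`:

```python
import torch
import torch.nn.functional as F

# Scaled attention scores as printed earlier in this section (rounded)
S = torch.tensor([[-1.2905, 4.4249, -3.4800],
                  [0.7593, -2.2869, -2.1343],
                  [-1.1067, -0.0801, -3.3173]])

n = S.shape[0]

# True on and below the diagonal: token i may attend to tokens 0..i
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))

weights = F.softmax(S.masked_fill(~causal, float("-inf")), dim=-1)

print(weights.sum(dim=-1))               # every row sums to one
print(torch.triu(weights, diagonal=1))   # future positions have zero weight
```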

## Value Vectors

Right now, we would apply the scores to the token vectors directly.

It turns out that the attention mechanism performs even better if we introduce one more indirection and calculate **value vectors** from the token vectors and only then apply the scores.

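The computation of the value vectors works exactly like the computation of the key and query vectors.
We first create $W_V$, compute the value vectors, and then multiply the normalized scores with them to obtain the output: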

In [21]:
W_V = torch.randn(5, 4)

In [22]:
V = torch.matmul(T, W_V)
V

tensor([[-0.9285,  0.3301,  1.8359, -1.3448],
        [ 0.4676, -0.1512, -0.5678,  0.8648],
        [ 0.6143,  2.6772, -1.3256, -3.2423]])

In [23]:
R = torch.matmul(S, V)
R

tensor([[-0.9285,  0.3301,  1.8359, -1.3448],
        [-0.8652,  0.3082,  1.7268, -1.2445],
        [ 0.1138,  0.0517,  0.0270,  0.1831]])
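
Putting the steps of this section together, a single attention head could be sketched as a small module. This is one possible arrangement under stated assumptions, not a definitive implementation: the class name `CausalSelfAttentionHead` and the parameters `d_model` and `d_head` are hypothetical, the projections are bias-free `nn.Linear` layers, and the input is a single unbatched sequence of shape `(sequence_length, d_model)`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttentionHead(nn.Module):
    """Single attention head following the steps of this section (hypothetical name)."""

    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.key = nn.Linear(d_model, d_head, bias=False)
        self.query = nn.Linear(d_model, d_head, bias=False)
        self.value = nn.Linear(d_model, d_head, bias=False)
        self.d_head = d_head

    def forward(self, T: torch.Tensor) -> torch.Tensor:
        # T has shape (sequence_length, d_model)
        n = T.shape[0]
        K, Q, V = self.key(T), self.query(T), self.value(T)

        # Query-key similarities, scaled by 1 / sqrt(d_head)
        S = (Q @ K.T) * self.d_head ** -0.5

        # Hide future tokens, then normalize each row
        causal = torch.tril(torch.ones(n, n, dtype=torch.bool, device=T.device))
        S = S.masked_fill(~causal, float("-inf"))
        weights = F.softmax(S, dim=-1)

        # Weighted combination of the value vectors
        return weights @ V


# Usage with the shapes from this section: 3 tokens of dimension 5, head size 4
g = torch.Generator().manual_seed(0)  # local generator, arbitrary seed
T = torch.randn(3, 5, generator=g)
head = CausalSelfAttentionHead(d_model=5, d_head=4)
R = head(T)
print(R.shape)
```

Because of the mask, changing the last token vector leaves the outputs for the earlier tokens unchanged, which is a convenient property to test for.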