# Attention

At its core, attention is a mechanism that allows a token vector to include information about the context of that token (i.e. the surrounding token vectors).

As an example, consider the token sequence $T_0, T_1, T_2$ (e.g. "an", "example", "sequence").
It seems sensible that token $T_2$ should have some information about $T_0$ and $T_1$ if we want to model the sequence successfully.

As usual, we will need a few imports from `torch`:

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

<torch._C.Generator at 0x74ac46143770>

## Linear Combinations

We will continue working with our example sequence $T_0, T_1, T_2$.
We will call the vectors that represent the tokens $\mathbf{x_0}, \mathbf{x_1}$ and $\mathbf{x_2}$ respectively.

Let's say that every token is represented by a vector of dimension $5$, i.e. the entire sequence is represented by a tensor of shape $3\times 5$.

We will use a random tensor throughout this section. In reality, this tensor would be the output of the layer that precedes the attention layer:

In [5]:
X = torch.randn(3, 5)
X

tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674],
        [ 0.5349,  0.8094,  1.1103, -1.6898, -0.9890]])

Now let's say that we would like the token vectors to be able to "look" at each other.
The simplest way would be to calculate averages.

For example, if we would like to get the information contained in $T_0$ and $T_1$, we might compute the average of $\mathbf{x_0}$ and $\mathbf{x_1}$:

In [6]:
1 / 2 * X[0] + 1 / 2 * X[1]

tensor([ 0.0752,  1.1685, -0.2018,  0.3460, -0.4278])

Similarly, if we would like to combine the information in $T_0, T_1$ and $T_2$ we might compute the average of $\mathbf{x_0}, \mathbf{x_1}$ and $\mathbf{x_2}$:

In [7]:
1 / 3 * X[0] + 1 / 3 * X[1] + 1 / 3 * X[2]

tensor([ 0.2284,  1.0488,  0.2356, -0.3326, -0.6148])

This doesn't look like a great idea, primarily because not every token is equally relevant to every other token.
Instead, we want the weights in our linear combinations to be data-driven, i.e. we would like to compute arbitrary linear combinations:

$w_0 \cdot \mathbf{x_0} + w_1 \cdot \mathbf{x_1} + w_2 \cdot \mathbf{x_2}$

where the weights $w_0, w_1$ and $w_2$ are computed from the data rather than fixed in advance.

That is, the input of a hypothetical attention layer would be a tensor containing the vectors $\mathbf{x_0}, \mathbf{x_1}$ and $\mathbf{x_2}$, while the output would be another tensor containing new vectors $\mathbf{y_0}, \mathbf{y_1}$ and $\mathbf{y_2}$, each of which is a linear combination of the input vectors:

$w_{00} \cdot \mathbf{x_0} + w_{01} \cdot \mathbf{x_1} + w_{02} \cdot \mathbf{x_2} = \mathbf{y_0}$

$w_{10} \cdot \mathbf{x_0} + w_{11} \cdot \mathbf{x_1} + w_{12} \cdot \mathbf{x_2} = \mathbf{y_1}$

$w_{20} \cdot \mathbf{x_0} + w_{21} \cdot \mathbf{x_1} + w_{22} \cdot \mathbf{x_2} = \mathbf{y_2}$

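We can write this more compactly as a matrix product $\mathbf{Y} = \mathbf{W}\mathbf{X}$, where $\mathbf{W}$ is the $3 \times 3$ matrix with entries $w_{ij}$ and the rows of $\mathbf{X}$ and $\mathbf{Y}$ are the vectors $\mathbf{x_i}$ and $\mathbf{y_i}$.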

Put differently, attention "mixes" the input vectors together and gives us new output vectors that were able to "communicate" with each other in some sense.
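Here is a minimal sketch of this matrix form. The averages from above become rows of a $3 \times 3$ weight matrix, with row $0$ being the trivial "average" of $\mathbf{x_0}$ with itself, and one matrix product then produces all output vectors at once. The matrix name `W_avg` and the seed are illustrative choices.

```python
import torch

torch.manual_seed(42)
X = torch.randn(3, 5)  # three token vectors of dimension 5 (rows x_0, x_1, x_2)

# Row i holds the weights w_i0, w_i1, w_i2 used to build y_i:
# y_0 = x_0, y_1 = average of x_0 and x_1, y_2 = average of x_0, x_1 and x_2
W_avg = torch.tensor([
    [1.0, 0.0, 0.0],
    [1 / 2, 1 / 2, 0.0],
    [1 / 3, 1 / 3, 1 / 3],
])

Y = torch.matmul(W_avg, X)  # shape (3, 5): one output vector y_i per token

print(Y[1])  # equals 1 / 2 * X[0] + 1 / 2 * X[1]
print(Y[2])  # equals 1 / 3 * X[0] + 1 / 3 * X[1] + 1 / 3 * X[2]
```

Attention keeps this exact structure. The only thing that changes is how the weight matrix is computed.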

The big question is now how to compute the attention weights.

## Naive Attention

Let's take a stab at a very naive attention mechanism.
The idea would be to calculate the similarity of every token with every other token.
We could use the dot product for this:

In [9]:
for i in range(3):
    for j in range(3):
        dot_product = torch.dot(X[i], X[j])
        print(f"attention({i}, {j}) = {dot_product}")

attention(0, 0) = 1.4987846612930298
attention(0, 1) = -0.12174583971500397
attention(0, 2) = 1.265914797782898
attention(1, 0) = -0.12174582481384277
attention(1, 1) = 5.602515697479248
attention(1, 2) = -0.06531322002410889
attention(2, 0) = 1.2659146785736084
attention(2, 1) = -0.06531322002410889
attention(2, 2) = 6.007389068603516


This can of course be done much more efficiently via matrix multiplication:

In [15]:
S = torch.matmul(X, X.transpose(0, 1))
S

tensor([[ 1.4988, -0.1217,  1.2659],
        [-0.1217,  5.6025, -0.0653],
        [ 1.2659, -0.0653,  6.0074]])
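To confirm that the matrix product computes the same numbers as the double loop, we can compare the two directly. Fresh random data is used here, and the seed is arbitrary. The check also makes explicit a property visible in the output above: the naive score matrix is symmetric, because $\mathbf{x_i} \cdot \mathbf{x_j} = \mathbf{x_j} \cdot \mathbf{x_i}$. Its diagonal holds the squared norms $\mathbf{x_i} \cdot \mathbf{x_i}$, which tend to be large.

```python
import torch

torch.manual_seed(0)
X = torch.randn(3, 5)

# Pairwise dot products with an explicit double loop
S_loop = torch.empty(3, 3)
for i in range(3):
    for j in range(3):
        S_loop[i, j] = torch.dot(X[i], X[j])

# The same scores via a single matrix product
S = torch.matmul(X, X.transpose(0, 1))

print(torch.allclose(S, S_loop, atol=1e-6))  # True
print(torch.allclose(S, S.T, atol=1e-6))     # True: naive scores are symmetric
```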

This yields a matrix of (unnormalized) "attention scores".
We now normalize each row of this matrix with the softmax function. This makes all entries positive and makes each row sum to $1$, giving a matrix of attention weights:

In [16]:
W = torch.softmax(S, dim=1)
W

tensor([[0.5025, 0.0994, 0.3981],
        [0.0032, 0.9933, 0.0034],
        [0.0086, 0.0023, 0.9891]])
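The softmax of a row $\mathbf{s}$ is $\text{softmax}(\mathbf{s})_j = e^{s_j} / \sum_k e^{s_k}$. The sketch below copies the score matrix printed above and checks `torch.softmax` against this formula. The manual version works for these small values. For very large scores it can overflow, whereas `torch.softmax` is implemented in a numerically stable way.

```python
import torch

# Score matrix S copied from the output above
S = torch.tensor([[1.4988, -0.1217, 1.2659],
                  [-0.1217, 5.6025, -0.0653],
                  [1.2659, -0.0653, 6.0074]])

# softmax(s)_j = exp(s_j) / sum_k exp(s_k), applied to each row separately
W_manual = torch.exp(S) / torch.exp(S).sum(dim=1, keepdim=True)
W = torch.softmax(S, dim=1)

print(torch.allclose(W, W_manual, atol=1e-6))  # True
print(W.sum(dim=1))  # each row sums to 1 (up to floating-point rounding)
```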

Now we have a matrix of "attention weights", where entry $w_{ij}$ indicates how much attention vector $\mathbf{x_i}$ should pay to vector $\mathbf{x_j}$.
We can now compute linear combinations using these data-driven weights, starting with $\mathbf{y_0}$:

$w_{00} \cdot \mathbf{x_0} + w_{01} \cdot \mathbf{x_1} + w_{02} \cdot \mathbf{x_2} = \mathbf{y_0}$

In [18]:
W[0, 0] * X[0] + W[0, 1] * X[1] + W[0, 2] * X[2]

tensor([ 0.3636,  0.6064,  0.4964, -0.5111, -0.9314])

Next: $w_{10} \cdot \mathbf{x_0} + w_{11} \cdot \mathbf{x_1} + w_{12} \cdot \mathbf{x_2} = \mathbf{y_1}$

In [20]:
W[1, 0] * X[0] + W[1, 1] * X[1] + W[1, 2] * X[2]

tensor([-0.1822,  2.1967, -0.6292,  0.4535,  0.2585])

Next: $w_{20} \cdot \mathbf{x_0} + w_{21} \cdot \mathbf{x_1} + w_{22} \cdot \mathbf{x_2} = \mathbf{y_2}$

In [21]:
W[2, 0] * X[0] + W[2, 1] * X[1] + W[2, 2] * X[2]

tensor([ 0.5315,  0.8067,  1.0987, -1.6683, -0.9873])

Again we can realize this much more efficiently via matrix multiplication:

In [24]:
torch.matmul(W, X)

tensor([[ 0.3636,  0.6064,  0.4964, -0.5111, -0.9314],
        [-0.1822,  2.1967, -0.6292,  0.4535,  0.2585],
        [ 0.5315,  0.8067,  1.0987, -1.6683, -0.9873]])

Unfortunately this naive attention will not work well in practice.
The reason is that we need to be able to differentiate between "information that a token vector represents" and "information that a token vector is interested in".

For example, if a token vector encodes that it is the subject of a sentence, it will currently pay high attention to other subjects of the sentence (mostly to itself).
Instead, it should probably pay attention to e.g. token vectors that encode the predicate of the sentence, articles, etc.

Note that this is a hand-wavy intuition: token vectors mostly represent inscrutable high-dimensional concepts that often have no real analogy in linguistics.
Despite this, the overall idea still works well in practice.

## Key and Query Vectors

We now introduce the first important component of the real self-attention mechanism - the **key vectors** and **query vectors**.

Every token vector generates a key vector and a query vector.
The key vector indicates the information that the token represents and the query vector contains the information the token is interested in.

To continue our informal example, a token might have the key vector "I am the subject of the sentence" and the query vector "I am interested in the predicate of the sentence".
Of course, in reality key and query vectors will not be this interpretable. They will represent some inscrutable high-dimensional concepts that the language model learned during training.

To borrow from databases:

A "query" is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The "key" is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match with the query.

The key vectors and query vectors are computed from the token vectors via simple linear layers.

For our example, we will create random projection matrices. In reality, these would be parameters that the neural network learns during training.
Let's call the matrix that will produce the key vectors $W_K$ and the matrix that will produce the query vectors $W_Q$:

In [25]:
W_K = torch.randn(5, 4)
W_Q = torch.randn(5, 4)

We can now compute the key vectors:

In [27]:
torch.matmul(X[0], W_K)

tensor([ 0.6023, -0.7260,  1.1799,  0.2383])

In [26]:
K = torch.matmul(X, W_K)
K

tensor([[ 0.6023, -0.7260,  1.1799,  0.2383],
        [-0.6521,  4.4224, -3.7460, -1.2657],
        [-0.7106, -4.3429,  4.2984, -2.3664]])

We can also compute the query vectors:

In [28]:
Q = torch.matmul(X, W_Q)
Q

tensor([[-1.6964,  1.3355, -0.5133,  0.0674],
        [ 1.6595, -0.4445, -0.1917,  1.7729],
        [-0.1650, -2.9899, -3.8893,  1.2756]])

Now we will compute the attention scores in a similar way as before.
The big difference is that instead of computing the similarity of the token vectors with each other, we will _compute the similarity between the key vectors and the query vectors_:

In [42]:
S = torch.matmul(Q, K.transpose(0, 1))
S

tensor([[-2.5809,  8.8498, -6.9600],
        [ 1.5185, -4.5739, -4.2686],
        [-2.2135, -0.1601, -6.6347]])
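Queries and keys come from different projection matrices. As a result, the score token $i$ gives token $j$ no longer has to equal the score token $j$ gives token $i$, which is exactly the asymmetry the naive version lacked. Here is a small check with fresh random tensors of the same shapes (the seed is arbitrary):

```python
import torch

torch.manual_seed(0)
X = torch.randn(3, 5)
W_K = torch.randn(5, 4)
W_Q = torch.randn(5, 4)

K = torch.matmul(X, W_K)  # one key vector per token (rows)
Q = torch.matmul(X, W_Q)  # one query vector per token (rows)
S = torch.matmul(Q, K.transpose(0, 1))

# Entry (i, j) is the dot product of token i's query with token j's key
print(torch.allclose(S[0, 1], torch.dot(Q[0], K[1]), atol=1e-5))  # True
# Unlike the naive scores, S is generally not symmetric
print(torch.allclose(S, S.T))  # False
```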

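A minor but important technical detail is that we need to scale the attention scores by $1/\sqrt{d_k}$, where $d_k = 4$ is the dimension of the key vectors. Without this scaling, the dot products grow with $d_k$. That pushes the softmax towards extreme, nearly one-hot outputs whose gradients are very small: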

In [43]:
S = S / (4**0.5)
S

tensor([[-1.2905,  4.4249, -3.4800],
        [ 0.7593, -2.2869, -2.1343],
        [-1.1067, -0.0801, -3.3173]])
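Why $\sqrt{d_k}$ specifically? Suppose query and key entries are independent with mean $0$ and variance $1$ (an idealized assumption). Then $\mathbf{q} \cdot \mathbf{k}$ is a sum of $d_k$ such products and has variance $d_k$. Dividing by $\sqrt{d_k}$ brings the variance back to about $1$, whatever the dimension. The simulation below illustrates this with standard normal entries. The sample count and seed are arbitrary.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000  # number of random (query, key) pairs per dimension

results = {}
for d_k in [4, 64, 256]:
    # Idealized assumption: query and key entries are independent N(0, 1)
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)

    scores = (q * k).sum(dim=-1)  # one dot product per (q, k) pair
    scaled = scores / d_k ** 0.5  # the same scaling as S / (4**0.5) above

    results[d_k] = (scores.var().item(), scaled.var().item())
    print(f"d_k={d_k:3d}  var(scores)={results[d_k][0]:7.2f}  var(scaled)={results[d_k][1]:.2f}")
```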

In [44]:
W = torch.softmax(S, dim=1)
W

tensor([[3.2830e-03, 9.9635e-01, 3.6758e-04],
        [9.0669e-01, 4.3103e-02, 5.0212e-02],
        [2.5632e-01, 7.1558e-01, 2.8102e-02]])

We still have one problem - right now our tokens can "look into the future".
For example token $T_0$ can "see" $T_1$ and $T_2$.

But during inference time, this will not be possible - we can't take into account tokens that haven't been generated yet.
Therefore we should disable this during training as well.

We will "mask" the attention scores of future tokens:

In [59]:
mask = torch.tril(torch.ones(S.shape[0], S.shape[0]))
mask

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [60]:
masked_S = S.masked_fill(mask == 0, float("-inf"))
masked_S

tensor([[-1.2905,    -inf,    -inf],
        [ 0.7593, -2.2869,    -inf],
        [-1.1067, -0.0801, -3.3173]])
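Why fill with `-inf` instead of simply `0`? A masked score must receive exactly zero weight after the softmax, and $e^0 = 1$ is not zero. The comparison below uses the first row of the scaled scores from above.

```python
import torch

row = torch.tensor([-1.2905, 4.4249, -3.4800])  # first row of the scaled scores S
keep = torch.tensor([True, False, False])       # token 0 may only see itself

# Zeroing out future scores does not remove them: exp(0) = 1 still gets weight
zeroed = torch.where(keep, row, torch.zeros_like(row))
w_zeroed = torch.softmax(zeroed, dim=0)
print(w_zeroed)  # the two "masked" positions still get most of the weight

# Filling with -inf gives exp(-inf) = 0, i.e. exactly zero weight
masked = row.masked_fill(~keep, float("-inf"))
w_masked = torch.softmax(masked, dim=0)
print(w_masked)  # tensor([1., 0., 0.])
```

A row consisting only of `-inf` would turn into `NaN`s after the softmax. The causal mask never produces such a row, because every token can always attend to itself.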

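Note that `S` was already divided by $\sqrt{4} = 2$ above, so we must not scale the masked scores a second time. Instead, we apply the softmax to them directly:

In [63]:
masked_W = torch.softmax(masked_S, dim=1)
masked_W

tensor([[1.0000, 0.0000, 0.0000],
        [0.9546, 0.0454, 0.0000],
        [0.2563, 0.7156, 0.0281]])

Every row still sums to $1$, but token $T_i$ can now only attend to $T_0, \dots, T_i$.

Additionally, we can apply dropout to the attention weights. In training mode, `nn.Dropout(p)` sets each entry to zero with probability $p$ and multiplies the remaining entries by $1/(1-p)$, here by $2$. Which entries are dropped is random and changes from call to call:

In [65]:
dropout = nn.Dropout(0.5)

In [66]:
dropout(masked_W)

tensor([[2.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.4312, 0.0562]])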
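Dropout on the attention weights is only meant to be active during training. The minimal sketch below uses fresh random weights and an arbitrary seed. It shows that `nn.Dropout` rescales the surviving entries in training mode and leaves the input unchanged after `.eval()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(0.5)
W = torch.softmax(torch.randn(3, 3), dim=-1)  # some attention weights (all > 0)

dropout.train()  # training mode (the default for a freshly created module)
W_train = dropout(W)
# Each entry is either dropped (0) or kept and scaled by 1 / (1 - 0.5) = 2
print(W_train)

dropout.eval()  # evaluation/inference mode: dropout becomes the identity
W_eval = dropout(W)
print(torch.equal(W_eval, W))  # True
```

Calling `.eval()` on a parent module switches all of its submodules, including such a dropout layer, into evaluation mode.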

## Value Vectors

Right now, we apply the attention weights to the token vectors directly.

It turns out that the attention mechanism performs even better if we introduce one more indirection and calculate **value vectors** from the token vectors and only then apply the scores.

To borrow from databases:

The "value" in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.

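The computation of the value vectors works exactly like the computation of the key and query vectors. To keep this example short, we then apply the unmasked weights `W` from above. A model that must not look into the future would use `masked_W` instead: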

In [39]:
W_V = torch.randn(5, 4)

In [40]:
V = torch.matmul(X, W_V)
V

tensor([[-0.9285,  0.3301,  1.8359, -1.3448],
        [ 0.4676, -0.1512, -0.5678,  0.8648],
        [ 0.6143,  2.6772, -1.3256, -3.2423]])

In [45]:
R = torch.matmul(W, V)
R

tensor([[ 0.4630, -0.1485, -0.5602,  0.8561],
        [-0.7909,  0.4272,  1.5735, -1.3448],
        [ 0.1138,  0.0517,  0.0270,  0.1831]])
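With the causal mask, each output row can only mix the value vectors of the current and earlier tokens. In particular, the first token's output is exactly its own value vector. The self-contained sketch below uses fresh random matrices with the shapes used above. The seed is arbitrary.

```python
import torch

torch.manual_seed(0)
X = torch.randn(3, 5)
W_Q, W_K, W_V = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
S = (Q @ K.T) / K.shape[-1] ** 0.5  # scaled scores, shape (3, 3)

mask = torch.tril(torch.ones(3, 3))
masked_W = torch.softmax(S.masked_fill(mask == 0, float("-inf")), dim=1)

R = masked_W @ V  # row i mixes the value vectors V[0], ..., V[i]

# The first token can only attend to itself, so its output is its own value vector
print(torch.allclose(R[0], V[0]))  # True
```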

## Implementing a Self-Attention Class

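In practice, the input has an additional batch dimension, i.e. `X` has shape `(batch_size, num_tokens, d_in)`. We need to take this into account, for example by transposing only the last two dimensions of the key tensor: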

In [67]:
X = torch.randn(2, 3, 5)
X

tensor([[[-0.4249,  0.9442, -0.1849,  1.0608,  0.2083],
         [-0.5778,  0.3255, -0.8146, -0.7599, -2.0461],
         [-1.5295,  0.4049,  0.6319,  0.3125,  1.9892]],

        [[-0.4611, -0.0639, -1.3667,  0.3298, -0.9827],
         [ 0.3018,  0.1787,  0.4097, -1.5754,  2.2508],
         [ 1.0012,  1.3642,  0.6333,  0.4050,  0.3416]]])

In [70]:
W_Q = nn.Linear(5, 4, bias=False)
W_K = nn.Linear(5, 4, bias=False)
W_V = nn.Linear(5, 4, bias=False)
K = W_K(X)
Q = W_Q(X)
V = W_V(X)

K.shape, Q.shape, V.shape

(torch.Size([2, 3, 4]), torch.Size([2, 3, 4]), torch.Size([2, 3, 4]))

In [71]:
S = torch.matmul(Q, K.transpose(1, 2))
S

tensor([[[ 0.1268,  0.1442,  0.5098],
         [-0.1395, -0.4737, -0.1583],
         [-0.0595,  0.8114,  0.3408]],

        [[ 0.0802,  0.3535,  0.3662],
         [-0.1526,  0.1022, -0.3185],
         [ 0.0630,  0.1827, -0.6134]]], grad_fn=<UnsafeViewBackward0>)
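With a batch dimension, `torch.matmul` treats all leading dimensions as batch dimensions and multiplies the last two. That is why we transpose dimensions 1 and 2 of `K` instead of 0 and 1, since `K.transpose(0, 1)` would swap the batch and token dimensions. The sketch below checks that the batched product equals the unbatched computation run separately for each example (fresh layers, arbitrary seed).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(2, 3, 5)  # (batch_size, num_tokens, d_in)
W_Q = nn.Linear(5, 4, bias=False)
W_K = nn.Linear(5, 4, bias=False)

Q, K = W_Q(X), W_K(X)  # nn.Linear acts on the last dimension: (2, 3, 4)

# Swap only the token and feature dimensions of K; the batch dimension stays in front
S = torch.matmul(Q, K.transpose(1, 2))  # (2, 3, 3)

# Equivalent to running the unbatched computation separately for each example
S_per_example = torch.stack([Q[b] @ K[b].T for b in range(X.shape[0])])
print(torch.allclose(S, S_per_example, atol=1e-6))  # True
```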

In [73]:
mask = torch.tril(torch.ones(3, 3))
mask

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [78]:
masked_S = S.masked_fill(mask == 0, float("-inf"))
masked_S

tensor([[[ 0.1268,    -inf,    -inf],
         [-0.1395, -0.4737,    -inf],
         [-0.0595,  0.8114,  0.3408]],

        [[ 0.0802,    -inf,    -inf],
         [-0.1526,  0.1022,    -inf],
         [ 0.0630,  0.1827, -0.6134]]], grad_fn=<MaskedFillBackward0>)

In [79]:
K.shape

torch.Size([2, 3, 4])

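Because of the batch dimension, we also apply the softmax along the last dimension (`dim=-1`). Putting all the steps together gives a self-attention module: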

In [11]:
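class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, c_len, dropout):
        super().__init__()

        # Projections producing the query, key and value vectors
        self.W_Q = nn.Linear(d_in, d_out, bias=False)
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Causal mask for the maximum context length c_len; a buffer moves with
        # the module between devices but is not a trainable parameter
        self.register_buffer("mask", torch.tril(torch.ones(c_len, c_len)))

    def forward(self, X):
        # X: (batch_size, num_tokens, d_in) with num_tokens <= c_len
        batch_size, num_tokens, d_in = X.shape

        Q = self.W_Q(X)
        K = self.W_K(X)
        V = self.W_V(X)

        # Scores (batch_size, num_tokens, num_tokens): transpose only the last two dims
        S = torch.matmul(Q, K.transpose(1, 2))
        # Crop the mask so that sequences shorter than c_len also work
        mask = self.mask[:num_tokens, :num_tokens]
        masked_S = S.masked_fill(mask == 0, float("-inf"))
        masked_S = masked_S / K.shape[-1] ** 0.5
        W = torch.softmax(masked_S, dim=-1)

        W = self.dropout(W)

        # Weighted sum of the value vectors: (batch_size, num_tokens, d_out)
        R = torch.matmul(W, V)

        return R

layer = SelfAttention(d_in=5, d_out=4, c_len=3, dropout=0.5)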

In [9]:
X = torch.randn(2, 3, 5)
X.shape

torch.Size([2, 3, 5])

In [10]:
Y = layer(X)
Y.shape

torch.Size([2, 3, 4])

Note that in practice the number of input dimensions and output dimensions will usually be the same, i.e. the self-attention layer doesn't change the dimensionality of the tensor, it only "mixes" the elements of the tensor together.
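PyTorch 2.0 and later ship a fused implementation of this computation, `torch.nn.functional.scaled_dot_product_attention`. With `is_causal=True` it applies the same lower-triangular mask, and by default it scales by $1/\sqrt{d_k}$. The sketch below compares it with the manual steps on random query, key and value tensors, without dropout. It assumes PyTorch 2.0 or newer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(2, 3, 4)  # (batch_size, num_tokens, d_k)
K = torch.randn(2, 3, 4)
V = torch.randn(2, 3, 4)

# Manual causal attention, following the steps above (no dropout)
S = torch.matmul(Q, K.transpose(1, 2)) / K.shape[-1] ** 0.5
mask = torch.tril(torch.ones(3, 3))
W = torch.softmax(S.masked_fill(mask == 0, float("-inf")), dim=-1)
R_manual = torch.matmul(W, V)

# PyTorch's fused implementation (requires PyTorch 2.0 or newer)
R_fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(R_manual, R_fused, atol=1e-5))  # True
```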

## Multi-Head Attention

The final piece we are missing is the multi-head attention.

Instead of a single self-attention layer, we use several self-attention layers ("heads"), each with its own weights, and concatenate their outputs.

In [16]:
d_in = 5
d_out = 4
c_len = 3
dropout = 0.5

n_heads = 2
heads = [SelfAttention(d_in, d_out, c_len, dropout) for _ in range(n_heads)]

result = [head(X) for head in heads]

In [17]:
[head_out.shape for head_out in result]

[torch.Size([2, 3, 4]), torch.Size([2, 3, 4])]

Next we combine them:

In [18]:
head_out_combined = torch.cat(result, dim=-1)
head_out_combined.shape

torch.Size([2, 3, 8])

Note that in practice we don't want the multi-head attention layer to blow up the size of the tensor.

Therefore, we set the output dimension of each head to $4 / 2 = 2$, so that after concatenating the two heads the output again has dimension $4$:

In [19]:
d_out = 2
heads = [SelfAttention(d_in, d_out, c_len, dropout) for _ in range(n_heads)]

head_out_combined = torch.cat([head(X) for head in heads], dim=-1)
head_out_combined.shape

torch.Size([2, 3, 4])

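While this technically already works, it is inefficient, since we process the heads one after another.
Instead, we can process all heads in parallel. A single projection of size $d_{out}$ produces the keys, queries and values for all heads at once, and each head works with its own slice of size $d_{out} / n_{heads}$. Here is the skeleton of such a module: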

In [21]:
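class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, n_heads):
        super().__init__()

        # d_out is split evenly across the heads
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads

        # One projection per role, shared by all heads: each head uses
        # its own head_dim-sized slice of the d_out output features
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_Q = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)

    def forward(self, X):
        batch_size, n_tokens, d_in = X.shape

        # Keys, queries and values for all heads at once: (batch_size, n_tokens, d_out)
        K = self.W_K(X)
        Q = self.W_Q(X)
        V = self.W_V(X)
        # The remaining steps (splitting into heads, masking, scaling, softmax,
        # combining the heads) are worked through step by step in the cells below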

In [34]:
n_tokens = 3
d_in = 4
d_out = 4
n_heads = 2
head_dim = d_out // n_heads

In [30]:
X = torch.randn(2, n_tokens, d_in)
X.shape

torch.Size([2, 3, 4])

In [31]:
# Shape (d_in, d_out), so that X @ W maps each d_in-dimensional token to d_out dimensions
W_K = torch.randn(d_in, d_out)
W_Q = torch.randn(d_in, d_out)
W_V = torch.randn(d_in, d_out)

In [32]:
K = torch.matmul(X, W_K)
Q = torch.matmul(X, W_Q)
K.shape, Q.shape

(torch.Size([2, 3, 4]), torch.Size([2, 3, 4]))

In [36]:
# Split the last dimension (d_out = 4) into n_heads = 2 chunks of head_dim = 2
K_view = K.view(2, n_tokens, n_heads, head_dim)

In [38]:
K_view.shape

torch.Size([2, 3, 2, 2])
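`view` does not move any data. It reinterprets the last dimension $d_{out}$ as $n_{heads}$ blocks of `head_dim` consecutive features. Head $h$ therefore works with features $h \cdot \text{head\_dim}$ up to $(h+1) \cdot \text{head\_dim} - 1$ of every key vector, which amounts to each head having its own slice of the projection matrix. The sketch below checks this on a random tensor with the shapes used here (arbitrary seed).

```python
import torch

torch.manual_seed(0)
batch_size, n_tokens, n_heads, head_dim = 2, 3, 2, 2
d_out = n_heads * head_dim

K = torch.randn(batch_size, n_tokens, d_out)
K_view = K.view(batch_size, n_tokens, n_heads, head_dim)

# Head h sees the h-th block of head_dim consecutive features of each key vector
for h in range(n_heads):
    block = K[:, :, h * head_dim:(h + 1) * head_dim]
    print(h, torch.equal(K_view[:, :, h, :], block))  # True for every head
```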

In [40]:
Q_view = Q.view(2, n_tokens, n_heads, head_dim)

In [41]:
Q_view.shape

torch.Size([2, 3, 2, 2])

In [42]:
# Move the head dimension in front of the tokens: (batch, n_heads, n_tokens, head_dim)
K_view = K_view.transpose(1, 2)
K_view.shape

torch.Size([2, 2, 3, 2])

In [43]:
Q_view = Q_view.transpose(1, 2)
Q_view.shape

torch.Size([2, 2, 3, 2])

In [45]:
# Batched over (batch, head): one (n_tokens x n_tokens) score matrix per head
S = torch.matmul(Q_view, K_view.transpose(2, 3))
S.shape

torch.Size([2, 2, 3, 3])
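With 4-dimensional inputs, `torch.matmul` treats the first two dimensions (batch and head) as batch dimensions. A single call therefore computes every head's score matrix for every example. The sketch below compares it with an explicit loop over examples and heads, using random tensors of the shapes above (arbitrary seed).

```python
import torch

torch.manual_seed(0)
batch_size, n_heads, n_tokens, head_dim = 2, 2, 3, 2
Q_view = torch.randn(batch_size, n_heads, n_tokens, head_dim)
K_view = torch.randn(batch_size, n_heads, n_tokens, head_dim)

# One batched call: matmul treats the first two dimensions as batch dimensions
S = torch.matmul(Q_view, K_view.transpose(2, 3))  # (2, 2, 3, 3)

# Reference: an explicit loop over examples and heads
S_loop = torch.empty(batch_size, n_heads, n_tokens, n_tokens)
for b in range(batch_size):
    for h in range(n_heads):
        S_loop[b, h] = Q_view[b, h] @ K_view[b, h].T

print(torch.allclose(S, S_loop, atol=1e-6))  # True
```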

In [47]:
S[0]

tensor([[[ -2.7366,  -1.1301,  -2.7070],
         [  2.1529,   0.7935,   1.6183],
         [ -6.5526,  -2.3705,  -4.6866]],

        [[ -0.3216,   3.0144, -13.2677],
         [ -0.6322,  -1.1477,  -0.0831],
         [ -0.2841,   1.0993,  -5.9737]]])

In [49]:
mask = torch.tril(torch.ones(n_tokens, n_tokens))
mask

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [51]:
masked_S = S.masked_fill(mask == 0, float("-inf"))
masked_S[0]

tensor([[[-2.7366,    -inf,    -inf],
         [ 2.1529,  0.7935,    -inf],
         [-6.5526, -2.3705, -4.6866]],

        [[-0.3216,    -inf,    -inf],
         [-0.6322, -1.1477,    -inf],
         [-0.2841,  1.0993, -5.9737]]])

In [53]:
# Each head's dot products run over head_dim (not d_out) dimensions,
# so we scale by sqrt(head_dim)
masked_S = masked_S / K_view.shape[-1] ** 0.5

In [55]:
W = torch.softmax(masked_S, dim=-1)
W[0]

tensor([[[1.0000, 0.0000, 0.0000],
         [0.7234, 0.2766, 0.0000],
         [0.0417, 0.8023, 0.1560]],

        [[1.0000, 0.0000, 0.0000],
         [0.5901, 0.4099, 0.0000],
         [0.2719, 0.7232, 0.0049]]])

In [69]:
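# Dropout would be applied to W here, just as in the single-head case (omitted)

V = torch.matmul(X, W_V)
V_view = V.view(2, n_tokens, n_heads, head_dim)  # split into heads
V_view = V_view.transpose(1, 2)                  # (batch, n_heads, n_tokens, head_dim)
R = torch.matmul(W, V_view)                      # weighted sum of value vectors, per head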

In [70]:
V_view.shape, W.shape

(torch.Size([2, 2, 3, 2]), torch.Size([2, 2, 3, 3]))

In [71]:
R.shape

torch.Size([2, 2, 3, 2])

In [72]:
R = R.transpose(1, 2)

In [73]:
R.shape

torch.Size([2, 3, 2, 2])

In [75]:
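# After transpose(), R is no longer contiguous in memory, so view() cannot merge
# the head and head_dim dimensions directly; contiguous() first copies R into a matching layout
R_combined = R.contiguous().view(2, n_tokens, d_out)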
R_combined.shape

torch.Size([2, 3, 4])
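Putting the individual steps together, one way to complete the `MultiHeadSelfAttention` skeleton is sketched below. Some details are assumptions rather than part of the walkthrough above. The helper method `split_heads` is a name introduced here for illustration. Dropout is applied to the attention weights, as in `SelfAttention`. The mask buffer is cropped to the actual sequence length, and the scores are scaled by the square root of `head_dim`.

```python
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention assembled from the steps above (a sketch)."""

    def __init__(self, d_in, d_out, context_len, dropout, n_heads):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads

        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_Q = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(context_len, context_len)))

    def split_heads(self, T, batch_size, n_tokens):
        # Hypothetical helper: (batch, n_tokens, d_out) -> (batch, n_heads, n_tokens, head_dim)
        return T.view(batch_size, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, X):
        batch_size, n_tokens, _ = X.shape

        K = self.split_heads(self.W_K(X), batch_size, n_tokens)
        Q = self.split_heads(self.W_Q(X), batch_size, n_tokens)
        V = self.split_heads(self.W_V(X), batch_size, n_tokens)

        # Scores for every head: (batch, n_heads, n_tokens, n_tokens)
        S = torch.matmul(Q, K.transpose(2, 3))
        mask = self.mask[:n_tokens, :n_tokens]  # crop for shorter sequences
        S = S.masked_fill(mask == 0, float("-inf")) / self.head_dim ** 0.5
        W = self.dropout(torch.softmax(S, dim=-1))

        # (batch, n_heads, n_tokens, head_dim) -> (batch, n_tokens, d_out)
        R = torch.matmul(W, V).transpose(1, 2)
        return R.contiguous().view(batch_size, n_tokens, self.d_out)


torch.manual_seed(0)
mha = MultiHeadSelfAttention(d_in=4, d_out=4, context_len=3, dropout=0.5, n_heads=2)
mha.eval()  # disable dropout so that the outputs are deterministic

X = torch.randn(2, 3, 4)
Y = mha(X)
print(Y.shape)  # torch.Size([2, 3, 4])
```

Many Transformer implementations additionally pass the combined heads through a final linear output projection. It is left out here to stay close to the steps above.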