In the previous notebook we prepared and massaged our text into input vectors that encode both the tokens and their positions. These input vectors are what the rest of the model works with.

# Self-Attention



Self-attention is a technique in which each position in the input sequence can weigh the relevance of every other position in the same sequence when the representation of that sequence is computed. (Traditional "attention", by contrast, looks at relations between two different sequences, such as an input and an output sequence, rather than within a single sequence.)

## Simplified Self-Attention

Let's take a simple sentence:
"I am learning this"

This is our input sequence; let's call it $x$. It has 4 tokens: $x^{(1)}$, $x^{(2)}$, $x^{(3)}$ and $x^{(4)}$.

In general, if our context length is $T$, then we have $x^{(1)} \ldots x^{(T)}$.

Each $x^{(i)}$ is a $d$-dimensional embedding vector representing a token.

Now we will calculate a context vector $z^{(i)}$ for each $x^{(i)}$. This vector will contain information from all of the vectors $x^{(1)} \ldots x^{(T)}$.

### Attention scores

For example, for $x^{(3)}$ we'll calculate a context vector $z^{(3)}$. We call $x^{(3)}$ our "query" vector, and for this query we calculate an attention score for each of the tokens in our sentence: for example, $w_{31}$ is the attention score between our query (the 3rd token) and the first token. This gives us four scores: $w_{31}$, $w_{32}$, $w_{33}$ and $w_{34}$.

The attention score $w_{31}$ is the dot product of the third token's embedding with the first token's embedding. So if:
$$ x^{(3)} = [0.2,0.7,0.9] \text{  (representing the token "learning")} $$
$$ x^{(1)} = [0.8,0.9,0.3] \text{  (representing the token "I")} $$

Then:

$$ w_{31} = x^{(3)} \cdot x^{(1)} = (0.2 \cdot 0.8) + (0.7 \cdot 0.9) + (0.9 \cdot 0.3) = 1.06 $$

We'll do this between our query $x^{(3)}$ and each of $x^{(1)}, x^{(2)}, x^{(3)}, x^{(4)}$, which results in a vector of attention scores $w_3 = [w_{31}, w_{32}, w_{33}, w_{34}]$.
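As a quick check of the arithmetic above, here is the same dot product in plain Python, using the rounded example vectors from the equations (not the random embeddings used in the cells that follow). The helper name `dot` is our own:

```python
# Rounded example embeddings from the equations above
x3 = [0.2, 0.7, 0.9]  # "learning" (the query)
x1 = [0.8, 0.9, 0.3]  # "I"

def dot(a, b):
    """Dot product: sum of the element-wise products of two equal-length vectors."""
    return sum(ai * bi for ai, bi in zip(a, b))

w31 = dot(x3, x1)
print(round(w31, 2))  # 1.06
```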

Now in PyTorch, using an input matrix of our 4 tokens, each with an embedding dimension of 3:

In [56]:
#|echo: false
import torch
torch.manual_seed(42)

embedding_dim = 3
inputs = torch.rand((4,embedding_dim))
print(inputs)

tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408],
        [0.1332, 0.9346, 0.5936]])


We'll calculate the attention scores with $x^{(3)}$ as our query:

In [57]:
query = inputs[2]
attention_score_for_x3 = torch.empty((4))
for i, xi in enumerate(inputs):
    # attention score w_3i: dot product of input x_i with the query x3
    attention_score_for_x3[i] = torch.dot(xi, query)

print(attention_score_for_x3)

tensor([1.3127, 1.1213, 1.5807, 1.3343])


### Normalizing

Now that we've calculated the attention scores $w_3$ for query $x^{(3)}$, we have a number for every token. These numbers, however, are not normalized between $0$ and $1$, which is what we'd really like. For this, we'll use the softmax function (also used in [Cross-Entropy Loss](../theory/cross-entropy-loss.ipynb)), which makes sure the numbers in our attention vector are each between $0$ and $1$ and add up to one. We call these normalized values attention weights.

In [58]:
attention_weights_for_x3 = torch.softmax(attention_score_for_x3, dim=0)
print(attention_weights_for_x3)

tensor([0.2407, 0.1987, 0.3146, 0.2459])
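Softmax exponentiates each score and divides by the sum of all the exponentials: $\alpha_{3i} = \frac{e^{w_{3i}}}{\sum_{j=1}^{T} e^{w_{3j}}}$. A minimal from-scratch sketch for comparison with `torch.softmax` (the helper `softmax_from_scratch` is our own, hypothetical name; the scores are copied from the output above):

```python
import torch

def softmax_from_scratch(scores: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a 1-D tensor of scores into weights that sum to 1."""
    # Subtracting the max avoids overflow in exp(); it cancels out in the division
    exps = torch.exp(scores - scores.max())
    return exps / exps.sum()

# Attention scores for x3 as query, copied from the output above
scores = torch.tensor([1.3127, 1.1213, 1.5807, 1.3343])
weights = softmax_from_scratch(scores)
print(weights)
print(torch.allclose(weights, torch.softmax(scores, dim=0)))  # True
```

Subtracting the maximum is a common numerical-stability trick: it changes every exponent by the same constant, which divides out.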


### Context Vector

So now that we have our normalized attention weights for a single query $x^{(3)}$ we can calculate the full context vector that corresponds to $x^{(3)}$

Our attention weights from the previous step were approximately $[0.24, 0.20, 0.31, 0.25]$, or, in general, a vector $[\alpha_{31}, \alpha_{32}, \alpha_{33}, \alpha_{34}]$.

To calculate the context vector for $x^{(3)}$, we multiply each input $x^{(i)}$ by its weight $\alpha_{3i}$ and add up the resulting vectors:

$$ z^{(3)} = \sum_{i=1}^{T} \alpha_{3i} \, x^{(i)} $$

In [59]:
query = inputs[2]
context_vector_for_x3 = torch.zeros((query.shape)) # dimension of query (3 in this case)

for i, xi in enumerate(inputs):
    context_vector_for_x3 += attention_weights_for_x3[i] * xi  # add alpha_3i * x_i to z3

print(context_vector_for_x3)

tensor([0.5165, 0.7774, 0.6536])
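The loop above can also be written with tensor operations instead of Python iteration. A sketch using broadcasting; to keep it self-contained, the inputs are copied from the printed tensor (rounded to four decimals), so the result can differ from the loop in the last digit. The names `x_rounded`, `q3` and `weights_x3` are our own:

```python
import torch

# Inputs copied from the printed tensor above (rounded to 4 decimals)
x_rounded = torch.tensor([[0.8823, 0.9150, 0.3829],
                          [0.9593, 0.3904, 0.6009],
                          [0.2566, 0.7936, 0.9408],
                          [0.1332, 0.9346, 0.5936]])
q3 = x_rounded[2]  # x3 as the query

# Attention scores and weights for x3, as before
weights_x3 = torch.softmax(x_rounded @ q3, dim=0)  # shape (4,)

# weights_x3[:, None] has shape (4, 1) and broadcasts against (4, 3),
# scaling each row x_i by alpha_3i; summing over the rows gives z3
z3 = (weights_x3[:, None] * x_rounded).sum(dim=0)
print(z3)
```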


### All context vectors

So far we have looked at how to calculate a single context vector $z^{(3)}$ for a single token $x^{(3)}$ in our input sequence. Next we need a more scalable way to calculate $z^{(i)}$ for every token in the input sequence.

That means repeating what we did for $x^{(3)}$ with every input as the query, starting with the attention scores:

In [60]:
attention_scores_manual = torch.empty((4, 4))  # for each query of the 4 inputs, calculate 4 attention scores
for i, query in enumerate(inputs):
    for j, xi in enumerate(inputs):
        attention_scores_manual[i][j] = torch.dot(query, xi)

print("attention scores:")
print(attention_scores_manual)
print("for comparison, the attention scores for x3 as query:")
print(attention_score_for_x3)

attention scores:
tensor([[1.7622, 1.4337, 1.3127, 1.1999],
        [1.4337, 1.4338, 1.1213, 0.8494],
        [1.3127, 1.1213, 1.5807, 1.3343],
        [1.1999, 0.8494, 1.3343, 1.2435]])
for comparison, the attention scores for x3 as query:
tensor([1.3127, 1.1213, 1.5807, 1.3343])


Python `for` loops, however, are slow and can't take advantage of GPU acceleration with CUDA, so let's find a way to do the same computation with pure tensor operations instead. Our inputs look like this, with 4 tokens, each of dimension 3:

In [61]:
#| echo: false
print(inputs)

tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408],
        [0.1332, 0.9346, 0.5936]])


We can transpose this matrix, so it looks like this:

In [62]:
#| echo: false
print(inputs.T)

tensor([[0.8823, 0.9593, 0.2566, 0.1332],
        [0.9150, 0.3904, 0.7936, 0.9346],
        [0.3829, 0.6009, 0.9408, 0.5936]])


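If we now multiply these two matrices, entry $(i, j)$ of the result is the dot product $x^{(i)} \cdot x^{(j)}$, which is exactly the attention score $w_{ij}$: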

In [63]:
attention_scores = inputs @ inputs.T  # matrix multiplication: (4, 3) @ (3, 4) -> (4, 4)
print(attention_scores)
print("for comparison, the attention scores for x3 as query:")
print(attention_score_for_x3)

tensor([[1.7622, 1.4337, 1.3127, 1.1999],
        [1.4337, 1.4338, 1.1213, 0.8494],
        [1.3127, 1.1213, 1.5807, 1.3343],
        [1.1999, 0.8494, 1.3343, 1.2435]])
for comparison, the attention scores for x3 as query:
tensor([1.3127, 1.1213, 1.5807, 1.3343])


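These are still attention scores, not attention weights, so let's normalize them. With `dim=-1`, softmax normalizes along the last dimension, so each row (one row per query) sums to 1: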

In [64]:
attention_weights = torch.softmax(attention_scores, dim=-1)  # normalize the attention scores
print("attention weights:")
print(attention_weights)
print("for comparison, the attention weights for x3 as query:")
print(attention_weights_for_x3)

attention weights:
tensor([[0.3415, 0.2459, 0.2179, 0.1946],
        [0.3040, 0.3040, 0.2225, 0.1695],
        [0.2407, 0.1987, 0.3146, 0.2459],
        [0.2569, 0.1809, 0.2938, 0.2683]])
for comparison, the attention weights for x3 as query:
tensor([0.2407, 0.1987, 0.3146, 0.2459])
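To see what `dim=-1` does, here is a small sketch with made-up scores (row $i$ holds the scores for query $i$): normalizing along the last dimension gives one distribution per query, while `dim=0` would instead normalize each column, which is not what attention needs:

```python
import torch

# Made-up attention scores: row i holds the scores for query i
scores = torch.tensor([[2.0, 1.0, 0.1, 0.5],
                       [0.3, 2.5, 1.0, 0.0],
                       [1.2, 0.4, 3.0, 0.7],
                       [0.0, 1.1, 0.2, 1.9]])

row_wise = torch.softmax(scores, dim=-1)  # normalize across each row: one distribution per query
col_wise = torch.softmax(scores, dim=0)   # normalize down each column instead

print(row_wise.sum(dim=-1))  # every row sums to 1
print(col_wise.sum(dim=-1))  # rows generally do not sum to 1
```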


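From here we can calculate all context vectors at once: row $i$ of `attention_weights @ inputs` is the sum of the inputs weighted by row $i$ of the attention weights, i.e. $z^{(i)}$: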

In [65]:
all_context_vectors = attention_weights @ inputs  # matrix multiplication
print("context vectors for all inputs:")
print(all_context_vectors)
print("for comparison, the context vector for x3:")
print(context_vector_for_x3)

context vectors for all inputs:
tensor([[0.6191, 0.7634, 0.5991],
        [0.6395, 0.7318, 0.6090],
        [0.5165, 0.7774, 0.6536],
        [0.5113, 0.7897, 0.6428]])
for comparison, the context vector for x3:
tensor([0.5165, 0.7774, 0.6536])
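The three vectorized steps above (scores, weights, weighted sum) fit into one small function. A sketch, where `simple_self_attention` is our own, hypothetical name and the inputs are copied (rounded) from the printed tensor:

```python
import torch

def simple_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Simplified self-attention without trainable weights.

    x: (T, d) tensor with one embedding per row.
    Returns a (T, d) tensor whose row i is the context vector z^(i).
    """
    scores = x @ x.T                         # (T, T): dot products between all pairs of tokens
    weights = torch.softmax(scores, dim=-1)  # normalize each query's row to sum to 1
    return weights @ x                       # row i is the weighted sum of all inputs

x_rounded = torch.tensor([[0.8823, 0.9150, 0.3829],
                          [0.9593, 0.3904, 0.6009],
                          [0.2566, 0.7936, 0.9408],
                          [0.1332, 0.9346, 0.5936]])
contexts = simple_self_attention(x_rounded)
print(contexts)
```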


## Self-Attention with trainable weights

From here we'll expand to what is called "scaled dot-product attention". Here, too, we want to calculate context vectors (one for each of our input tokens) as a weighted sum over (a projection of) the inputs. There are some differences from what we've done so far, though:

- instead of directly using $q^{(i)} = x^{(i)}$ as the query vector, we'll use a projection of $x^{(i)}$. We compute this query with a trainable weight matrix $W_q$.
- instead of taking the dot product of $q^{(i)}$ directly with each $x^{(j)}$ to calculate our attention scores and weights, we'll take the dot product with a projection of $x^{(j)}$, called its key. We compute the keys with a trainable weight matrix $W_k$.
- instead of computing the weighted sum over the inputs $x^{(j)}$ themselves, we'll weight a projection of each $x^{(j)}$, called its value. We compute the values with a trainable weight matrix $W_v$.

So we'll have trainable weight matrices $W_q$, $W_k$ and $W_v$, which the model learns during training and uses to project the inputs into query, key and value vectors, respectively.
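Put together, the steps in this section compute the following, treating each $x^{(i)}$ as a row vector, with $d_k$ the dimension of the projected keys (`d_out` in our code) and $T$ the number of tokens:

$$
\begin{aligned}
q^{(i)} &= x^{(i)} W_q, \qquad k^{(j)} = x^{(j)} W_k, \qquad v^{(j)} = x^{(j)} W_v \\
w_{ij} &= q^{(i)} \cdot k^{(j)} \\
\alpha_{ij} &= \frac{\exp\left(w_{ij} / \sqrt{d_k}\right)}{\sum_{l=1}^{T} \exp\left(w_{il} / \sqrt{d_k}\right)} \\
z^{(i)} &= \sum_{j=1}^{T} \alpha_{ij} \, v^{(j)}
\end{aligned}
$$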

### Single context vector (1)

Our token embedding vectors have a certain dimension, `embedding_dim = 3` in our example. We can choose to project them into a different dimension, from 3 to 2 for example. (In practice the output dimension is often kept the same as the input dimension.) For illustration purposes, let's go with 2 here:

In [66]:
d_in = embedding_dim
d_out = 2

W_query_1 = torch.rand((d_in, d_out))
W_key_1 = torch.rand((d_in, d_out))
W_value_1 = torch.rand((d_in, d_out))

We can now use $W_q$, $W_k$ and $W_v$ to project an embedding vector from its original dimension (3) into a dimension of 2. Let's project $x^{(3)}$ into our query:

In [67]:
x3 = inputs[2]
print("x3 input:")
print(x3)

print("W_query:")
print(W_query_1)

query = x3 @ W_query_1
print("projected query:")
print(query)

x3 input:
tensor([0.2566, 0.7936, 0.9408])
W_query:
tensor([[0.8694, 0.5677],
        [0.7411, 0.4294],
        [0.8854, 0.5739]])
projected query:
tensor([1.6442, 1.0264])


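Instead of keeping our projection matrices as plain tensors like above, we'll wrap them in PyTorch parameters (`torch.nn.Parameter`), the type PyTorch uses for trainable weights. Since we're not training anything in this notebook, we set `requires_grad=False`: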

In [68]:
W_query = torch.nn.Parameter(W_query_1, requires_grad=False)
W_key = torch.nn.Parameter(W_key_1, requires_grad=False)
W_value = torch.nn.Parameter(W_value_1, requires_grad=False)
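For comparison, a sketch of what changes when a parameter is trainable (`requires_grad=True`, the default for `torch.nn.Parameter`): PyTorch records the operations that use it and can compute gradients with respect to the weights, which is what an optimizer would use to update them during training. The weight values are copied from the `W_query` printout above; `W_demo`, `x3_demo` and `query_demo` are hypothetical names chosen so the notebook's own variables are not overwritten:

```python
import torch

# Hypothetical trainable copy of W_query (values copied from the printout above)
W_demo = torch.nn.Parameter(torch.tensor([[0.8694, 0.5677],
                                          [0.7411, 0.4294],
                                          [0.8854, 0.5739]]))  # requires_grad=True by default
x3_demo = torch.tensor([0.2566, 0.7936, 0.9408])

query_demo = x3_demo @ W_demo  # PyTorch records this operation for backpropagation
print(query_demo)              # the output now carries a grad_fn

# Backpropagate a toy scalar (the sum of the query's entries) to get its gradient w.r.t. W_demo
query_demo.sum().backward()
print(W_demo.grad)             # every column equals x3_demo
```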


Let's calculate the key and value vectors for every input token in our sentence "I am learning this": $x^{(1)}$: "I", $x^{(2)}$: "am", $x^{(3)}$: "learning" and $x^{(4)}$: "this".

In [69]:
values = inputs @ W_value
print("values:")
print(values)
keys = inputs @ W_key
print("keys:")
print(keys)

values:
tensor([[0.6307, 0.4225],
        [0.5699, 0.3401],
        [0.8266, 0.2332],
        [0.6742, 0.2259]])
keys:
tensor([[0.5956, 1.2759],
        [0.5394, 1.2740],
        [0.5617, 1.2937],
        [0.4637, 0.9897]])


Now we can calculate the attention scores for the query vector we projected from $x^{(3)}$:

In [70]:
#| echo: false
print("query projection from x3:")
print(query)

query projection from x3:
tensor([1.6442, 1.0264])


The attention score $w_{31}$ is now the dot product of this query with the key projected from $x^{(1)}$:

In [71]:
x1 = inputs[0]
print("x1 input:")
print(x1)
print("key projection from x1: ")
key = x1 @ W_key
print(key)
print("... which is the same as: ")
print(keys[0])  # keys[0] is the key for x1

print("attention score between our query and x1's projected key:")
attention_score_x3_x1 = query.dot(key)
print(attention_score_x3_x1)

x1 input:
tensor([0.8823, 0.9150, 0.3829])
key projection from x1: 
tensor([0.5956, 1.2759])
... which is the same as: 
tensor([0.5956, 1.2759])
attention score between our query and x1's projected key:
tensor(2.2888)


Just as we computed the attention score between the query and the key of $x^{(1)}$, we can compute all attention scores for our query at once by multiplying the query with the transposed key matrix:

In [72]:
attention_scores_x3_as_query = query @ keys.T
print("attention scores for x3 as query:")
print(attention_scores_x3_as_query)

attention scores for x3 as query:
tensor([2.2888, 2.1945, 2.2514, 1.7783])


As shown above, for our chosen query we end up with 4 attention scores, one for each input token. As before, we want to normalize our attention scores into attention weights, but instead of applying softmax directly, we first scale the attention scores by dividing them by the square root of the dimension $d_k$ of our projected keys (2 in this case): $\alpha_{3j} = \text{softmax}_j\left(w_{3j} / \sqrt{d_k}\right)$.

In [73]:
print("dimension of our projected key: ", d_out)
scaled_attention_weights_x3_as_query = \
  torch.softmax(attention_scores_x3_as_query / d_out**0.5, dim=-1)
print("scaled attention weights for x3 as query:")
print(scaled_attention_weights_x3_as_query)

dimension of our projected key:  2
scaled attention weights for x3 as query:
tensor([0.2773, 0.2594, 0.2700, 0.1933])


We don't have our full context vector yet: for that we still need a weighted combination of our projected value vectors. Before we get there, let's digress a little on why we scale the scores before the softmax.

### Why scaling?

As the dimension of our projected key (and query) vectors grows, their dot products tend to grow too, because they sum over more terms. An example:

In [88]:
key_dim = 3
print("dimension of the key: ", key_dim)
small_proj_key_dim = 2
large_proj_key_dim = 64

test_token_embeddings = torch.rand((4, key_dim)) # 4 tokens, each with key_dim features
print("test token embeddings for 4 token:")
print(test_token_embeddings)

test_query  = test_token_embeddings[2]  # let's take the 3rd token as query
print("test query:")
print(test_query)

W_k_for_small_output_dim = torch.rand(key_dim, small_proj_key_dim)
W_q_for_small_output_dim = torch.rand(key_dim, small_proj_key_dim)
W_k_for_large_output_dim = torch.rand(key_dim, large_proj_key_dim)
W_q_for_large_output_dim = torch.rand(key_dim, large_proj_key_dim)

projected_small_keys = test_token_embeddings @ W_k_for_small_output_dim
projected_small_query = test_query @ W_q_for_small_output_dim
projected_large_keys = test_token_embeddings @ W_k_for_large_output_dim
projected_large_query = test_query @ W_q_for_large_output_dim

print("projected small query:")
print(projected_small_query)
print("projected small keys:")
print(projected_small_keys)
print("projected large query:")
print(projected_large_query)
print("projected large keys:")
print(projected_large_keys)

dimension of the key:  3
test token embeddings for 4 token:
tensor([[0.2754, 0.1046, 0.5527],
        [0.1031, 0.1160, 0.0935],
        [0.3645, 0.9416, 0.6955],
        [0.4407, 0.6029, 0.6991]])
test query:
tensor([0.3645, 0.9416, 0.6955])
projected small query:
tensor([1.6317, 1.3363])
projected small keys:
tensor([[0.4279, 0.0668],
        [0.1197, 0.0325],
        [0.7030, 0.2156],
        [0.6780, 0.1689]])
projected large query:
tensor([1.3598, 0.5360, 1.1592, 0.6897, 0.7081, 1.1354, 1.4924, 0.7659, 0.5333,
        0.6289, 1.6174, 1.0476, 1.2979, 0.9037, 1.1668, 1.0850, 1.2032, 0.7618,
        0.6385, 0.5907, 0.9166, 1.4778, 1.0010, 1.3485, 0.9444, 1.3983, 0.7808,
        0.6974, 1.0598, 0.9657, 1.0985, 0.4365, 1.1316, 0.7342, 1.8704, 0.9181,
        1.0740, 0.4736, 0.8474, 1.0835, 0.4376, 1.0398, 0.3557, 1.6521, 1.1620,
        1.5043, 1.2346, 0.4994, 1.5802, 1.1302, 1.5070, 1.2899, 0.6512, 0.6052,
        1.6926, 1.2671, 0.7991, 1.2501, 0.6242, 0.8402, 1.2997, 1.0477, 0.7979,


In [89]:
attention_scores_for_small_keys = projected_small_query @ projected_small_keys.T
print("attention scores for small keys:")
print(attention_scores_for_small_keys)
attention_scores_for_large_keys = projected_large_query @ projected_large_keys.T
print("attention scores for large keys:")
print(attention_scores_for_large_keys)

attention scores for small keys:
tensor([0.7875, 0.2388, 1.4352, 1.3320])
attention scores for large keys:
tensor([29.5998,  9.9446, 63.8539, 55.4694])


See how much larger the attention scores are for the large projected key dimension? Let's see what happens if we apply softmax directly to both sets of scores:

In [None]:
attention_weights_for_small_keys = torch.softmax(attention_scores_for_small_keys, dim=0)
print("attention weights for small keys:")
print(attention_weights_for_small_keys)
attention_weights_for_large_keys = torch.softmax(attention_scores_for_large_keys, dim=0)
print("attention weights for large keys:")
print(attention_weights_for_large_keys)

attention weights for small keys:
tensor([0.1918, 0.1108, 0.3666, 0.3307])
attention weights for large keys:
tensor([1.3290e-15, 3.8671e-24, 9.9977e-01, 2.2833e-04])


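For the large projected keys, softmax puts almost all of the weight on a single token. Such a saturated softmax also has very small gradients, which makes training harder. If we scale the scores by $1/\sqrt{d_k}$ before applying softmax, the weights for the large projected key dimension become much less extreme and closer in range to those for the small projected key dimension: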

In [92]:
scaled_att_weights_for_small_keys = \
  torch.softmax(attention_scores_for_small_keys / small_proj_key_dim**0.5, dim=0)
print("scaled attention weights for small keys:")
print(scaled_att_weights_for_small_keys)
scaled_att_weights_for_large_keys = \
  torch.softmax(attention_scores_for_large_keys / large_proj_key_dim**0.5, dim=0)
print("scaled attention weights for large keys:")
print(scaled_att_weights_for_large_keys)

scaled attention weights for small keys:
tensor([0.2115, 0.1435, 0.3343, 0.3108])
scaled attention weights for large keys:
tensor([0.0101, 0.0009, 0.7323, 0.2567])
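Why the square root specifically? If the components of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of $d_k$ terms that each have variance 1, so it has variance $d_k$ and standard deviation $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ brings the standard deviation back to about 1, whatever the dimension. A quick empirical sketch under that assumption (standard normal vectors, not the `torch.rand` projections above); it uses its own `torch.Generator` so the notebook's seeded random state is left untouched:

```python
import torch

gen = torch.Generator().manual_seed(0)  # local generator: leaves the global seed untouched
n_samples = 10_000
stds = {}  # d_k -> (std of raw dot products, std of scaled dot products)

for d_k in (2, 64, 512):
    q = torch.randn(n_samples, d_k, generator=gen)  # random queries: mean 0, variance 1
    k = torch.randn(n_samples, d_k, generator=gen)  # random keys: mean 0, variance 1
    dots = (q * k).sum(dim=-1)                      # one dot product per sample
    scaled = dots / d_k**0.5
    stds[d_k] = (dots.std().item(), scaled.std().item())
    print(f"d_k={d_k:4d}  std(dot)={stds[d_k][0]:7.3f}  "
          f"sqrt(d_k)={d_k**0.5:7.3f}  std(scaled)={stds[d_k][1]:.3f}")
```

The raw standard deviation should track $\sqrt{d_k}$, while the scaled one should stay close to 1 for every dimension.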


### Single context vector (2)

Now that we have our scaled attention weights, we can calculate the context vector for our query.  We'll do this by making a weighted combination of the projected values.  Our scaled attention weights were:

In [94]:
#| echo: false
print("scaled attention weights, one for each token:")
print(scaled_attention_weights_x3_as_query)

scaled attention weights, one for each token:
tensor([0.2773, 0.2594, 0.2700, 0.1933])


We multiply these with our projected values, which were:

In [95]:
#| echo: false
print("projected values:")
print(values)

projected values:
tensor([[0.6307, 0.4225],
        [0.5699, 0.3401],
        [0.8266, 0.2332],
        [0.6742, 0.2259]])


The resulting context vector $z^{(3)}$ for the query projected from $x^{(3)}$:

In [100]:
context_vector_x3 = scaled_attention_weights_x3_as_query @ values
print("context vector for x3:")
print(context_vector_x3)
print("double check: the first number is the same as:")
print(0.2773*0.6307+0.2594*0.5699+0.27*0.8266+0.1933*0.6742)
print("the second number is the same as:")
print(0.2773*0.4225+0.2594*0.3401+0.27*0.2332+0.1933*0.2259)


context vector for x3:
tensor([0.6762, 0.3120])
double check: the first number is the same as:
0.67623003
the second number is the same as:
0.31201166
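So far we have computed the context vector for a single query. As in the simplified version, we can compute all context vectors at once with matrix multiplications. A sketch that wraps this in a function (`scaled_dot_product_self_attention` is our own, hypothetical name). To be self-contained, it recreates this notebook's random tensors with a local generator seeded with 42, drawing them in the same order as above; this assumes a fresh CPU `torch.Generator` seeded with 42 yields the same values as `torch.manual_seed(42)`. Row 3 of the result should then match `context_vector_x3` up to rounding:

```python
import torch

def scaled_dot_product_self_attention(x, W_q, W_k, W_v):
    """Self-attention for all tokens at once.

    x: (T, d_in) inputs; W_q, W_k, W_v: (d_in, d_out) projection matrices.
    Returns a (T, d_out) tensor whose row i is the context vector z^(i).
    """
    queries = x @ W_q                                   # (T, d_out)
    keys = x @ W_k                                      # (T, d_out)
    values = x @ W_v                                    # (T, d_out)
    scores = queries @ keys.T                           # (T, T): row i holds w_i1 ... w_iT
    d_k = keys.shape[-1]
    weights = torch.softmax(scores / d_k**0.5, dim=-1)  # scale, then normalize each row
    return weights @ values                             # (T, d_out)

# Recreate the notebook's tensors: inputs, then W_query, W_key, W_value, in that order
gen = torch.Generator().manual_seed(42)
x = torch.rand((4, 3), generator=gen)
W_q = torch.rand((3, 2), generator=gen)
W_k = torch.rand((3, 2), generator=gen)
W_v = torch.rand((3, 2), generator=gen)

all_contexts = scaled_dot_product_self_attention(x, W_q, W_k, W_v)
print(all_contexts)
print(all_contexts[2])  # expected to match context_vector_x3 above
```

Recent PyTorch versions (2.0 and later) also provide `torch.nn.functional.scaled_dot_product_attention`, which performs the scaling, softmax and weighted sum on already-projected queries, keys and values.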
