# Self Attention

These personal learning notes draw on [Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html).

In [None]:
#|hide
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(123)
torch.set_printoptions(precision=1, sci_mode=False, profile='short')

## What is self-attention?

Self-attention grew out of research on translation, where it was introduced to give the model access to all elements in a sequence at each time step.  In language tasks, the meaning of a word can depend on its context within a larger text.  Attention enables the model to weigh the importance of the different elements in the input sequence and adjust their influence on the output.

## Embedding an Input Sentence

Our input is: "Playing music makes me very happy".  We'll create an embedding for this entire sentence first.

In [None]:
sentence = "Playing music makes me very happy"

sentence_words = sentence.split()
sentence_words

['Playing', 'music', 'makes', 'me', 'very', 'happy']

In [None]:
sentence_words_sorted = sorted(sentence_words)
sentence_words_sorted

['Playing', 'happy', 'makes', 'me', 'music', 'very']

In [None]:
dict = {word_str:word_idx for word_idx, word_str in enumerate(sentence_words_sorted)}
dict

{'Playing': 0, 'happy': 1, 'makes': 2, 'me': 3, 'music': 4, 'very': 5}

`dict` is our dictionary, conveniently restricted to just the words we're using here.  Every word has a number associated with it (its index in our dictionary).  Note that the name `dict` shadows Python's built-in `dict` type for the rest of the notebook.

We can now translate our sentence into an array of integers:

In [None]:
sentence_int = torch.tensor([dict[word] for word in sentence_words])
sentence_int

tensor([0, 4, 2, 3, 5, 1])

Now that our sentence is translated into a list of integers, we can use those with an embedding layer to encode the inputs into a real vector embedding.  Let's use 16 dimensions, so that each word is translated/mapped onto an embedding of 16 floats.

If our sentence is 6 words long (or whatever context length we end up choosing), the output of our embedding layer will have shape $6 \times 16$.  We'll create a PyTorch embedding layer with 6 possible indices and a 16-dimensional embedding vector for each index.

In [None]:
embed = torch.nn.Embedding(6,16)
sentence_embedded = embed(sentence_int).detach()
print(sentence_embedded)
print(sentence_embedded.shape)

tensor([[ 0.3, -0.2, -0.3, -0.6,  0.3,  0.7, -0.2, -0.4,  0.8, -1.2,  0.7, -1.4,
          0.2,  1.9,  0.5,  0.3],
        [ 0.5,  1.0, -0.3, -1.1, -0.0,  1.6, -2.3,  1.1,  0.7,  0.7, -0.9, -0.1,
         -0.2,  0.1,  0.4, -1.4],
        [-1.3,  0.2, -2.1,  1.1, -0.4, -0.9, -0.5, -1.1,  0.9,  1.6,  0.6, -0.2,
          0.1, -0.1,  0.3, -0.6],
        [ 0.9,  1.6, -1.5,  1.1, -1.2,  1.3,  1.1,  0.1,  2.2, -0.8, -0.3,  0.8,
         -0.7, -0.8,  0.2,  0.2],
        [ 0.3, -0.5,  1.0,  0.8, -0.4,  0.5, -0.2, -1.7, -1.6, -1.1,  0.9, -0.7,
         -0.6, -0.7,  0.6, -1.4],
        [-0.1, -1.0, -0.2,  0.9,  1.6,  1.3,  1.3, -0.2,  0.5, -1.6,  1.0, -1.1,
         -1.2,  0.3, -0.6, -2.8]])
torch.Size([6, 16])


So, we gave the embedding layer a tensor of 6 integers, which was translated into a $6 \times 16$ tensor: each index, representing a word, is translated into an array of 16 floats.  We can also look at the weights of our embedding layer:

In [None]:
embed.weight

Parameter containing:
tensor([[ 0.3, -0.2, -0.3, -0.6,  0.3,  0.7, -0.2, -0.4,  0.8, -1.2,  0.7, -1.4,
          0.2,  1.9,  0.5,  0.3],
        [-0.1, -1.0, -0.2,  0.9,  1.6,  1.3,  1.3, -0.2,  0.5, -1.6,  1.0, -1.1,
         -1.2,  0.3, -0.6, -2.8],
        [-1.3,  0.2, -2.1,  1.1, -0.4, -0.9, -0.5, -1.1,  0.9,  1.6,  0.6, -0.2,
          0.1, -0.1,  0.3, -0.6],
        [ 0.9,  1.6, -1.5,  1.1, -1.2,  1.3,  1.1,  0.1,  2.2, -0.8, -0.3,  0.8,
         -0.7, -0.8,  0.2,  0.2],
        [ 0.5,  1.0, -0.3, -1.1, -0.0,  1.6, -2.3,  1.1,  0.7,  0.7, -0.9, -0.1,
         -0.2,  0.1,  0.4, -1.4],
        [ 0.3, -0.5,  1.0,  0.8, -0.4,  0.5, -0.2, -1.7, -1.6, -1.1,  0.9, -0.7,
         -0.6, -0.7,  0.6, -1.4]], requires_grad=True)

This is basically a lookup matrix, in which we can look up the embedding vector for every token in our dictionary.  The dictionary token with index:

- `0` will return: `[0.3, -0.2, -0.3, -0.6,  0.3,  0.7, -0.2, -0.4,  0.8, -1.2,  0.7, -1.4, 0.2,  1.9,  0.5,  0.3]`
- `1` will return: `[-0.1, -1.0, -0.2,  0.9,  1.6,  1.3,  1.3, -0.2,  0.5, -1.6,  1.0, -1.1, -1.2,  0.3, -0.6, -2.8]`
- `2` will return: `[-1.3,  0.2, -2.1,  1.1, -0.4, -0.9, -0.5, -1.1,  0.9,  1.6,  0.6, -0.2, 0.1, -0.1,  0.3, -0.6]`

and so on.  Given that our sentence has the token indices $\begin{bmatrix}
0 & 4 & 2 & 3 & 5 & 1
\end{bmatrix}$, we expect the first row, then the 5th, then the 3rd, and so on, which gives the same result as the embedded sentence above:

$$\begin{bmatrix}
0.3 & -0.2 & -0.3 & -0.6 & 0.3 & 0.7 & -0.2 & -0.4 & 0.8 & -1.2 & 0.7 & -1.4 & 0.2 & 1.9 & 0.5 & 0.3 \\
0.5 & 1.0 & -0.3 & -1.1 & -0.0 & 1.6 & -2.3 & 1.1 & 0.7 & 0.7 & -0.9 & -0.1 & -0.2 & 0.1 & 0.4 & -1.4 \\
-1.3 & 0.2 & -2.1 & 1.1 & -0.4 & -0.9 & -0.5 & -1.1 & 0.9 & 1.6 & 0.6 & -0.2 & 0.1 & -0.1 & 0.3 & -0.6 \\
0.9 & 1.6 & -1.5 & 1.1 & -1.2 & 1.3 & 1.1 & 0.1 & 2.2 & -0.8 & -0.3 & 0.8 & -0.7 & -0.8 & 0.2 & 0.2 \\
0.3 & -0.5 & 1.0 & 0.8 & -0.4 & 0.5 & -0.2 & -1.7 & -1.6 & -1.1 & 0.9 & -0.7 & -0.6 & -0.7 & 0.6 & -1.4 \\
-0.1 & -1.0 & -0.2 & 0.9 & 1.6 & 1.3 & 1.3 & -0.2 & 0.5 & -1.6 & 1.0 & -1.1 & -1.2 & 0.3 & -0.6 & -2.8
\end{bmatrix}$$
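
The lookup behaviour described above can be checked directly. The following is a minimal sketch, assuming a freshly created embedding layer of the same size as the one in these notes; `token_ids`, `via_layer` and `via_indexing` are names introduced here for illustration only.

```python
import torch

torch.manual_seed(123)

# Hypothetical stand-ins for the embedding layer and token indices used above.
embed = torch.nn.Embedding(6, 16)
token_ids = torch.tensor([0, 4, 2, 3, 5, 1])

# Calling the layer on indices...
via_layer = embed(token_ids).detach()
# ...should match picking the same rows out of the weight matrix.
via_indexing = embed.weight[token_ids].detach()

print(torch.equal(via_layer, via_indexing))  # True if the layer is a pure row lookup
print(via_layer.shape)                       # one 16-float row per token
```

Seen this way, the embedding layer has no arithmetic of its own at lookup time: it only selects rows, and training adjusts the contents of those rows.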

## Defining weight matrices

### Set up and dimensions

Self-attention has 3 weight matrices, each of which is adjusted during training like any other model parameter.

- $W_{q}$: projects our input to the *query*
- $W_{k}$: projects our input to the *key*
- $W_{v}$: projects our input to the *value*

The *query* $q$, *key* $k$ and *value* $v$ are each a vector computed for one input element.  We can calculate them through matrix multiplication between the $W$ matrices and the embedded inputs $x$.  Our sequence has length $T$.

- $q^{(i)} = W_{q} x^{(i)}$ for the element at index $i$, with $i$ between $0$ and $T-1$
- $k^{(i)} = W_{k} x^{(i)}$ for the element at index $i$, with $i$ between $0$ and $T-1$
- $v^{(i)} = W_{v} x^{(i)}$ for the element at index $i$, with $i$ between $0$ and $T-1$

This will give us three vectors for each input element (token) in our sequence.

Let $d$ be the size (number of dimensions) of each embedded word vector $x$ (here 16).  The vector $q^{(i)}$ is the query vector for the word at index $i$, and its dimension is ours to choose; we'll call it $d_q$.  In the same way, we'll call the dimension of $k^{(i)}$ $d_k$.

We'll calculate the dot product between the query and key vectors, which means both must have the same dimension: $d_q = d_k$. Let's choose $d_q = d_k = 24$ in this case.<br/>
If $$q^{(i)} = W_{q} x^{(i)}$$ then: 

- the dimension of $q^{(i)}$ is $d_q$, which equals $d_k$ (here 24, a value we chose)
- the dimension of $W_{q}$ is $d_q \times d$, here 24 by 16, because every word is represented by 16 floats
- the dimension of $x^{(i)}$ is $d$, here 16 (16 floats for every word)

The dimension of the value vector, $d_v$, can be chosen arbitrarily; let's say 28 in our example.  That's the size of the resulting context vector.
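
To make the dimension constraints concrete, here is a small sketch using random stand-in tensors. The names `W_q`, `W_k`, `W_v` and `x` are hypothetical and only mirror the symbols above. It shows that the query-key dot product is defined because $d_q = d_k$, while a 24-element vector and a 28-element vector cannot be combined that way.

```python
import torch

d, d_q, d_k, d_v = 16, 24, 24, 28   # sizes chosen in these notes

x = torch.randn(d)                  # hypothetical embedded token
W_q = torch.rand(d_q, d)
W_k = torch.rand(d_k, d)
W_v = torch.rand(d_v, d)

q, k, v = W_q @ x, W_k @ x, W_v @ x
score = torch.dot(q, k)             # works because d_q == d_k
print(q.shape, k.shape, v.shape)

# Vectors of different sizes cannot be combined with a dot product.
try:
    torch.dot(q, v)                 # 24 vs 28 elements
    mismatch_raises = False
except RuntimeError:
    mismatch_raises = True
print(mismatch_raises)
```

The value dimension is free precisely because values are never dotted with queries or keys; they are only averaged using the attention weights.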

Let's set up some arbitrary weight matrices:


In [None]:
d = 16
d_q, d_k, d_v = 24, 24, 28
W_query = torch.rand(d_q,d)
W_key = torch.rand(d_k,d)
W_value = torch.rand(d_v,d)

W_query.shape

torch.Size([24, 16])

### Calculate the query, key and value for one word

In [None]:
x_2 = sentence_embedded[2]
query_2 = W_query.matmul(x_2)
print(f'W_query has shape {W_query.shape}, x_2 has shape {x_2.shape} and resulting query_2 has shape: {query_2.shape}\n')
print('the resulting tensor is our query tensor for word at index 2:')
print(query_2)

W_query has shape torch.Size([24, 16]), x_2 has shape torch.Size([16]) and resulting query_2 has shape: torch.Size([24])

the resulting tensor is our query tensor for word at index 2:
tensor([-2.4, -1.3,  0.0,  0.4, -0.2, -1.8, -1.0, -0.7, -1.9, -0.0, -1.6, -0.7,
        -2.0, -1.3, -1.6, -1.5, -1.1, -2.8, -0.4,  0.7, -1.7,  1.0, -1.1, -3.2])


We can do the same to get the key and value vectors for the word at index 2:

In [None]:
query_2 = W_query @ x_2 # same as matmul
key_2 = W_key @ x_2
value_2 = W_value @ x_2

print(query_2)
print(key_2)
print(value_2)

tensor([-2.4, -1.3,  0.0,  0.4, -0.2, -1.8, -1.0, -0.7, -1.9, -0.0, -1.6, -0.7,
        -2.0, -1.3, -1.6, -1.5, -1.1, -2.8, -0.4,  0.7, -1.7,  1.0, -1.1, -3.2])
tensor([ 0.6, -2.3, -1.8, -1.3, -1.9, -0.6, -1.5, -3.0,  0.4, -1.9, -0.7, -2.1,
        -2.0, -0.9, -1.6, -2.1, -0.4, -0.2,  0.5, -1.1, -2.5, -0.4,  0.4, -3.0])
tensor([-1.1, -0.9, -3.0, -0.7, -2.2,  0.1,  0.0, -2.8, -2.1,  0.7, -0.7, -1.6,
        -2.6, -1.3, -0.9, -0.5, -1.8, -3.0, -0.7, -1.3,  0.5, -1.1, -1.8, -2.2,
         0.6, -0.0, -1.8, -1.3])
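
As a side note, the same projection can be written with PyTorch's functional linear operation, which computes `x @ W.T` when no bias is passed. This is a sketch with random stand-in tensors (`W` and `x` are hypothetical names), not a claim about how these notes go on to implement attention.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, d_q = 16, 24
W = torch.rand(d_q, d)   # stand-in for W_query
x = torch.randn(d)       # stand-in for one embedded token

manual = W @ x                 # matrix-vector product, as in the cell above
functional = F.linear(x, W)    # x @ W.T, no bias term

print(manual.shape)
print(torch.allclose(manual, functional, atol=1e-5))
```

This is also why a bias-free `torch.nn.Linear(d, d_q)` layer stores its weight with shape `(d_q, d)`, the same layout as `W_query` here.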


### Generalizing the calculation to all inputs in the sequence

We can now generalize what we did for a single token or word to all of the inputs in our sequence.

In [None]:
keys = (W_key @ sentence_embedded.T).T
values = (W_value @ sentence_embedded.T).T

print(f'keys: \n{keys}')
print(f'values: \n {values}')

keys: 
tensor([[ 0.9,  1.2,  2.2,  1.3,  0.8, -1.6, -0.2,  2.0, -0.4,  1.7,  0.1,  2.2,
         -0.2, -0.7, -0.2, -0.2,  0.8,  1.0,  0.7,  1.7,  2.8,  1.5, -0.9,  1.1],
        [ 0.9,  0.1,  0.7, -1.1,  1.3, -0.2,  1.0, -0.7,  1.5,  0.3, -0.3,  0.3,
          1.5, -1.1,  1.4,  0.4, -2.5,  0.4,  0.0, -0.1,  1.2,  1.3,  1.7,  0.5],
        [ 0.6, -2.3, -1.8, -1.3, -1.9, -0.6, -1.5, -3.0,  0.4, -1.9, -0.7, -2.1,
         -2.0, -0.9, -1.6, -2.1, -0.4, -0.2,  0.5, -1.1, -2.5, -0.4,  0.4, -3.0],
        [ 2.5,  2.0,  1.3,  2.5,  2.3,  3.6,  2.9,  1.0,  3.3,  2.8,  3.6,  1.1,
          3.1,  2.8,  1.8,  1.9,  0.4,  1.4,  2.4,  1.3,  2.2,  2.2,  2.4,  1.8],
        [-3.1, -2.5, -1.1, -3.5, -4.7, -6.2, -0.9, -3.2, -1.4, -3.5, -2.8, -2.3,
         -1.3, -3.1, -2.3,  0.4, -2.5, -3.9, -4.2, -1.6, -2.0, -1.7, -1.0, -5.0],
        [-1.1, -1.4,  0.9, -2.3, -2.7, -3.2, -1.4, -1.0, -0.8,  1.0, -2.0, -0.7,
         -0.7, -2.5, -2.9, -1.0, -1.0, -1.2, -3.1, -0.6,  1.4, -0.7, -0.9, -1.8]])
values: 
 tens

Each of these is a matrix with one row per word in our input sequence, where each row is the key or value vector for the corresponding word.
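
The batched form can be sanity-checked against the per-token calculation. The sketch below uses random stand-ins (`X` for the embedded sentence, `W_k` for the key matrix; both names are assumptions) and also shows the equivalent, slightly shorter `X @ W_k.T`.

```python
import torch

torch.manual_seed(0)
T, d, d_k = 6, 16, 24
X = torch.randn(T, d)      # stand-in for sentence_embedded
W_k = torch.rand(d_k, d)   # stand-in for W_key

keys = (W_k @ X.T).T       # form used in the cell above
keys_alt = X @ W_k.T       # same product without the double transpose

print(keys.shape)                                        # one row per token
print(torch.allclose(keys, keys_alt, atol=1e-5))         # both forms agree
print(torch.allclose(keys[2], W_k @ X[2], atol=1e-5))    # row 2 is the key of token 2
```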

If we want the attention vector for the second input element, that element acts as the query.  We will matrix-multiply that query with the key vectors of all input elements.
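
A minimal sketch of that step for a single query follows, using random stand-in tensors. The scaling by $\sqrt{d_k}$ and the softmax follow the scaled dot-product attention of "Attention Is All You Need"; the variable names (`scores`, `weights`, `context`) and the choice of index 1 for the second element are assumptions made here for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d, d_k, d_v = 6, 16, 24, 28
X = torch.randn(T, d)                          # stand-in for the embedded sentence
W_q, W_k, W_v = torch.rand(d_k, d), torch.rand(d_k, d), torch.rand(d_v, d)

keys = X @ W_k.T                               # (T, d_k)
values = X @ W_v.T                             # (T, d_v)
query = W_q @ X[1]                             # (d_k,) query for the second element

scores = keys @ query                          # (T,) one dot product per key
weights = F.softmax(scores / math.sqrt(d_k), dim=0)   # normalized, sums to 1
context = weights @ values                     # (d_v,) weighted sum of the values

print(scores.shape, weights.sum(), context.shape)
```

The resulting `context` has $d_v$ entries, which matches the earlier remark that the value dimension sets the size of the context vector.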

## Resources

- [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Thinking Like Transformers](https://arxiv.org/abs/2106.06981)