In the previous notebook we turned our text into input vectors that encode both the tokens and their positions.  These input vectors are what the rest of the model operates on.

# Self-Attention



Self-attention is a technique whereby each position in the input sequence can weigh the relevance of every other position in the same sequence when the representation for the sequence is being computed.  (Traditional "attention", as opposed to self-attention, looks at relations between two different sequences, an input and an output sequence, rather than within a single sequence.)

## Simplified Self-Attention

Let's take a simple sentence:
"I am learning this"

This is our input sequence; let's call it $x$.  It has 4 tokens: $x^{(1)}$, $x^{(2)}$, $x^{(3)}$ and $x^{(4)}$.

In general, if our context length is $T$, then we have $x^{(1)} \ldots x^{(T)}$.

Each $x^{(i)}$ is a $d$-dimensional embedding vector representing a token.

Now we will calculate a context vector $z^{(i)}$ for each $x^{(i)}$.  This vector will contain information from all vectors $x^{(1)} \ldots x^{(T)}$.

### Attention weights

So, for example, for $x^{(3)}$ we'll calculate a context vector $z^{(3)}$.  We'll call $x^{(3)}$ our "query" vector, and for this query we'll calculate an attention score for each of the tokens in our sentence.  For example, $w_{31}$ is the attention score between our query (the 3rd token) and the first token.  In this way we'll have $w_{31}$, $w_{32}$, $w_{33}$ and $w_{34}$.

The attention score $w_{31}$ is the dot product of the third token's vector with the first token's vector.  So if:
$$ x^{(3)} = [0.2,0.7,0.9] \text{  (representing the token "learning")} $$
$$ x^{(1)} = [0.8,0.9,0.3] \text{  (representing the token "I")} $$

Then:

$$ w_{31} = x^{(3)} \cdot x^{(1)} = (0.2 \times 0.8) + (0.7 \times 0.9) + (0.9 \times 0.3) = 1.06 $$
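
The arithmetic above can be double-checked with a few lines of plain Python. This is only a sketch of the hand calculation; the variable names `x3`, `x1` and `w_31` are chosen here for illustration, and the values are the example vectors from the text.

```python
# Example embeddings from the text (illustrative values)
x3 = [0.2, 0.7, 0.9]  # "learning"
x1 = [0.8, 0.9, 0.3]  # "I"

# Dot product: multiply element-wise, then sum the products
w_31 = sum(a * b for a, b in zip(x3, x1))
print(round(w_31, 2))  # 1.06
```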

We'll do this between our query $x^{(3)}$ and each of $x^{(1)}, x^{(2)}, x^{(3)}, x^{(4)}$, which results in a vector of attention scores $w_3 = [w_{31}, w_{32}, w_{33}, w_{34}]$.

Now in Python, using an input tensor that holds our 4 tokens, each with an embedding dimension of 3 (random values stand in for real embeddings):

In [2]:
#|echo: false
import torch
torch.manual_seed(42)

inputs = torch.rand((4,3))
print(inputs)

tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408],
        [0.1332, 0.9346, 0.5936]])


We'll calculate the attention scores with $x^{(3)}$ as our query:

In [4]:
query = inputs[2]  # x^(3), the third token (index 2)
attention_score_for_x3 = torch.empty(inputs.shape[0])  # one score per token
for i, xi in enumerate(inputs):
    # Attention score w_3i: dot product of token i with the query
    attention_score_for_x3[i] = torch.dot(xi, query)

print(attention_score_for_x3)

tensor([1.3127, 1.1213, 1.5807, 1.3343])
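
The loop above can also be written as a single matrix-vector product, since each score is the dot product of one row of `inputs` with the query. The sketch below hard-codes the printed values of `inputs` so that it runs on its own; the names `scores_loop` and `scores_matmul` are introduced here for illustration and are not part of the notebook.

```python
import torch

# Values copied from the printed `inputs` tensor above
inputs = torch.tensor([[0.8823, 0.9150, 0.3829],
                       [0.9593, 0.3904, 0.6009],
                       [0.2566, 0.7936, 0.9408],
                       [0.1332, 0.9346, 0.5936]])
query = inputs[2]

# Loop version, one dot product per token
scores_loop = torch.stack([torch.dot(xi, query) for xi in inputs])

# (4, 3) @ (3,) -> (4,): the same dot products in one operation
scores_matmul = inputs @ query

print(torch.allclose(scores_loop, scores_matmul))  # True
```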


### Normalizing

Now that we've calculated the attention scores $w_3$ for query $x^{(3)}$, we have a vector with one number per token.  These numbers, however, are not normalized: we'd like each of them to lie between $0$ and $1$ and all of them to add up to one.  For this we'll use the softmax function (see also [Cross-Entropy Loss](../theory/cross-entropy-loss.ipynb)), which guarantees both properties.

In [6]:
attention_weights_for_x3 = torch.softmax(attention_score_for_x3, dim=0)
print(attention_weights_for_x3)

tensor([0.2407, 0.1987, 0.3146, 0.2459])
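
To see what `torch.softmax` computes, here is a minimal plain-Python sketch of the same formula: each score is exponentiated and then divided by the sum of all the exponentials. The helper `softmax` below is written for illustration only, and the scores are copied from the printed output above. Subtracting the maximum first is a common numerical-stability trick and does not change the result.

```python
import math

def softmax(scores):
    """Illustrative softmax: exp(s_i) / sum_j exp(s_j)."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Attention scores copied from the printed tensor above
scores = [1.3127, 1.1213, 1.5807, 1.3343]
weights = softmax(scores)

print([round(w, 4) for w in weights])
print(sum(weights))  # adds up to 1, up to floating-point rounding
```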


### Context Vector

So now that we have our normalized attention weights for a single query $x^{(3)}$, we can calculate the full context vector that corresponds to $x^{(3)}$.

Our attention weights from the previous step were approximately $[0.24, 0.20, 0.31, 0.25]$ or, more generally, a vector $[\alpha_{31}, \alpha_{32}, \alpha_{33}, \alpha_{34}]$.

To calculate the context vector for $x^{(3)}$ we'll take each $\alpha _{3i}$ and multiply that by $x^{(i)}$.  Then we'll add up all those vectors.
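
Written as a formula, using the notation from above with $T = 4$ tokens in our example, this weighted sum is:

$$ z^{(3)} = \sum_{i=1}^{T} \alpha_{3i} \, x^{(i)} = \alpha_{31} x^{(1)} + \alpha_{32} x^{(2)} + \alpha_{33} x^{(3)} + \alpha_{34} x^{(4)} $$

Since each $x^{(i)}$ is a $d$-dimensional vector and each $\alpha_{3i}$ is a scalar, $z^{(3)}$ is again a $d$-dimensional vector.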

In [None]:
query = inputs[2]
context_vector_for_x3 = torch.zeros(query.shape)  # same dimension as the query (3 in this case)

for i, xi in enumerate(inputs):
    # Accumulate z^(3): each token vector scaled by its attention weight
    context_vector_for_x3 += attention_weights_for_x3[i] * xi

print(context_vector_for_x3)

tensor([0.5165, 0.7774, 0.6536])


### All context vectors

So far we have looked at how to calculate a single context vector $z^{(3)}$ for a single token $x^{(3)}$ in our input sequence.  We'll need to make this more scalable and find a way to calculate $z$ for all tokens in our input sequence.