
## Illustrated: Self-Attention
Step-by-step guide to self-attention with illustrations and code

[Illustrated: Self-Attention](https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a)

[Raimi Karim](https://towardsdatascience.com/@remykarem)

> Colab made by: [Manuel Romero](https://twitter.com/mrm8488)

![Self-attention illustration](https://miro.medium.com/max/1973/1*_92bnsMJy8Bl539G4v93yg.gif)

What do *BERT, RoBERTa, ALBERT, SpanBERT, DistilBERT, SesameBERT, SemBERT, MobileBERT, TinyBERT and CamemBERT* all have in common? And I’m not looking for the answer “BERT” 🤭.

Answer: **self-attention** 🤗. We are not only talking about architectures bearing the name “BERT”, but more correctly about **Transformer-based architectures**. Transformer-based architectures, which are primarily used in modelling language understanding tasks, eschew recurrence in neural networks (RNNs) and instead rely entirely on self-attention mechanisms to draw global dependencies between inputs and outputs. But what’s the math behind this?

That’s what we’re going to find out today. This post walks you through the mathematical operations involved in a self-attention module. By the end of this article, you should be able to write a self-attention module from scratch.

## Step 0. What is self-attention?

If you’re wondering whether self-attention is similar to attention, the answer is yes! They fundamentally share the same concept and many common mathematical operations.
A self-attention module takes in n inputs, and returns n outputs. What happens in this module? In layman’s terms, the self-attention mechanism allows the inputs to interact with each other (“self”) and find out who they should pay more attention to (“attention”). The outputs are aggregates of these interactions and attention scores.

The illustrations are divided into the following steps:
1. Prepare inputs
2. Initialise weights
3. Derive key, query and value
4. Calculate attention scores for Input 1
5. Calculate softmax
6. Multiply scores with values
7. Sum weighted values to get Output 1
8. Repeat steps 4–7 for Input 2 & Input 3
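
The steps above can be sketched as one small function. This is a minimal sketch rather than a reference implementation: it uses NumPy so that it runs on its own, the helper names `softmax_rows` and `self_attention` are made up for illustration, and the random inputs are only there to show the shapes.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=-1, keepdims=True)

def self_attention(x, w_key, w_query, w_value):
    """Unscaled dot-product self-attention for n inputs stacked as an (n, d_in) array."""
    keys = x @ w_key                  # step 3: (n, d_k)
    queries = x @ w_query             # step 3: (n, d_k)
    values = x @ w_value              # step 3: (n, d_v)
    scores = queries @ keys.T         # step 4: (n, n), one row per query
    weights = softmax_rows(scores)    # step 5: each row sums to 1
    return weights @ values           # steps 6-7: (n, d_v), one output per input

# Hypothetical random inputs and weights, used only to check the shapes.
rng = np.random.default_rng(0)
x_demo = rng.normal(size=(3, 4))
out_demo = self_attention(
    x_demo,
    rng.normal(size=(4, 3)),
    rng.normal(size=(4, 3)),
    rng.normal(size=(4, 3)),
)
print(out_demo.shape)  # (3, 3)
```

Three inputs go in and three outputs come out, which is the "n inputs, n outputs" behaviour described in Step 0.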

In [1]:
import torch

## Step 1: Prepare inputs


For this tutorial, for the sake of simplicity, we start with 3 inputs, each with dimension 4.

![Step 1: prepare inputs](https://miro.medium.com/max/1973/1*hmvdDXrxhJsGhOQClQdkBA.png)

In [2]:
x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
]

x = torch.tensor(x, dtype=torch.float32)
x

tensor([[1., 0., 1., 0.],
        [0., 2., 0., 2.],
        [1., 1., 1., 1.]])

## Step 2: Initialise weights

Every input must have three representations (see diagram below). These representations are called **key** (orange), **query** (red), and **value** (purple). For this example, let’s say we want these representations to have a dimension of 3. Because every input has a dimension of 4, each set of weights must have a shape of 4×3.

![Step 2: initialise weights](https://miro.medium.com/max/1975/1*VPvXYMGjv0kRuoYqgFvCag.gif)

>**Note**: *We’ll see later that the dimension of value is also the dimension of the output.*


In order to obtain these representations, every input (green) is multiplied with a set of weights for keys, a set of weights for querys (I know that’s not the right spelling), and a set of weights for values. In our example, we initialise the three sets of weights as follows.

In [3]:
# Weights for key
w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]

# Weights for query
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]

# Weights for value
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]

w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Weights for key: 
 tensor([[0., 0., 1.],
        [1., 1., 0.],
        [0., 1., 0.],
        [1., 1., 0.]])
Weights for query: 
 tensor([[1., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 1.]])
Weights for value: 
 tensor([[0., 2., 0.],
        [0., 3., 0.],
        [1., 0., 3.],
        [1., 1., 0.]])


>**Note**: *In a neural network setting, these weights are usually small numbers, initialised randomly, for example by sampling from a Gaussian distribution or by using the Xavier or Kaiming initialisation schemes.*
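
As a rough illustration of the note above, here is how 4×3 weights could be drawn randomly instead of being written by hand. This is a sketch in plain NumPy using the standard Xavier (Glorot) uniform and Kaiming (He) normal formulas. The variable names, the seed, and the 0.02 standard deviation are illustrative assumptions, not values from this tutorial.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
d_in, d_out = 4, 3  # input dimension and representation dimension used in this tutorial

# Small Gaussian values (the 0.02 standard deviation is an illustrative choice).
w_gaussian = rng.normal(loc=0.0, scale=0.02, size=(d_in, d_out))

# Xavier/Glorot uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
limit = np.sqrt(6.0 / (d_in + d_out))
w_xavier = rng.uniform(-limit, limit, size=(d_in, d_out))

# Kaiming/He normal: N(0, std^2) with std = sqrt(2 / fan_in).
w_kaiming = rng.normal(loc=0.0, scale=np.sqrt(2.0 / d_in), size=(d_in, d_out))

print(w_gaussian.shape, w_xavier.shape, w_kaiming.shape)  # (4, 3) (4, 3) (4, 3)
```

The hand-picked integer weights used in this tutorial are kept for the rest of the walkthrough because they make the arithmetic easy to follow.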

## Step 3: Derive key, query and value

Now that we have the three sets of weights, let’s actually obtain the **key**, **query** and **value** representations for every input.

### Obtaining the keys


Key representation for Input 1:

```
               [0, 0, 1]
[1, 0, 1, 0] x [1, 1, 0] = [0, 1, 1]
               [0, 1, 0]
               [1, 1, 0]

               (1x4)(4x3)    = (1x3)
input vector x weight matrix = key vector
```

Use the same set of weights to get the key representation for Input 2:

```
               [0, 0, 1]
[0, 2, 0, 2] x [1, 1, 0] = [4, 4, 0]
               [0, 1, 0]
               [1, 1, 0]
          
```

Use the same set of weights to get the key representation for Input 3:

```
               [0, 0, 1]
[1, 1, 1, 1] x [1, 1, 0] = [2, 3, 1]
               [0, 1, 0]
               [1, 1, 0]
```

A faster way is to vectorise the above operations:

```
               [0, 0, 1]
[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]

           (3x4)(4x3)    = (3x3)
input matrix x weight matrix = key matrix
```
![Step 3: obtaining the keys](https://miro.medium.com/max/1975/1*dr6NIaTfTxEWzxB2rc0JWg.gif)
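
To check that the vectorised form gives the same keys as the three separate products, here is a small sketch. It uses NumPy arrays with an `_np` suffix (names chosen here so they do not clash with the tensors defined in the notebook cells).

```python
import numpy as np

x_np = np.array([[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]], dtype=np.float32)
w_key_np = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float32)

# One input at a time: each (4,) vector times the (4, 3) weight matrix.
keys_loop = np.stack([row @ w_key_np for row in x_np])

# All inputs at once: (3, 4) @ (4, 3) -> (3, 3).
keys_vectorised = x_np @ w_key_np

print(np.array_equal(keys_loop, keys_vectorised))  # True
```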
     

### Obtaining the values

Let’s do the same to obtain the value representations for every input:

```
               [0, 2, 0]
[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3] 
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]

           (3x4)(4x3)    = (3x3)
input matrix x weight matrix = value matrix
```

![Step 3: obtaining the values](https://miro.medium.com/max/1975/1*5kqW7yEwvcC0tjDOW3Ia-A.gif)


### Obtaining the querys

And finally, the query representations:

```
               [1, 0, 1]
[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]

           (3x4)(4x3)    = (3x3)
input matrix x weight matrix = query matrix
```

![Step 3: obtaining the queries](https://miro.medium.com/max/1975/1*wO_UqfkWkv3WmGQVHvrMJw.gif)

### Calculating key, value and query

>**Note**: *In practice, a bias vector may be added to the product of the matrix multiplication.*
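
A minimal sketch of what adding a bias looks like for the keys. The bias values in `b_key` are hypothetical and chosen only for illustration. The rest of this tutorial uses no bias.

```python
import numpy as np

x_np = np.array([[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]], dtype=np.float32)
w_key_np = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float32)
b_key = np.array([0.5, -1.0, 0.0], dtype=np.float32)  # hypothetical bias, one entry per key dimension

# The (3,) bias is broadcast across the 3 rows of the (3, 3) product.
keys_with_bias = x_np @ w_key_np + b_key
print(keys_with_bias)
```

Each row of the key matrix is shifted by the same vector, so `[0, 1, 1]` becomes `[0.5, 0, 1]` for Input 1.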

In [4]:
keys = x @ w_key
querys = x @ w_query
values = x @ w_value

print("Keys: \n", keys)
# tensor([[0., 1., 1.],
#         [4., 4., 0.],
#         [2., 3., 1.]])

print("Querys: \n", querys)
# tensor([[1., 0., 2.],
#         [2., 2., 2.],
#         [2., 1., 3.]])
print("Values: \n", values)
# tensor([[1., 2., 3.],
#         [2., 8., 0.],
#         [2., 6., 3.]])

Keys: 
 tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
Querys: 
 tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]])
Values: 
 tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]])


## Step 4: Calculate attention scores
        

![Step 4: calculate attention scores](https://miro.medium.com/max/1973/1*u27nhUppoWYIGkRDmYFN2A.gif)

To obtain **attention scores**, we start by taking the dot product of Input 1’s **query** (red) with **all keys** (orange), including its own. Since there are 3 key representations (because we have 3 inputs), we obtain 3 attention scores (blue).

```
            [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
            [1, 0, 1]

         (1x3)(3x3)   = (1x3)
query vector x transpose(key matrix)  = attention score
```

Notice that we only use the **query** from Input 1. Later we’ll work on repeating this same step for the other **querys**.

>**Note**: *The above operation is known as dot product attention, one of several [score functions](https://towardsdatascience.com/attn-illustrated-attention-5ec4ad276ee3). Other score functions include scaled dot product and additive/concat.*

Let's calculate it for all queries:
```
[1, 0, 2]   [0, 4, 2]   [2,  4,  4]
[2, 2, 2] x [1, 4, 3] = [4, 16, 12]
[2, 1, 3]   [1, 0, 1]   [4, 12, 10]


         (3x3)(3x3)   = (3x3)
query matrix x transpose(key matrix)  = attention score matrix
```
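
The note earlier in this step mentions scaled dot product as another score function. As a sketch, assuming the scaling from Attention Is All You Need (dividing by the square root of the key dimension), it would look like this with the same queries and keys. The `_np` names are illustrative.

```python
import numpy as np

queries_np = np.array([[1, 0, 2], [2, 2, 2], [2, 1, 3]], dtype=np.float64)
keys_np = np.array([[0, 1, 1], [4, 4, 0], [2, 3, 1]], dtype=np.float64)

d_k = keys_np.shape[-1]  # key dimension, 3 in this tutorial
scaled_scores = queries_np @ keys_np.T / np.sqrt(d_k)
print(scaled_scores)
```

Scaling shrinks every score by the same factor, which makes the softmax in the next step less peaked. The rest of this tutorial keeps the unscaled scores.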

In [5]:
attn_scores = querys @ keys.T
print(attn_scores)

# tensor([[ 2.,  4.,  4.],  # attention scores from Query 1
#         [ 4., 16., 12.],  # attention scores from Query 2
#         [ 4., 12., 10.]]) # attention scores from Query 3

tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])


## Step 5: Calculate softmax


![Step 5: calculate softmax](https://miro.medium.com/max/1973/1*jf__2D8RNCzefwS0TP1Kyg.gif)

Take the **[softmax](https://en.wikipedia.org/wiki/Softmax_function)** across these **attention scores** (blue).
```
softmax([2, 4, 4]) = [0.06, 0.47, 0.47] ≈ [0.0, 0.5, 0.5]  (rounded for readability)
```
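
For Input 1, the softmax can be worked out by hand: exponentiate each score, then divide by the sum of the exponentials. A small sketch using only the standard library:

```python
import math

scores = [2.0, 4.0, 4.0]                  # attention scores for Query 1
exps = [math.exp(s) for s in scores]      # e^2, e^4, e^4
total = sum(exps)
weights = [e / total for e in exps]

print([round(w, 4) for w in weights])  # [0.0634, 0.4683, 0.4683]
```

The weights sum to 1, and the two keys with the larger score share almost all of the attention.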

In [6]:
from torch.nn.functional import softmax

print(attn_scores)
attn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
#         [6.0337e-06, 9.8201e-01, 1.7986e-02],
#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])

# For readability, approximate the above as follows
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
print(attn_scores_softmax)

tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])
tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
        [6.0337e-06, 9.8201e-01, 1.7986e-02],
        [2.9539e-04, 8.8054e-01, 1.1917e-01]])
tensor([[0.0000, 0.5000, 0.5000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.9000, 0.1000]])


In [8]:
import numpy as np

attn_scores1 = [
  [ 2.,  4.,  4.],
  [ 4., 16., 12.],
  [ 4., 12., 10.]
]

# Row-wise softmax: normalise each query's scores over the key axis
exp_scores1 = np.exp(attn_scores1)
attn_scores_softmax1 = exp_scores1 / np.sum(exp_scores1, axis=-1, keepdims=True)
print(attn_scores_softmax1)

[[6.33789383e-02 4.68310531e-01 4.68310531e-01]
 [6.03366485e-06 9.82007865e-01 1.79861014e-02]
 [2.95387223e-04 8.80536902e-01 1.19167711e-01]]


## Step 6: Multiply scores with values

![Step 6: multiply scores with values](https://miro.medium.com/max/1973/1*9cTaJGgXPbiJ4AOCc6QHyA.gif)

The softmaxed attention scores for each input (blue) are multiplied by the corresponding **value** (purple). This results in 3 alignment vectors (yellow). In this tutorial, we’ll refer to them as **weighted values**.
```
1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]

(1x3).T  * (3x3)   = (3x3) 

transpose(softmax_attention_score) * value matrix = weighted values
``` 

Let's calculate it all together.

```
[[[0.0],         [[[1., 2., 3.]],
  [0.0],     *    [[2., 8., 0.]],
  [0.0]],         [[2., 6., 3.]]]
 [[0.5],
  [1.0],
  [0.9]],
 [[0.5],
  [0.0],
  [0.1]]]

(3x3x1)      *   (3x1x3)          = (3x3x3)

transpose(softmax_attention_score)[:, :, None] * value matrix[:, None] = weighted values
```

In [9]:
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)

tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[1.0000, 4.0000, 0.0000],
         [2.0000, 8.0000, 0.0000],
         [1.8000, 7.2000, 0.0000]],

        [[1.0000, 3.0000, 1.5000],
         [0.0000, 0.0000, 0.0000],
         [0.2000, 0.6000, 0.3000]]])


In [28]:
weighted_values.shape

torch.Size([3, 3, 3])

In [23]:
values[:, None]

tensor([[[1., 2., 3.]],

        [[2., 8., 0.]],

        [[2., 6., 3.]]])

In [21]:
values[:, None].shape

torch.Size([3, 1, 3])

In [24]:
attn_scores_softmax.T[:, :, None]

tensor([[[0.0000],
         [0.0000],
         [0.0000]],

        [[0.5000],
         [1.0000],
         [0.9000]],

        [[0.5000],
         [0.0000],
         [0.1000]]])

In [32]:
attn_scores_softmax.T[:, :, None].shape

torch.Size([3, 3, 1])

## Step 7: Sum weighted values


![Step 7: sum weighted values](https://miro.medium.com/max/1973/1*1je5TwhVAwwnIeDFvww3ew.gif)

Take all the **weighted values** (yellow) and sum them element-wise:

```
  [0.0, 0.0, 0.0]
+ [1.0, 4.0, 0.0]
+ [1.0, 3.0, 1.5]
-----------------
= [2.0, 7.0, 1.5]
```

The resulting vector ```[2.0, 7.0, 1.5]``` (dark green) **is Output 1**, which is based on the **query representation from Input 1** interacting with all keys, including its own.
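
Steps 6 and 7 together amount to a single matrix product between the softmaxed scores and the values. The sketch below checks this with NumPy, using the rounded scores from Step 5. The `_np` names are illustrative.

```python
import numpy as np

attn_np = np.array([[0.0, 0.5, 0.5], [0.0, 1.0, 0.0], [0.0, 0.9, 0.1]])
values_np = np.array([[1.0, 2.0, 3.0], [2.0, 8.0, 0.0], [2.0, 6.0, 3.0]])

# Steps 6 and 7 written out: weighted[j, i] = attn[i, j] * values[j], then sum over j.
weighted = values_np[:, None] * attn_np.T[:, :, None]   # shape (3, 3, 3)
outputs_broadcast = weighted.sum(axis=0)                # shape (3, 3)

# The same result as one matrix product: (3, 3) @ (3, 3).
outputs_matmul = attn_np @ values_np

print(np.allclose(outputs_broadcast, outputs_matmul))  # True
print(outputs_matmul[0])  # [2.  7.  1.5]
```

The first row is Output 1 from the sum above. The broadcast form mirrors the illustrations, while the matrix product is the more compact way to write it.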

## Step 8: Repeat for Input 2 & Input 3

![Step 8: repeat for Input 2 and Input 3](https://miro.medium.com/max/1973/1*G8thyDVqeD8WHim_QzjvFg.gif)

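The matrix operations in the code cells above already cover all three queries, so summing the weighted values along the first axis produces Output 1, Output 2 and Output 3 together.

>**Note**: *The dimension of **query** and **key** must always be the same because of the dot product score function. However, the dimension of **value** may be different from **query** and **key**. The resulting output will consequently follow the dimension of **value**.*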
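
To illustrate the note above, the sketch below keeps the 3-dimensional keys and queries from this tutorial but swaps in a hypothetical 4×2 value weight matrix (`w_value_2d`, made up for this example). The outputs then have dimension 2.

```python
import numpy as np

x_np = np.array([[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]], dtype=np.float64)
w_key_np = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float64)
w_query_np = np.array([[1, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 1]], dtype=np.float64)
w_value_2d = np.array([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=np.float64)  # hypothetical 4x2 weights

scores = (x_np @ w_query_np) @ (x_np @ w_key_np).T      # (3, 3): needs d_query == d_key
shifted = scores - scores.max(axis=-1, keepdims=True)   # stabilise exp()
attn = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
outputs_2d = attn @ (x_np @ w_value_2d)                 # (3, 3) @ (3, 2) -> (3, 2)

print(outputs_2d.shape)  # (3, 2)
```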

In [33]:
outputs = weighted_values.sum(dim=0)
print(outputs)

# tensor([[2.0000, 7.0000, 1.5000],  # Output 1
#         [2.0000, 8.0000, 0.0000],  # Output 2
#         [2.0000, 7.8000, 0.3000]]) # Output 3

tensor([[2.0000, 7.0000, 1.5000],
        [2.0000, 8.0000, 0.0000],
        [2.0000, 7.8000, 0.3000]])


## Bonus: TensorFlow 2 implementation

In [34]:
%tensorflow_version 2.x
import tensorflow as tf

### Step 1: Prepare inputs

In [35]:
x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]

x = tf.convert_to_tensor(x, dtype=tf.float32)
print(x)

tf.Tensor(
[[1. 0. 1. 0.]
 [0. 2. 0. 2.]
 [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)


### Step 2: Initialise weights

In [36]:
w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = tf.convert_to_tensor(w_key, dtype=tf.float32)
w_query = tf.convert_to_tensor(w_query, dtype=tf.float32)
w_value = tf.convert_to_tensor(w_value, dtype=tf.float32)
print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Weights for key: 
 tf.Tensor(
[[0. 0. 1.]
 [1. 1. 0.]
 [0. 1. 0.]
 [1. 1. 0.]], shape=(4, 3), dtype=float32)
Weights for query: 
 tf.Tensor(
[[1. 0. 1.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 1. 1.]], shape=(4, 3), dtype=float32)
Weights for value: 
 tf.Tensor(
[[0. 2. 0.]
 [0. 3. 0.]
 [1. 0. 3.]
 [1. 1. 0.]], shape=(4, 3), dtype=float32)


### Step 3: Derive key, query and value

In [37]:
keys = tf.matmul(x, w_key)
querys = tf.matmul(x, w_query)
values = tf.matmul(x, w_value)
print(keys)
print(querys)
print(values)

tf.Tensor(
[[0. 1. 1.]
 [4. 4. 0.]
 [2. 3. 1.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 0. 2.]
 [2. 2. 2.]
 [2. 1. 3.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 2. 3.]
 [2. 8. 0.]
 [2. 6. 3.]], shape=(3, 3), dtype=float32)


### Step 4: Calculate attention scores

In [38]:
attn_scores = tf.matmul(querys, keys, transpose_b=True)
print(attn_scores)

tf.Tensor(
[[ 2.  4.  4.]
 [ 4. 16. 12.]
 [ 4. 12. 10.]], shape=(3, 3), dtype=float32)


### Step 5: Calculate softmax

In [39]:
attn_scores_softmax = tf.nn.softmax(attn_scores, axis=-1)
print(attn_scores_softmax)

# For readability, approximate the above as follows
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = tf.convert_to_tensor(attn_scores_softmax)
print(attn_scores_softmax)

tf.Tensor(
[[6.3378938e-02 4.6831051e-01 4.6831051e-01]
 [6.0336647e-06 9.8200780e-01 1.7986100e-02]
 [2.9538720e-04 8.8053685e-01 1.1916770e-01]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[0.  0.5 0.5]
 [0.  1.  0. ]
 [0.  0.9 0.1]], shape=(3, 3), dtype=float32)


### Step 6: Multiply scores with values

In [40]:
weighted_values = values[:, None] * tf.transpose(attn_scores_softmax)[:, :, None]
print(weighted_values)

tf.Tensor(
[[[0.  0.  0. ]
  [0.  0.  0. ]
  [0.  0.  0. ]]

 [[1.  4.  0. ]
  [2.  8.  0. ]
  [1.8 7.2 0. ]]

 [[1.  3.  1.5]
  [0.  0.  0. ]
  [0.2 0.6 0.3]]], shape=(3, 3, 3), dtype=float32)


### Step 7: Sum weighted values

In [41]:
outputs = tf.reduce_sum(weighted_values, axis=0)  # sum over the key axis (Step 7)
print(outputs)

tf.Tensor(
[[2.        7.        1.5      ]
 [2.        8.        0.       ]
 [2.        7.7999997 0.3      ]], shape=(3, 3), dtype=float32)



## Extending to Transformers

So, where do we go from here? Transformers! Indeed, we live in exciting times of deep learning research and abundant compute resources. The Transformer was introduced in [Attention Is All You Need](https://arxiv.org/abs/1706.03762), originally to perform neural machine translation. Researchers picked up from there, reassembling, cutting, adding and extending the parts, and extended its usage to more language tasks.

Here I will briefly mention how we can extend self-attention to a Transformer architecture.

Within the self-attention module:
- Dimension
- Bias

Inputs to the self-attention module:
- Embedding module
- Positional encoding
- Truncating
- Masking

Adding more self-attention modules:

- Multihead
- Layer stacking

Modules between self-attention modules:

- Linear transformations
- LayerNorm
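
As one example from the lists above, here is a rough sketch of masking applied to the attention scores from Step 4. It assumes a causal mask, where each input may only attend to itself and to earlier inputs. That choice, and the variable names, are illustrative assumptions and not something this tutorial prescribes.

```python
import numpy as np

scores = np.array([[2.0, 4.0, 4.0], [4.0, 16.0, 12.0], [4.0, 12.0, 10.0]])

# Causal mask (assumption): True above the diagonal marks the positions to hide.
mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
masked_scores = np.where(mask, -np.inf, scores)

# Row-wise softmax; exp(-inf) is 0, so hidden positions get zero weight.
shifted = masked_scores - masked_scores.max(axis=-1, keepdims=True)
masked_weights = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

print(np.round(masked_weights, 3))
```

Input 1 can only attend to itself here, so its first row of weights is `[1, 0, 0]`, while Input 3 still sees all three keys.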



## References

[Attention Is All You Need](https://arxiv.org/abs/1706.03762) (arxiv.org)

[The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/) (jalammar.github.io)

[Attn: Illustrated Attention](https://towardsdatascience.com/attn-illustrated-attention-5ec4ad276ee3) (towardsdatascience.com)