# Attention in Transformers

Implementing attention mechanisms for transformers. Based on the course [Attention in Transformers: Concepts and Code in PyTorch](https://learn.deeplearning.ai/courses/attention-in-transformers-concepts-and-code-in-pytorch) taught by Josh Starmer of StatQuest.

---

## Positional Encodings

The positional encoding in the transformer model allows the model to consider the order of words in a sequence when creating embeddings.

- The positional encodings are summed with the token embeddings, and attention operates on that sum.
- A ***context aware*** or ***contextualized*** embedding can therefore be computed.
- This extends the concept of word embeddings to the sentence and document level.
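
As a minimal sketch of summing token embeddings with their positional encodings, the snippet below uses the sinusoidal scheme from the original transformer paper. That scheme is an assumption (these notes do not say which encoding is used), `sinusoidal_positional_encoding` is a hypothetical helper that assumes an even `d_model`, and the word embedding values are made up for illustration.

```python
import math
import torch


def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Return a (seq_len, d_model) matrix of sine/cosine encodings (even d_model assumed)."""
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    # one frequency per (sine, cosine) pair of dimensions
    freqs = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(positions * freqs)  # even columns
    pe[:, 1::2] = torch.cos(positions * freqs)  # odd columns
    return pe


# hypothetical word embeddings: 3 tokens, 2 dimensions
word_embeddings = torch.tensor([[0.5, -0.2],
                                [1.0, 0.3],
                                [-0.7, 0.8]])

# the input to attention is the element-wise sum
token_encodings = word_embeddings + sinusoidal_positional_encoding(3, 2)
print(token_encodings.shape)
```

Because the encodings depend only on position, the same word receives a different input vector at different positions in the sequence.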

***Encoder-only transformer*** models are used to produce contextualized embeddings for a number of downstream tasks like clustering and classification.

- They are "encoder-only" because they only use the ***encoder*** component of the transformer.
- They use the ***self-attention*** mechanism, and can see all tokens in the sequence.

Of course, there are also ***decoder-only transformer*** models.

- These models only use the ***decoder*** component of the transformer.
- These models use the ***masked self-attention*** mechanism, so each token can only see itself and the tokens that *precede* it in the sequence (i.e., no words that come after the word of interest).
- Decoder-only transformers are generative models (e.g., ChatGPT).

---

## Self-Attention

$Attention(Q,K,V) = SoftMax(\frac{QK^T}{\sqrt{d_k}})V$

- $Q \rightarrow$ query
- $K \rightarrow$ key
- $V \rightarrow$ value

The concepts ***query***, ***key***, and ***value*** come from database terminology:

- $query \rightarrow \{key: value\}$

In the context of transformers, we find the ***keys*** that are most similar to the ***queries*** and obtain the ***values*** of the associated ***keys***.
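
To make the database analogy concrete, here is a small sketch that contrasts a "hard" lookup (return the value of the single best-matching key) with the "soft" lookup that attention performs (blend all values, weighted by similarity). The keys, values, and query are hypothetical numbers chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# hypothetical toy "database": three keys, each paired with one value
keys = torch.tensor([[2.0, 0.0],
                     [0.0, 2.0],
                     [1.0, 1.0]])
values = torch.tensor([[10.0],
                       [20.0],
                       [30.0]])

query = torch.tensor([[1.0, 0.0]])

# dot-product similarity between the query and every key: shape (1, 3)
scores = query @ keys.T

# hard lookup: pick the value of the most similar key
hard_result = values[scores.argmax(dim=1)]

# soft lookup: weight every value by its softmax-normalized similarity
weights = F.softmax(scores, dim=1)
soft_result = weights @ values

print(hard_result, soft_result)
```

The soft result is a weighted average of all the values, dominated by the value whose key best matches the query.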

### Example

Assume we have the following sequence consisting of 3 tokens: `Write a poem`. Each token gets encoded into a 2-dimensional embedding (for illustration purposes only). The resulting matrix has shape $(3, 2)$, since there are 3 tokens and 2 embedding dimensions:

- $Token Embeddings = \begin{bmatrix} 1.16 && 0.23 \\ 0.57 && 1.36 \\ 4.41 && -2.16 \end{bmatrix}$

We then create a matrix of $(d, d)$ weights for the query matrix, where $d$ is the desired dimensionality (here, we use 2 for illustration purposes):

- $QueryWeights^T = \begin{bmatrix} 0.54 && -0.17 \\ 0.59 && 0.65 \end{bmatrix}$
- **Note:** the $T$ indicates that this is the ***transpose*** of the weight matrix. PyTorch's `nn.Linear` stores its weights with shape `(out_features, in_features)` and multiplies the input by the transpose.

Now we can calculate the query matrix:

- $Q = TokenEmbeddings \times QueryWeights^T$

And we repeat for the key matrix and value matrix:

- $KeyWeights^T = \begin{bmatrix} -0.15 && -0.34 \\ 0.14 && 0.42 \end{bmatrix}$

- $K = TokenEmbeddings \times KeyWeights^T$

- $ValueWeights^T = \begin{bmatrix} 0.62 && 0.61 \\ -0.52 && 0.13 \end{bmatrix}$

- $V = TokenEmbeddings \times ValueWeights^T$
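
The three matrix products above can be reproduced directly with tensors. This is a minimal sketch that hard-codes the rounded weights from the example; the variable names (`W_q_T`, `W_k_T`, `W_v_T`) are chosen here for illustration.

```python
import torch

token_embeddings = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

# transposed weight matrices from the example (rounded to two decimals)
W_q_T = torch.tensor([[0.54, -0.17],
                      [0.59, 0.65]])
W_k_T = torch.tensor([[-0.15, -0.34],
                      [0.14, 0.42]])
W_v_T = torch.tensor([[0.62, 0.61],
                      [-0.52, 0.13]])

# (3, 2) @ (2, 2) -> (3, 2): one query, key, and value vector per token
Q = token_embeddings @ W_q_T
K = token_embeddings @ W_k_T
V = token_embeddings @ W_v_T

print(Q.shape, K.shape, V.shape)
```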

Now we can solve the self-attention equation using our matrices:

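- $QK^T \rightarrow$ computes the ***dot product*** between every query vector and every key vector.
    - The dot product is essentially an unscaled similarity measure.
    - Cosine similarity is the same dot product divided by the lengths of the two vectors, which confines it to $[-1, 1]$.
- The dot product similarities in $QK^T$ are divided by $\sqrt{d_k}$, where $d_k$ is the dimensionality of the key vectors (the number of columns in $K$).
    - Dot products tend to grow in magnitude as $d_k$ grows, which can push the softmax into regions with very small gradients; dividing by $\sqrt{d_k}$ counteracts this and makes training more stable.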
- We take the softmax of $\frac{QK^T}{\sqrt{d_k}}$ to convert the scaled dot product similarities to probabilities.
    - The rows will sum to $1$, which can be interpreted as the ***percent similarity*** between the tokens at any given position.
- Finally, multiply the softmax percentages by the values in $V$.
    - The percentages from the softmax function tell us how much influence each word should have on the final encoding.
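
The steps above can be sketched one line at a time. This example reuses the rounded weights from the worked example, so its numbers are only approximations of what a module with unrounded weights would produce.

```python
import torch
import torch.nn.functional as F

token_embeddings = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

# rounded, transposed weight matrices from the worked example
Q = token_embeddings @ torch.tensor([[0.54, -0.17], [0.59, 0.65]])
K = token_embeddings @ torch.tensor([[-0.15, -0.34], [0.14, 0.42]])
V = token_embeddings @ torch.tensor([[0.62, 0.61], [-0.52, 0.13]])

d_k = K.size(1)                                      # dimensionality of each key vector

sims = Q @ K.T                                       # (3, 3) unscaled dot-product similarities
scaled_sims = sims / d_k ** 0.5                      # divide by sqrt(d_k)
attention_percents = F.softmax(scaled_sims, dim=1)   # softmax over each row
output = attention_percents @ V                      # (3, 2) weighted sum of the values

print(attention_percents.sum(dim=1))                 # each row sums to 1
print(output)
```

Row $i$ of `attention_percents` holds the share of influence that every token has on the new encoding of token $i$.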

### Implementation

Now, we can implement self-attention in PyTorch:

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# self-attention module:
class SelfAttention(nn.Module):
    def __init__(self,
                 d_model=2, # dimensionality of embeddings
                 row_dim=0, # keeps track of the row index
                 col_dim=1  # keeps track of the column index
                 ):
        
        super().__init__()

        # initialize weight matrices for Q,K,V:
        self.W_q = nn.Linear(
            in_features=d_model,  # embedding dimensions
            out_features=d_model, # embedding dimensions
            bias=False            # no intercept ("bias")
        )

        self.W_k = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False
        )

        self.W_v = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False
        )

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings):
        # token_encodings = word embeddings + positional encodings
        # calculate Q, K, and V:
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        # calculate attention:
        sims = torch.matmul(
            q, k.transpose(dim0=self.row_dim, dim1=self.col_dim)
        )

        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores


Test it out:

In [2]:
torch.manual_seed(42)

# test encoding matrix:
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

self_attn = SelfAttention(d_model=2, row_dim=0, col_dim=1)

self_attn(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

It worked! We can also validate the math:

In [3]:
# retrieve weight matrices:
W_k = self_attn.W_k.weight.transpose(0, 1)
W_q = self_attn.W_q.weight.transpose(0, 1)
W_v = self_attn.W_v.weight.transpose(0, 1)

print(f"W_Q: {W_q}\n\nW_K: {W_k}\n\nW_V: {W_v}")

W_Q: tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

W_K: tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

W_V: tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)


Manually calculate:

In [4]:
print(torch.matmul(encodings_matrix, W_k) == self_attn.W_k(encodings_matrix))
print(torch.matmul(encodings_matrix, W_q) == self_attn.W_q(encodings_matrix))
print(torch.matmul(encodings_matrix, W_v) == self_attn.W_v(encodings_matrix))

tensor([[True, True],
        [True, True],
        [True, True]])
tensor([[True, True],
        [True, True],
        [True, True]])
tensor([[True, True],
        [True, True],
        [True, True]])


Same outputs!
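
Exact `==` comparisons on floating-point tensors happen to succeed here, but they can fail when two mathematically equal computations round differently. A more forgiving check, sketched below with a hypothetical stand-in layer `W`, is `torch.allclose`, which compares within a tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

x = torch.tensor([[1.16, 0.23],
                  [0.57, 1.36],
                  [4.41, -2.16]])

# hypothetical stand-in for one of the module's projection layers
W = nn.Linear(in_features=2, out_features=2, bias=False)

# nn.Linear stores its weight as (out_features, in_features), hence the transpose
manual = x @ W.weight.T
layer = W(x)

# tolerance-based comparison instead of element-wise exact equality
print(torch.allclose(manual, layer))
```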

---

## Masked Self-Attention

$MaskedSelfAttention(Q,K,V,M) = SoftMax(\frac{QK^T}{\sqrt{d_k}}+M)V$

The only difference is the addition of the masking matrix, $M$.
- The purpose of the mask is to prevent tokens from observing any information about tokens that come after them in a sequence.
- Any token $t_i$ can "see" the full context of tokens ***up to***, and including, its own position.

The mask matrix $M$ applies $-\infty$ to the positions of tokens we want to mask out, and $0$ otherwise. For example:

- $\begin{bmatrix} 0 && -\infty && -\infty \\ 0 && 0 && -\infty \\ 0 && 0 && 0 \end{bmatrix}$

$-\infty$ has the effect of producing $0\%$ similarities for masked positions after applying the softmax.
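
A small sketch of the additive mask in action. The scaled similarities below are hypothetical numbers chosen only for illustration; the mask is built with `torch.triu` so that everything above the diagonal is $-\infty$.

```python
import torch
import torch.nn.functional as F

seq_len = 3

# additive mask M: 0 on and below the diagonal, -inf above it
M = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# hypothetical scaled similarities (stand-in for QK^T / sqrt(d_k))
scaled_sims = torch.tensor([[0.5, 1.0, -0.3],
                            [0.2, 0.1, 0.9],
                            [1.5, -0.4, 0.7]])

# exp(-inf) = 0, so masked positions get exactly 0% after the softmax
attention_percents = F.softmax(scaled_sims + M, dim=1)

print(M)
print(attention_percents)
```

The PyTorch implementation in this section uses a boolean mask with `masked_fill` and a large negative number (`-1e9`) in place of $-\infty$, which has the same practical effect.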

### Implementation

We can code masked self-attention in PyTorch:

In [5]:
# masked self-attention module:
class MaskedSelfAttention(nn.Module):
    def __init__(self,
                 d_model=2, # dimensionality of embeddings
                 row_dim=0, # keeps track of the row index
                 col_dim=1  # keeps track of the column index
                 ):
        
        super().__init__()

        # initialize weight matrices for Q,K,V:
        self.W_q = nn.Linear(
            in_features=d_model,  # embedding dimensions
            out_features=d_model, # embedding dimensions
            bias=False            # no intercept ("bias")
        )

        self.W_k = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False
        )

        self.W_v = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False
        )

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        # Q, K, V:
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)


        # calculate attention:
        sims = torch.matmul(
            q, k.transpose(dim0=self.row_dim, dim1=self.col_dim)
        )

        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        # apply the mask:
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask, value=-1e9)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

Test:

In [6]:
torch.manual_seed(42)

encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

masked_self_attn = MaskedSelfAttention(d_model=2, row_dim=0, col_dim=1)

# create the mask:
mask = torch.tril(torch.ones(3,3))
mask = mask == 0 # boolean

masked_self_attn(encodings_matrix, mask=mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

Again, we can manually validate the math:

In [7]:
W_k = masked_self_attn.W_k.weight.transpose(0, 1)
W_q = masked_self_attn.W_q.weight.transpose(0, 1)
W_v = masked_self_attn.W_v.weight.transpose(0, 1)

print(torch.matmul(encodings_matrix, W_k) == masked_self_attn.W_k(encodings_matrix))
print(torch.matmul(encodings_matrix, W_q) == masked_self_attn.W_q(encodings_matrix))
print(torch.matmul(encodings_matrix, W_v) == masked_self_attn.W_v(encodings_matrix))

tensor([[True, True],
        [True, True],
        [True, True]])
tensor([[True, True],
        [True, True],
        [True, True]])
tensor([[True, True],
        [True, True],
        [True, True]])
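
The check above validates the weight matrices rather than the mask itself. One way to test the mask, sketched here under simplifying assumptions, is to change only the last token and confirm that the outputs for earlier tokens do not move. `masked_attention` is a hypothetical helper that skips the learned Q/K/V projections (it uses the encodings directly) to keep the example short.

```python
import torch
import torch.nn.functional as F


def masked_attention(x: torch.Tensor) -> torch.Tensor:
    """Masked self-attention with identity Q/K/V projections (a simplification)."""
    seq_len, d_k = x.shape
    scaled_sims = (x @ x.T) / d_k ** 0.5
    # True above the diagonal marks the future positions to hide
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scaled_sims = scaled_sims.masked_fill(mask, float("-inf"))
    return F.softmax(scaled_sims, dim=1) @ x


x = torch.tensor([[1.16, 0.23],
                  [0.57, 1.36],
                  [4.41, -2.16]])

x_changed = x.clone()
x_changed[2] = torch.tensor([-3.0, 5.0])  # alter only the last token

out = masked_attention(x)
out_changed = masked_attention(x_changed)

# tokens 0 and 1 never attend to token 2, so their outputs should be unaffected
print(torch.allclose(out[:2], out_changed[:2]))
# token 2 attends to itself, so its output should change
print(torch.allclose(out[2], out_changed[2]))
```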


---

## Multi-Head Attention

We can apply attention multiple times simultaneously to the inputs: this is ***multi-head attention***.
- This is helpful when dealing with longer or more complicated sequences, since each head can learn to capture different relationships among the tokens.
- Each head has its own associated $Q$, $K$, and $V$ weight matrices and outputs its own contextualized embeddings.
- To get back to the original dimensionality of the embeddings space, the outputs of multi-head attention are passed through a fully connected neural network layer with $d$ outputs.
    - Alternatively, the shape of the value weight matrix can be modified so that the concatenated head outputs already have the desired number of features.
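
A minimal sketch of the projection back to $d$ features. The per-head outputs are simulated with random tensors (an assumption made to keep the example self-contained), and `W_o` is a hypothetical name for the fully connected layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, d_model, num_heads = 3, 2, 8

# hypothetical stand-ins for the per-head outputs: one (seq_len, d_model) tensor per head
head_outputs = [torch.randn(seq_len, d_model) for _ in range(num_heads)]

# concatenate along the feature dimension: (3, 8 * 2) = (3, 16)
concatenated = torch.cat(head_outputs, dim=1)

# fully connected layer with d_model outputs maps back to the embedding size
W_o = nn.Linear(in_features=num_heads * d_model, out_features=d_model)
projected = W_o(concatenated)

print(concatenated.shape, projected.shape)
```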

--- 

## Encoder-Decoder Attention

The original transformer had both an ***encoder*** and a ***decoder***.
- The encoder uses self-attention.
- The decoder uses masked self-attention.
- The encoder and decoder are connected so they can calculate ***encoder-decoder attention***.

Encoder-decoder attention uses the output from the encoder to calculate **keys** and **values**, and the **queries** are calculated from the output of the decoder.
- After calculating $Q$, $K$, and $V$, encoder-decoder attention is calculated just like regular self-attention.
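
A minimal sketch of encoder-decoder attention, assuming random tensors as stand-ins for the encoder and decoder outputs and freshly initialized weight layers. Note that the two sequences do not need to have the same length.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d_model = 2

# hypothetical stand-ins: 3 source tokens from the encoder, 4 target tokens from the decoder
encoder_output = torch.randn(3, d_model)
decoder_output = torch.randn(4, d_model)

W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

q = W_q(decoder_output)  # queries come from the decoder: (4, 2)
k = W_k(encoder_output)  # keys come from the encoder:    (3, 2)
v = W_v(encoder_output)  # values come from the encoder:  (3, 2)

# from here on, it is the same computation as self-attention
scaled_sims = (q @ k.T) / k.size(1) ** 0.5           # (4, 3)
attention_percents = F.softmax(scaled_sims, dim=1)   # each decoder token distributes 100% over encoder tokens
output = attention_percents @ v                      # (4, 2)

print(scaled_sims.shape, output.shape)
```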

Encoder-only transformers took just the encoder component to do specific tasks (e.g., classification, clustering). Likewise, decoder-only transformers took just the decoder component to do specific tasks (e.g., text generation).
- While Seq2Seq / Encoder-Decoder models have fallen out of favor, encoder-decoder attention is still used for ***multi-modal models***.
    - *e.g.*, a model might have an encoder trained on images, sound, etc. and a text-based decoder designed to produce captions, respond to audio prompts, etc.

---

## Bringing It All Together

Let's bring all the different types of attention, as well as the ability to perform multi-head attention, together in a single class:

In [8]:
class Attention(nn.Module):
    def __init__(self,
                 d_model=2, # dimensionality of embeddings
                 row_dim=0, # keeps track of the row index
                 col_dim=1  # keeps track of the column index
                 ):
        
        super().__init__()

        # initialize weight matrices for Q,K,V:
        self.W_q = nn.Linear(
            in_features=d_model,  # embedding dimensions
            out_features=d_model, # embedding dimensions
            bias=False            # no intercept ("bias")
        )

        self.W_k = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False
        )

        self.W_v = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False
        )

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self,
                encodings_for_q, 
                encodings_for_k,
                encodings_for_v,
                mask=None):
        # we've modified the forward pass to accept 
        # different encodings to pass to Q, K, and V
        # Q, K, V:
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)


        # calculate attention:
        sims = torch.matmul(
            q, k.transpose(dim0=self.row_dim, dim1=self.col_dim)
        )

        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        # apply the mask:
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask, value=-1e9)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

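Try it out with separate encodings for the queries, keys, and values:
- All three inputs hold the same values for illustrative purposes, so the result should match the self-attention output from earlier.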

In [9]:
torch.manual_seed(42)

# encodings:
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

# attention:
attn = Attention(d_model=2, row_dim=0, col_dim=1)
attn(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)
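
The attention arithmetic can also be cross-checked against PyTorch's built-in `F.scaled_dot_product_attention`. This sketch assumes PyTorch 2.0 or later (where that function is available), skips the weight projections, and uses random tensors as stand-ins for already-projected `q`, `k`, and `v`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical, already-projected queries, keys, and values: 3 tokens, 2 dimensions
q = torch.randn(3, 2)
k = torch.randn(3, 2)
v = torch.randn(3, 2)

scaled_sims = (q @ k.T) / k.size(1) ** 0.5

# unmasked attention, computed by hand
manual = F.softmax(scaled_sims, dim=1) @ v

# masked attention, computed by hand (True above the diagonal = hidden)
mask = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)
manual_masked = F.softmax(scaled_sims.masked_fill(mask, float("-inf")), dim=1) @ v

# the built-in works on batched inputs, so add and then remove a batch dimension
builtin = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
).squeeze(0)
builtin_masked = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), is_causal=True
).squeeze(0)

print(torch.allclose(manual, builtin, atol=1e-5))
print(torch.allclose(manual_masked, builtin_masked, atol=1e-5))
```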

Now, let's enable **multi-head attention**:

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self,
                 d_model=2,
                 row_dim=0,
                 col_dim=1,
                 num_heads=1):
        
        super().__init__()

        # attention heads:
        self.heads = nn.ModuleList(
            [
                Attention(d_model, row_dim, col_dim)
                for _ in range(num_heads)
            ]
        )

        self.col_dim = col_dim
    
    def forward(self,
                encodings_for_q,
                encodings_for_k,
                encodings_for_v):
        
        # compute attention across all heads:
        attn_outputs = torch.cat(
            [
                head(encodings_for_q, encodings_for_k, encodings_for_v)
                for head in self.heads
            ], dim=self.col_dim
        )

        return attn_outputs

Test it out:

In [11]:
torch.manual_seed(42)

N_HEADS = 8

multi_head_attention = MultiHeadAttention(d_model=2,
                                          row_dim=0,
                                          col_dim=1,
                                          num_heads=N_HEADS)

encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])


outputs = multi_head_attention(encodings_for_q, encodings_for_k, encodings_for_v)
print(outputs)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268,  0.6226,  0.1312,  1.0106,  0.8625,
          0.3422,  0.7333, -0.8037,  1.4087, -0.6674,  0.5665,  0.7700, -0.9269],
        [ 0.2040,  0.7057, -0.7417, -0.9193,  0.5522,  0.2499,  1.4153,  1.0420,
          0.6753,  2.1341, -0.7498,  0.9677, -0.5970,  1.5640,  0.7713, -0.9210],
        [ 3.4989,  2.2427, -0.7190, -0.8447,  0.5669,  0.2324,  0.3679,  0.5894,
          0.1412, -0.1826, -0.9414,  2.2589, -0.7832, -0.0405,  0.7669, -0.8751]],
       grad_fn=<CatBackward0>)


The outputs will have shape `(sequence_length, N_HEADS*d_model)`, since each attention head will output `d_model` features per token:

In [12]:
outputs.shape

torch.Size([3, 16])

**TRIPLE BAM!!!**