# Attention

### Recommended Material:

<a href="https://www.youtube.com/watch?v=eMlx5fFNoYc">Great Attention Visual Explainer Video by 3B1B</a> (Highly Recommended)

<a href="https://nlp.seas.harvard.edu/annotated-transformer/#:~:text=interactive()%0A%20%20%20%20)%0A%0A%0Ashow_example(example_mask)-,Attention,-An%20attention%20function">Attention Section from "The Annotated Transformer"</a>

## Motivation for Attention
Context is really important. Many words can take on very different meanings depending on the context they appear in.

For example, the word "bar" could refer to:
- the place where you get drinks
- a long rod of material
- a measure of music
- the law exam

The list goes on.

With word embeddings alone, the string "bar" is mapped to the same vector regardless of the other tokens in the sequence, despite having many unrelated meanings.

To determine which meaning of "bar" is intended, we need to be able to gather information from the other words in the text.

Another example where this is useful is pronouns, since there needs to be some way to determine which proper noun is being referred to.

In general, the model would benefit from some way to share information between tokens, to add context.

Since all the information about a token is stored in that token's high-dimensional vector (initially just the word embedding + positional embedding), we want to add/subtract to this token's vector based on each token around it to store this additional information. The 3B1B video has some nice visual intuition for this.

## Big Idea
In a nutshell, the self-attention mechanism allows each token to look backwards at the tokens that come before it and add information to its vector by adding or subtracting.

It can be thought of as the following steps:

Each token emits a **Query** vector that roughly represents "What am I looking for?"

Each token also emits a **Key** vector that roughly represents "What am I in the context of this attention head?" or "What is my answer to the query?"

Each token also emits a **Value** vector that roughly represents "If someone finds me to be a match (my key matches their query), what information should I give to them / add to their vector?"

## Details

### Obtaining Q, K, V
To compute the Query, Key, and Value vectors, the attention mechanism passes each token's vector through 3 separate linear layers in parallel. The weights and biases of these layers are parameters learned by the model. For example, the query is Q = W_Q * residual + b_Q, where residual is the token's current vector. Each layer maps a vector of dimension d_model to one of dimension d_head. In multi-headed attention, the model dimension is often split across the attention heads, so that d_head = d_model / num_heads.

In practice, all the vectors are passed in at once with some matrix multiplication.
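
As a rough sketch of that matrix multiplication for a single head, the snippet below projects every token in one step. The sizes and the random stand-in weights are assumptions chosen for illustration, not values from a trained model.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from any particular model)
n_tokens, d_model, d_head = 4, 8, 2

# One row per token: the residual-stream vectors for the whole sequence
residual = torch.randn(n_tokens, d_model)

# Random stand-ins for the learned parameters of a single attention head
W_Q, b_Q = torch.randn(d_model, d_head), torch.zeros(d_head)
W_K, b_K = torch.randn(d_model, d_head), torch.zeros(d_head)
W_V, b_V = torch.randn(d_model, d_head), torch.zeros(d_head)

# One matrix multiplication per projection handles every token at once
Q = residual @ W_Q + b_Q
K = residual @ W_K + b_K
V = residual @ W_V + b_V

print(Q.shape, K.shape, V.shape)
```

Each of `Q`, `K`, and `V` ends up with one `d_head`-dimensional row per token.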

### Attending to other tokens

Once we have the matrices containing the queries, keys, and values of every token, we take the dot product of each query with every key. This gives each token a vector of scores representing how well each other token's key "answered" (matched) its query.

The dot product between a query and a key is larger when the two vectors are aligned, i.e. point in similar directions. Sometimes this is referred to as the token with the query "attending" to the other token with an aligned key.
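
A small hand-made example may help here. The query and key values below are invented for illustration; the point is only that `scores[i, j]` is large when token `i`'s query lines up with token `j`'s key.

```python
import torch

# Hand-picked 2-dimensional queries and keys for three tokens (illustrative values)
Q = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
K = torch.tensor([[2.0, 0.0], [0.0, 3.0], [-1.0, -1.0]])

# scores[i, j] is the dot product of token i's query with token j's key
scores = Q @ K.T
print(scores)
# tensor([[ 2.,  0., -1.],
#         [ 0.,  3., -1.],
#         [ 2.,  3., -2.]])
```

Token 0's query matches token 0's key best, while every query scores negatively against token 2's key, which points the opposite way.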

#### Rescaling

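We rescale the attention scores by dividing them by the square root of d_head (also written d_k), a.k.a. the head dimension. Without this, the dot products tend to grow in magnitude with the head dimension, which pushes the softmax toward nearly one-hot outputs with very small gradients.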
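
One way to see why the square root is the natural scale: if the entries of a query and a key are independent with unit variance, their dot product has variance d_head. The sketch below checks this empirically under that assumption (random Gaussian vectors, an illustrative `d_head` of 64); real queries and keys need not follow this distribution.

```python
import torch

torch.manual_seed(0)
d_head = 64  # illustrative head dimension

# 10,000 random query/key pairs with unit-variance entries
q = torch.randn(10_000, d_head)
k = torch.randn(10_000, d_head)

raw = (q * k).sum(dim=-1)    # unscaled dot products
scaled = raw / d_head**0.5   # dot products after rescaling

print(f"std of raw scores:    {raw.std().item():.2f}")
print(f"std of scaled scores: {scaled.std().item():.2f}")
```

The raw scores should have a standard deviation near sqrt(64) = 8, and the scaled scores a standard deviation near 1.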

#### Causal Masking

We apply a causal mask so that only the lower-triangular values of the matrix (including the diagonal) are left, where rows index queries and columns index keys. In causal self-attention we don't want any token to attend to tokens that come after it, because during training we have the model predict the next token at every position in parallel. Seeing ahead would let the model read off the ground-truth answer.

Since we apply this mask before the softmax, we fill the positions we want to remove with negative infinity, so that they become 0 after the softmax.
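
A minimal sketch of the masking step on a 3x3 score matrix, assuming rows are query positions and columns are key positions. It uses `torch.triu` to mark the "future" positions and the out-of-place `masked_fill`; the score values are made up.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0],
                       [7.0, 8.0, 9.0]])

# True strictly above the diagonal: positions where the key comes after the query
future = torch.triu(torch.ones(3, 3), diagonal=1).bool()

# masked_fill returns a new tensor, leaving `scores` untouched
masked = scores.masked_fill(future, float("-inf"))
print(masked)
# tensor([[1., -inf, -inf],
#         [4., 5., -inf],
#         [7., 8., 9.]])
```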

#### Softmax

We want the attention values for a particular token to add up to 1, so we will apply softmax.
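
To illustrate, here is softmax applied row by row to a small, already-masked score matrix (made-up values, with every unmasked score equal so the result is easy to read).

```python
import torch

ninf = float("-inf")
masked = torch.tensor([[1.0, ninf, ninf],
                       [1.0, 1.0, ninf],
                       [1.0, 1.0, 1.0]])

# Softmax over the last dimension: one probability distribution per query row
pattern = masked.softmax(dim=-1)
print(pattern)
```

Each row sums to 1, the masked positions get exactly 0, and the remaining weight is shared among the visible tokens: all of it on the first token in row 0, half each in row 1, and a third each in row 2.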

#### Dotting with Values

Then, to find out what we should add to each token, we multiply the attention pattern by the Value matrix: each token's attention weight on another token scales that other token's value vector, and these weighted value vectors are summed to give the update for that token.

$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

After this, the result is passed through another linear layer. In the case of multiheaded attention, the outputs are summed across the heads, and the total is finally added back onto the residual stream.
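
Putting the steps together, a single-head version of the formula with the causal mask included might look like the sketch below. The function name `causal_attention` is hypothetical, the inputs are random, and the output projection and the multi-head sum are left out for brevity.

```python
import torch
from torch import Tensor


def causal_attention(Q: Tensor, K: Tensor, V: Tensor) -> Tensor:
    """Single-head causal attention for inputs of shape [tokens, d_head]."""
    n_tokens, d_head = Q.shape
    # Scaled query-key dot products: [tokens, tokens]
    scores = Q @ K.T / d_head**0.5
    # Hide keys that come after each query, then normalise each row
    future = torch.triu(
        torch.ones(n_tokens, n_tokens, device=Q.device), diagonal=1
    ).bool()
    pattern = scores.masked_fill(future, float("-inf")).softmax(dim=-1)
    # Weighted sum of value vectors for every query position
    return pattern @ V


torch.manual_seed(0)
Q, K, V = torch.randn(3, 5, 4).unbind(0)  # 5 tokens, d_head = 4
out = causal_attention(Q, K, V)
print(out.shape)
```

Because the first token can only attend to itself, `out[0]` equals `V[0]`.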

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-attn-21.png" width="1400">



In [1]:
import einops
import torch
from torch import nn
from torch import Tensor

### References for Einops Einsum
https://einops.rocks/api/einsum/

Understanding Einsum in general (video uses np):
https://www.youtube.com/watch?v=pkVwUVEHmfI

`einsum` and `einops` streamline tensor operations in attention mechanisms by eliminating the need for repetitive and error-prone tensor transpositions and reshaping. Specifically:

`einsum`: Enables concise, readable, and efficient tensor contractions and summations across multiple dimensions by specifying how indices are related. This is particularly useful in attention mechanisms for computing dot products, such as the query-key dot product, without manual transpositions.

You can find an implementation without einops in Karpathy's GPT from scratch. There, you transpose first and then take the dot products, rather than being able to notate the two operations in one step.

# Implementation

In [1]:
def hidden_apply_causal_mask(
    attn_scores: Tensor, masked_value: float = float("-inf")
) -> Tensor:
    # Define a mask that is True for all positions we want to set probabilities to zero for
    all_ones = torch.ones(
        attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device
    )
    mask = torch.triu(all_ones, diagonal=1).bool()
    # Apply the mask to attention scores, then return the masked scores
    attn_scores.masked_fill_(mask, masked_value)

    return attn_scores


class HiddenAttention(nn.Module):
    def __init__(self, num_heads: int, dim_model: int, dim_head: int) -> None:
        super().__init__()

        # hyper parameters
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_head = dim_head

        # weights
        self.W_Q = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_K = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_V = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_O = nn.Parameter(torch.ones((num_heads, dim_head, dim_model)))

        # biases
        self.b_Q = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_K = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_V = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_O = nn.Parameter(torch.zeros((dim_model)))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the attention layer.
        Takes a tensor of shape [batch, tokens, dim_model]
        Outputs a tensor of shape [batch, tokens, dim_model]
        """
        batch_size = x.shape[0]
        tokens_size = x.shape[1]

        # Calculate query, key and value vectors
        q = (
            einops.einsum(
                x,
                self.W_Q,
                "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
            )
            + self.b_Q
        )
        assert q.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )
        k = (
            einops.einsum(
                x,
                self.W_K,
                "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
            )
            + self.b_K
        )
        assert k.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )
        v = (
            einops.einsum(
                x,
                self.W_V,
                "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
            )
            + self.b_V
        )
        assert v.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q,
            k,
            "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K",
        )
        assert attn_scores.shape == torch.Size(
            [batch_size, self.num_heads, tokens_size, tokens_size]
        )
        attn_scores_masked = hidden_apply_causal_mask(
            attn_scores / self.dim_head**0.5, float("-inf")
        )
        attn_pattern = attn_scores_masked.softmax(-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v,
            attn_pattern,
            "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head",
        )
        assert z.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = (
            einops.einsum(
                z,
                self.W_O,
                "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model",
            )
            + self.b_O
        )
        assert attn_out.shape == torch.Size([batch_size, tokens_size, self.dim_model])

        return attn_out

In [3]:
def apply_causal_mask(
    attn_scores: Tensor, masked_value: float = float("-inf")
) -> Tensor:
    """
    Applies a causal mask to attention scores, and returns masked scores.
    Takes an input of size [batch, n_heads, query_pos, key_pos]
    And outputs a tensor of size [batch, n_heads, query_pos, key_pos]
    """
    # Define a mask that is True for all positions we want to set probabilities to zero for
    mask = None
    # Apply the mask to attention scores and replace the masked values with the ignore value masked_value
    masked_attn_scores = None

    return masked_attn_scores

In [4]:
"""
Test case for your apply_causal_mask
"""
ignore = float("-inf")
test1 = apply_causal_mask(
    torch.tensor(
        [
            [1.0, 2, 3],
            [4, 5, 6],
            [7, 8, 9],
        ]
    ),
    ignore,
)

assert torch.allclose(
    test1, torch.tensor([[1.0, ignore, ignore], [4.0, 5.0, ignore], [7.0, 8.0, 9.0]])
), "Oh no it looks like your matrix doesn't pass test 1"
print(test1)


In [None]:
class Attention(nn.Module):
    def __init__(self, num_heads: int, dim_model: int, dim_head: int) -> None:
        super().__init__()

        # hyper parameters
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_head = dim_head

        # weights
        self.W_Q = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_K = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_V = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_O = nn.Parameter(torch.ones((num_heads, dim_head, dim_model)))

        # biases
        self.b_Q = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_K = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_V = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_O = nn.Parameter(torch.zeros((dim_model)))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the attention layer.
        Takes a tensor of shape [batch, tokens, dim_model]
        Outputs a tensor of shape [batch, tokens, dim_model]
        """
        batch_size = x.shape[0]
        tokens_size = x.shape[1]

        # Calculate query, key and value vectors
        q = None
        assert q.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )
        k = None
        assert k.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )
        v = None
        assert v.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = None
        assert attn_scores.shape == torch.Size(
            [batch_size, self.num_heads, tokens_size, tokens_size]
        )
        attn_scores_masked = None
        attn_probs = None

        # Take weighted sum of value vectors, according to attention probabilities
        z = None
        assert z.shape == torch.Size(
            [batch_size, tokens_size, self.num_heads, self.dim_head]
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = None
        assert attn_out.shape == torch.Size([batch_size, tokens_size, self.dim_model])

        return attn_out

In [None]:
batch_size = 12
tokens_dim = 20
dim_model = 30
dim_heads = 10
num_heads = 2
ground_truth = HiddenAttention(num_heads, dim_model, dim_heads)
user_model = Attention(num_heads, dim_model, dim_heads)
test = torch.rand((batch_size, tokens_dim, dim_model))

truth_output = ground_truth(test)
user_output = user_model(test)

assert torch.allclose(
    truth_output, user_output
), "Uh oh your model doesn't give the same outputs"
print("passed all tests!")