## Coding Attention Mechanisms

This notebook covers Chapter 3 of [*Build a Large Language Model from Scratch*](https://www.manning.com/books/build-a-large-language-model-from-scratch) by Sebastian Raschka (2025).

### The Problem of Modeling Long Sequences

- Before transformers, recurrent neural networks (RNNs) were often used for tasks like language translation.
    - RNNs suit sequential data because they feed the outputs of previous steps in as inputs to the current step.
        - For tasks like translation, an RNN is typically split into an encoder and a decoder.
        - The encoder processes the entire input context into a hidden state ("memory cell;" essentially an embedding).
        - The decoder uses the hidden state to generate the output.
    - Problematically, the decoder has no way of accessing earlier hidden states at decoding time, relying ***solely*** on the "current hidden state," which can lead to a loss of context (p. 53).

### Capturing Dependencies with Attention Mechanisms

> "Although RNNs work fine for translating short sentences, they don't work well for longer texts as they don't have direct access to previous words in the input. One major shortcoming in this approach is that the RNN ***must remember the entire encoded input in a single hidden state*** before passing it to the decoder" (Raschka 2025:54).

- The [***Bahdanau attention***](https://arxiv.org/abs/1409.0473) mechanism was developed in 2014 to allow the decoder of an RNN to "selectively access different parts of the input sequence at each decoding step" (p. 54).
    - The selectivity of the mechanism means that some tokens will be more important than others. 
- In 2017, the [***transformer***](https://arxiv.org/abs/1706.03762) architecture was proposed.
    - It showed that RNNs are not necessary for processing language and text.
- The transformer model uses ***self-attention***.
    - With self-attention, the relevancy of each token in each position can be considered.
    - Each token in an input sequence can "attend to" all other tokens in every other position of the same sequence.

### Self-attention

- The "self" of self-attention refers to "the mechanism's ability to compute attention weights by relating different positions within a single input sequence" (p. 56).
    - Dependencies between different tokens at different positions are learned, as is their relative importance.
- The goal is to calculate ***context vectors***, each of which is essentially an enriched embedding.
    - They are "enriched" because they contain information about all the other tokens in the sequence.

**Let's start with a simplified version of self-attention without any trainable weights:**

In [1]:
import torch

# assume these are our initial (random) 3D embedding vectors
# for the sequence "Your journey starts with one step":
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # x_1: Your
     [0.55, 0.87, 0.66], # x_2: journey
     [0.57, 0.85, 0.64], # x_3: starts
     [0.22, 0.58, 0.33], # x_4: with
     [0.77, 0.25, 0.10], # x_5: one
     [0.05, 0.80, 0.55]  # x_6: step
    ]
)

We can compute the ***intermediate attention scores*** between a "query token" (i.e., the token at the current position) and each token in the input sequence by taking the dot product of the query token $x^{(q)}$ and every single token in the input sequence:

In [2]:
# assume our query token is at position 2 [index 1], "journey":
query = inputs[1]

# initialize a tensor of shape inputs.shape[0] to store the attention scores:
attention_scores = torch.empty(inputs.shape[0])

# compute attention scores for each (query, token) pair:
for idx, x_i in enumerate(inputs):
    attention_scores[idx] = torch.dot(x_i, query)

print(attention_scores)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


The dot product yields a scalar value, but it also is a ***measure of similarity***: "a higher dot product indicates a greater degree of alignment or similarity between the vectors" (p. 59).
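
As a quick check, the dot product is just an element-wise multiplication followed by a sum. The minimal sketch below recomputes the score between "Your" and "journey" by hand; the names `x_1`, `x_2`, and `manual_dot` are illustrative and do not appear in the notebook's cells.

```python
import torch

# embeddings copied from the `inputs` tensor above
x_1 = torch.tensor([0.43, 0.15, 0.89])  # "Your"
x_2 = torch.tensor([0.55, 0.87, 0.66])  # "journey" (the query)

# multiply matching elements and accumulate the products
manual_dot = 0.0
for a, b in zip(x_1, x_2):
    manual_dot += a * b

print(manual_dot)           # hand-computed dot product
print(torch.dot(x_1, x_2))  # built-in dot product
```

Both values should agree with the first entry of `attention_scores` (about `0.9544`).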

**Next, we normalize attention scores so they sum to 1:**
- This is useful for both interpretation and LLM training stability.

In [3]:
# simple approach:
attention_scores_normed = attention_scores / attention_scores.sum()
print(f"Attention scores:\n {attention_scores}")
print(f"Normalized attention weights:\n {attention_scores_normed}")
print(f"Sum of normed attention weights: {attention_scores_normed.sum()}")

Attention scores:
 tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Normalized attention weights:
 tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum of normed attention weights: 1.0000001192092896


In actual practice, we'd use the ***softmax*** function:
- Softmax ensures weights are always positive.
- Allows for interpreting weights as probabilities / relative importance.
    - Higher weights indicate more importance.

In [4]:
# simple softmax:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

# apply:
attention_scores_softmax = softmax_naive(attention_scores)
print(f"Softmax attention scores:\n {attention_scores_softmax}")
print(f"Sum: {attention_scores_softmax.sum()}")

Softmax attention scores:
 tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0


The naive implementation can become numerically unstable with very large or very small input values (i.e., "overflow" and "underflow"), so in practice use the optimized PyTorch softmax implementation:

In [5]:
torch_softmax_attn_scores = torch.softmax(attention_scores, dim=0)
print(f"PyTorch softmax weights:\n {torch_softmax_attn_scores}")
print(f"Sum: {torch_softmax_attn_scores.sum()}")

PyTorch softmax weights:
 tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0
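
To make the overflow issue concrete, here is a small sketch with deliberately large scores. The values in `large_scores` are made up for illustration, and subtracting the maximum is shown only as a common stabilization trick, not as a claim about how `torch.softmax` is implemented internally.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

# made-up scores that are far too large for exp() in float32
large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

# exp(1000.) overflows to inf, and inf / inf evaluates to nan
naive = softmax_naive(large_scores)

# the built-in softmax handles the same input without trouble
stable = torch.softmax(large_scores, dim=0)

# shifting by the max leaves the softmax result unchanged but keeps exp() small
shifted = softmax_naive(large_scores - large_scores.max())

print(naive)
print(stable)
print(shifted)
```

The naive version returns only `nan` values here, while the other two agree with each other.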


**Now, we calculate the context vector for the query token,** $z^{(q)}$:
- Multiply the embedded input tokens by the attention weights.
- Then, sum the vectors.
- The context vector is a ***weighted sum*** of all of the input vectors.

In [6]:
# again, assume we are interested in the token at position 2 [index 1]:
query = inputs[1]

# placeholder for context vector:
context_vector = torch.zeros(query.shape)

# iterate and compute:
for i, x_i in enumerate(inputs):
    context_vector += torch_softmax_attn_scores[i]*x_i

print(context_vector)

tensor([0.4419, 0.6515, 0.5683])


The context vector has the same dimensionality (here, `3D`) as the original token embeddings.
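
The three steps so far can be summarized in equation form for query token 2. This is a sketch in the notation used above; the symbols $\omega_{2i}$ (unnormalized score) and $\alpha_{2i}$ (normalized weight) are labels assumed here for convenience.

$$
\omega_{2i} = x^{(2)} \cdot x^{(i)}, \qquad
\alpha_{2i} = \frac{\exp(\omega_{2i})}{\sum_{j=1}^{6} \exp(\omega_{2j})}, \qquad
z^{(2)} = \sum_{i=1}^{6} \alpha_{2i}\, x^{(i)}
$$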

### Computing attention weights for all input tokens

We'll generalize the attention computation to the entire sequence:

In [7]:
# our toy sequence has a length of 6, so let's create an empty tensor:
attention_scores = torch.empty(6, 6)

# now, loop over the entire sequence:
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        # attention scores for tokens (i, j):
        attention_scores[i, j] = torch.dot(x_i, x_j)

# now, we have a row of attention scores for each token in the sequence:
print(attention_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Eliminate the expensive `for` loop by using matrix multiplication:

In [8]:
attention_scores = inputs @ inputs.T
print(attention_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Normalize with softmax:

In [9]:
attention_weights = torch.softmax(attention_scores, dim=-1) # dim=-1 normalizes along the last dimension, so each row sums to 1.
print(attention_weights)
print(f"Sums: {attention_weights.sum(dim=-1)}")

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


Lastly, create all the context vectors, $z^{(i)}$, this time using matrix multiplication:

In [10]:
context_vectors = attention_weights @ inputs
print(context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
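
As a sanity check, the second row of the matrix result should match the context vector computed earlier with an explicit loop. The self-contained sketch below rebuilds both; `context_vector_2` is a name introduced here for illustration.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # x_1: Your
     [0.55, 0.87, 0.66],  # x_2: journey
     [0.57, 0.85, 0.64],  # x_3: starts
     [0.22, 0.58, 0.33],  # x_4: with
     [0.77, 0.25, 0.10],  # x_5: one
     [0.05, 0.80, 0.55]]  # x_6: step
)

# matrix version: all context vectors at once
attention_weights = torch.softmax(inputs @ inputs.T, dim=-1)
context_vectors = attention_weights @ inputs

# loop version: weighted sum of the inputs for token 2 (index 1) only
context_vector_2 = torch.zeros(inputs.shape[1])
for weight, x_i in zip(attention_weights[1], inputs):
    context_vector_2 += weight * x_i

print(torch.allclose(context_vectors[1], context_vector_2))
```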


### Implementing trainable weights

Let's implement the original ***scaled dot-product attention*** from the transformer paper.

There are three trainable weight matrices:

1. The ***query*** matrix, $W_q$
    - A query is the current token being interpreted by the model.
    - The query vector is used to determine which other tokens in the sequence are relevant to understanding the meaning of the current token.
2. The ***key*** matrix, $W_k$
    - A key is basically an index (used for searching), and each token has a key vector.
    - Keys are used to match the query.
3. The ***value*** matrix, $W_v$
    - The value vector is analogous to the value in a `{key: value}` pair:
        - It contains the "representation" of the input.
        - Once the model determines which keys best match the query, it retrieves the corresponding value vectors.

By way of example, let's start at position 2:

In [11]:
# get the token embedding:
x_2 = inputs[1]

# get the embedding dimensionality:
d_in = inputs.shape[1]

# set the output embedding size (note: in GPT-like models, d_in and d_out are usually equal):
d_out = 2

# create matrices (requires_grad=False just for simplicity):
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# get the query, key, and value vectors for token 2:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)

tensor([0.4306, 1.4551])


But we should use matrix multiplication to perform all key and value computations at once:

In [12]:
keys = inputs @ W_key
values = inputs @ W_value
print(f"Keys shape: {keys.shape}")
print(f"Values shape: {values.shape}")

Keys shape: torch.Size([6, 2])
Values shape: torch.Size([6, 2])


Now we can get the unnormalized attention score of token 2 with respect to itself, $\omega_{22}$:

In [13]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


And generalize the computation to the scores between query 2 and all keys:

In [14]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


We now want to create the attention weights. But instead of applying the softmax directly to the raw attention scores, we first divide the scores by the square root of the keys' embedding dimension, $\sqrt{d_k}$:

- Scaling improves training performance because it helps avoid small gradients:

    > "...large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero" (p. 69). 

In [15]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
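
The effect of the scaling can be illustrated with a small experiment. This is a sketch under the assumption that query and key components are independent with unit variance (random normal stand-ins, not the notebook's actual vectors); the sample size and dimensions are arbitrary choices.

```python
import torch

torch.manual_seed(0)

# (1) softmax saturates as its inputs grow
scores = torch.tensor([1.0, 2.0, 3.0])
small = torch.softmax(scores, dim=-1)       # fairly spread out
large = torch.softmax(scores * 50, dim=-1)  # nearly one-hot ("step-like")
print(small)
print(large)

# (2) the spread of dot products grows with the embedding dimension,
#     and dividing by sqrt(d_k) brings it back to roughly 1
for d_k in (4, 64, 512):
    q = torch.randn(5_000, d_k)
    k = torch.randn(5_000, d_k)
    dots = (q * k).sum(dim=-1)  # one dot product per row
    print(d_k, dots.std().item(), (dots / d_k**0.5).std().item())
```

Under these assumptions the unscaled standard deviation grows roughly like $\sqrt{d_k}$, which is why larger embedding sizes push the softmax toward the saturated regime unless the scores are scaled.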


Finally, we compute the context vector:

In [16]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


### Implementing a self-attention class

Such step-by-step computations are useful for learning, but not in practice. Let's implement self-attention as a class:

In [17]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        
        # query matrix:
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))

        # key matrix:
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))

        # value matrix:
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    # forward pass method:
    def forward(self, x):
        # get queries, keys, and values:
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # raw attention scores ('omega'):
        attn_scores = queries @ keys.T

        # scaled attention weights:
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        # context vectors:
        context_vec = attn_weights @ values
        return context_vec

Try it out:

In [18]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Let's refactor the code to use `nn.Linear` layers instead of `nn.Parameter`:

- When `bias=False`, `nn.Linear` essentially performs matrix multiplication.
- It also has optimized weight initialization, which makes it more stable during training (see p. 72).

In [19]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        # simpler matrix math because of nn.Linear:
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # attention:
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        # context vector:
        context_vec = attn_weights @ values
        return context_vec

Try it out:

In [20]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


The weight initialization is different for `nn.Linear` and `nn.Parameter`, but we can show the two implementations are equivalent by transferring weights:

In [21]:
torch.manual_seed(123)
sa_v2_v1 = SelfAttention_v2(d_in, d_out)
sa_v2_v1_cv = sa_v2_v1(inputs)

sa_v1_v2 = SelfAttention_v1(d_in, d_out)
sa_v1_v2_cv = sa_v1_v2(inputs)

In [22]:
# nn.Linear:
sa_v2_v1_cv

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

In [23]:
# nn.Parameter:
sa_v1_v2_cv

tensor([[1.4035, 1.0391],
        [1.4410, 1.0669],
        [1.4391, 1.0655],
        [1.3786, 1.0178],
        [1.3653, 1.0086],
        [1.4025, 1.0361]], grad_fn=<MmBackward0>)

Reassign weights (**note:** we must transpose `nn.Parameter` objects to match `nn.Linear.weight` objects):

In [24]:
sa_v2_v1.W_query.weight = nn.Parameter(sa_v1_v2.W_query.T)
sa_v2_v1.W_key.weight = nn.Parameter(sa_v1_v2.W_key.T)
sa_v2_v1.W_value.weight = nn.Parameter(sa_v1_v2.W_value.T)

# rerun:
sa_v2_v1(inputs)

tensor([[1.4035, 1.0391],
        [1.4410, 1.0669],
        [1.4391, 1.0655],
        [1.3786, 1.0178],
        [1.3653, 1.0086],
        [1.4025, 1.0361]], grad_fn=<MmBackward0>)

We get the same outputs!
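
The transpose is needed because `nn.Linear` stores its weight with shape `(out_features, in_features)` and multiplies the input by the transposed weight. A minimal sketch (the tensors `x` and `W` are random stand-ins introduced for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2

linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)

# the stored weight is (d_out, d_in), i.e. transposed relative to a (d_in, d_out) matrix
print(linear.weight.shape)

# without a bias, the layer is a plain matrix multiplication with the transposed weight
manual = x @ linear.weight.T
print(torch.allclose(linear(x), manual))

# copying a (d_in, d_out) matrix W into the layer therefore requires W.T
W = torch.rand(d_in, d_out)
linear.weight = nn.Parameter(W.T)
print(torch.allclose(linear(x), x @ W))
```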

### Multi-head and Causal Attention

***Causal***, or ***masked***, attention mechanisms introduce masking so that the model cannot see future tokens in a sequence.
- Each predicted word should only depend on preceding tokens.

***Multi-headed attention***, well, splits the attention mechanism across multiple *heads*. 
- Each *head* is supposed to learn different things from the input data.
- The model should learn to attend to information from different "subspaces" at different positions.

**Starting with masked attention:**
- For GPT-like models, future tokens become masked.
- In the square attention matrix, tokens above the diagonal are masked.
- Normalization is applied to ensure each row sums to `1`.

Example:

In [25]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(
    attn_scores / keys.shape[-1]**0.5, dim=-1
)

print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


Use `tril` to mask above the diagonal:

In [26]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


Now we can simply multiply the attention weights by this mask matrix:

In [27]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


We must renormalize the weights so each row sums to 1:

In [28]:
row_sum = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sum
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


We can simplify this by replacing the values above the diagonal in the attention *score* matrix with $-\infty$ before applying the softmax, since the softmax transforms $-\infty$ into zero:

- We can use `torch.triu` to implement this trick.

In [29]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


Finally, apply the softmax along `dim=1` (the last dimension of this 2D matrix, so equivalent to `dim=-1`) so that each row sums to `1`:

In [30]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
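
The two masking approaches give the same result for any score matrix, not only the one above. The sketch below checks this on random scores; `attn_scores` here is a random stand-in, and `approach_1` / `approach_2` are names introduced for illustration.

```python
import torch

torch.manual_seed(123)
n = 4
attn_scores = torch.randn(n, n)  # random stand-in for (already scaled) attention scores

# approach 1: softmax first, zero out future positions, then renormalize each row
weights = torch.softmax(attn_scores, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
approach_1 = masked / masked.sum(dim=-1, keepdim=True)

# approach 2: set future positions to -inf, then apply a single softmax
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
approach_2 = torch.softmax(attn_scores.masked_fill(future, -torch.inf), dim=-1)

print(torch.allclose(approach_1, approach_2))
```

The second approach needs only one normalization pass, which is why it is the one used in the classes that follow.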


**Adding dropout to masked attention:**

In ***dropout***, we randomly "drop" hidden layer units to prevent overfitting or biasing the model to a specific set of hidden layer units.

- Dropout is ***only*** used during training.

For transformer and GPT-like models, dropout is usually applied at one of two points:

1. After calculating attention weights (more common).
2. After applying attention weights to value vectors.

Example:

In [31]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout rate of 50% (in practice, more like 10% - 20%)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


We can apply the dropout to the attention weights:

- **Note:** to compensate for the dropped elements, the remaining elements are scaled up by a factor of $\frac{1}{1 - \text{Dropout Rate}}$ (here, $\frac{1}{1 - 0.5} = 2$).

In [32]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)
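
A short sketch of the scaling and of the training-only behavior, using a dropout rate of 20% so that the scale factor differs from the rate's reciprocal. The tensor size and the variable names `p`, `x`, and `dropped` are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(123)
p = 0.2
dropout = torch.nn.Dropout(p)
x = torch.ones(1000, 1000)

# modules start in training mode, so dropout is active here
dropped = dropout(x)
print(dropped.unique())  # surviving entries are scaled to 1 / (1 - p)
print(dropped.mean())    # the average stays close to the original value of 1

# in evaluation mode, dropout passes the input through unchanged
dropout.eval()
print(torch.equal(dropout(x), x))
```

Scaling the survivors keeps the expected value of each element the same with and without dropout, so nothing needs to be rescaled at inference time.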


**Package a causal attention class**

Let's simulate a batch by duplicating the inputs:

In [33]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [34]:
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


Now build our `CausalAttention` class:

In [35]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        # output dims:
        self.d_out = d_out

        # query, key, and value matrices:
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # dropout:
        self.dropout = nn.Dropout(dropout)

        # this buffer helps ensure our tensors are all on the same device
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    # forward pass:
    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape
        
        # queries, keys, and values:
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # attention:
        attn_scores = queries @ keys.transpose(1, 2) # we transpose (1, 2), because the batch size is at index 0.
        attn_scores.masked_fill_( # performed in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
       
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        attn_weights = self.dropout(attn_weights)

        # context vector:
        context_vec = attn_weights @ values
        return context_vec


Try it out:

In [36]:
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print(f"Context vectors shape: {context_vecs.shape}")
print(context_vecs)

Context vectors shape: torch.Size([2, 6, 2])
tensor([[[-0.2834, -0.2539],
         [-0.3675, -0.1289],
         [-0.3957, -0.0917],
         [-0.3571, -0.0482],
         [-0.3523, -0.0932],
         [-0.3354, -0.0407]],

        [[-0.2834, -0.2539],
         [-0.3675, -0.1289],
         [-0.3957, -0.0917],
         [-0.3571, -0.0482],
         [-0.3523, -0.0932],
         [-0.3354, -0.0407]]], grad_fn=<UnsafeViewBackward0>)
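
To see what `register_buffer` does in practice, here is a minimal sketch with a hypothetical toy module, `MaskHolder`, that stores the same mask twice: once as a buffer and once as an ordinary attribute. Casting with `.double()` is used as a stand-in for moving to another device, since both operations are applied to registered buffers.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask)  # tracked by the module
        self.plain_mask = mask.clone()      # ordinary attribute, not tracked

holder = MaskHolder(3)

# buffers are saved with the model but are not trainable parameters
print(list(holder.state_dict().keys()))
print(len(list(holder.parameters())))

# module-wide conversions are applied to buffers, but not to plain tensor attributes
holder.double()
print(holder.mask.dtype, holder.plain_mask.dtype)
```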


### Implementing multi-head attention

In multi-headed attention, we have multiple, independent attention heads with their own sets of weights.

- We can implement multi-headed causal attention by "stacking" multiple `CausalAttention` modules.
- That is, we'll create several instances of `CausalAttention`.
- The attention heads are independent and conceptually run in parallel (although the simple wrapper below evaluates them one after another in a list comprehension).

We'll have multiple query, key, value, and context matrices.

Simple example:

In [37]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, 
                 dropout, num_heads, qkv_bias=False):
        super().__init__()

        # create multiple heads:
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    d_in, d_out, context_length, dropout, qkv_bias
                )
                for head in range(num_heads)
            ]
        )
        
    # forward pass:
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


Running on our input with `2` heads:

In [38]:
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)

tensor([[[ 0.3678, -0.2396,  0.2942,  0.0256],
         [ 0.4388, -0.1884,  0.3751,  0.1230],
         [ 0.4649, -0.1622,  0.4020,  0.1552],
         [ 0.4257, -0.1461,  0.3591,  0.1498],
         [ 0.4004, -0.0560,  0.3546,  0.1504],
         [ 0.3921, -0.0950,  0.3361,  0.1494]],

        [[ 0.3678, -0.2396,  0.2942,  0.0256],
         [ 0.4388, -0.1884,  0.3751,  0.1230],
         [ 0.4649, -0.1622,  0.4020,  0.1552],
         [ 0.4257, -0.1461,  0.3591,  0.1498],
         [ 0.4004, -0.0560,  0.3546,  0.1504],
         [ 0.3921, -0.0950,  0.3361,  0.1494]]], grad_fn=<CatBackward0>)


Because `d_out` is `2`, and we have `2` attention heads, the final embedding dimension is `d_out * num_heads = 4`:

In [39]:
print(context_vecs.shape)

torch.Size([2, 6, 4])


Here, `context_vecs.shape` corresponds to:

- `2` $\rightarrow$ batch size (number of texts)
- `6` $\rightarrow$ the number of tokens per document (context length)
- `4` $\rightarrow$ the final, `4D` embedding of each token

We can simplify the multi-head attention calculations by combining the functionality into a single `MultiHeadAttention` class.

In [40]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # logic check:
        assert (d_out % num_heads == 0), "Error: d_out must be divisible by num_heads!"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # embedding size of each individual head

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # linear layer for head outputs
        # (not strictly necessary, but commonly used):
        self.out_proj = nn.Linear(d_out, d_out)

        # dropout:
        self.dropout = nn.Dropout(dropout)

        # register buffer:
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape

        # queries, keys, values
        # of shape (batch_size, num_tokens, d_out):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        # split the matrices:
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)

        # transpose from (batch_size, num_tokens, num_heads, head_dim)
        # to (batch_size, num_heads, num_tokens, head_dim):
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # attention scores:
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # mask
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        # context vectors:
        # (batch_size, num_heads, num_tokens, head_dim) --> (batch_size, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # merge the heads: (batch_size, num_tokens, num_heads, head_dim) --> (batch_size, num_tokens, d_out)
        context_vec = context_vec.contiguous().view(
            batch_size, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec)
        return context_vec

    

To understand the shapes and transposes, play around with this example:

In [41]:
batch_size = 2
num_tokens = 6
token_embedding_dims = 3
num_heads = 2
out_dim = 2
head_dim = out_dim // num_heads

torch.manual_seed(123)
linear = nn.Linear(token_embedding_dims, out_dim, bias=False)
queries = linear(batch)

print(queries.shape) # (batch_size, num_tokens, out_dim)
print(queries)

torch.Size([2, 6, 2])
tensor([[[-0.3536,  0.3965],
         [-0.3021, -0.0289],
         [-0.3015, -0.0232],
         [-0.1353, -0.0978],
         [-0.2052,  0.0870],
         [-0.1542, -0.1499]],

        [[-0.3536,  0.3965],
         [-0.3021, -0.0289],
         [-0.3015, -0.0232],
         [-0.1353, -0.0978],
         [-0.2052,  0.0870],
         [-0.1542, -0.1499]]], grad_fn=<UnsafeViewBackward0>)


The `view` call changes the shape to `(batch_size, num_tokens, num_heads, head_dim)`:

In [42]:
viewed_queries = queries.view(
    batch_size, num_tokens, num_heads, head_dim
)

print(viewed_queries.shape)
print(viewed_queries[0].shape)

torch.Size([2, 6, 2, 1])
torch.Size([6, 2, 1])
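
With a head dimension larger than 1, it is easier to see what the `view` and `transpose` calls do: each head simply receives a contiguous slice of the embedding columns. The sketch below uses a made-up tensor `x` filled with consecutive numbers and an arbitrary `head_dim` of 3.

```python
import torch

batch_size, num_tokens, num_heads, head_dim = 2, 6, 2, 3
d_out = num_heads * head_dim

# made-up tensor with consecutive values so the slices are easy to follow
x = torch.arange(batch_size * num_tokens * d_out, dtype=torch.float32)
x = x.view(batch_size, num_tokens, d_out)

# split into heads: (batch, tokens, d_out) -> (batch, heads, tokens, head_dim)
heads = x.view(batch_size, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)

# head 0 holds the first head_dim columns, head 1 the remaining ones
print(torch.equal(heads[:, 0], x[:, :, :head_dim]))
print(torch.equal(heads[:, 1], x[:, :, head_dim:]))

# merging reverses the split: transpose back and flatten the last two dimensions
merged = heads.transpose(1, 2).contiguous().view(batch_size, num_tokens, d_out)
print(torch.equal(merged, x))
```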


**Using the new class:**

In [43]:
# seed:
torch.manual_seed(123)

# information from our batch of inputs:
batch_size, num_tokens, d_in = batch.shape

print(f"Batch size: {batch_size}")
print(f"Num tokens: {num_tokens}")
print(f"Input token embedding dimensions: {d_in}\n")

# multi-headed attention:
d_out = 2 # controls dimensionality of the context vectors
num_heads = 2

mha = MultiHeadAttention(
    d_in=d_in, d_out=d_out, context_length=num_tokens,
    dropout=0.0, num_heads=num_heads
)

context_vecs = mha(batch)

print(f"Context matrix shape: {context_vecs.shape}")
print(f"Context matrix:\n {context_vecs}")

Batch size: 2
Num tokens: 6
Input token embedding dimensions: 3

Context matrix shape: torch.Size([2, 6, 2])
Context matrix:
 tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)


### Exercise: Match the smallest GPT-2 model

- `12` attention heads
- `768`-dimensional input and output embeddings
- `1024` context length

In [None]:
# embeddings and context:
EMBEDDING_DIM = 768
CONTEXT_LENGTH = 1024

# simulate an input sequence:
doc = torch.randn(CONTEXT_LENGTH, EMBEDDING_DIM)

# simulate a batch:
gpt_batch = torch.stack((doc, doc), dim=0)

In [45]:
print(f"Doc shape: {doc.shape}")
print(f"Doc tensor: {doc}")

Doc shape: torch.Size([1024, 768])
Doc tensor: tensor([[-0.1606, -0.4015,  0.6957,  ...,  1.6873,  1.7270,  0.7496],
        [ 0.7789,  0.7275, -1.2886,  ..., -0.6831, -0.2214, -1.1118],
        [-0.5157,  0.2716, -0.9348,  ..., -1.3915, -1.0310,  1.6252],
        ...,
        [ 0.6208,  1.7514,  0.0539,  ...,  1.2243, -0.5144,  0.4884],
        [-0.0042, -0.7276,  0.8973,  ...,  0.2352,  0.3677, -0.1048],
        [ 0.2063,  0.1570,  0.0966,  ...,  0.8885, -1.0166, -1.6929]])


In [46]:
print(f"Batch shape: {gpt_batch.shape}")
print(f"Batch tensor: {gpt_batch}")

Batch shape: torch.Size([2, 1024, 768])
Batch tensor: tensor([[[-0.1606, -0.4015,  0.6957,  ...,  1.6873,  1.7270,  0.7496],
         [ 0.7789,  0.7275, -1.2886,  ..., -0.6831, -0.2214, -1.1118],
         [-0.5157,  0.2716, -0.9348,  ..., -1.3915, -1.0310,  1.6252],
         ...,
         [ 0.6208,  1.7514,  0.0539,  ...,  1.2243, -0.5144,  0.4884],
         [-0.0042, -0.7276,  0.8973,  ...,  0.2352,  0.3677, -0.1048],
         [ 0.2063,  0.1570,  0.0966,  ...,  0.8885, -1.0166, -1.6929]],

        [[-0.1606, -0.4015,  0.6957,  ...,  1.6873,  1.7270,  0.7496],
         [ 0.7789,  0.7275, -1.2886,  ..., -0.6831, -0.2214, -1.1118],
         [-0.5157,  0.2716, -0.9348,  ..., -1.3915, -1.0310,  1.6252],
         ...,
         [ 0.6208,  1.7514,  0.0539,  ...,  1.2243, -0.5144,  0.4884],
         [-0.0042, -0.7276,  0.8973,  ...,  0.2352,  0.3677, -0.1048],
         [ 0.2063,  0.1570,  0.0966,  ...,  0.8885, -1.0166, -1.6929]]])


Compute:

In [47]:
# GPT-2 has 12 attention heads:
NUM_HEADS = 12

# MultiHeadAttention splits d_out across the heads (head_dim = d_out // num_heads),
# so to keep a 768-dimensional output, OUT_DIM must equal EMBEDDING_DIM:
OUT_DIM = EMBEDDING_DIM

# attention object:
mha = MultiHeadAttention(
    d_in=EMBEDDING_DIM, d_out=OUT_DIM, context_length=CONTEXT_LENGTH,
    dropout=0.0, num_heads=NUM_HEADS, qkv_bias=False
)

# compute:
context_vecs = mha(gpt_batch)

# inspect:
print(f"Context matrix shape: {context_vecs.shape}")
print(f"First batch context matrix:\n {context_vecs[0]}")

Context matrix shape: torch.Size([2, 1024, 768])
First batch context matrix:
 tensor([[ 0.4348, -0.1356, -0.6276,  ..., -0.0624, -0.1139, -0.0113],
        [ 0.2209, -0.0992, -0.2100,  ..., -0.1952, -0.0247, -0.1179],
        [ 0.1565,  0.0024, -0.2139,  ..., -0.0586, -0.0182, -0.2590],
        ...,
        [ 0.0015,  0.0321,  0.0141,  ..., -0.0213, -0.0011,  0.0176],
        [-0.0076,  0.0298,  0.0232,  ..., -0.0211,  0.0050,  0.0185],
        [-0.0084,  0.0305,  0.0135,  ..., -0.0069, -0.0016,  0.0177]],
       grad_fn=<SelectBackward0>)
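
For a sense of scale, the per-head dimension and the number of trainable parameters in this configuration can be worked out from the four linear layers alone. The sketch below rebuilds only those layers, mirroring the `MultiHeadAttention` class above with `qkv_bias=False`; the `projections` container is introduced here for illustration, and the mask is left out because buffers are not parameters.

```python
import torch.nn as nn

EMBEDDING_DIM = 768
NUM_HEADS = 12

# each head works on an equal slice of the embedding
head_dim = EMBEDDING_DIM // NUM_HEADS

# the same four projections used by the class: three without bias, one with bias
projections = nn.ModuleDict({
    "W_query": nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM, bias=False),
    "W_key": nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM, bias=False),
    "W_value": nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM, bias=False),
    "out_proj": nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM),
})

n_params = sum(p.numel() for p in projections.parameters())
print(f"Head dimension: {head_dim}")      # 64
print(f"Trainable parameters: {n_params:,}")  # 2,360,064
```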


Done!