# Writing Attention mechanisms
The attention mechanism in large language models (LLMs) allows each token in a sequence to dynamically focus on the most relevant other tokens when building its representation, instead of relying only on fixed-size contexts like in traditional RNNs or CNNs. By computing similarity scores between queries, keys, and values derived from token embeddings, attention lets the model capture dependencies across arbitrary distances — for example, linking a pronoun to its noun several sentences earlier. This enables LLMs to model long-range context, disambiguate meaning, and integrate information from the whole sequence efficiently, which is the foundation of how transformers achieve state-of-the-art performance in natural language understanding and generation.

We will implement four variants of the attention mechanism. Each variant builds on the previous one.
1. **Simplified self-attention**
   A simplified self-attention technique to introduce the broader idea.
2. **Self-attention**
   Self-attention with trainable weights, which forms the basis of the mechanism used in LLMs.
3. **Causal attention**
   A type of self-attention used in LLMs that allows a model to consider only previous and current inputs in a sequence, preserving temporal order during text generation.
4. **Multi-head attention**
   An extension of self-attention and causal attention that enables the model to simultaneously attend to information from different representation subspaces.

Let's start with our input data. Consider the sentence <b>"Your journey starts with one step"</b>. It has already been tokenized and embedded, which gives us the tensor below.

In [2]:
import torch
inputs = torch.tensor(
 [[0.43, 0.15, 0.89], # Your (x^1)
 [0.55, 0.87, 0.66], # journey (x^2)
 [0.57, 0.85, 0.64], # starts (x^3)
 [0.22, 0.58, 0.33], # with (x^4)
 [0.77, 0.25, 0.10], # one (x^5)
 [0.05, 0.80, 0.55]] # step (x^6)
)

print(inputs[0])

tensor([0.4300, 0.1500, 0.8900])


Let's start with a simplified self-attention mechanism.

## 1. Attending to different parts of the input with self-attention
We’ll now cover the inner workings of the self-attention mechanism and learn how to code it from the ground up. Self-attention serves as the cornerstone of every LLM based on the transformer architecture.

### a. Simplified self-attention mechanism without trainable weights (we will just do this for 1 token)
Let’s begin by implementing a simplified variant of self-attention, free from any trainable weights. The goal is to illustrate a few key concepts in self-attention before adding trainable weights.

The first step of implementing self-attention is to compute the intermediate values `ω`, referred to as attention scores. We calculate the intermediate attention scores between the query token and each input token. We determine these scores by computing the dot product of the query, `x(2)`, with every input token, including the query itself. To demonstrate this, we pick the second input embedding, `[0.55, 0.87, 0.66]`, for the word **"journey"**, as the query.

In [3]:
query = inputs[1]

# Create a new uninitialized tensor with one slot per input token (shape [6])
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    # Dot product between the current input embedding x_i
    # and our selected query, inputs[1]
    attn_scores_2[i] = torch.dot(x_i, query)
# The result is a one-dimensional tensor of shape [6]
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


The result is a tensor of shape `[6]`: one score for each embedding in the sentence. Repeating this for every query gives each word (token, embedding) a score, and later a weight, for every word in the sentence, itself included.
<br>
To convert scores to weights we need to normalize the output tensor. We use softmax for this.

In [4]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
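
To make the normalization step concrete, here is a minimal pure-Python sketch of softmax applied to the scores printed above. The helper name `softmax` and the max-subtraction step (a common trick for numerical stability) are illustrative choices and not part of the notebook; in practice `torch.softmax` is the better option.

```python
import math

def softmax(scores):
    """Illustrative helper: turn a list of scores into weights that sum to 1."""
    # Subtracting the max leaves the result unchanged but avoids overflow in exp()
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Attention scores for the query "journey", copied from the output above
attn_scores_2 = [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]
attn_weights_2 = softmax(attn_scores_2)

print([round(w, 4) for w in attn_weights_2])
print("Sum:", round(sum(attn_weights_2), 4))
```

Up to rounding, the printed weights should agree with the `torch.softmax` output shown above.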


Now that we have computed the normalized attention **weights**, we are ready for the final step of calculating the **context** vector `z(2)` by multiplying the **embedded input tokens**, `x(i)`, with the corresponding attention weights and then summing the resulting vectors. Thus, context vector `z(2)` is the weighted sum of all input vectors, obtained by multiplying each input vector by its corresponding attention weight:

In [5]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


**Summary of Self-Attention**
- Compute scores: Each token is compared with every token in the sequence, itself included, using a similarity measure (dot product).
- Normalize scores: Apply softmax to convert scores into attention weights that sum to 1.
- Compute context vectors: Multiply the attention weights with the value vectors (in the simplified version, the input embeddings themselves) to get a new representation for each token. The context vectors have the same shape as the original embeddings.

```
Token Embeddings (X)
   ┌───────────────┐
   │  x1  x2  x3   │  <-- shape [T, d_model]
   └───────────────┘
           │
           ▼
Compute Scores (Q · K^T / √d)
   ┌───────────────┐
   │ s11 s12 s13   │
   │ s21 s22 s23   │  <-- shape [T, T]
   │ s31 s32 s33   │
   └───────────────┘
           │
           ▼
Softmax to get Attention Weights (α)
   ┌───────────────┐
   │ α11 α12 α13   │
   │ α21 α22 α23   │  <-- rows sum to 1
   │ α31 α32 α33   │
   └───────────────┘
           │
           ▼
Weighted sum of Values to get Context Vectors (C)
   ┌───────────────┐
   │ c1  c2  c3    │  <-- shape [T, d_model], same as embeddings
   │ c1  c2  c3    │
   │ c1  c2  c3    │
   └───────────────┘
```
---
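
The three steps summarized above can also be written as one small, dependency-free function. This is only a sketch of the simplified variant (no trainable weights and no scaling); the function name `simple_self_attention` is made up for illustration, and the input values are copied from the `inputs` tensor defined earlier.

```python
import math

def simple_self_attention(query, inputs):
    """Illustrative helper: context vector for a single query vector."""
    # Step 1: dot-product score between the query and every input vector
    scores = [sum(q * x for q, x in zip(query, x_i)) for x_i in inputs]
    # Step 2: softmax over the scores (max subtracted for numerical stability)
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Step 3: weighted sum of the input vectors, one output per embedding dimension
    dim = len(inputs[0])
    return [sum(w * x_i[d] for w, x_i in zip(weights, inputs)) for d in range(dim)]

inputs = [
    [0.43, 0.15, 0.89],  # Your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
]

context_vec_2 = simple_self_attention(inputs[1], inputs)
print([round(c, 4) for c in context_vec_2])
```

The result should agree, up to rounding, with the context vector computed for "journey" in the previous cell.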

### b. Computing attention weights for all input tokens
So far, we have computed attention weights and the context vector for input 2 (`inputs[1]`). Now let’s extend this computation to calculate attention weights and context vectors for all inputs.

In [6]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Each element in the tensor represents the attention score between a pair of inputs. These scores are still unnormalized; we will turn them into attention weights with softmax in a moment.

When computing the preceding attention score tensor, we used for loops in Python. However, for loops are generally slow, and we can achieve the same results using matrix multiplication:

In [7]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
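
As a quick sanity check, the sketch below compares the nested-loop result with the matrix-multiplication result. It rebuilds the `inputs` tensor so that it can run on its own; the variable names `loop_scores` and `matmul_scores` are chosen here for illustration.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

# Nested-loop version: one dot product per (i, j) pair
loop_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        loop_scores[i, j] = torch.dot(x_i, x_j)

# Matrix-multiplication version: all pairs at once
matmul_scores = inputs @ inputs.T

# Both versions should agree up to floating-point error
print(torch.allclose(loop_scores, matmul_scores))
# Without trainable weights, score(i, j) equals score(j, i)
print(torch.allclose(matmul_scores, matmul_scores.T))
```

The symmetry seen here is a property of the simplified variant only; once separate query and key projections are introduced, the score matrix is generally no longer symmetric.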


Then we normalize each row.

In [8]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In the context of using PyTorch, the dim parameter in functions like `torch.softmax` specifies the dimension of the input tensor along which the function will be computed. By setting dim=-1, we are instructing the softmax function to apply the normalization along the last dimension of the attn_scores tensor. If attn_scores is a two-dimensional tensor (for example, with a shape of [rows, columns]), it will normalize across the columns so that the values in each row (summing over the column dimension) sum up to 1.
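
The effect of `dim` is easiest to see on a small tensor. The following sketch uses a made-up 2 × 3 tensor and compares `dim=-1` (normalize within each row) with `dim=0` (normalize within each column).

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_wise = torch.softmax(scores, dim=-1)  # each row becomes a distribution
col_wise = torch.softmax(scores, dim=0)   # each column becomes a distribution

print(row_wise.sum(dim=-1))  # one sum per row
print(col_wise.sum(dim=0))   # one sum per column
```

For attention weights we want the first behavior: each query (row) distributes a total weight of 1 across all keys (columns).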

In the third and final step, we use these attention weights to compute all context vectors via matrix multiplication:

In [9]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## 2. Implementing self-attention with trainable weights
Our next step will be to implement the self-attention mechanism used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention. The self-attention mechanism with trainable weights builds on the previous concepts. We want to compute context vectors as weighted sums over the input vectors specific to a certain input element.<br>
We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices:
- Wq (Query weight matrix): A learned projection that transforms the input embeddings into queries, representing what each token is looking for.
- Wk (Key weight matrix): A learned projection that transforms the input embeddings into keys, representing what each token has to offer for matching.
- Wv (Value weight matrix): A learned projection that transforms the input embeddings into values, representing the actual content or information that is combined into the context vectors.

### a. Computing the attention weights step by step
Earlier, we defined the second input element `x(2)` as the query when we computed the simplified attention weights to compute the context vector `z(2)`. Then we generalized this to compute all context vectors `z(1) ... z(T)` for the six-word input sentence **“Your journey starts with one step.”** Similarly, we start here by computing only one context vector, `z(2)`, for illustration purposes. We will then modify this code to calculate all context vectors.

In [10]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

Note that in GPT-like models, the input and output dimensions are usually the same, but to better follow the computation, we’ll use different input `(d_in=3)` and output `(d_out=2)` dimensions here.<br><br>
 <b>Next</b>, we initialize the three weight matrices Wq, Wk, and Wv.

In [11]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

We set `requires_grad=False` to reduce clutter in the outputs, but if we were to use the weight matrices for model training, we would set `requires_grad=True` to update these matrices during model training.<br><br>
 <b>Next</b>, we compute the query, key, and value vectors:

In [12]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print('query', query_2)
print('key', key_2)
print('value', value_2)

query tensor([0.4306, 1.4551])
key tensor([0.4433, 1.1419])
value tensor([0.3951, 1.0037])


Even though our temporary goal is only to compute the one context vector, `z(2)`, we still require the key and value vectors for all input elements as they are involved in computing the attention weights with respect to the query `q(2)`. We can obtain all keys and values via matrix multiplication:

In [13]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


First, let’s compute the attention score ω22:

In [14]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


Again, we can generalize this computation to all attention scores via matrix
multiplication:

In [15]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Now we want to go from the attention scores to the attention weights. As before, we use the softmax function, but this time we first scale the attention scores by dividing them by the square root of the embedding dimension of the keys (taking the square root is mathematically the same as exponentiating by 0.5, which is how the code writes it):

In [16]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
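
A commonly given motivation for the scaling is that dot products tend to grow with the key dimension, and large scores push softmax toward a near one-hot output. The sketch below illustrates this with made-up scores and an assumed key dimension of 64; neither value comes from the notebook, where `d_k` is only 2.

```python
import math

def softmax(scores):
    """Illustrative helper: numerically stable softmax over a list."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64                         # assumed key dimension, for illustration only
scores = [8.0, 4.0, 0.0, -4.0]   # made-up raw dot products

unscaled = softmax(scores)
scaled = softmax([s / math.sqrt(d_k) for s in scores])

print("unscaled:", [round(w, 3) for w in unscaled])
print("scaled:  ", [round(w, 3) for w in scaled])
```

Without scaling, almost all weight lands on the first entry; with scaling, the distribution stays noticeably softer.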


Finally, we compute the context vector as a weighted sum over the value vectors, as we have done before. The attention weights serve as a weighting factor that weighs the respective importance of each value vector. Also as before, we can use matrix multiplication to obtain the output in one step:

In [17]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


### b. Implementing a compact self-attention Python class
At this point, we have gone through a lot of steps to compute the self-attention outputs. We did so mainly for illustration purposes so we could go through one step at a time. In practice, with the LLM implementation in the next chapter in mind, it is helpful to organize this code into a Python class, as shown in the following listing.

In [18]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

The `SelfAttention_v1` class inherits from `nn.Module`, a fundamental building block of PyTorch models that provides the functionality needed to create and manage model layers.
 The `__init__` method initializes trainable weight matrices **(W_query, W_key, and W_value)** for queries, keys, and values, each transforming the input dimension `d_in` to an output dimension `d_out`. During the forward pass, in the `forward` method, we compute the attention scores `(attn_scores)` by multiplying queries and keys, then scale these scores and normalize them using `softmax`. Finally, we create a context vector by weighting the values with these normalized attention scores.<br><br>
We can use this class as follows:

In [19]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Since `inputs` contains six **embedding vectors**, the result is a matrix storing the six **context vectors**. As a quick check, the second row (`[0.3061, 0.8210]`) matches `context_vec_2` from the previous section.

 Self-attention involves the trainable weight matrices Wq, Wk, and Wv. These matrices transform input data into queries, keys, and values, respectively, which are crucial components of the attention mechanism. As the model is exposed to more data during training, it adjusts these trainable weights, as we will see in upcoming chapters.

### c. A self-attention class using PyTorch’s Linear layers
 We can improve `SelfAttention_v1` by using PyTorch's fully connected (dense) `nn.Linear` layers, which effectively perform matrix multiplication when the bias units are disabled. Additionally, a significant advantage of using `nn.Linear` instead of manually implementing `nn.Parameter(torch.rand(...))` is that `nn.Linear` has an optimized weight initialization scheme, contributing to more stable and effective model training.

In [20]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

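You can use `SelfAttention_v2` in the same way as `SelfAttention_v1`. Note that the outputs differ from those of `SelfAttention_v1` because the two classes start from different initial weights: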

In [21]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
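
One detail worth knowing when comparing the two classes: `nn.Linear` stores its weight matrix with shape `(d_out, d_in)`, that is, transposed relative to the `(d_in, d_out)` parameters used in `SelfAttention_v1`. The sketch below checks this on a standalone layer; the random input `x` is a made-up stand-in for token embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
linear = nn.Linear(d_in, d_out, bias=False)

x = torch.rand(6, d_in)  # six made-up token embeddings

# nn.Linear keeps its weight with shape (d_out, d_in) ...
print(linear.weight.shape)

# ... so calling the layer matches a matmul with the transposed weight
manual = x @ linear.weight.T
print(torch.allclose(linear(x), manual))
```

It follows that copying the transposed weights of a `SelfAttention_v2` instance into a `SelfAttention_v1` instance should make both produce the same context vectors.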


## 3. Hiding future words with Causal attention

We want the self-attention mechanism to consider only the current token and the tokens before it when predicting the next token in a sequence. Causal attention, also known as masked attention, is a specialized form of self-attention. It restricts a model to only consider previous and current inputs in a sequence. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.<br>
To achieve this in GPT-like LLMs, for each token processed, we mask out the future tokens, which come after the current token in the input text, as illustrated below:

| Query \ Key | your | journey | starts | with | one | step |
| ----------- | ---- | ------- | ------ | ---- | --- | ---- |
| your        | ✓    | <span style="color:red">✗</span>       | <span style="color:red">✗</span>      | <span style="color:red">✗</span>    | <span style="color:red">✗</span>   | <span style="color:red">✗</span>    |
| journey     | ✓    | ✓       | <span style="color:red">✗</span>      | <span style="color:red">✗</span>    | <span style="color:red">✗</span>   | <span style="color:red">✗</span>    |
| starts      | ✓    | ✓       | ✓      | <span style="color:red">✗</span>    | <span style="color:red">✗</span>   | <span style="color:red">✗</span>    |
| with        | ✓    | ✓       | ✓      | ✓    | <span style="color:red">✗</span>   | <span style="color:red">✗</span>    |
| one         | ✓    | ✓       | ✓      | ✓    | ✓   | <span style="color:red">✗</span>    |
| step        | ✓    | ✓       | ✓      | ✓    | ✓   | ✓    |


### a. Applying a causal attention mask
To apply a causal attention mask and obtain the masked attention weights, let’s work with the attention scores and weights from the previous section.
In the first step, we compute the attention weights using the softmax function as we have done previously:

In [22]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


We can implement the second step using PyTorch’s tril function to create a mask
where the values above the diagonal are zero:


In [24]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


Now, we can multiply this mask with the attention weights to zero-out the values above
the diagonal:

In [25]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


The third step is to renormalize the attention weights to sum up to 1 again in each
row. We can achieve this by dividing each element in each row by the sum in each row:

In [26]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


Let’s use a mathematical property of the softmax function to compute the masked attention weights more efficiently, in fewer steps. The softmax function converts its inputs into a probability distribution. When negative infinity values (-∞) are present in a row, the softmax function assigns them zero probability. (Mathematically, this is because e^(-∞) approaches 0.)<br>
 We can implement this more efficient masking “trick” by creating a mask with 1s above the diagonal and then replacing the attention scores at those positions with negative infinity (`-inf`) values:

In [27]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(mask)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


Now all we need to do is apply the softmax function to these masked results, and we
are done:

In [26]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
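
The two masking approaches give identical weights, which can be checked numerically. The sketch below uses made-up random scores and omits the scaling step for brevity; `approach_1` and `approach_2` are names introduced here for illustration.

```python
import torch

torch.manual_seed(0)
n = 4
scores = torch.rand(n, n)  # made-up attention scores

# Approach 1: softmax, zero out future positions, renormalize each row
weights = torch.softmax(scores, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
approach_1 = masked / masked.sum(dim=-1, keepdim=True)

# Approach 2: set future positions to -inf, then apply a single softmax
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
approach_2 = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

print(torch.allclose(approach_1, approach_2))
```

The agreement is expected: renormalizing after masking divides out the original softmax denominator, leaving the same ratio of exponentials over the allowed positions.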


### b. Masking additional attention weights with dropout
Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively “dropping” them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. It’s important to emphasize that dropout is only used during training and is disabled afterward.<br>
 In the transformer architecture, including models like GPT, dropout in the attention mechanism is typically applied at two specific times: after calculating the attention weights or after applying the attention weights to the value vectors. Here we will apply the dropout mask after computing the attention weights, because it’s the more common variant in practice.
 In the following code example, we use a dropout rate of 50%, which means masking out half of the attention weights. (When we train the GPT model, we will use a lower dropout rate, such as 0.1 or 0.2.)<br>
As a demonstration, we first apply PyTorch’s dropout implementation to a `6 × 6` tensor consisting of 1s for simplicity:

In [29]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print("6 x 6 matrix before dropout: \n", example)
print("6 x 6 matrix after dropout: \n", dropout(example))

6 x 6 matrix before dropout: 
 tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
6 x 6 matrix after dropout: 
 tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


When applying dropout to an attention weight matrix with a rate of 50%, each element is set to zero with probability 0.5, so on average half of the elements in the matrix are dropped. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.<br>
 Now let’s apply dropout to the attention weight matrix itself:

In [28]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)
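
Because dropout is only active during training, it helps to see how the same module behaves in both modes. The sketch below toggles a standalone `torch.nn.Dropout` between `train()` and `eval()`; the variable names `train_out` and `eval_out` are chosen here for illustration.

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)

dropout.train()              # training mode (the default for a new module)
train_out = dropout(example)

dropout.eval()               # evaluation mode: dropout becomes a no-op
eval_out = dropout(example)

print(train_out.unique())    # only zeros and rescaled survivors
print(torch.equal(eval_out, example))
```

Calling `model.eval()` on a full model switches every dropout layer inside it in the same way, so no extra handling is needed at inference time.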


### c. Implementing a compact causal attention class
We will now incorporate the causal attention and dropout modifications into the `SelfAttention` Python class we developed in section 2. This class will then serve as a template for developing multi-head attention, which is the final attention class we will implement.
 But before we begin, let’s ensure that the code can handle batches consisting of more than one input so that the `CausalAttention` class supports the batch outputs produced by the data loader we implemented in chapter 2.
 For simplicity, to simulate such batch inputs, we duplicate the input text example:

In [32]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) 

torch.Size([2, 6, 3])


The following `CausalAttention` class is similar to the `SelfAttention` class we implemented earlier, except that we added the dropout and causal mask components:

In [33]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s above the diagonal mark the future positions to hide
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Swap the token and feature dimensions; the batch dimension stays first
        attn_scores = queries @ keys.transpose(1, 2)
        # In-place fill; the mask is cropped to the actual sequence length
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

While all added code lines should be familiar at this point, we have now added a `self.register_buffer()` call in the `__init__` method. Using `register_buffer` in PyTorch is not strictly necessary for all use cases, but it offers several advantages here. For instance, when we use the `CausalAttention` class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training our LLM. This means we don’t need to manually ensure these tensors are on the same device as the model parameters, avoiding device mismatch errors.<br><br>
We can use the `CausalAttention` class as follows, similar to `SelfAttention` previously:

In [35]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])
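
To see what `register_buffer` does in isolation, the sketch below defines a tiny module with one trainable layer and one registered mask. The class `MaskHolder` is hypothetical and exists only for this inspection; it is not part of the notebook's attention classes.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Hypothetical minimal module, used only to inspect buffer behavior."""
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(3, 2, bias=False)  # one trainable parameter
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

holder = MaskHolder(context_length=6)

print([name for name, _ in holder.named_parameters()])  # trainable tensors
print([name for name, _ in holder.named_buffers()])     # non-trainable state
print("mask" in holder.state_dict())                    # saved with the model
print(holder.mask.requires_grad)                        # not updated by training
```

The mask therefore travels with the module (for example through `.to(device)` and `state_dict()`), while an optimizer built from `parameters()` never sees it.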


## 4. Extending single-head attention to multi-head attention
Our final step will be to extend the previously implemented causal attention class over <b>multiple heads</b>. This is also called <b>multi-head attention.</b><br><br>
    The term “multi-head” refers to dividing the attention mechanism into multiple “heads,” each operating independently. In this context, a single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially.<br><br>
    We will tackle this expansion from causal attention to multi-head attention in two steps. First, we will intuitively build a multi-head attention module by stacking multiple `CausalAttention` modules. Then we will implement the same multi-head attention module in a more complicated but more computationally efficient way.<br>

### a. Stacking multiple single-head attention layers
In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then combining their outputs. Running several heads lets the model attend to information from different representation subspaces at the same time.<br><br>

The attention mechanism is thus run multiple times, conceptually in parallel. Here is a simple `MultiHeadAttentionWrapper` class that stacks multiple instances of our previously implemented `CausalAttention` module.

In [38]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

For example, if we use this `MultiHeadAttentionWrapper` class with two attention heads (via `num_heads=2`) and a `CausalAttention` output dimension of `d_out=2`, we get a four-dimensional context vector (`d_out*num_heads=4`). To illustrate this further with a concrete example, we can use the `MultiHeadAttentionWrapper` class in the same way as the `CausalAttention` class before:

In [33]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
     d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


We can combine `CausalAttention` and `MultiHeadAttentionWrapper` into a single `MultiHeadAttention` class. In addition to merging the two, we will make some other modifications to implement multi-head attention more efficiently.<br>
In the `MultiHeadAttentionWrapper`, multiple heads are implemented by creating a list of `CausalAttention` objects (`self.heads`), each representing a separate attention head. Each `CausalAttention` object performs the attention mechanism independently, and the results from all heads are concatenated. In contrast, the following `MultiHeadAttention` class integrates the multi-head functionality within a single class.<br>
It splits the input into multiple heads by reshaping the projected `query`, `key`, and `value` tensors and then combines the results from these heads after computing attention.<br><br>
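To make the concatenation step of the wrapper concrete, here is a minimal, self-contained sketch. The `CausalAttention` class below is a re-creation that is assumed to match the one from the earlier section, and the random tensor `x` is only a stand-in for `batch`. The sketch checks that the wrapper-style output is simply the per-head outputs placed side by side along the last dimension.

```python
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    # Assumed re-creation of the single-head class from the earlier section
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        # Hide future tokens before the softmax
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return self.dropout(attn_weights) @ values

torch.manual_seed(123)
x = torch.rand(2, 6, 3)  # (batch, num_tokens, d_in); random stand-in for `batch`
heads = nn.ModuleList([CausalAttention(3, 2, 6, 0.0) for _ in range(2)])

per_head = [head(x) for head in heads]   # two tensors of shape (2, 6, 2)
combined = torch.cat(per_head, dim=-1)   # shape (2, 6, 4)

print(combined.shape)
print(torch.equal(combined[..., :2], per_head[0]))  # first two columns: head 0
print(torch.equal(combined[..., 2:], per_head[1]))  # last two columns: head 1
```

Each head produces its own `(b, num_tokens, d_out)` tensor, so the heads in this design share nothing except the input.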
Let’s take a look at the `MultiHeadAttention` class:

In [43]:
class MultiHeadAttention(nn.Module):
     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
         super().__init__()
         assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
         self.d_out = d_out
         self.num_heads = num_heads
         self.head_dim = d_out // num_heads # Reduces the projection dim to match the desired output dim
         
         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

         # Uses a Linear layer to combine head outputs
         self.out_proj = nn.Linear(d_out, d_out)
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
             torch.triu(torch.ones(context_length, context_length), diagonal=1)
         )
         
     def forward(self, x):
         b, num_tokens, d_in = x.shape

         # Tensor shape: (b, num_tokens, d_out)
         keys = self.W_key(x)
         queries = self.W_query(x)
         values = self.W_value(x)

         # We implicitly split the matrix by adding a num_heads dimension.  
         # Then we unroll the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
         keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
         values = values.view(b, num_tokens, self.num_heads, self.head_dim)
         queries = queries.view(
             b, num_tokens, self.num_heads, self.head_dim
         )

         # Transposes from shape (b, num_tokens, num_heads, head_dim) 
         # to (b, num_heads, num_tokens, head_dim)
         keys = keys.transpose(1, 2)
         queries = queries.transpose(1, 2)
         values = values.transpose(1, 2)

         attn_scores = queries @ keys.transpose(2, 3) # Computes dot product for each head
         mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # Masks truncated to the number of tokens
        
         attn_scores.masked_fill_(mask_bool, -torch.inf) # Uses the mask to fill attention scores
         
         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
         attn_weights = self.dropout(attn_weights)
         
         context_vec = (attn_weights @ values).transpose(1, 2) # Tensor shape: (b, num_tokens, num_heads, head_dim)
         context_vec = context_vec.contiguous().view(
             b, num_tokens, self.d_out
         ) # Combines heads, where self.d_out = self.num_heads * self.head_dim
         
         context_vec = self.out_proj(context_vec) # Adds an optional linear projection
         return context_vec

Even though the reshaping (`.view`) and transposing (`.transpose`) of tensors inside the `MultiHeadAttention` class may look mathematically complicated, the `MultiHeadAttention` class implements the same concept as the `MultiHeadAttentionWrapper` earlier.<br>

 On a big-picture level, in the previous `MultiHeadAttentionWrapper`, we stacked multiple single-head attention layers that we combined into a multi-head attention layer. The `MultiHeadAttention` class takes an integrated approach. It starts with a multi-head layer and then internally splits this layer into individual attention heads.<br>
 
 The splitting of the query, key, and value tensors is achieved through tensor reshaping and transposing operations using PyTorch’s `.view` and `.transpose` methods. The input is first transformed (via linear layers for queries, keys, and values) and then reshaped to represent multiple heads.<br>
 
 The key operation is to split the `d_out` dimension into `num_heads` and `head_dim`, where `head_dim = d_out // num_heads`. This splitting is then achieved using the `.view` method: a tensor of dimensions `(b, num_tokens, d_out)` is reshaped to dimension `(b, num_tokens, num_heads, head_dim)`.
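As a small illustration of this split, the sketch below applies `.view` to a toy tensor whose entries are just consecutive numbers (an assumption made for readability; in the class this would be the projected keys, queries, or values). It shows that head 0 receives the first `head_dim` features of every token and head 1 receives the remaining ones.

```python
import torch

b, num_tokens, num_heads, head_dim = 1, 3, 2, 4
d_out = num_heads * head_dim

# Toy stand-in for projected keys, shape (b, num_tokens, d_out)
keys = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# Split the last dimension into (num_heads, head_dim) without copying data
split = keys.view(b, num_tokens, num_heads, head_dim)
print(split.shape)  # torch.Size([1, 3, 2, 4])

# Head 0 sees features 0..head_dim-1 of each token, head 1 sees the rest
print(torch.equal(split[:, :, 0, :], keys[:, :, :head_dim]))
print(torch.equal(split[:, :, 1, :], keys[:, :, head_dim:]))
```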

The tensors are then transposed to bring the `num_heads` dimension before the `num_tokens` dimension, resulting in a shape of `(b, num_heads, num_tokens, head_dim)`. This transposition is crucial for correctly aligning the queries, keys, and values across the different heads and performing batched matrix multiplications efficiently. <br>
To illustrate this batched matrix multiplication, suppose we have the tensor below. We then perform a batched matrix multiplication between the tensor itself and a view of the tensor in which the last two dimensions, `num_tokens` and `head_dim`, are transposed:

In [44]:
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
 [0.8993, 0.0390, 0.9268, 0.7388],
 [0.7179, 0.7058, 0.9156, 0.4340]],
 [[0.0772, 0.3565, 0.1479, 0.5331],
 [0.4066, 0.2318, 0.4545, 0.9737],
 [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a @ a.transpose(2, 3))


tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In this case, the matrix multiplication implementation in PyTorch handles the four-dimensional input tensor so that the matrix multiplication is carried out between the last two dimensions `(num_tokens, head_dim)` and then repeated for the individual heads.
In other words, the preceding code is a more compact way to compute the matrix multiplication for each head separately:

In [45]:
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)
second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


The results are exactly the same as those we obtained with the batched matrix multiplication `a @ a.transpose(2, 3)`.
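Instead of comparing the printed numbers by eye, the equivalence can also be checked programmatically. The following sketch uses a random tensor as a stand-in for `a` (an assumption, so the values differ from the ones printed above) and compares the batched product against an explicit loop over the heads.

```python
import torch

torch.manual_seed(123)
a = torch.rand(1, 2, 3, 4)  # (b, num_heads, num_tokens, head_dim); random stand-in

# Batched version: one call handles all heads at once
batched = a @ a.transpose(2, 3)  # shape (1, 2, 3, 3)

# Explicit version: one matrix multiplication per head, stacked back together
looped = torch.stack(
    [a[0, h] @ a[0, h].T for h in range(a.shape[1])]
).unsqueeze(0)

print(batched.shape)
print(torch.allclose(batched, looped))
```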

Continuing with `MultiHeadAttention`, after computing the attention weights and context vectors, the context vectors from all heads are transposed back to the shape `(b, num_tokens, num_heads, head_dim)`. These vectors are then reshaped (flattened) into the shape `(b, num_tokens, d_out)`, effectively combining the outputs from all heads.<br>
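The merge step can be sketched in isolation as follows. The random tensor `per_head` is a stand-in for `attn_weights @ values` (an assumption for illustration). After the transpose, the tensor is no longer contiguous in memory, which is why the class calls `.contiguous()` before `.view`. The check at the end confirms that each token's merged vector is the output of head 0 followed by the output of head 1.

```python
import torch

b, num_heads, num_tokens, head_dim = 2, 2, 3, 4
d_out = num_heads * head_dim

torch.manual_seed(123)
# Stand-in for attn_weights @ values, shape (b, num_heads, num_tokens, head_dim)
per_head = torch.rand(b, num_heads, num_tokens, head_dim)

# Move the head dimension next to head_dim: (b, num_tokens, num_heads, head_dim)
swapped = per_head.transpose(1, 2)
print(swapped.is_contiguous())  # False

# Flatten (num_heads, head_dim) into d_out
merged = swapped.contiguous().view(b, num_tokens, d_out)

# Same result as concatenating the per-head outputs along the last dimension
expected = torch.cat([per_head[:, 0], per_head[:, 1]], dim=-1)
print(merged.shape)
print(torch.equal(merged, expected))
```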
Additionally, we added an output projection layer (`self.out_proj`) to `MultiHeadAttention` after combining the heads, which is not present in the `CausalAttention` class. This output projection layer is not strictly necessary, but it is commonly used in many LLM architectures, which is why I added it here for completeness.<br>
Even though the `MultiHeadAttention` class looks more complicated than the `MultiHeadAttentionWrapper` due to the additional reshaping and transposition of tensors, it is more efficient. The reason is that we need only one matrix multiplication to compute the keys, for instance, `keys = self.W_key(x)` (the same is true for the queries and values). In the `MultiHeadAttentionWrapper`, we needed to repeat this matrix multiplication, which is computationally one of the most expensive steps, for each attention head.<br>
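To see why a single projection is enough, the sketch below compares one large key projection, reshaped into heads, against separate per-head projections built from row slices of the same weight matrix. The dimensions and the random input are assumptions chosen for illustration, not values from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out, num_heads = 3, 4, 2
head_dim = d_out // num_heads
x = torch.rand(2, 6, d_in)  # (b, num_tokens, d_in); random stand-in input

W_key = nn.Linear(d_in, d_out, bias=False)

# Integrated approach: one matrix multiplication, then a reshape into heads
keys_all = W_key(x).view(2, 6, num_heads, head_dim)

# Wrapper-style approach: one smaller matrix multiplication per head,
# each using the rows of the weight matrix that belong to that head
keys_per_head = [
    x @ W_key.weight[h * head_dim:(h + 1) * head_dim].T
    for h in range(num_heads)
]

for h in range(num_heads):
    print(torch.allclose(keys_all[:, :, h, :], keys_per_head[h], atol=1e-6))
```

Both routes give the same keys per head; the integrated version simply obtains them from one larger matrix multiplication instead of several smaller ones.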
The `MultiHeadAttention` class can be used similarly to the `SelfAttention` and `CausalAttention` classes we implemented earlier:

In [42]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


We have now implemented the `MultiHeadAttention` class that we will use when we implement and train the LLM. Note that while the code is fully functional, I used relatively small embedding sizes and numbers of attention heads to keep the outputs readable.<br>
For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention heads and a context vector embedding size of 768. The largest GPT-2 model (1.5 billion parameters) has 25 attention heads and a context vector embedding size of 1,600. The embedding sizes of the token inputs and context embeddings are the same in GPT models (`d_in = d_out`).
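As a quick sanity check on these numbers, the short sketch below computes the per-head dimension that the `d_out % num_heads == 0` assertion in `MultiHeadAttention` would require for both configurations. The dictionary keys and layout are hypothetical names chosen for this example.

```python
# Embedding size and head count as quoted in the text above
configs = {
    "smallest GPT-2": {"emb_dim": 768, "num_heads": 12},
    "largest GPT-2": {"emb_dim": 1600, "num_heads": 25},
}

head_dims = {}
for name, cfg in configs.items():
    # Same divisibility requirement as in MultiHeadAttention.__init__
    assert cfg["emb_dim"] % cfg["num_heads"] == 0
    head_dims[name] = cfg["emb_dim"] // cfg["num_heads"]
    print(f"{name}: head_dim = {head_dims[name]}")
```

In both cases the embedding size divides evenly by the number of heads, and each head works with 64 dimensions.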