In [2]:
import torch
print(torch.__version__)

2.1.2+cu118


- This chapter covers attention mechanisms, the engine of LLMs:

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/01.webp" width="700px">

## 3.3 Attending to different parts of the input with self-attention

### 3.3.1 A simple self-attention mechanism without trainable weights


- This section explains a very simplified variant of self-attention, which does not contain any trainable weights
- This is purely for illustration purposes and NOT the attention mechanism that is used in transformers
- The next section, section 3.3.2, will extend this simple attention mechanism to implement the real self-attention mechanism

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/07.webp" width="700px">

In [3]:
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55]   # step     (x^6)
    ]
)

print(inputs.shape)
print(inputs)

torch.Size([6, 3])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


- (In this book, we follow the common machine learning and deep learning convention where training examples are represented as rows and feature values as columns; in the case of the tensor shown above, each row represents a word, and each column represents an embedding dimension)

- The primary objective of this section is to demonstrate how the context vector z^(2)
 is calculated using the second input element, x^(2), as the query token

- The figure below depicts the initial step in this process, which involves calculating the attention scores ω between
 x^(2) (the query) and all input elements (including x^(2) itself) through a dot product operation

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/08.webp" width="700px">


In [4]:
query = inputs[1]             # A
attn_scores_2 = torch.empty(inputs.shape[0])

print(f'query shape: {query.shape} and attn_scores_2 shape: {attn_scores_2.shape}')
query, attn_scores_2

#A The second input token serves as query

query shape: torch.Size([3]) and attn_scores_2 shape: torch.Size([6])


(tensor([0.5500, 0.8700, 0.6600]), tensor([0., 0., 0., 0., 0., 0.]))

In [5]:
# computing attention scores
for i, x_i in enumerate(inputs):
    #print(x_i.shape)
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
print("attn_scores_2 shape: ", attn_scores_2.shape)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
attn_scores_2 shape:  torch.Size([6])


- In the context of self-attention mechanisms, the dot product determines the extent to which elements in a
 sequence attend to each other: the higher the dot product, the higher the similarity and attention score between two elements.
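
- To make the "higher dot product, higher attention score" reading concrete, the following minimal sketch ranks the six tokens by their score against the query x^(2). It assumes the same `inputs` values as above; the `tokens` list and the `ranked` variable are introduced here purely for illustration.

```python
import torch

# Same toy embeddings as in the cell above.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your
        [0.55, 0.87, 0.66],  # journey
        [0.57, 0.85, 0.64],  # starts
        [0.22, 0.58, 0.33],  # with
        [0.77, 0.25, 0.10],  # one
        [0.05, 0.80, 0.55],  # step
    ]
)
tokens = ["Your", "journey", "starts", "with", "one", "step"]  # labels for printing only

query = inputs[1]        # x^(2), "journey"
scores = inputs @ query  # (6, 3) @ (3,) -> (6,): one dot product per token

# Sort tokens from highest to lowest attention score.
ranked = sorted(zip(tokens, scores.tolist()), key=lambda pair: pair[1], reverse=True)
for token, score in ranked:
    print(f"{token:8s} {score:.4f}")
```

- Keep in mind that a dot product grows with the lengths of the two vectors as well as with their alignment, so it is a similarity measure only in a loose sense.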

In the next step, as shown in Figure 3.9, we normalize each of the attention scores that we computed previously.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/09.webp" width="700px">

**Figure 3.9** - After computing the attention scores ω21 to ω2T with respect to the input query x(2), the next step is to obtain the attention weights α21 to α2T by normalizing the attention scores.

- The main goal behind the normalization shown in Figure 3.9 is to obtain attention weights that sum up to 1. This normalization is a convention that is useful for interpretation and for maintaining training stability in an LLM. Here's a straightforward method for achieving this normalization step:

In [6]:
attn_weigths_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights: ", attn_weigths_2_tmp)
print("Sum attn weights: ", attn_weigths_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum attn weights:  tensor(1.0000)


- In practice, it's more common and advisable to use the softmax function for normalization.
 This approach is better at managing extreme values and offers more favorable gradient
 properties during training. Below is a basic implementation of the softmax function for
 normalizing the attention scores:

In [7]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weigths_2_naive = softmax_naive(attn_scores_2)
print("Attention weights: ", attn_weigths_2_naive)
print("Sum attn weights: ", attn_weigths_2_naive.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum attn weights:  tensor(1.)


- In addition, the softmax function ensures that the attention weights are always positive.
 This makes the output interpretable as probabilities or relative importance, where higher
 weights indicate greater importance.

- Note that this naive softmax implementation (softmax_naive) may encounter numerical
 instability problems, such as overflow and underflow, when dealing with large or small input
 values. Therefore, in practice, it's advisable to use the PyTorch implementation of softmax,
 which has been extensively optimized for performance:

In [8]:
attn_weigths_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights: ", attn_weigths_2)
print("Sum attn weights: ", attn_weigths_2.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum attn weights:  tensor(1.)
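
- The overflow problem mentioned above, and one common remedy, can be sketched in a few lines. The example below uses plain Python lists rather than tensors so that the arithmetic is easy to follow; `softmax_stable` is a hypothetical helper written for this illustration and is not meant to describe how PyTorch implements `torch.softmax` internally. In plain Python the overflow surfaces as an `OverflowError`; with floating-point tensors it would typically show up as `inf` or `nan` values instead.

```python
import math

def softmax_naive(scores):
    # Direct translation of the formula: exp(x_i) / sum_j exp(x_j)
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(scores):
    # Subtracting the maximum does not change the result mathematically,
    # but it keeps every exponent <= 0, so exp() cannot overflow.
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

small = [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]  # scores from above
large = [1000.0, 1001.0, 1002.0]                           # illustrative extreme scores

# On moderate inputs, both versions agree.
agree = all(
    math.isclose(a, b, rel_tol=1e-9)
    for a, b in zip(softmax_naive(small), softmax_stable(small))
)

# On large inputs, the naive version overflows.
try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable_large = softmax_stable(large)
print("agree on small inputs:", agree)
print("naive overflowed on large inputs:", naive_overflowed)
print("stable result:", [round(w, 4) for w in stable_large])
```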


- Now that we have computed the normalized attention weights, we are ready for the final step
 illustrated in Figure 3.10: calculating the context vector z(2) by multiplying the embedded
 input tokens, x(i), with the corresponding attention weights and then summing the resulting
 vectors.

 <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/10.webp" width="700px">

  **Figure 3.10** The final step, after calculating and normalizing the attention scores to obtain the attention
 weights for query x(2), is to compute the context vector z(2). This context vector is a combination of all input
 vectors x(1) to x(T) weighted by the attention weights.


In [9]:
query = inputs[1]  # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weigths_2[i] * x_i
    print(f'iteration i = {i} : {context_vec_2}')
print(context_vec_2)

iteration i = 0 : tensor([0.0596, 0.0208, 0.1233])
iteration i = 1 : tensor([0.1904, 0.2277, 0.2803])
iteration i = 2 : tensor([0.3234, 0.4260, 0.4296])
iteration i = 3 : tensor([0.3507, 0.4979, 0.4705])
iteration i = 4 : tensor([0.4340, 0.5250, 0.4813])
iteration i = 5 : tensor([0.4419, 0.6515, 0.5683])
tensor([0.4419, 0.6515, 0.5683])
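
- The accumulation loop above is a weighted sum of the rows of `inputs`, which can equivalently be written as a single vector-matrix product. The sketch below rebuilds the weights from scratch so that it runs on its own; `context_loop` and `context_matmul` are names chosen for this illustration.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

# Attention weights for the second token, as in the previous cells.
attn_scores_2 = inputs @ inputs[1]
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

# Version 1: explicit loop, accumulating weight_i * x_i.
context_loop = torch.zeros(inputs.shape[1])
for i, x_i in enumerate(inputs):
    context_loop += attn_weights_2[i] * x_i

# Version 2: the same weighted sum as one product, (6,) @ (6, 3) -> (3,).
context_matmul = attn_weights_2 @ inputs

print(context_loop)
print(context_matmul)
print(torch.allclose(context_loop, context_matmul))
```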


### 3.3.2 Computing attention weights for all input tokens

 In the previous section, we computed the attention weights and the context vector for input 2,
 as shown in the highlighted row in Figure 3.11. Now, we extend this computation to
 calculate attention weights and context vectors for all inputs.
 
  <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/11.webp" width="700px">


 We follow the same three steps as before, as summarized in Figure 3.12, except that we
 make a few modifications in the code to compute all context vectors instead of only the
 second context vector, z(2).

   <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/12.webp" width="700px">


- **Step 1:** Compute attention scores for all pairs of inputs:

In [10]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


- When computing the preceding attention score tensor, we used for-loops in Python.
 However, for-loops are generally slow, and we can achieve the same results using matrix
 multiplication:

In [11]:
attn_scores = inputs @ inputs.T  # shape(6, 3) * (3, 6) = output shape(6, 6)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


- **Step 2:** We now normalize each row so that the values in each row sum to 1:

In [12]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


- In the context of using PyTorch, the dim parameter in functions like torch.softmax specifies
 the dimension of the input tensor along which the function will be computed. By setting
 dim=-1, we are instructing the softmax function to apply the normalization along the last
 dimension of the attn_scores tensor. If attn_scores is a 2D tensor (for example, with a
 shape of [rows, columns]), dim=-1 will normalize across the columns so that the values in
 each row (summing over the column dimension) sum up to 1.

In [13]:
row_2_sum = sum( [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print(f'row_2_sum: {row_2_sum}')
print(f'All row sums: {attn_weights.sum(dim=-1)}')

row_2_sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
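
- The effect of the `dim` argument is easiest to see on a small non-square tensor. The following sketch uses a made-up 2×3 `scores` tensor (not taken from the example above) and compares `dim=-1` with `dim=0`.

```python
import torch

# Illustrative 2x3 tensor: 2 rows, 3 columns.
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_normalized = torch.softmax(scores, dim=-1)  # normalize within each row
col_normalized = torch.softmax(scores, dim=0)   # normalize within each column

print(row_normalized.sum(dim=-1))  # one sum per row, each equal to 1
print(col_normalized.sum(dim=0))   # one sum per column, each equal to 1
print(row_normalized.sum(dim=0))   # column sums of the row-normalized tensor: not 1 in general
```

- For attention weights we want each query (row) to distribute a total weight of 1 over all keys (columns), which is why `dim=-1` is the appropriate choice here.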


- **Step 3:** Using these attention weights, we compute all context vectors via matrix multiplication:

In [14]:
all_context_vectors = attn_weights @ inputs
print(all_context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


- We can double-check that the code is correct by comparing the 2nd row with the context
 vector z(2) that we computed previously in section 3.3.1:

In [15]:
print("Previous 2nd context vector: ", context_vec_2)

Previous 2nd context vector:  tensor([0.4419, 0.6515, 0.5683])


- Based on the result, we can see that the previously calculated context_vec_2 matches the
 second row of all_context_vectors exactly.

> Note:
This concludes the code walkthrough of a simple self-attention mechanism. In the next
 section, we will add trainable weights, enabling the LLM to learn from data and improve its
 performance on specific tasks.

## 3.4 Implementing self-attention with trainable weights

- In this section, we implement the self-attention mechanism that is used in the
 original transformer architecture, the GPT models, and most other popular LLMs. This
 self-attention mechanism is also called scaled dot-product attention. Figure 3.13 provides a
 mental model illustrating how this self-attention mechanism fits into the broader context of
 implementing an LLM.

    <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/13.webp" width="700px">


 - The most notable difference from the simplified mechanism in section 3.3 is the introduction
 of weight matrices that are updated during model training. These trainable weight matrices
 are crucial so that the model (specifically, the attention module inside the model) can learn
 to produce "good" context vectors. (Note that we will train the LLM in chapter 5.)
 
 - We will tackle this self-attention mechanism in two subsections. First, we will code it
 step-by-step as before. Second, we will organize the code into a compact Python class that
 can be imported into an LLM architecture, which we will code in chapter 4.

### 3.4.1 Computing the attention weights step by step

- We will implement the self-attention mechanism step by step by introducing the three
 trainable weight matrices Wq, Wk, and Wv. These three matrices are used to project the
 embedded input tokens, x(i), into query, key, and value vectors as illustrated in Figure 3.14.

    <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/14.webp" width="700px">

- **Figure 3.14** In the first step of the self-attention mechanism with trainable weight matrices, we compute query
 (q), key (k), and value (v) vectors for input elements x. Similar to previous sections, we designate the second
 input, x(2), as the query input. The query vector q(2) is obtained via matrix multiplication between the input x(2)
 and the weight matrix Wq. Similarly, we obtain the key and value vectors via matrix multiplication involving the
 weight matrices Wk and Wv.

- Earlier in section 3.3.1, we defined the second input element x(2) as the query when we
 computed the simplified attention weights to compute the context vector z(2). Later, in
 section 3.3.2, we generalized this to compute all context vectors z(1) ... z(T) for the
 six-word input sentence "Your journey starts with one step."

- Similarly, we will start by computing only one context vector, z(2), for illustration
 purposes. In the next section, we will modify this code to calculate all context vectors.

In [16]:
x_2 = inputs[1]             #A
d_in = inputs.shape[1]      #B
d_out = 2                   #C

#A The second input element
#B The input embedding size, d_in = 3
#C The output embedding size, d_out = 2

> Note that in GPT-like models, the input and output dimensions are usually the same, but for
 illustration purposes, to better follow the computation, we choose different input (d_in=3)
 and output (d_out=2) dimensions here.

- Next, we initialize the three weight matrices Wq, Wk, and Wv that are shown in Figure 3.14:


In [17]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

> Note that we are setting requires_grad=False to reduce clutter in the outputs for
 illustration purposes, but if we were to use the weight matrices for model training, we
 would set requires_grad=True to update these matrices during model training.

- Next, we compute the query, key, and value vectors as shown earlier in Figure 3.14:


In [18]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


**WEIGHT PARAMETERS VS ATTENTION WEIGHTS**

- Note that in the weight matrices W, the term "weight" is short for "weight
 parameters," the values of a neural network that are optimized during training. This
 is not to be confused with the attention weights. As we already saw in the previous
 section, attention weights determine the extent to which a context vector depends on
 the different parts of the input, i.e., to what extent the network focuses on different
 parts of the input.

- In summary, weight parameters are the fundamental, learned coefficients that define
 the network's connections, while attention weights are dynamic, context-specific
 values.

- Even though our temporary goal is to compute only the one context vector, z(2), we still
 require the key and value vectors for all input elements, as they are involved in computing
 the attention weights with respect to the query q(2), as illustrated in Figure 3.14.
    
    We can obtain all keys and values via matrix multiplication:

In [19]:
keys = inputs @ W_key
values = inputs @ W_value
print(f'keys.shape: {keys.shape}')
print(f'values.shape: {values.shape}')

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


- The second step is now to compute the attention scores, as shown in Figure 3.15.

    <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/15.webp" width="700px">

- **Figure 3.15** The attention score computation is a dot-product computation similar to what we used in the
 simplified self-attention mechanism in section 3.3. The new aspect here is that we are not directly computing
 the dot-product between the input elements but using the query and key obtained by transforming the inputs
 via the respective weight matrices.

 First, let's compute the attention score ω22:

In [20]:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22)

tensor(1.8524)


- Again, we can generalize this computation to all attention scores via matrix multiplication:


In [21]:
attn_scores_2 = query_2 @ keys.T  # all attention scores for query_2: (2,) @ (2, 6) -> (6,)
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [22]:
attn_scores_2.shape

torch.Size([6])

- The third step is now going from the attention scores to the attention weights, as illustrated
 in Figure 3.16.

    <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/16.webp" width="700px">

- We normalize these attention scores using softmax.

- The difference from earlier is that we now scale the attention scores by dividing them by the square root of the
 embedding dimension of the keys (note that taking the square root is mathematically the
 same as exponentiating by 0.5):

In [23]:
d_k = keys.shape[-1]  # taking embedding dimension
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


 > THE RATIONALE BEHIND SCALED DOT-PRODUCT ATTENTION
 
- The reason for scaling by the square root of the embedding dimension is to improve the
 training performance by avoiding small gradients. For instance, when scaling up the
 embedding dimension, which is typically greater than a thousand for GPT-like LLMs,
 large dot products can result in very small gradients during backpropagation due to
 the softmax function applied to them. As dot products increase, the softmax function
 behaves more like a step function, resulting in gradients nearing zero. These small
 gradients can drastically slow down learning or cause training to stagnate.
 
 - The scaling by the square root of the embedding dimension is the reason why this
 self-attention mechanism is also called scaled dot-product attention.
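
- The two effects described in this box can be illustrated with a small plain-Python experiment. It is only a sketch under stated assumptions: the query and key entries are drawn as independent standard-normal numbers (real embeddings need not behave this way), and `dot_product_std`, `spread`, and `peaks` are names introduced for this illustration. The first part shows that the spread of raw dot products grows roughly with the square root of the dimension and that dividing by that square root brings it back to about 1. The second part shows how softmax becomes increasingly peaked as its inputs are scaled up.

```python
import math
import random
import statistics

def softmax(scores):
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot_product_std(dim, num_samples, rng):
    """Standard deviation of q . k for random q, k with unit-variance entries."""
    dots = []
    for _ in range(num_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        k = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pstdev(dots)

# Part 1: raw dot products spread out as the dimension grows.
rng = random.Random(0)  # fixed seed for reproducibility
spread = {}
for dim in (4, 256):
    std = dot_product_std(dim, num_samples=2000, rng=rng)
    spread[dim] = std
    print(f"d_k={dim:>3}: std of raw scores = {std:6.2f}, "
          f"after dividing by sqrt(d_k) = {std / math.sqrt(dim):.2f}")

# Part 2: larger scores push softmax toward a one-hot (step-like) output.
base_scores = [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440]  # attn_scores_2 from above
peaks = []
for factor in (1, 10, 100):
    weights = softmax([s * factor for s in base_scores])
    peaks.append(max(weights))
    print(f"scores x {factor:>3}: largest attention weight = {max(weights):.4f}")
```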

- Now, the final step is to compute the context vectors, as illustrated in Figure 3.17.

    <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/17.webp" width="700px">

- **In the final step of the self-attention computation, we compute the context vector by combining all  value vectors via the attention weights**

 - Similar to section 3.3, where we computed the context vector as a weighted sum over the
 input vectors, we now compute the context vector as a weighted sum over the value
 vectors. Here, the attention weights serve as weighting factors that determine the respective
 importance of each value vector. Similar to section 3.3, we can use matrix multiplication to
 obtain the output in one step:

In [24]:
print(f'Shape of attn_weights_2: {attn_weights_2.shape}')
print(f'Shape of values: {values.shape}')

Shape of attn_weights_2: torch.Size([6])
Shape of values: torch.Size([6, 2])


In [25]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


- So far, we only computed a `single context vector, z(2)`. In the next section, we will generalize the code to compute `all context vectors` in the input sequence, `z(1) to z(T)`.

> WHY QUERY, KEY, AND VALUE?
 
- The terms "key," "query," and "value" in the context of attention mechanisms are
 borrowed from the domain of information retrieval and databases, where similar
 concepts are used to store, search, and retrieve information.
 
- A "query" is analogous to a search query in a database. It represents the current
 item (e.g., a word or token in a sentence) the model focuses on or tries to
 understand. The query is used to probe the other parts of the input sequence to
 determine how much attention to pay to them.

- The "key" is like a database key used for indexing and searching. In the attention
 mechanism, each item in the input sequence (e.g., each word in a sentence) has an
 associated key. These keys are used to match with the query.

- The "value" in this context is similar to the value in a key-value pair in a database. It
 represents the actual content or representation of the input items. Once the model
 determines which keys (and thus which parts of the input) are most relevant to the
 query (the current focus item), it retrieves the corresponding values.
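
- One way to read this analogy is that attention performs a "soft" lookup: a dictionary returns exactly one value for an exact key match, whereas attention returns a blend of all values, weighted by how well each key matches the query. The sketch below contrasts the two in plain Python; the `table`, `keys`, `values`, and `query` contents are made up for illustration.

```python
import math

def softmax(scores):
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hard lookup: the query must match a key exactly, and one value is returned.
table = {"journey": [10.0, 0.0], "step": [0.0, 10.0]}
hard_result = table["journey"]

# Soft lookup: every key is compared with the query via a dot product,
# and the result is a weighted average of all values.
keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [1.0, 0.2]

scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
weights = softmax(scores)
soft_result = [
    sum(w * value[d] for w, value in zip(weights, values))
    for d in range(len(values[0]))
]

print("hard lookup:", hard_result)
print("match weights:", [round(w, 3) for w in weights])
print("soft lookup:", [round(x, 3) for x in soft_result])
```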

### 3.4.2 Implementing a compact self-attention Python class

- In the previous sections, we have gone through a lot of steps to compute the self-attention
 outputs. This was mainly done for illustration purposes so we could go through one step at
 a time. In practice, with the LLM implementation in the next chapter in mind, it is helpful to
 organize this code into a Python class as follows:

In [26]:
# A compact self-attention class
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec        

- The __init__ method initializes trainable weight matrices (W_query, W_key, and
 W_value) for queries, keys, and values, each transforming the input dimension d_in to an
 output dimension d_out.

- During the forward pass, using the forward method, we compute the attention scores
 (attn_scores) by multiplying queries and keys and then normalize these scores using softmax.
 Finally, we create a context vector by weighting the values with these normalized attention
 scores.

In [27]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)  # d_in = 3, d_out = 2
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [28]:
print(context_vec_2)

tensor([0.3061, 0.8210])


- Figure 3.18 summarizes the self-attention mechanism we just implemented.

    <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/18.webp" width="700px">

- In self-attention, we transform the input vectors in the input matrix X with the three weight
 matrices, Wq, Wk, and Wv. Then, we compute the attention weight matrix based on the resulting queries `(Q)`
 and keys `(K)`. Using the attention weights and values `(V)`, we then compute the context vectors (Z). (For visual
 clarity, we focus on a single input text with n tokens in this figure, not a batch of multiple inputs. Consequently,
 the 3D input tensor is simplified to a 2D matrix in this context. This approach allows for a more straightforward
 visualization and understanding of the processes involved. Also, for consistency with later figures, the values in
 the attention matrix do not depict the real attention weights.)

- As shown in Figure 3.18, self-attention involves the trainable weight matrices Wq, Wk, and
 Wv. These matrices transform input data into queries, keys, and values, which are crucial
 components of the attention mechanism. As the model is exposed to more data during
 training, it adjusts these trainable weights, as we will see in upcoming chapters.

- We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's
 nn.Linear layers, which effectively perform matrix multiplication when the bias units are
 disabled. Additionally, a significant advantage of using nn.Linear instead of manually
 implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight
 initialization scheme, contributing to more stable and effective model training:
In [29]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [30]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


> Note that SelfAttention_v1 and SelfAttention_v2 give different outputs because they use different initial weights: nn.Linear uses a more sophisticated weight initialization scheme than torch.rand.
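
- The claim that a bias-free nn.Linear layer is effectively a matrix multiplication can be checked directly. The sketch below uses an arbitrary random input `x` (an assumption for illustration, not the `inputs` tensor from above) and also shows that nn.Linear stores its weight with shape (d_out, d_in), that is, transposed relative to the (d_in, d_out) matrices used in SelfAttention_v1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # illustrative input: 6 tokens, 3 features each

linear = nn.Linear(d_in, d_out, bias=False)
print(linear.weight.shape)  # stored as (d_out, d_in)

out_linear = linear(x)            # what SelfAttention_v2 computes
out_manual = x @ linear.weight.T  # explicit matrix multiplication

print(torch.allclose(out_linear, out_manual))
```

- Because of the transposed storage, copying weights between the two class variants would require a `.T` on the nn.Linear side.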

Exercise 3.1 is solved in the `exercise_solutions.ipynb` file.

- In the next section, we will make enhancements to the self-attention mechanism, focusing
specifically on incorporating causal and multi-head elements.

- The causal aspect involves modifying the attention mechanism to prevent the model from
accessing future information in the sequence, which is crucial for tasks like language
modeling, where each word prediction should only depend on previous words.

- The multi-head component involves splitting the attention mechanism into multiple
"heads." Each head learns different aspects of the data, allowing the model to
simultaneously attend to information from different representation subspaces at different
positions. This improves the model's performance in complex tasks.

## 3.5 Hiding future words with causal attention

- In this section, we modify the standard self-attention mechanism to create a causal
 attention mechanism, which is essential for developing an LLM in the subsequent chapters.

 - Causal attention, also known as masked attention, is a specialized form of self-attention.
 It restricts a model to only consider previous and current inputs in a sequence when
 processing any given token. This is in contrast to the standard self-attention mechanism,
 which allows access to the entire input sequence at once.
 
 - Consequently, when computing attention scores, the causal attention mechanism
 ensures that the model only factors in tokens that occur at or before the current token in
 the sequence.
 
 - To achieve this in GPT-like LLMs, for each token processed, we mask out the future
 tokens, which come after the current token in the input text, as illustrated in Figure 3.19.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/19.webp" width="700px">

- **Figure 3.19** In causal attention, we mask out the attention weights above the diagonal such that for a given input, the LLM can't access future tokens when computing the context vectors using the attention weights. For example, for the word "journey" in the second row, we only keep the attention weights for the words before ("Your") and in the current position ("journey").

- As illustrated in Figure 3.19, we mask out the attention weights above the diagonal, and we
 normalize the non-masked attention weights, such that the attention weights sum to 1 in
 each row. In the next section, we will implement this masking and normalization procedure
 in code.

### 3.5.1 Applying a causal attention mask

- In this section, we implement the causal attention mask in code. We start with the
 procedure summarized in Figure 3.20.

 <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/20.webp" width="700px">

 - **Figure 3.20** One way to obtain the masked attention weight matrix in causal attention is to apply the softmax
 function to the attention scores, zeroing out the elements above the diagonal and normalizing the resulting
 matrix.

 - In the first step illustrated in Figure 3.20, we compute the attention weights using the
 softmax function as we have done in previous sections:

In [31]:
queries = sa_v2.W_query(inputs)   #A
keys = sa_v2.W_key(inputs)
values = sa_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

#A Reuse the query and key weight matrices of the SelfAttention_v2 object from the previous section for convenience

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


- We can implement step 2 in Figure 3.20 using PyTorch's tril function to create a mask
 where the values above the diagonal are zero:

In [32]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


- Now, we can multiply this mask with the attention weights to zero out the values above the
 diagonal:

In [33]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


- The third step in Figure 3.20 is to renormalize the attention weights to sum up to 1 again in
 each row. We can achieve this by dividing each element in each row by the sum in each row:
 

In [34]:
row_sums = masked_simple.sum(dim=1, keepdim=True)
print(row_sums.shape)
print("row sums:\n", row_sums)

masked_simple_norm = masked_simple / row_sums
print("\nmasked attention weights: \n", masked_simple_norm)

torch.Size([6, 1])
row sums:
 tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)

masked attention weights: 
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


**INFORMATION LEAKAGE**

- When we apply a mask and then renormalize the attention weights, it might initially
 appear that information from future tokens (which we intend to mask) could still
 influence the current token because their values are part of the softmax calculation.
 However, the key insight is that when we renormalize the attention weights after
 masking, what we're essentially doing is recalculating the softmax over a smaller
 subset (since masked positions don't contribute to the softmax value).

- The mathematical elegance of softmax is that despite initially including all positions
 in the denominator, after masking and renormalizing, the effect of the masked
 positions is nullified — they don't contribute to the softmax score in any meaningful
 way.
- In simpler terms, after masking and renormalization, the distribution of attention
 weights is as if it was calculated only among the unmasked positions to begin with.
 This ensures there's no information leakage from future (or otherwise masked)
 tokens as we intended.
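- The claim above can be checked numerically. The following is a minimal sketch that uses made-up attention scores (the tensor `scores` is a hypothetical stand-in, not the `attn_scores` computed in this notebook). It compares "softmax over everything, then mask and renormalize" against "softmax over only the visible positions":

```python
import torch

torch.manual_seed(0)
n = 4
scores = torch.randn(n, n)  # hypothetical attention scores, for illustration only

# Approach 1: softmax over all positions, then mask and renormalize each row
weights = torch.softmax(scores, dim=-1)
causal = torch.tril(torch.ones(n, n))
masked = weights * causal
renormalized = masked / masked.sum(dim=-1, keepdim=True)

# Approach 2: softmax computed only over the visible positions of each row
subset = torch.zeros(n, n)
for i in range(n):
    subset[i, : i + 1] = torch.softmax(scores[i, : i + 1], dim=-1)

# Both approaches should agree up to floating-point rounding
print(torch.allclose(renormalized, subset, atol=1e-6))
```

The two results agree because the full softmax denominator cancels when a row is divided by the sum of its unmasked entries.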

- While we could be technically done with implementing causal attention at this point, we can take advantage of a mathematical property of the softmax function and implement the computation of the masked attention weights more efficiently in fewer steps, as shown in Figure 3.21

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/21.webp" width="700px">

- Figure 3.21 A more efficient way to obtain the masked attention weight matrix in causal attention is to mask  the attention scores with negative infinity values before applying the softmax function.

- The softmax function converts its inputs into a probability distribution. When negative
 infinity values (-∞) are present in a row, the softmax function assigns them zero
 probability. (Mathematically, this is because e^(-∞) approaches 0.)

- We can implement this more efficient masking "trick" by creating a mask with 1's above
 the diagonal and then filling the attention scores at these positions with negative infinity (-inf) values:

In [35]:
print(attn_scores)

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)


In [36]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [37]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [38]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


- As we can see based on the output, the values in each row sum to 1, and no further
 normalization is necessary.

- We could now use the modified attention weights to compute the context vectors via `context_vec = attn_weights @ values`, as in section 3.4. However, in the next section, we first cover another minor tweak to the causal attention mechanism that is useful for reducing overfitting when training LLMs.

### 3.5.2 Masking additional attention weights with dropout

 - Dropout in deep learning is a technique where randomly selected hidden layer units are
 ignored during training, effectively "dropping" them out. This method helps prevent
 overfitting by ensuring that a model does not become overly reliant on any specific set of
 hidden layer units. It's important to emphasize that dropout is only used during training
 and is disabled afterward.
 - In the transformer architecture, including models like GPT, dropout in the attention
 mechanism is typically applied at one of two points: after calculating the attention weights
 or after applying the attention weights to the value vectors.
 - Here, we will apply the dropout mask after computing the attention weights, as
 illustrated in Figure 3.22, because it's the more common variant in practice.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/22.webp" width="700px">


- Figure 3.22 Using the causal attention mask (upper left), we apply an additional dropout mask (upper right) to zero out additional attention weights to reduce overfitting during training.

- In the following code example, we use a dropout rate of 50%, which means masking out
 about half of the attention weights. (When we train the GPT model in later chapters, we will use a lower dropout rate, such as 0.1 or 0.2.)

- We first apply PyTorch's dropout implementation to a 6×6 tensor consisting of ones for illustration purposes:

In [39]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # we choose dropout rate of 50%
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


>When applying dropout to an attention weight matrix with a rate of 50%, each element is set to zero independently with probability 0.5, so roughly half of the elements are zeroed (the exact number varies from row to row, as the output above shows). To compensate for the reduction in active elements, the remaining elements are scaled up by a factor of 1/0.5 = 2. This scaling keeps the expected value of the attention weights unchanged, so that the average influence of the attention mechanism remains consistent between the training and inference phases.
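- To make the scaling behavior concrete, here is a small sketch on a large tensor of ones (the tensor size and the variable names are our own choices for illustration). It checks that about half of the elements survive, that the mean stays close to 1, and that dropout becomes a no-op in evaluation mode:

```python
import torch

torch.manual_seed(123)
drop = torch.nn.Dropout(p=0.5)
x = torch.ones(1000, 1000)   # large tensor so that the averages are stable

drop.train()                 # training mode: elements are dropped at random
y_train = drop(x)
kept_fraction = (y_train != 0).float().mean().item()
print(f"fraction kept: {kept_fraction:.3f}")                # close to 0.5, not exactly 0.5
print(f"mean after dropout: {y_train.mean().item():.3f}")   # close to 1.0 due to the 1/(1-p) scaling

drop.eval()                  # evaluation mode: dropout does nothing
y_eval = drop(x)
print(torch.equal(y_eval, x))
```

In a full model, calling `model.eval()` switches all dropout layers to evaluation mode at once, which is how dropout is disabled after training.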

In [40]:
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [41]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


 - Having gained an understanding of causal attention and dropout masking, we will
 develop a concise Python class in the following section. This class is designed to facilitate
 the efficient application of these two techniques.

### 3.5.3 Implementing a compact causal attention class


- In this section, we will now incorporate the causal attention and dropout modifications into
 the SelfAttention Python class we developed in section 3.4. This class will then serve as a
 template for developing multi-head attention in the upcoming section, which is the final
 attention class we implement in this chapter.
- But before we begin, one more thing is to ensure that the code can handle batches
 consisting of more than one input so that the CausalAttention class supports the batch
 outputs produced by the data loader we implemented in chapter 2.
- For simplicity, to simulate such batch inputs, we duplicate the input text example:

In [42]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) #2 inputs with 6 tokens each, and each token has embedding dimension 3

torch.Size([2, 6, 3])


In [43]:
b, num_tokens, d_in = batch.shape 
b, num_tokens, d_in

(2, 6, 3)

In [44]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)                                          #A
        self.register_buffer(                                                       #A2
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)      #B
        )
    
    def forward(self, x):
        b, num_tokens, d_in = x.shape      # x has shape (batch, num_tokens, d_in)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)   #C
        attn_scores.masked_fill_(                      #D
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # PyTorch automatically applies the mask to all batches at once.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values 
        return context_vec  
    
#A Compared to the previous SelfAttention_v1 class, we added a dropout layer
#B The register_buffer call is also a new addition (more information is provided in the following text)
#C We transpose dimensions 1 and 2, keeping the batch dimension at the first position (0)
#D In PyTorch, operations with a trailing underscore are performed in-place, avoiding unnecessary memory copies
#A2 register_buffer is called here with two arguments:
    # name: a string specifying the name under which the buffer is registered (here, 'mask').
    # tensor: the tensor to register.
    # PyTorch associates the name with a module attribute, so self.register_buffer('mask', tensor) makes the tensor available as self.mask.

- `register_buffer` is a method provided by PyTorch's `nn.Module` class. It registers a tensor with the module as a "buffer".

- What are buffers in PyTorch? Buffers are tensors that are stored inside the model but are not model parameters. Unlike parameters (`nn.Parameter`), buffers are not returned by `model.parameters()`, so they are not updated by the optimizer during training. However, they are saved and loaded with the model by default (as part of its `state_dict`), which makes them useful for storing things like:
  - masks
  - constants
  - indices
  - normalizing coefficients

> While all added code lines should be familiar from previous sections, we now added a
 `self.register_buffer()` call in the `__init__` method. The use of `register_buffer` in
 PyTorch is not strictly necessary for all use cases but offers several advantages here. For
 instance, when we use the CausalAttention class in our LLM, `buffers` are automatically
 `moved to the appropriate device (CPU or GPU) along with our model`, which will be relevant
 when training the LLM in future chapters. This means we don't need to manually ensure
 these tensors are on the same device as our model parameters, avoiding device mismatch
 errors.
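- The difference between a registered buffer and an ordinary tensor attribute can be seen in a small sketch. `MaskHolder` is a hypothetical toy module written only for this illustration. Because a GPU may not be available, the example uses a dtype conversion to show that `.to(...)` carries buffers along with the parameters; moving the module to a device works analogously:

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.linear = nn.Linear(3, 2)
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask)   # tracked by the module
        self.plain_mask = mask.clone()       # ordinary attribute, not tracked

m = MaskHolder(context_length=4)

print([name for name, _ in m.named_parameters()])  # trainable parameters only
print([name for name, _ in m.named_buffers()])     # registered buffers
print("mask" in m.state_dict(), "plain_mask" in m.state_dict())
print(m.mask.requires_grad)

# .to(...) converts parameters and buffers together, but it does not touch
# plain tensor attributes such as plain_mask.
m = m.to(torch.float64)
print(m.linear.weight.dtype, m.mask.dtype, m.plain_mask.dtype)
```

The plain attribute is left behind by `.to(...)`, which is exactly the situation that leads to device mismatch errors when a mask is stored without `register_buffer`.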

In [45]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape: ", context_vecs.shape)

context_vecs.shape:  torch.Size([2, 6, 2])


The resulting context vectors form a 3D tensor of shape (2, 6, 2), in which each token is now
 represented by a 2-dimensional embedding:

In [46]:
context_vecs

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)

- Figure 3.23 provides a mental model that summarizes what we have accomplished so far.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/23.webp" width="700px">


 > As illustrated in Figure 3.23, in this section, we focused on the concept and implementation
 of causal attention in neural networks. In the next section, we will expand on this concept
 and `implement a multi-head attention module that runs several such causal attention mechanisms in parallel.`
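- One way to convince ourselves that an attention implementation really is causal is to change a later token and confirm that the context vectors of all earlier positions stay the same. The sketch below uses a simplified, function-based causal attention with fixed random projection matrices (`W_q`, `W_k`, `W_v`, and `causal_attention` are hypothetical names for this illustration, not the `CausalAttention` class from this section):

```python
import torch

torch.manual_seed(123)
num_tokens, d_in, d_out = 6, 3, 2
W_q = torch.randn(d_in, d_out)   # hypothetical fixed projection matrices
W_k = torch.randn(d_in, d_out)
W_v = torch.randn(d_in, d_out)

def causal_attention(x):
    """Single-head causal attention for one sequence of shape (num_tokens, d_in)."""
    queries, keys, values = x @ W_q, x @ W_k, x @ W_v
    scores = queries @ keys.T
    mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool()
    scores = scores.masked_fill(mask, -torch.inf)
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values

x = torch.rand(num_tokens, d_in)
x_changed = x.clone()
x_changed[-1] += 1.0   # perturb only the last token

out = causal_attention(x)
out_changed = causal_attention(x_changed)

# Earlier positions cannot see the last token, so their outputs should match
print(torch.allclose(out[:-1], out_changed[:-1]))
# The last position attends to itself, so its output is expected to differ
print(torch.allclose(out[-1], out_changed[-1]))
```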

## 3.6 Extending single-head attention to multi-head attention

- In this final section of the chapter, we extend the previously implemented causal
 attention class to multiple heads. This is also called multi-head attention.

 The term "multi-head" refers to dividing the attention mechanism into multiple "heads,"
 each operating independently. In this context, a `single causal attention module` can be
 considered `single-head attention`, where there is only one set of attention weights
 processing the input sequentially.

> In the following subsections, we will tackle this expansion from causal attention to
 multi-head attention. The first subsection will intuitively build a multi-head attention module by
 stacking multiple CausalAttention modules for illustration purposes. The second
 subsection will then implement the same multi-head attention module in a more
 complicated but computationally more efficient way.

### 3.6.1 Stacking multiple single-head attention layers

- In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism (depicted earlier in Figure 3.18 in section 3.4.1), each with its own weights, and then combining their outputs. Using multiple instances of the self-attention mechanism can be computationally intensive, but it's crucial for the kind of complex pattern recognition that models like transformer-based LLMs are known for.

Figure 3.24 illustrates the structure of a multi-head attention module, which consists of  multiple single-head attention modules, as previously depicted in Figure 3.18, stacked on  top of each other.

 <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/24.webp" width="700px">


> Figure 3.24 The multi-head attention module in this figure depicts two single-head attention modules stacked
 on top of each other. So, instead of using a single matrix Wv for computing the value matrices, in a multi-head
 attention module with two heads, we now have two value weight matrices: Wv1 and Wv2. The same applies to
 the other weight matrices, Wq and Wk. We obtain two sets of context vectors Z1 and Z2 that we can combine
 into a single context vector matrix Z.

 - As mentioned before, the main idea behind multi-head attention is to run the attention
 mechanism multiple times (in parallel) with different, learned linear projections — the
 results of multiplying the input data (like the query, key, and value vectors in attention
 mechanisms) by a weight matrix.

In [47]:
# A wrapper class to implement multi-head attention
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

- For example, if we use this MultiHeadAttentionWrapper class with two attention heads (via num_heads=2) and CausalAttention output dimension d_out=2, this results in 4-dimensional context vectors (d_out*num_heads=4), as illustrated in Figure 3.25.

 <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/25.webp" width="700px">


>Figure 3.25 Using the MultiHeadAttentionWrapper, we specified the number of attention heads
 (num_heads). If we set num_heads=2, as shown in this figure, we obtain a tensor with two sets of context
 vector matrices. In each context vector matrix, the rows represent the context vectors corresponding to the
 tokens, and the columns correspond to the embedding dimension specified via d_out=2. We concatenate
 these context vector matrices along the column dimension. Since we have 2 attention heads and an
 embedding dimension of 2, the final embedding dimension is 2 × 2 = 4.

In [51]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print(f'context_vecs.shape: {context_vecs.shape}')

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


- The first dimension of the resulting context_vecs tensor is 2 since we have two input texts
 (the input texts are duplicated, which is why the context vectors are exactly the same for
 both). The second dimension refers to the 6 tokens in each input. The third dimension
 refers to the 4-dimensional embedding of each token.

**Exercise 3.2** (solved in the `exercise_solutions.ipynb` file)

- Change the input arguments for the MultiHeadAttentionWrapper(...,
 num_heads=2) call such that the output context vectors are 2-dimensional instead of
 4-dimensional while keeping the setting num_heads=2. Hint: You don't have to modify
 the class implementation; you just have to change one of the other input arguments.

>In this section, we implemented a MultiHeadAttentionWrapper that combined multiple single-head attention modules. However, note that these are processed sequentially via `[head(x) for head in self.heads]` in the forward method. We can improve this implementation by processing the heads in parallel. One way to achieve this is by computing the outputs for all attention heads simultaneously via matrix multiplication, as we will explore in the next section.

### 3.6.2 Implementing multi-head attention with weight splits

- In the previous section, we created a MultiHeadAttentionWrapper to implement
 multi-head attention by stacking multiple single-head attention modules. This was done by
 instantiating and combining several CausalAttention objects.

 - Instead of maintaining two separate classes, MultiHeadAttentionWrapper and
 CausalAttention, we can combine both of these concepts into a single
 MultiHeadAttention class. Also, in addition to just merging the
 MultiHeadAttentionWrapper with the CausalAttention code, we will make some other
 modifications to implement multi-head attention more efficiently. 


> In the `MultiHeadAttentionWrapper`, multiple heads are implemented by creating a list
 of `CausalAttention` objects (`self.heads`), each representing a separate attention head.
 The CausalAttention class independently performs the attention mechanism, and the
 results from each head are concatenated. In contrast, the following `MultiHeadAttention`
 class integrates the multi-head functionality within a single class. It splits the input into
 multiple heads by reshaping the projected query, key, and value tensors and then combines
 the results from these heads after computing attention.

In [53]:
# An efficient multi-head attention class

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads 
        self.head_dim = d_out // num_heads                                       #A
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)                                  #B
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)                                                     #C
        values = self.W_value(x)                                                 #C
        queries = self.W_query(x)                                                #C

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)           #D
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)       #D
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)     #D

        keys = keys.transpose(1, 2)                                              #E
        queries = queries.transpose(1, 2)                                        #E
        values = values.transpose(1, 2)                                          #E

        attn_scores = queries @ keys.transpose(2, 3)                             #F
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]                   #G

        attn_scores.masked_fill_(mask_bool, -torch.inf)                          #H

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)                    #I
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)   #J
        context_vec = self.out_proj(context_vec)                                 #K
        return context_vec

 #A Reduce the projection dim to match desired output dim
 #B Use a Linear layer to combine head outputs
 #C Tensor shape: (b, num_tokens, d_out)
 #D We implicitly split the matrix by adding a `num_heads` dimension. Then we unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
 #E Transpose from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
 #F Compute dot product for each head
 #G Mask truncated to the number of tokens
 #H Use the mask to fill attention scores
 #I Tensor shape: (b, num_tokens, num_heads, head_dim)
 #J Combine heads, where self.d_out = self.num_heads * self.head_dim
 #K Add an optional linear projection

- Even though the reshaping (.view) and transposing (.transpose) of tensors inside the
 MultiHeadAttention class looks very complicated, mathematically, the MultiHeadAttention class implements the same concept as the MultiHeadAttentionWrapper earlier.


- On a big-picture level, in the previous MultiHeadAttentionWrapper, we stacked
 multiple single-head attention layers that we combined into a multi-head attention layer.
 The MultiHeadAttention class takes an integrated approach. It starts with a multi-head
 layer and then internally splits this layer into individual attention heads, as illustrated in
 Figure 3.26.

  <img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/26.webp" width="700px">


- Figure 3.26 In the MultiHeadAttentionWrapper class with two attention heads, we initialized two weight
 matrices Wq1 and Wq2 and computed two query matrices Q1 and Q2 as illustrated at the top of this figure. In
 the MultiHeadAttention class, we initialize one larger weight matrix Wq, only perform one matrix
 multiplication with the inputs to obtain a query matrix Q, and then split the query matrix into Q1 and Q2 as
 shown at the bottom of this figure. We do the same for the keys and values, which are not shown to reduce
 visual clutter.

- The splitting of the query, key, and value tensors, as depicted in Figure 3.26, is achieved
 through tensor reshaping and transposing operations using PyTorch's .view and
 .transpose methods. The input is first transformed (via linear layers for queries, keys, and
 values) and then reshaped to represent multiple heads.
 
- The key operation is to split the `d_out` dimension into `num_heads` and `head_dim`, where
 `head_dim = d_out / num_heads`. This splitting is then achieved using the `.view` method: a
 tensor of dimensions `(b, num_tokens, d_out)` is reshaped to dimension `(b, num_tokens, num_heads, head_dim)`.

The tensors are then transposed to bring the `num_heads` dimension before the num_tokens dimension, resulting in a shape of `(b, num_heads, num_tokens, head_dim)`. This transposition is crucial for correctly aligning the queries, keys, and values across the different heads and performing batched matrix multiplications efficiently.

- Continuing with MultiHeadAttention, after computing the attention weights and context
 vectors, the context vectors from all heads are transposed back to the shape (b,
 num_tokens, num_heads, head_dim). These vectors are then reshaped (flattened) into the
 shape (b, num_tokens, d_out), effectively combining the outputs from all heads.
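- The reshaping steps described above can be tried out in isolation. The following sketch uses random tensors in place of the projected queries and keys (the sizes and the helper `split_heads` are our own choices for illustration). It shows the split into heads, the batched matrix multiplication across heads, and the merge back to the original layout:

```python
import torch

torch.manual_seed(123)
b, num_tokens, num_heads, head_dim = 2, 6, 2, 4
d_out = num_heads * head_dim

queries = torch.randn(b, num_tokens, d_out)   # stand-in for projected queries
keys = torch.randn(b, num_tokens, d_out)      # stand-in for projected keys

def split_heads(t):
    # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
    return t.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)

q, k = split_heads(queries), split_heads(keys)
print(q.shape)

# Head h corresponds to a contiguous slice of the last dimension
h = 1
print(torch.equal(q[:, h], queries[:, :, h * head_dim:(h + 1) * head_dim]))

# One batched matmul computes the attention scores of all heads at once ...
scores = q @ k.transpose(2, 3)                # (b, num_heads, num_tokens, num_tokens)
# ... and matches computing each head separately
per_head = torch.stack(
    [q[:, i] @ k[:, i].transpose(1, 2) for i in range(num_heads)], dim=1
)
print(torch.allclose(scores, per_head, atol=1e-6))

# Transposing back and flattening restores the (b, num_tokens, d_out) layout
merged = q.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, queries))
```

The `.contiguous()` call is needed before `.view` because `.transpose` returns a tensor whose memory layout no longer matches its shape.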

- Additionally, we added a so-called `output projection layer (self.out_proj)` to
 MultiHeadAttention after combining the heads, which is not present in the
 CausalAttention class. `This output projection layer is not strictly necessary` (see the
 References section in Appendix B for more details), but it `is commonly used in many LLM architectures`, which is why we added it here for completeness.

- Even though the `MultiHeadAttention class` looks more complicated than the
 MultiHeadAttentionWrapper due to the additional reshaping and transposition of tensors,
 it is more efficient. The reason is that we only need one matrix multiplication to compute
 the keys, for instance, `keys = self.W_key(x)` (the same is true for the queries and
 values). In the MultiHeadAttentionWrapper, we needed to repeat this matrix multiplication,
 which is computationally one of the most expensive steps, for each attention head.
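- To illustrate why a single large projection can replace the per-head projections, the sketch below builds per-head weight matrices from row blocks of one `nn.Linear` weight and checks that the concatenated per-head results match the single large matrix multiplication. The sizes and variable names are assumptions chosen for this example:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, num_heads, head_dim = 3, 2, 2
d_out = num_heads * head_dim
x = torch.rand(2, 6, d_in)                       # stand-in input batch

W_key = nn.Linear(d_in, d_out, bias=False)       # one large projection
keys_all = W_key(x)                              # a single matrix multiplication

# Equivalent per-head projections built from row blocks of the same weight
# (nn.Linear stores its weight with shape (d_out, d_in))
keys_per_head = []
for h in range(num_heads):
    W_h = W_key.weight[h * head_dim:(h + 1) * head_dim]   # (head_dim, d_in)
    keys_per_head.append(x @ W_h.T)                       # one matmul per head
keys_concat = torch.cat(keys_per_head, dim=-1)

print(torch.allclose(keys_all, keys_concat, atol=1e-6))
```

Note that in this sketch the total output size is shared across the heads (`d_out = num_heads * head_dim`), as in the MultiHeadAttention class, whereas each head in the MultiHeadAttentionWrapper produces its own `d_out`-dimensional output.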

The `MultiHeadAttention` class can be used similarly to the `SelfAttention` and `CausalAttention` classes we implemented earlier:

In [56]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape 
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print(f'context_vecs.shape: {context_vecs.shape}')

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


- In this section, we implemented the `MultiHeadAttention` class that we will use in the
 upcoming sections when implementing and training the LLM itself. Note that while the code
 is fully functional, we used relatively small embedding sizes and numbers of attention
 heads to keep the outputs readable.

- For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention
 heads and a context vector embedding size of 768. The largest GPT-2 model (1.5 billion
 parameters) has 25 attention heads and a context vector embedding size of 1600. Note
 that the embedding sizes of the token inputs and context embeddings are the same in GPT
 models `(d_in = d_out)`.
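- As a quick arithmetic check using only the figures quoted above, we can compute the per-head dimension (`head_dim = d_out // num_heads`) that these settings imply. The dictionary keys below are our own labels for illustration:

```python
configs = {
    "smallest GPT-2": {"emb_dim": 768, "num_heads": 12},
    "largest GPT-2": {"emb_dim": 1600, "num_heads": 25},
}

head_dims = {}
for name, cfg in configs.items():
    # MultiHeadAttention requires d_out to be divisible by num_heads
    assert cfg["emb_dim"] % cfg["num_heads"] == 0
    head_dims[name] = cfg["emb_dim"] // cfg["num_heads"]
    print(f"{name}: head_dim = {head_dims[name]}")
```

Both settings give 64 dimensions per head, so in these two models the embedding size grows together with the number of heads while the head size stays the same.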

 **EXERCISE 3.3 INITIALIZING GPT-2 SIZE ATTENTION MODULES**
 
 - Using the MultiHeadAttention class, initialize a multi-head attention module that
 has the same number of attention heads as the smallest GPT-2 model (12 attention
 heads). Also ensure that you use the respective input and output embedding sizes
 similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a
 context length of 1024 tokens.

 - The solution to this exercise is in the `exercise_solution.ipynb` file.

## 3.7 Summary

- Attention mechanisms transform input elements into enhanced context
 vector representations that incorporate information about all inputs. 
- A self-attention mechanism computes the context vector representation
 as a weighted sum over the inputs. 
- In a simplified attention mechanism, the attention weights are computed
 via dot products. 
- A dot product is just a concise way of multiplying two vectors
 element-wise and then summing the products. 
- Matrix multiplications, while not strictly required, help us to implement
 computations more efficiently and compactly by replacing nested
 for-loops. 
- In self-attention mechanisms that are used in LLMs, also called scaled
 dot-product attention, we include trainable weight matrices to compute
 intermediate transformations of the inputs: queries, values, and keys. 
- When working with LLMs that read and generate text from left to right,
 we add a causal attention mask to prevent the LLM from accessing future
 tokens. 
- In addition to causal attention masks that zero out attention weights, we can also
 add a dropout mask to reduce overfitting in LLMs. 
- The attention modules in transformer-based LLMs involve multiple
 instances of causal attention, which is called multi-head attention. 
- We can create a multi-head attention module by stacking multiple
 instances of causal attention modules. 
- A more efficient way of creating multi-head attention modules involves
 batched matrix multiplications.