# Chapter 3

__Figure 3.1 A mental model of the three main stages of developing an LLM: coding the LLM, pretraining it on a general text dataset, and finetuning it on a labeled dataset. This chapter focuses on attention mechanisms, which are an integral part of an LLM architecture.__

![LLM mental model](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image001.png)

__Figure 3.2 The figure depicts different attention mechanisms we will code in this chapter, starting with a simplified version of self-attention before adding the trainable weights. The causal attention mechanism adds a mask to self-attention that allows the LLM to generate one word at a time. Finally, multi-head attention organizes the attention mechanism into multiple heads, allowing the model to capture various aspects of the input data in parallel.__

![different attention mechanisms](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image003.png)

## 3.1 The problem with modeling long sequences

__Figure 3.3 When translating text from one language to another, such as German to English, it's not possible to merely translate word by word. Instead, the translation process requires contextual understanding and grammar alignment.__


![problem of word by word translation](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image005.png)

__Figure 3.4 Before the advent of transformer models, encoder-decoder RNNs were a popular choice for machine translation. The encoder takes a sequence of tokens from the source language as input, where a hidden state (an intermediate neural network layer) of the encoder encodes a compressed representation of the entire input sequence. Then, the decoder uses its current hidden state to begin the translation, token by token.__

![RNNs use hidden state to encode the entire sequence of tokens](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image007.png)

## 3.2 Capturing data dependencies with attention mechanisms

__Figure 3.5 Using an attention mechanism, the text-generating decoder part of the network can access all input tokens selectively. This means that some input tokens are more important than others for generating a given output token. The importance is determined by the so-called attention weights, which we will compute later. Note that this figure shows the general idea behind attention and does not depict the exact implementation of the Bahdanau mechanism, which is an RNN method outside this book's scope.__

![attention mechanism](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image009.png)

__Figure 3.6 Self-attention is a mechanism in transformers that is used to compute more efficient input representations by allowing each position in a sequence to interact with and weigh the importance of all other positions within the same sequence. In this chapter, we will code this self-attention mechanism from the ground up before we code the remaining parts of the GPT-like LLM in the following chapter.__

![self-attention mechanism](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image011.png)

## 3.3 Attending to different parts of the input with self-attention

### 3.3.1 A simple self-attention mechanism without trainable weights

__Figure 3.7 The goal of self-attention is to compute a context vector, for each input element, that combines information from all other input elements. In the example depicted in this figure, we compute the context vector z<sup>(2)</sup>. The importance or contribution of each input element for computing z<sup>(2)</sup> is determined by the attention weights α<sub>21</sub> to α<sub>2T</sub>. When computing z<sup>(2)</sup>, the attention weights are calculated with respect to input element x<sup>(2)</sup> and all other inputs. The exact computation of these attention weights is discussed later in this section.__

![context vector](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image013.png)

Computing a dot product in plain Python, without PyTorch:

```python
x1 = [0.4, 0.1, 0.8]
x2 = [0.5, 0.8, 0.6]

# Multiply the vectors element-wise and sum the products
dot_product = sum(a * b for a, b in zip(x1, x2))

print("The dot product of x1 and x2 is:", dot_product)
```

```
The dot product of x1 and x2 is: 0.76
```


The same computation with PyTorch:

```python
import torch

torch.dot(torch.tensor(x1), torch.tensor(x2))
```

```
tensor(0.7600)
```

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
```

__Figure 3.8 The overall goal of this section is to illustrate the computation of the context vector z<sup>(2)</sup> using the second input element, x<sup>(2)</sup>, as a query. This figure shows the first intermediate step: computing the attention scores ω between the query x<sup>(2)</sup> and all other input elements as a dot product. (Note that the numbers in the figure are truncated to one digit after the decimal point to reduce visual clutter.)__

![self-attention scores](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image015.png)

```python
print(inputs.shape)     # (num_tokens, embedding_dim)
print(inputs.shape[0])  # number of tokens
print(inputs.shape[1])  # embedding dimension
```

```
torch.Size([6, 3])
6
3
```


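```python
query = inputs[1]  # the second token, "journey", is the query

attn_scores_2 = torch.empty(inputs.shape[0])

# Attention score = dot product between each input and the query
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)

print(attn_scores_2)
```

```
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
```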
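Each attention score is the dot product `x_i · query`, so the loop above can be replaced by a single matrix-vector product. The sketch below re-creates the `inputs` tensor so it runs on its own. It then checks that the loop and the vectorized form give the same six scores.

```python
import torch

# Same six 3-dimensional embeddings as above ("Your journey starts with one step")
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
query = inputs[1]  # second token, "journey", acts as the query

# Loop version: one dot product per input token
scores_loop = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    scores_loop[i] = torch.dot(x_i, query)

# Vectorized version: (6, 3) @ (3,) -> (6,)
scores_vec = inputs @ query

print(scores_vec)
print(torch.allclose(scores_loop, scores_vec))
```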


__Figure 3.9 After computing the attention scores ω<sub>21</sub> to ω<sub>2T</sub> with respect to the input query x<sup>(2)</sup>, the next step is to obtain the attention weights α<sub>21</sub> to α<sub>2T</sub> by normalizing the attention scores.__

![attention weights](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image017.png)

```python
# Simple normalization: divide each score by the sum of all scores
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())
```

```
Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)
```


In practice, it's more common and advisable to use the __softmax__ function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training. Below is a basic implementation of the softmax function for normalizing the attention scores:

```python
def softmax_naive(x):
    # Exponentiate, then normalize so the values sum to 1
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())
```

```
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
```
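The claim that softmax copes better with extreme values can be checked directly. The sketch below uses made-up score vectors, not the chapter's data, to show three things:

- Plain sum-normalization can produce negative "weights" when the scores have mixed signs.
- `softmax_naive` overflows to `nan` when the scores are large.
- Subtracting the maximum score before exponentiating (the usual "max trick") avoids the overflow.

`softmax_stable` is a hypothetical helper. The final comparison assumes that `torch.softmax` applies an equivalent safeguard internally.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x, dim=0):
    # exp(x - c) / sum(exp(x - c)) == exp(x) / sum(exp(x)) for any constant c,
    # and choosing c = max(x) keeps every exponent <= 0, so exp cannot overflow
    x_shifted = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(x_shifted)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

# Hypothetical scores chosen to stress the two normalizations
mixed_scores = torch.tensor([2.0, -1.5, 0.5])
large_scores = torch.tensor([1000.0, 999.0, 998.0])

# Dividing by the sum keeps the negative score negative
print(mixed_scores / mixed_scores.sum())

# exp(1000) is inf in float32, and inf / inf is nan
print(softmax_naive(large_scores))

# The max-shifted version stays finite and agrees with PyTorch
print(softmax_stable(large_scores))
print(torch.allclose(softmax_stable(large_scores), torch.softmax(large_scores, dim=0)))
```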


PyTorch's built-in softmax implementation gives the same result:

```python
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
```

```
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
```


__Figure 3.10 The final step, after calculating and normalizing the attention scores to obtain the attention weights for query x<sup>(2)</sup>, is to compute the context vector z<sup>(2)</sup>. This context vector is a combination of all input vectors x<sup>(1)</sup> to x<sup>(T)</sup> weighted by the attention weights.__

![query weights](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image019.png)

```python
query = inputs[1]  # 2nd input token is the query

print(query.shape)
```

```
torch.Size([3])
```


```python
# Context vector = sum of all inputs, each weighted by its attention weight
context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i

print(context_vec_2)
```

```
tensor([0.4419, 0.6515, 0.5683])
```
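The loop above builds the context vector as a weighted sum of the input rows. That is the same operation as multiplying the weight vector by the input matrix. The self-contained sketch below redoes the three steps for the query "journey" and checks that the loop and the single product agree.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
query = inputs[1]

attn_scores_2 = inputs @ query                        # step 1: scores
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)  # step 2: weights

# Step 3, loop form: accumulate weighted input vectors
context_loop = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_loop += attn_weights_2[i] * x_i

# Step 3, vectorized: (6,) @ (6, 3) -> (3,)
context_vec = attn_weights_2 @ inputs

print(context_vec)
print(torch.allclose(context_loop, context_vec))
```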


### 3.3.2 Computing attention weights for all input tokens

__Figure 3.11 The highlighted row shows the attention weights for the second input element as a query, as we computed in the previous section. This section generalizes the computation to obtain all other attention weights.__

![attention weights for the 2nd token](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image021.png)

__Figure 3.12 Computing all context vectors at once.__

![computing all context vectors at once](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image023.png)

```python
attn_scores = torch.empty(6, 6)

# Dot product between every pair of inputs (i, j)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        print(f"dot({i}, {j}): {x_i} * {x_j}")
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)
```

```
dot(0, 0): tensor([0.4300, 0.1500, 0.8900]) * tensor([0.4300, 0.1500, 0.8900])
dot(0, 1): tensor([0.4300, 0.1500, 0.8900]) * tensor([0.5500, 0.8700, 0.6600])
dot(0, 2): tensor([0.4300, 0.1500, 0.8900]) * tensor([0.5700, 0.8500, 0.6400])
dot(0, 3): tensor([0.4300, 0.1500, 0.8900]) * tensor([0.2200, 0.5800, 0.3300])
dot(0, 4): tensor([0.4300, 0.1500, 0.8900]) * tensor([0.7700, 0.2500, 0.1000])
dot(0, 5): tensor([0.4300, 0.1500, 0.8900]) * tensor([0.0500, 0.8000, 0.5500])
dot(1, 0): tensor([0.5500, 0.8700, 0.6600]) * tensor([0.4300, 0.1500, 0.8900])
dot(1, 1): tensor([0.5500, 0.8700, 0.6600]) * tensor([0.5500, 0.8700, 0.6600])
dot(1, 2): tensor([0.5500, 0.8700, 0.6600]) * tensor([0.5700, 0.8500, 0.6400])
dot(1, 3): tensor([0.5500, 0.8700, 0.6600]) * tensor([0.2200, 0.5800, 0.3300])
dot(1, 4): tensor([0.5500, 0.8700, 0.6600]) * tensor([0.7700, 0.2500, 0.1000])
dot(1, 5): tensor([0.5500, 0.8700, 0.6600]) * tensor([0.0500, 0.8000, 0.5500])
dot(2, 0): tensor([0.5700, 0.8500, 0.6400]) * tensor([0.4300, 0.1500, 0.8900])
...
```

```python
# Same scores in one step: (6, 3) @ (3, 6) -> (6, 6)
attn_scores = inputs @ inputs.T

print(attn_scores)
```

```
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
```


```python
# Normalize each row so that every query's weights sum to 1
attn_weights = torch.softmax(attn_scores, dim=1)

print(attn_weights)
```

```
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
```


```python
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=1))
```

```
Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
```


```python
# Each context vector is a weighted sum of all inputs: (6, 6) @ (6, 3) -> (6, 3)
all_context_vecs = attn_weights @ inputs

print(all_context_vecs)
```

```
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
```


```python
print("Previous 2nd context vector:", context_vec_2)
```

```
Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])
```
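The three steps of the simplified mechanism (scores, weights, context vectors) fit in one small function. `simple_self_attention` is a hypothetical name used only for this sketch. The checks compare its rows against the context vectors printed above.

```python
import torch

def simple_self_attention(x):
    """Self-attention without trainable weights.

    x: tensor of shape (num_tokens, d_in)
    returns: context vectors of shape (num_tokens, d_in)
    """
    attn_scores = x @ x.T                              # step 1: pairwise dot products
    attn_weights = torch.softmax(attn_scores, dim=-1)  # step 2: normalize each row
    return attn_weights @ x                            # step 3: weighted sums of inputs

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

context = simple_self_attention(inputs)
print(context)

# Row 1 should match the loop-based context vector for the query "journey"
print(torch.allclose(context[1], torch.tensor([0.4419, 0.6515, 0.5683]), atol=1e-4))
```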


## 3.4 Implementing self-attention with trainable weights

__Figure 3.13 A mental model illustrating how the self-attention mechanism we code in this section fits into the broader context of this book and chapter. In the previous section, we coded a simplified attention mechanism to understand the basic mechanism behind attention mechanisms. In this section, we add trainable weights to this attention mechanism. In the upcoming sections, we will then extend this self-attention mechanism by adding a causal mask and multiple heads.__

![self-attention in the LLM mental model](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image025.png)

### 3.4.1 Computing the attention weights step by step

__Figure 3.14 In the first step of the self-attention mechanism with trainable weight matrices, we compute query (q), key (k), and value (v) vectors for input elements x. Similar to previous sections, we designate the second input, x<sup>(2)</sup>, as the query input. The query vector q<sup>(2)</sup> is obtained via matrix multiplication between the input x<sup>(2)</sup> and the weight matrix W<sub>q</sub>. Similarly, we obtain the key and value vectors via matrix multiplication involving the weight matrices W<sub>k</sub> and W<sub>v</sub>.__

![computing attention weights](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image027.png)

```python
x_2 = inputs[1]         # second input element
d_in = inputs.shape[1]  # input embedding size, d_in=3
d_out = 2               # output embedding size, d_out=2
```

```python
print("Input dimension:", d_in)
```

```
Input dimension: 3
```


```python
print(x_2)
```

```
tensor([0.5500, 0.8700, 0.6600])
```


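```python
torch.manual_seed(123)

# requires_grad=False keeps the printed outputs uncluttered;
# for training, these weight matrices would need requires_grad=True
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
```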

```python
print("W shape:", W_query.shape)
```

```
W shape: torch.Size([3, 2])
```


```python
print(W_query)
```

```
Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
```


```python
# Project the second input into query, key, and value space
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)
```

```
tensor([0.4306, 1.4551])
```


```python
# Keys and values are needed for every input token: (6, 3) @ (3, 2) -> (6, 2)
keys = inputs @ W_key
values = inputs @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
```

```
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
```


__Figure 3.15 The attention score computation is a dot-product computation similar to what we have used in the simplified self-attention mechanism in section 3.3. The new aspect here is that we are not directly computing the dot-product between the input elements but using the query and key obtained by transforming the inputs via the respective weight matrices.__

![attention score](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image029.png)

```python
keys_2 = keys[1]

# Unnormalized attention score omega_22
attn_scores_22 = query_2.dot(keys_2)

print(attn_scores_22)
```

```
tensor(1.8524)
```


```python
# All attention scores for query 2 at once: (2,) @ (2, 6) -> (6,)
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)
```

```
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
```


__Figure 3.16 After computing the attention scores _ω_, the next step is to normalize these scores using the softmax function to obtain the attention weights _α_.__

![normalize attention scores](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image031.png)

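```python
d_k = keys.shape[-1]  # embedding dimension of the keys

# Scale the scores by sqrt(d_k) before applying softmax
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)

print(attn_weights_2)
```

```
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
```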


```python
attn_scores_2
```

```
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
```

```python
d_k
```

```
2
```
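Dividing by `d_k**0.5` is the "scaled" part of scaled dot-product attention. The usual argument goes like this. If query and key components are independent with zero mean and unit variance, their dot product has variance `d_k`. Raw scores therefore grow with the embedding size, which pushes softmax toward a nearly one-hot output with small gradients. The sketch below checks the variance argument empirically with random standard-normal vectors. That distribution is an assumption made for illustration; real queries and keys need not follow it. `d_k_demo = 512` is a hypothetical size.

```python
import torch

torch.manual_seed(0)
d_k_demo = 512        # hypothetical key/query dimension
num_samples = 10_000

# Random queries and keys with independent standard-normal components
q = torch.randn(num_samples, d_k_demo)
k = torch.randn(num_samples, d_k_demo)

raw_scores = (q * k).sum(dim=-1)             # one dot product per row
scaled_scores = raw_scores / d_k_demo**0.5   # the scaling used in this chapter

print("std of raw scores:   ", raw_scores.std().item())     # near sqrt(512), about 22.6
print("std of scaled scores:", scaled_scores.std().item())  # near 1

# Effect on softmax: unscaled scores give a much peakier distribution
row = torch.randn(6, d_k_demo) @ torch.randn(d_k_demo)  # six raw scores for one query
print(torch.softmax(row, dim=-1))
print(torch.softmax(row / d_k_demo**0.5, dim=-1))
```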

__Figure 3.17 In the final step of the self-attention computation, we compute the context vector by combining all value vectors via the attention weights.__

![context vector](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image033.png)

```python
# Context vector = attention-weighted sum of the value vectors
context_vec_2 = attn_weights_2 @ values

print(context_vec_2)
```

```
tensor([0.3061, 0.8210])
```


```python
attn_weights_2
```

```
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
```

```python
values
```

```
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])
```

### 3.4.2 Implementing a compact self-attention Python class

```python
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        print("W_query:", self.W_query)
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
```

```python
# Same seed as in section 3.4.1, so the three weight matrices match
torch.manual_seed(123)

sa_v1 = SelfAttention_v1(d_in, d_out)

print(sa_v1(inputs))
```

```
W_query: Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
```
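The second row of the output above, `[0.3061, 0.8210]`, matches the context vector `context_vec_2` computed step by step in section 3.4.1. Both computations draw the three random matrices in the same order after `torch.manual_seed(123)`. The self-contained sketch below checks this. `SelfAttentionSketch` is a hypothetical, stripped-down copy of `SelfAttention_v1` without the debug print.

```python
import torch
import torch.nn as nn

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
d_in, d_out = inputs.shape[1], 2

# Step-by-step computation for the query x^(2)
torch.manual_seed(123)
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

query_2 = inputs[1] @ W_query
keys = inputs @ W_key
values = inputs @ W_value
attn_weights_2 = torch.softmax(query_2 @ keys.T / keys.shape[-1]**0.5, dim=-1)
context_vec_2 = attn_weights_2 @ values

class SelfAttentionSketch(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

# Re-seeding reproduces the same three matrices in the same order
torch.manual_seed(123)
sa = SelfAttentionSketch(d_in, d_out)
out = sa(inputs)

print(out[1], context_vec_2)
print(torch.allclose(out[1], context_vec_2))
```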


__Figure 3.18 In self-attention, we transform the input vectors in the input matrix X with the three weight matrices, W<sub>q</sub>, W<sub>k</sub>, and W<sub>v</sub>. Then, we compute the attention weight matrix based on the resulting queries (Q) and keys (K). Using the attention weights and values (V), we then compute the context vectors (Z). (For visual clarity, we focus on a single input text with n tokens in this figure, not a batch of multiple inputs. Consequently, the 3D input tensor is simplified to a 2D matrix in this context. This approach allows for a more straightforward visualization and understanding of the processes involved. Also, for consistency with later figures, the values in the attention matrix do not depict the real attention weights.)__

![context vector](https://drek4537l1klr.cloudfront.net/raschka/v-8/Figures/ch03__image035.png)

```python
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # nn.Linear layers replace the manually created nn.Parameter matrices
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        print("W_query shape:", self.W_query.weight.shape)
        print("W_query:", self.W_query.weight)
        print("W_query bias:", self.W_query.bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
```

```python
torch.manual_seed(789)

sa_v2 = SelfAttention_v2(d_in, d_out)

print(sa_v2(inputs))
```

```
W_query shape: torch.Size([2, 3])
W_query: Parameter containing:
tensor([[ 0.3161,  0.4568,  0.5118],
        [-0.1683, -0.3379, -0.0918]], requires_grad=True)
W_query bias: None
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
```


__Exercise 3.1 Comparing SelfAttention_v1 and SelfAttention_v2__

Note that `nn.Linear` in `SelfAttention_v2` uses a different weight initialization scheme than the `nn.Parameter(torch.rand(d_in, d_out))` used in `SelfAttention_v1`, which causes the two mechanisms to produce different results. To check that both implementations, `SelfAttention_v1` and `SelfAttention_v2`, are otherwise similar, we can transfer the weight matrices from a `SelfAttention_v2` object to a `SelfAttention_v1` object so that both produce the same results.

Your task is to correctly assign the weights from an instance of `SelfAttention_v2` to an instance of `SelfAttention_v1`. To do this, you need to understand the relationship between the weights in both versions. (Hint: `nn.Linear` stores the weight matrix in a transposed form.) After the assignment, you should observe that both instances produce the same outputs.
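The hint can be checked in isolation. `nn.Linear(d_in, d_out)` stores its `weight` with shape `(d_out, d_in)` and computes `x @ weight.T + bias`. The `(d_in, d_out)` matrix that `SelfAttention_v1` expects is therefore `weight.T`. The layer sizes and random input below are arbitrary choices for the sketch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)

# Without bias: the layer is just a matrix product with the transposed weight
linear = nn.Linear(d_in, d_out, bias=False)
print(linear.weight.shape)  # stored as (d_out, d_in)
print(torch.allclose(linear(x), x @ linear.weight.T))

# With bias (the default), the bias vector is added after the product
linear_b = nn.Linear(d_in, d_out)
print(torch.allclose(linear_b(x), x @ linear_b.weight.T + linear_b.bias))
```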

```python
# Function to transfer weights from SelfAttention_v2 to SelfAttention_v1
def transfer_weights(attention_v2, attention_v1):
    # nn.Linear stores weights as (d_out, d_in), so transpose them to (d_in, d_out)
    attention_v1.W_query.data = attention_v2.W_query.weight.T.detach().clone()
    attention_v1.W_key.data = attention_v2.W_key.weight.T.detach().clone()
    attention_v1.W_value.data = attention_v2.W_value.weight.T.detach().clone()

    # SelfAttention_v1 has no bias terms, so zero out any biases in SelfAttention_v2
    if attention_v2.W_query.bias is not None:
        print("Warning: SelfAttention_v2 includes bias terms. Setting to zero for comparison.")
        attention_v2.W_query.bias.data.zero_()
        attention_v2.W_key.bias.data.zero_()
        attention_v2.W_value.bias.data.zero_()

x = inputs

# Initialize both attention mechanisms
attention_v2 = SelfAttention_v2(d_in, d_out, qkv_bias=True)
attention_v1 = SelfAttention_v1(d_in, d_out)

# Transfer weights
transfer_weights(attention_v2, attention_v1)

# Compare outputs
output_v2 = attention_v2(x)
output_v1 = attention_v1(x)

print("Output of SelfAttention_v2:\n", output_v2)
print("Output of SelfAttention_v1:\n", output_v1)

# Check if outputs are close
if torch.allclose(output_v2, output_v1, atol=1e-6):
    print("The outputs are identical!")
else:
    print("The outputs differ.")
```

```
W_query shape: torch.Size([2, 3])
W_query: Parameter containing:
tensor([[ 0.0911,  0.4770, -0.5456],
        [-0.3887, -0.2299,  0.0232]], requires_grad=True)
W_query bias: Parameter containing:
tensor([-0.1347, -0.0634], requires_grad=True)
W_query: Parameter containing:
tensor([[0.2251, 0.3111],
        [0.1955, 0.9153],
        [0.7751, 0.6749]], requires_grad=True)
Warning: SelfAttention_v2 includes bias terms. Setting to zero for comparison.
Output of SelfAttention_v2:
 tensor([[-0.2406,  0.3448],
        [-0.2571,  0.3424],
        [-0.2573,  0.3424],
        [-0.2530,  0.3425],
        [-0.2568,  0.3423],
        [-0.2518,  0.3426]], grad_fn=<MmBackward0>)
Output of SelfAttention_v1:
 tensor([[-0.2406,  0.3448],
        [-0.2571,  0.3424],
        [-0.2573,  0.3424],
        [-0.2530,  0.3425],
        [-0.2568,  0.3423],
        [-0.2518,  0.3426]], grad_fn=<MmBackward0>)
The outputs are identical!
```


## 3.5 Hiding future words with causal attention

__Figure 3.19 In causal attention, we mask out the attention weights above the diagonal such that for a given input, the LLM can’t access future tokens when computing the context vectors using the attention weights. For example, for the word “journey” in the second row, we only keep the attention weights for the words before (“Your”) and in the current position (“journey”).__

![causal attention](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-19.png)

### 3.5.1 Applying the causal attention mask

__Figure 3.20 One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zeroing out the elements above the diagonal and normalizing the resulting matrix.__

![applying causal attention](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-20.png)

```python
# Reuse the query and key weights of sa_v2 from the previous section
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

print(attn_weights)
```

```
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
```


```python
context_length = attn_scores.shape[0]

# Lower-triangular matrix of ones: 1 = keep, 0 = mask out
mask_simple = torch.tril(torch.ones(context_length, context_length))

print(mask_simple)
```

```
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
```


```python
# Zero out the attention weights above the diagonal
masked_simple = attn_weights * mask_simple
print(masked_simple)
```

```
tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)
```


```python
# Renormalize so that each row sums to 1 again
row_sums = masked_simple.sum(dim=-1, keepdim=True)

masked_simple_norm = masked_simple / row_sums

print(masked_simple_norm)
```

```
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)
```


__Figure 3.21 A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.__

![more efficient causal attention](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-21.png)

```python
# Ones above the diagonal mark the future tokens
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

# Replace future-token scores with -inf so that softmax assigns them zero weight
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)

print(masked)
```

```
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
```


```python
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)

print(attn_weights)
```

```
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
```


```python
# Check if outputs are close
if torch.allclose(attn_weights, masked_simple_norm, atol=1e-6):
    print("The outputs are identical!")
else:
    print("The outputs differ.")
```

```
The outputs are identical!
```
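The check above agrees because the two approaches give exactly the same weights, not merely close ones. Write $s_{ij} = \omega_{ij}/\sqrt{d_k}$ for the scaled score between query $i$ and key $j$, and let $\alpha_{ij}$ be the full-row softmax weight over all $T$ tokens. In the approach of Figure 3.20 (softmax, then mask, then renormalize), the full-row normalizer cancels:

$$
\frac{\alpha_{ij}\,\mathbb{1}[j \le i]}{\sum_{k \le i} \alpha_{ik}}
= \frac{\dfrac{e^{s_{ij}}}{\sum_{m=1}^{T} e^{s_{im}}}\,\mathbb{1}[j \le i]}{\sum_{k \le i} \dfrac{e^{s_{ik}}}{\sum_{m=1}^{T} e^{s_{im}}}}
= \frac{e^{s_{ij}}\,\mathbb{1}[j \le i]}{\sum_{k \le i} e^{s_{ik}}}
$$

In the approach of Figure 3.21, every masked score becomes $-\infty$. Each such score contributes $e^{-\infty} = 0$ to both the numerator and the denominator of the softmax, which gives the same right-hand side. The $-\infty$ version just skips the separate renormalization pass.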


### 3.5.2 Masking additional attention weights with dropout


__Figure 3.22 Using the causal attention mask (upper left), we apply an additional dropout mask (upper right) to zero out additional attention weights to reduce overfitting during training.__

![dropout](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-22.png)

```python
torch.manual_seed(123)

dropout = torch.nn.Dropout(0.5)  # dropout rate of 50%

example = torch.ones(6, 6)

print(dropout(example))
```

```
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
```
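The surviving entries above equal 2 rather than 1. In training mode, `nn.Dropout(p)` zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)`, so the expected value of each element is unchanged. In evaluation mode it passes its input through unchanged. The sketch below checks both behaviors. The 100,000-element tensor is an arbitrary size chosen so that the mean is stable.

```python
import torch

torch.manual_seed(123)
p = 0.5
dropout = torch.nn.Dropout(p)
example = torch.ones(6, 6)

# Training mode (the default): zeros plus survivors scaled to 1 / (1 - p) = 2
out_train = dropout(example)
print(out_train.unique())

# Evaluation mode: dropout becomes the identity
dropout.eval()
out_eval = dropout(example)
print(torch.equal(out_eval, example))

# Back to training mode: averaged over many elements, the output stays close to 1
dropout.train()
big = dropout(torch.ones(100_000))
print(big.mean())
```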


```python
torch.manual_seed(123)

# Apply the same dropout module to the causal attention weights
print(dropout(attn_weights))
```

```
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)
```


### 3.5.3 Implementing a compact causal attention class

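```python
# Stack two copies of the inputs to simulate a batch of two sequences
batch = torch.stack((inputs, inputs), dim=0)

print(batch.shape)  # (batch_size, num_tokens, d_in)
```

```
torch.Size([2, 6, 3])
```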


```python
class CausalAttention(nn.Module):
    def __init__(self,
                 d_in,
                 d_out,
                 context_length,
                 dropout,
                 qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # A buffer moves with the module (e.g. to a GPU) but is not trained
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Transpose the token and feature dimensions; dimension 0 is the batch
        attn_scores = queries @ keys.transpose(1, 2)
        # Slicing the mask supports inputs shorter than context_length
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec
```

```python
torch.manual_seed(123)

context_length = batch.shape[1]  # number of tokens per sequence

ca = CausalAttention(d_in, d_out, context_length, 0.0)  # dropout rate 0.0

context_vecs = ca(batch)

print("context_vecs.shape:", context_vecs.shape)
```

```
context_vecs.shape: torch.Size([2, 6, 2])
```
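A direct way to confirm that the mask works is to change only the last token of a sequence. The context vectors of all earlier tokens should stay exactly the same. The sketch below also checks that the mask created with `register_buffer` is saved in the `state_dict` but is not a trainable parameter. `CausalAttentionSketch` is a hypothetical, condensed copy of the class above, and the random input batch is made up for the test.

```python
import torch
import torch.nn as nn

class CausalAttentionSketch(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return self.dropout(attn_weights) @ values

torch.manual_seed(123)
x = torch.rand(1, 6, 3)  # hypothetical batch: 1 sequence, 6 tokens, d_in=3
ca_sketch = CausalAttentionSketch(d_in=3, d_out=2, context_length=6, dropout=0.0)

# Replace only the last token
x_changed = x.clone()
x_changed[0, -1] = torch.rand(3)

out, out_changed = ca_sketch(x), ca_sketch(x_changed)
print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # earlier tokens unaffected
print(torch.allclose(out[:, -1], out_changed[:, -1]))    # last token changes

# The mask is part of the saved state, but not a trainable parameter
print("mask" in ca_sketch.state_dict())
print("mask" in dict(ca_sketch.named_parameters()))
```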


__Figure 3.23 Here’s what we’ve done so far. We began with a simplified attention mechanism, added trainable weights, and then added a causal attention mask. Next, we will extend the causal attention mechanism and code multi-head attention, which we will use in our LLM.__

![causal attention](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-23.png)

## 3.6 Extending single-head attention to multi-head attention

### 3.6.1 Stacking multiple single-head attention layers

__Figure 3.24 The multi-head attention module includes two single-head attention modules stacked on top of each other. So, instead of using a single matrix W<sub>v</sub> for computing the value matrices, in a multi-head attention module with two heads, we now have two value weight matrices: W<sub>v1</sub> and W<sub>v2</sub>. The same applies to the other weight matrices, W<sub>q</sub> and W<sub>k</sub>. We obtain two sets of context vectors Z<sub>1</sub> and Z<sub>2</sub> that we can combine into a single context vector matrix Z.__

![multi-head attention](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-24.png)

In [81]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # One independent CausalAttention module per head, each producing d_out features
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    d_in, d_out, context_length, dropout, qkv_bias
                )
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        # Run every head on the same input and concatenate along the feature dimension:
        # (b, num_tokens, d_out) per head -> (b, num_tokens, num_heads * d_out)
        return torch.cat([head(x) for head in self.heads], dim=-1)

__Figure 3.25 Using the `MultiHeadAttentionWrapper`, we specify the number of attention heads (`num_heads`). If we set `num_heads=2`, as in this example, we obtain a tensor with two sets of context vector matrices. In each context vector matrix, the rows represent the context vectors corresponding to the tokens, and the columns correspond to the embedding dimension specified via `d_out`. We concatenate these context vector matrices along the column dimension, so the final embedding dimension is `num_heads × d_out`. With two attention heads and `d_out=2`, as in the code below, this gives 2 × 2 = 4.__

![multi-head attention context vector](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-25.png)

In [83]:
torch.manual_seed(123)

context_length = batch.shape[1] # this is the number of tokens

d_in, d_out = 3, 2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)

print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


__Exercise 3.2 Returning two-dimensional embedding vectors__

Change the input arguments for the `MultiHeadAttentionWrapper(..., num_heads=2)` call such that the output context vectors are two-dimensional instead of four-dimensional while keeping the setting `num_heads=2`. Hint: You don't have to modify the class implementation; you just have to change one of the other input arguments.

In [84]:
torch.manual_seed(123)

context_length = batch.shape[1] # this is the number of tokens

d_in, d_out = 3, 1 # 2 heads x 1 output dimension each = 2 output dimensions

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)

print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
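The exercise works because the wrapper concatenates `num_heads` outputs of width `d_out`, so the final width is `num_heads × d_out`, and `d_out=1` with two heads yields two columns. The sketch below is a minimal illustration that assumes random tensors as stand-ins for the `CausalAttention` outputs (the names `head_outputs` and `combined` are hypothetical); it reproduces only the `torch.cat` step, not the attention computation.

```python
import torch

torch.manual_seed(123)

# Hypothetical sizes matching the exercise: batch 2, 6 tokens, d_out=1, 2 heads
batch_size, num_tokens, d_out, num_heads = 2, 6, 1, 2

# Random stand-ins for the per-head CausalAttention outputs, each (batch, tokens, d_out)
head_outputs = [torch.randn(batch_size, num_tokens, d_out) for _ in range(num_heads)]

# The wrapper's forward() concatenates along the last (feature) dimension
combined = torch.cat(head_outputs, dim=-1)
print(combined.shape)  # torch.Size([2, 6, 2])

# Each head occupies its own block of d_out columns
for h, out in enumerate(head_outputs):
    assert torch.equal(combined[..., h * d_out:(h + 1) * d_out], out)
```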


### 3.6.2 Implementing multi-head attention with weight splits

In [163]:
class MultiHeadAttention(nn.Module):
    def __init__(
            self,
            d_in,
            d_out,
            context_length,
            dropout,
            num_heads,
            qkv_bias=False
        ):
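        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Each head works on a d_out // num_heads slice of the projected features
        self.head_dim = d_out // num_heads
        # Single large projections whose outputs are later split across the heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer that combines the concatenated head outputs
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask (1s above the diagonal) marking future tokens
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )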

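    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # Project the inputs once: each result has shape (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into heads: (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head dimension forward: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Dot-product scores for every head at once: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Truncate the mask to the current sequence length and convert it to booleans
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Hide future tokens so that softmax assigns them zero weight
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before normalizing over the key dimension
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then swap back: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge the heads into (b, num_tokens, d_out); contiguous() is needed after transpose
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        # Output projection that mixes information across the heads
        context_vec = self.out_proj(context_vec)
        return context_vec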
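
The reshaping in `forward` is the least obvious part of the class. The following sketch assumes a random tensor named `projected` in place of `self.W_query(x)` and hypothetical sizes (`d_out=4`, two heads). It checks that `.view(...).transpose(1, 2)` hands each head its own contiguous block of columns, and that the final `.transpose(1, 2).contiguous().view(...)` undoes the split.

```python
import torch

torch.manual_seed(123)

# Hypothetical sizes: batch of 2, 6 tokens, d_out=4 split into 2 heads of width 2
b, num_tokens, d_out, num_heads = 2, 6, 4, 2
head_dim = d_out // num_heads

# Stand-in for self.W_query(x): shape (b, num_tokens, d_out)
projected = torch.randn(b, num_tokens, d_out)

# Split the feature dimension into heads, then move heads in front of tokens
split = projected.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(split.shape)  # torch.Size([2, 2, 6, 2])

# Head h sees columns h*head_dim:(h+1)*head_dim of the original projection
for h in range(num_heads):
    assert torch.equal(split[:, h], projected[..., h * head_dim:(h + 1) * head_dim])

# Mimic (attn_weights @ values), which returns a fresh contiguous (b, H, T, hd) tensor
per_head = split.contiguous()
back = per_head.transpose(1, 2)  # (b, num_tokens, num_heads, head_dim)
print(back.is_contiguous())      # False

# Merge the heads again, as the last lines of forward() do
merged = back.contiguous().view(b, num_tokens, d_out)
assert torch.equal(merged, projected)
```

The `contiguous()` call matters because `view` requires a compatible memory layout: calling `back.view(b, num_tokens, d_out)` directly would fail, whereas `reshape` would copy the data when needed.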


__Figure 3.26 In the `MultiHeadAttentionWrapper` class with two attention heads, we initialized two weight matrices, W<sub>q1</sub> and W<sub>q2</sub>, and computed two query matrices, Q<sub>1</sub> and Q<sub>2</sub> (top). In the `MultiHeadAttention` class, we initialize one larger weight matrix W<sub>q</sub>, only perform one matrix multiplication with the inputs to obtain a query matrix Q, and then split the query matrix into Q<sub>1</sub> and Q<sub>2</sub> (bottom). We do the same for the keys and values, which are not shown to reduce visual clutter.__

![multi-head attention matrix optimisations](https://drek4537l1klr.cloudfront.net/raschka/Figures/3-26.png)
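The equivalence shown in Figure 3.26 can be checked numerically. This is a minimal sketch with hypothetical sizes (`d_in=3`, `d_out=4`, two heads) and random inputs. It relies on `nn.Linear` storing its weight with shape `(out_features, in_features)`, so each head's slice of the fused output comes from a block of rows of `W_q.weight`; the names `W_q`, `Q`, and `W_qh` are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

# Hypothetical sizes: 3-dim inputs, d_out=4 split into 2 heads of width 2
d_in, d_out, num_heads = 3, 4, 2
head_dim = d_out // num_heads
x = torch.randn(2, 6, d_in)  # (batch, tokens, d_in)

# One large query projection, then a split into heads (bottom of the figure)
W_q = nn.Linear(d_in, d_out, bias=False)
Q = W_q(x).view(2, 6, num_heads, head_dim)

# Equivalent separate per-head projections (top of the figure)
for h in range(num_heads):
    W_qh = W_q.weight[h * head_dim:(h + 1) * head_dim]  # (head_dim, d_in)
    Q_h = x @ W_qh.T                                     # (batch, tokens, head_dim)
    assert torch.allclose(Q[:, :, h], Q_h, atol=1e-6)

print("fused projection matches per-head projections")
```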

In [86]:
# Shape (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

In [87]:
# Batched matmul: one (num_tokens x num_tokens) result per head
print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In [88]:
first_head = a[0, 0, :, :]

first_res = first_head @ first_head.T

print("First head:\n", first_res)

second_head = a[0, 1, :, :]

second_res = second_head @ second_head.T

print("Second head:\n", second_res)

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])
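The batched product `a @ a.transpose(2, 3)` performs the per-head computation above in a single call. As a sketch using a hypothetical random tensor `t` with the same layout as `a`, the following compares the batched matmul with an explicit loop over all heads and with `torch.einsum`, which appears here only as an alternative way of writing the same contraction, not as what the class uses.

```python
import torch

torch.manual_seed(123)

# Hypothetical tensor laid out like `a`: (batch, num_heads, num_tokens, head_dim)
t = torch.rand(1, 2, 3, 4)

# One batched matmul covers every head: (1, 2, 3, 3)
batched = t @ t.transpose(2, 3)

# The same products computed head by head, as in the previous cell
looped = torch.stack([t[0, h] @ t[0, h].T for h in range(t.shape[1])]).unsqueeze(0)

# The same contraction written with einsum: d is summed, q and k index tokens
via_einsum = torch.einsum("bhqd,bhkd->bhqk", t, t)

assert torch.allclose(batched, looped, atol=1e-6)
assert torch.allclose(batched, via_einsum, atol=1e-6)
print(batched.shape)  # torch.Size([1, 2, 3, 3])
```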


In [164]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape

d_out = 2

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)

print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
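Inside `forward`, future positions are set to `-torch.inf` before the softmax, and the scores are divided by the square root of `keys.shape[-1]`, which is `head_dim`, not `d_out`. The sketch below isolates those two steps on hypothetical random scores (2 heads, 4 tokens, `head_dim=2`) to confirm that masked weights come out as exactly zero while each row still sums to 1.

```python
import torch

torch.manual_seed(123)

# Hypothetical scores for 1 sequence, 2 heads, 4 tokens; head_dim=2
num_tokens, head_dim = 4, 2
attn_scores = torch.randn(1, 2, num_tokens, num_tokens)

# Same construction as the `mask` buffer: True strictly above the diagonal
mask_bool = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# Future positions get -inf; exp(-inf) = 0, so softmax gives them zero weight
masked = attn_scores.masked_fill(mask_bool, -torch.inf)
attn_weights = torch.softmax(masked / head_dim**0.5, dim=-1)

print(attn_weights[0, 0])  # lower-triangular rows that each sum to 1

assert torch.all(attn_weights.masked_select(mask_bool) == 0)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(1, 2, num_tokens))
```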


In [143]:
class MultiHeadAttention_Book(nn.Module):
    def __init__(self, d_in, d_out, 
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(
            b, num_tokens, self.num_heads, self.head_dim
        )

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec)
        return context_vec

In [146]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention_Book(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
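Both versions create the same four linear layers, so their parameter counts match, and splitting the projections into heads does not change that count. A minimal sketch, assuming a hypothetical helper `count_mha_params` that rebuilds only the layers from `__init__` (dropout has no parameters, and `mask` is registered as a buffer rather than a parameter):

```python
import torch.nn as nn


def count_mha_params(d_in, d_out, qkv_bias=False):
    """Hypothetical helper: rebuild the layers from MultiHeadAttention.__init__ and count them."""
    layers = nn.ModuleDict({
        "W_query": nn.Linear(d_in, d_out, bias=qkv_bias),
        "W_key": nn.Linear(d_in, d_out, bias=qkv_bias),
        "W_value": nn.Linear(d_in, d_out, bias=qkv_bias),
        "out_proj": nn.Linear(d_out, d_out),  # always has a bias
    })
    return sum(p.numel() for p in layers.parameters())


# Closed form: 3*d_in*d_out (+ 3*d_out if qkv_bias) + d_out*d_out + d_out
print(count_mha_params(3, 2))                 # 24
print(count_mha_params(3, 2, qkv_bias=True))  # 30
```

For the `d_in=3`, `d_out=2` setting used above, this gives 3 × 3 × 2 + 2 × 2 + 2 = 24 parameters; `num_heads` does not appear in the formula.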
