# 3.4 Implementing Self-Attention with Trainable Weights

In this section, we will implement the self-attention mechanism used in the original Transformer architecture, the GPT models, and most other popular large language models. This self-attention mechanism is also called scaled dot-product attention. Figure 3.13 depicts how this self-attention mechanism fits into the broader context of building large language models.

**Figure 3.13 shows how the self-attention mechanism we code in this section fits into the overall context of this book and chapter. In the previous section, we implemented a simplified attention mechanism to understand the basic principles behind attention. In this section, we build on it by adding trainable weights. In subsequent sections, we will further extend this self-attention mechanism by adding a causal mask and multiple heads.**

![3.13](../img/fig-3-13.jpg)

As shown in Figure 3.13, the self-attention mechanism with trainable weights builds on the previous concepts: we want to compute context vectors as weighted sums over the input vectors, specific to a given input element. As you will see, there are only slight differences compared to the basic self-attention mechanism we coded in Section 3.3.

The most significant difference is the introduction of weight matrices that are updated during model training. These trainable weight matrices are crucial because they enable the model (specifically, the attention module inside the model) to learn to produce "good" context vectors. (Note that we will train the large language model in Chapter 5.)

We will discuss this self-attention mechanism in two subsections. First, we will write the code step by step as before. Second, we will organize the code into a compact Python class that can be imported into our large language model architecture in Chapter 4.

## 3.4.1 Calculating Attention Weights Step by Step

We will implement the self-attention mechanism step by step by introducing three trainable weight matrices Wq, Wk, and Wv. These three matrices are used to project the embedded input token x(i) into query vectors, key vectors, and value vectors, as shown in Figure 3.14.

**Figure 3.14 shows the first step of implementing the self-attention mechanism with trainable weight matrices. In this step, for each input element x, we calculate its corresponding query (q), key (k), and value (v) vectors. As in the previous section, we treat the second input x(2) as the query input. The query vector q(2) is obtained by matrix multiplication of the input x(2) with the query weight matrix Wq. Similarly, we use the weight matrices Wk and Wv to calculate the key vector and value vector, respectively, through the corresponding matrix multiplications.**

![3.14](../img/fig-3-14.jpg)

In Section 3.3.1, we defined the second input element x(2) as the query vector to compute the simplified attention weights and obtain the context vector z(2). In Section 3.3.2, we generalized this computation to all context vectors z(1) to z(T) for the six-word input sentence "Your journey starts with one step."

Again, for illustration purposes, we will first compute only one context vector z(2). In the next section, we will modify this code to compute all context vectors.

Let's start by defining some variables:

In [4]:
x_2 = inputs[1]         # the second input element
d_in = inputs.shape[1]  # the input embedding size, d_in=3
d_out = 2               # the output embedding size, d_out=2

Note that in models like GPT, the input and output dimensions are usually the same, but to better illustrate the calculation process, we choose different input (d_in=3) and output (d_out=2) dimensions here.

Next, we initialize the three weight matrices Wq, Wk, and Wv shown in Figure 3.14:

In [5]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

Note that we set requires_grad to False for clearer output during demonstration, but if we were to use these weight matrices for model training, we would set requires_grad to True so that these matrices can be updated during model training.
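To make the effect of this flag concrete, here is a minimal sketch. The tensor `x_2` below is a random stand-in for one token embedding, and the names `W_frozen` and `W_trainable` are hypothetical, chosen only for this illustration. With `requires_grad=True`, PyTorch records the computation so that a gradient can flow back to the weight matrix; with `requires_grad=False`, it does not.

```python
import torch

torch.manual_seed(123)

# Hypothetical stand-in for one 3-dimensional token embedding (not the chapter's data).
x_2 = torch.rand(3)

W_frozen = torch.nn.Parameter(torch.rand(3, 2), requires_grad=False)
W_trainable = torch.nn.Parameter(torch.rand(3, 2), requires_grad=True)

out_frozen = x_2 @ W_frozen
out_trainable = x_2 @ W_trainable

print(out_frozen.requires_grad)     # output is detached from any autograd graph
print(out_trainable.requires_grad)  # output is connected to W_trainable

# Backpropagate a toy scalar "loss" so that W_trainable.grad is populated.
out_trainable.sum().backward()
print(W_trainable.grad)
print(W_frozen.grad)
```

An optimizer would use `W_trainable.grad` to update the matrix, which is what "updated during model training" refers to; the frozen matrix has no gradient to act on.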

Next, as shown in Figure 3.14, we compute the query, key, and value vectors:

In [6]:
query_2 = x_2 @ W_query 
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value
print(query_2)


As you can see from the query output, this produces a two-dimensional vector because we set the number of columns of the corresponding weight matrix to 2 via d_out:
```python
tensor([0.4306, 1.4551])
```

### Weight Parameters vs. Attention Weights

Note that in the weight matrices W, the word "weight" is short for "weight parameters," the values of a neural network that are optimized during training. These should not be confused with attention weights. As we saw in the previous section, attention weights determine how much a context vector depends on the different parts of the input, that is, how much the network attends to different parts of the input.

In short, weight parameters are the fundamental learned coefficients that define the network's connections, while attention weights are dynamic, context-specific values.
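The distinction can be sketched in a few lines. Everything below is an illustrative assumption rather than the chapter's data: the two input sequences are random, and the helper `attention_weights` is a hypothetical name. The weight parameters are created once and reused, while the attention weights are recomputed from each input.

```python
import torch

torch.manual_seed(0)

# Weight parameters: created once and shared by every input
# (during training they would be learned; here they are simply random).
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)

def attention_weights(x):
    """Return the (T, T) attention-weight matrix for an input x of shape (T, 3)."""
    queries = x @ W_query
    keys = x @ W_key
    scores = queries @ keys.T
    return torch.softmax(scores, dim=-1)

# Two hypothetical sequences of four 3-dimensional token embeddings each.
inputs_a = torch.rand(4, 3)
inputs_b = torch.rand(4, 3)

weights_a = attention_weights(inputs_a)
weights_b = attention_weights(inputs_b)

# Same weight parameters, but input-dependent attention weights.
print(torch.allclose(weights_a, weights_b))
# Each row of attention weights is a distribution over the input positions.
print(weights_a.sum(dim=-1))
```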

Although our immediate goal is only to compute the single context vector z(2), we still need the key and value vectors of all input elements, because they are involved in computing the attention weights with respect to the query vector q(2), as shown in Figure 3.14.

We can obtain all the key and value vectors by matrix multiplication:

In [7]:
keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)


From the output, we can see that we have successfully projected the 6 input tokens from the 3D space to the 2D embedding space:
```python
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
```

The second step is to calculate the attention score, as shown in Figure 3.15.

**Figure 3.15 The attention score calculation is a dot product calculation, similar to the simplified self-attention mechanism we used in Section 3.3. The new aspect is that instead of directly calculating the dot product between the input elements, we use the query and key vectors obtained by transforming the inputs through their respective weight matrices.**

![3.15](../img/fig-3-15.jpg)

First, let's calculate the attention score ω22:

In [8]:
keys_2 = keys[1]  # Python indexing starts at 0, so index 1 is the second key
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)


This yields the following unnormalized attention score:
```python
tensor(1.8524)
```

Again, we can extend this calculation to all attention scores via matrix multiplication:

In [9]:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)


From the output we can quickly see that the second element of the output matches the attn_score_22 we calculated previously:
```python
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
```

The third step is to go from attention scores to attention weights, as shown in Figure 3.16.

**Figure 3.16 After calculating the attention scores ω, the next step is to normalize these scores using the softmax function to obtain the attention weights α.**

![3.16](../img/fig-3-16.jpg)

Next, as shown in Figure 3.16, we compute the attention weights by scaling the attention scores and applying the softmax function we used earlier. The difference from before is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys (note that taking the square root is mathematically the same as raising to the power of 0.5):

In [None]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

The obtained attention weights are as follows:
```python
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
```

### The logic behind scaled dot product attention

The reason for normalizing by the embedding dimension size is to improve training performance by avoiding small gradients. When the embedding dimension is scaled up, as in GPT-like large language models where it typically exceeds a thousand, large dot products can produce very small gradients during backpropagation because of the softmax function applied to them. As the dot products increase, the softmax function behaves more like a step function, and its gradients approach zero. These small gradients can drastically slow down learning or cause training to stagnate.

This scaling by the square root of the embedding dimension is why this self-attention mechanism is also called scaled dot-product attention.
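The two claims above can be checked numerically with a small standard-library sketch. It rests on an assumption that is not part of the chapter: query and key entries are drawn independently from a standard normal distribution. The helper names `softmax` and `dot_product_std` are hypothetical. Under that assumption, the spread of the dot products grows roughly like the square root of the dimension, which is the factor the scaling divides out. The second loop shows that larger scores push softmax toward a one-hot output, where the derivative p * (1 - p) of each probability with respect to its own score is close to zero.

```python
import math
import random

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot_product_std(dim, trials=1000, seed=0):
    """Empirical standard deviation of q . k for random standard-normal q and k."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(dim)]
        k = [rng.gauss(0, 1) for _ in range(dim)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    return math.sqrt(sum((d - mean) ** 2 for d in dots) / trials)

# Columns: dimension, std of raw dot products, std after dividing by sqrt(dimension).
for dim in (4, 64, 256):
    raw = dot_product_std(dim)
    print(dim, round(raw, 2), round(raw / math.sqrt(dim), 2))

# Columns: scale factor, softmax probabilities, per-score derivative p * (1 - p).
base_scores = [0.5, 1.0, 1.5]
for scale in (1, 10, 100):
    probs = softmax([s * scale for s in base_scores])
    grads = [p * (1 - p) for p in probs]
    print(scale, [round(p, 4) for p in probs], [round(g, 4) for g in grads])
```

The scaled standard deviation in the last column of the first loop should stay near 1 regardless of the dimension, which keeps the softmax inputs in a range where gradients remain usable.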

The last step is to calculate the context vector, as shown in Figure 3.17. 

**Figure 3.17 In the last step of the self-attention calculation, we calculate the context vector by combining all the value vectors via the attention weights.**

![3.17](../img/fig-3-17.jpg)

In Section 3.3, we calculated the context vector as a weighted sum of the input vectors; now we calculate it as a weighted sum of the value vectors. Here, the attention weights act as weighting factors that measure the relative importance of each value vector. As in Section 3.3, we can use matrix multiplication to obtain the output in one step:

In [10]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

The resulting context vector is as follows:
```python
tensor([0.3061, 0.8210])
```

So far, we have only computed one context vector, z(2). In the next section, we will extend the code to compute all context vectors z(1) to z(T) in the input sequence.
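As a sanity check on the "weighted sum" reading, the sketch below spells out the sum term by term and compares it with the matrix product. It uses random stand-in tensors, so the numbers will not match the outputs shown above, and the variable names ending in `_loop` and `_matmul` are hypothetical.

```python
import torch

torch.manual_seed(123)

# Hypothetical stand-ins for the six 3-dimensional inputs and the weight matrices.
inputs = torch.rand(6, 3)
W_query, W_key, W_value = torch.rand(3, 2), torch.rand(3, 2), torch.rand(3, 2)

query_2 = inputs[1] @ W_query
keys = inputs @ W_key
values = inputs @ W_value

attn_scores_2 = query_2 @ keys.T
attn_weights_2 = torch.softmax(attn_scores_2 / (keys.shape[-1] ** 0.5), dim=-1)

# Explicit weighted sum: scale each value vector by its attention weight, then add.
context_loop = torch.zeros(values.shape[1])
for weight_i, value_i in zip(attn_weights_2, values):
    context_loop += weight_i * value_i

# The same result as a single matrix product.
context_matmul = attn_weights_2 @ values

print(torch.allclose(context_loop, context_matmul))
print(context_matmul)
```

Because the attention weights are non-negative and sum to 1, each component of the context vector lies between the smallest and largest corresponding component of the value vectors.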

### Why query, key, and value?

In the context of attention mechanisms, the terms "key", "query", and "value" are borrowed from the fields of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

A "query" is similar to a search query in a database. It represents the item that the model is currently focusing on or trying to understand (e.g., a word or token in a sentence). The query is used to explore other parts of the input sequence to determine how much attention should be given to them.

A "key" is similar to the keys used for indexing and searching in a database. In an attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match against queries.

A "value" in this context is similar to the value of a key-value pair in a database. It represents the actual content or representation of the input item. Once the model determines which keys (which parts of the input) are most relevant to the query (the currently focused item), it retrieves the corresponding values.
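The analogy can be illustrated with a toy comparison between a hard dictionary lookup and a "soft" lookup. The data and the `soft_lookup` helper below are invented for this sketch; it leaves out the learned projections and the scaling, and only shows how a query can retrieve a blend of values weighted by how well it matches each key.

```python
import math

# Hard lookup: the query must match one key exactly, and one value is returned.
database = {"apple": 1.0, "banana": 2.0, "cherry": 3.0}
print(database["banana"])

# Soft lookup: the query is compared against every key, and the result is a
# blend of all values, weighted by how well each key matches the query.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [0.0, 2.0]

def soft_lookup(query, keys, values):
    """Return (weights, blended_value) for one query vector."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    blended = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, blended

weights, blended = soft_lookup(query, keys, values)
print([round(w, 4) for w in weights])
print([round(b, 4) for b in blended])
```

Here the first key does not match the query at all, so its value contributes the least to the blended result.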

## 3.4.2 Implementing a compact self-attention Python class

In the previous section, we went through the steps for calculating the self-attention output one by one, mainly for explanation and demonstration. In practice, and with the large language model implementation in the next chapter in mind, it is helpful to organize this code into a Python class, as follows:

### Listing 3.1 Compact self-attention class

In [11]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
 
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In this PyTorch code, SelfAttention_v1 is a class that inherits from nn.Module, which is the basic unit of PyTorch models and provides the functions required to create and manage model layers.

The ‘__init__’ method is responsible for initializing the trainable weight matrices (W_query, W_key, and W_value) for query, key, and value respectively, each of which transforms the input dimension d_in to the output dimension d_out.

During the forward pass, implemented in the forward method, we compute the attention scores (attn_scores) by multiplying the queries and keys, scale them by the square root of the key dimension, and normalize them using the softmax function. Finally, we weight the values by these normalized attention scores (the attention weights) to create the context vectors.
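One of the things nn.Module manages is parameter registration. The sketch below uses a deliberately tiny, hypothetical module (`TinyProjection`, not part of the chapter's code) to show that tensors wrapped in nn.Parameter are tracked by the module, while a plain tensor attribute is not.

```python
import torch
import torch.nn as nn

class TinyProjection(nn.Module):
    """Hypothetical minimal module, used only to show parameter registration."""

    def __init__(self, d_in, d_out):
        super().__init__()
        # Wrapped in nn.Parameter: registered with the module as trainable.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        # Plain tensor attribute: not registered as a parameter.
        self.scratch = torch.rand(d_in, d_out)

    def forward(self, x):
        return x @ self.W_query

layer = TinyProjection(d_in=3, d_out=2)

names = [name for name, _ in layer.named_parameters()]
num_params = sum(p.numel() for p in layer.parameters())

print(names)
print(num_params)
```

An optimizer built from `layer.parameters()` would therefore update `W_query` but never see `scratch`.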

We can use this class as follows:

In [12]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))


Since the input contains six embedding vectors, we get a matrix containing six context vectors:
```python
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
```

A quick check shows that the second row ([0.3061, 0.8210]) matches the contents of context_vec_2 from the previous section.
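The same check can be done programmatically rather than by eye. The sketch below is an assumption-laden stand-in: it uses random inputs and weights instead of the chapter's tensors, and a hypothetical helper `attend` that mirrors the computation in the forward method. It confirms that the row for the second token in the full result equals the context vector computed from that token's query alone.

```python
import torch

torch.manual_seed(123)

# Hypothetical stand-ins for the chapter's inputs and weight matrices.
inputs = torch.rand(6, 3)
W_query, W_key, W_value = torch.rand(3, 2), torch.rand(3, 2), torch.rand(3, 2)

def attend(queries, keys, values):
    """Scaled dot-product attention for one query vector or a matrix of queries."""
    scores = queries @ keys.T
    weights = torch.softmax(scores / (keys.shape[-1] ** 0.5), dim=-1)
    return weights @ values

keys = inputs @ W_key
values = inputs @ W_value

all_context = attend(inputs @ W_query, keys, values)   # all six context vectors
context_2 = attend(inputs[1] @ W_query, keys, values)  # only the second one

print(all_context.shape)
print(torch.allclose(all_context[1], context_2))
```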

Figure 3.18 provides an overview of the self-attention mechanism we just implemented.

**Figure 3.18 In the self-attention mechanism, we transform the input vectors in the input matrix X with three weight matrices Wq, Wk, and Wv. We then compute the attention weight matrix based on the resulting queries (Q) and keys (K). Using these attention weights and the values (V), we compute the context vectors (Z). For visual simplicity, this figure only shows a single input text containing n tokens, rather than a batch of multiple inputs. This simplified representation as a 2D matrix makes the process more intuitive to visualize and understand.**

![3.18](../img/fig-3-18.jpg)

As shown in Figure 3.18, self-attention involves three trainable weight matrices Wq, Wk, and Wv. These matrices transform input data into queries, keys, and values, and are the core components of the attention mechanism. As the model is exposed to more data during training, these trainable weights are adjusted accordingly, which we will discuss in detail in the following chapters.

We can further improve SelfAttention_v1 by using PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. In addition, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, which contributes to more stable and effective model training.

### Listing 3.2 Self-attention class using PyTorch's Linear layer

In [14]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
 
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

You can use SelfAttention_v2 in the same way as SelfAttention_v1:

In [15]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))


The output is: 
```python
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
```

Note that the outputs of SelfAttention_v1 and SelfAttention_v2 differ because they use different initial weights: nn.Linear uses a more sophisticated weight initialization scheme than nn.Parameter(torch.rand(d_in, d_out)).
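The difference between the two initializations can be inspected directly. The sketch below relies on the default behavior described in the PyTorch nn.Linear documentation, where weights are drawn uniformly from the interval between -1/sqrt(in_features) and 1/sqrt(in_features); treat that bound as an assumption about your installed version. The names `manual`, `linear`, and `x` are placeholders for this illustration.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(789)
d_in, d_out = 3, 2

manual = torch.rand(d_in, d_out)             # uniform in [0, 1): never negative
linear = nn.Linear(d_in, d_out, bias=False)  # PyTorch's default initialization

# nn.Linear stores its weight with shape (d_out, d_in), the transpose of `manual`.
print(manual.shape, linear.weight.shape)

# Assumed default range for nn.Linear weights: [-1/sqrt(d_in), 1/sqrt(d_in)].
bound = 1 / math.sqrt(d_in)
print(manual.min().item() >= 0.0)
print(linear.weight.abs().max().item() <= bound)

# Without a bias, the layer is a plain matrix multiplication by the transposed weight.
x = torch.rand(6, d_in)  # hypothetical batch of six input vectors
print(torch.allclose(linear(x), x @ linear.weight.T))
```

Weights centered on zero with a range that shrinks as the input dimension grows keep the initial outputs small, which is one reason the nn.Linear outputs above are much closer to zero than those of SelfAttention_v1.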

### Exercise 3.1 Comparing SelfAttention_v1 and SelfAttention_v2

Note that the nn.Linear in SelfAttention_v2 uses a different weight initialization scheme than the nn.Parameter(torch.rand(d_in, d_out)) in SelfAttention_v1, which causes the two mechanisms to produce different results. To check that the implementations of SelfAttention_v1 and SelfAttention_v2 are otherwise similar, we can transfer the weight matrix of the SelfAttention_v2 object to the SelfAttention_v1 object so that both objects produce the same results.

Your task is to correctly assign the weights of the SelfAttention_v2 instance to the SelfAttention_v1 instance. To do this, you need to understand the relationship between the weights in the two versions. (Hint: nn.Linear stores the weight matrix in transposed form.) After completing the weight assignment, you should find that the two instances produce the same output.

In the next section, we will enhance the self-attention mechanism by incorporating causal and multi-head elements. The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence. This is crucial for tasks like language modeling, where the prediction of each word should depend only on the previous words.

The multi-head component involves splitting the attention mechanism into multiple "heads." Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions. This improves the model's performance on complex tasks.