# Chapter 3: Writing an Attention Mechanism

The packages used in this notebook are:

In [1]:
from importlib.metadata import version
import torch

print("torch version:", version("torch"))

torch version: 2.1.1


## 3.1 Problems with long sequence modeling

- There is no code in this section.

## 3.2 Capturing Data Dependencies Using Attention Mechanism

- There is no code in this section.

## 3.3 Using self-attention to focus on different parts of the input

### 3.3.1 A simple self-attention mechanism without trainable weights

- This section introduces an extremely simplified version of the self-attention mechanism, which does not contain any trainable weights. This is just for illustration purposes and is not the actual attention mechanism used in transformer models. The following section 3.3.2 will expand on this simple attention mechanism and introduce the real self-attention mechanism.
- Suppose we have an input sequence from $x^{(1)}$ to $x^{(T)}$.
- The input is a piece of text (such as the sentence "Your journey starts with one step"), which has been converted into the token embedding form described in Chapter 2.
- For example, $x^{(1)}$ is a d-dimensional vector representing the word "Your", and so on.
- **Goal:** For each element $x^{(i)}$ in the input sequence from $x^{(1)}$ to $x^{(T)}$, where $z$ and $x$ have the same dimension, calculate the context vector $z^{(i)}$.
- The context vector $z^{(i)}$ is a weighted average of the inputs $x^{(1)}$ to $x^{(T)}$.
- The context vector is the contextual information for a specific input.
- Instead of using $x^{(i)}$ as a placeholder for an arbitrary input token, we consider the secondInput, $x^{(2)}$.
- To make this concrete, instead of using a placeholder $z^{(i)}$, we consider the context vector for the second output, $z^{(2)}$.
- The second context vector $z^{(2)}$ is a weighted average of all inputs $x^{(1)}$ to $x^{(T)}$, with the weights determined based on the second input element $x^{(2)}$. These attention weights determine how much each input element contributes to the final weighted average when computing $z^{(2)}$.
- In short, $z^{(2)}$ can be thought of as a variant of $x^{(2)}$ that not only contains the information of $x^{(2)}$, but also incorporates the information of all other input elements that are relevant to the current task.

- By convention, unnormalized attention weights are called **"attention scores"**, while normalized attention scores (they sum to 1) are called **"attention weights"**.

- The calculation of attention weights and context vectors is summarized in the following diagram:

<img src="figures/attention.png" width="600px">

- The code below shows the above diagram step by step.

<br>

- **Step 1:** Calculate the unnormalized attention score $\omega$.
- Assuming we use the second input token as the query, i.e., $q^{(2)} = x^{(2)}$, we calculate the unnormalized attention score by dot product:
- $\omega_{21} = x^{(1)} \cdot q^{(2)\top}$
- $\omega_{22} = x^{(2)} \cdot q^{(2)\top}$
- $\omega_{23} = x^{(3)} \cdot q^{(2)\top}$
- ...
- $\omega_{2T} = x^{(T)} \cdot q^{(2)\top}$
- Here, $\omega$ is the Greek letter "omega" and is used to represent the unnormalized attention score.
- The subscript "21" in $\omega_{21}$ means that the second element of the input sequence is used as the query to be compared with the first element of the input sequence.

<img src="figures/dot-product.png" width="450px">

- Suppose we have the following input sentence that has been converted into a 3-dimensional vector as described in Chapter 3 (for illustration purposes, a very small embedding dimension is used here so that it fits on the page without line wrapping):

In [2]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

- We take the second element $x^{(2)}$ in the input sequence as an example to calculate the context vector $z^{(2)}$; in the following section, we will generalize this method to calculate all context vectors.
- The first step is to calculate the unnormalized attention score, which is achieved by calculating the dot product between the query $x^{(2)}$ and all other input tokens:

In [3]:
# Take the second element from the input sequence as the query vector
query = inputs[1]

# Create an empty tensor to store the attention scores, with the same shape as the batch size of the input sequence
attn_scores_2 = torch.empty(inputs.shape[0])

# Iterate over each element of the input sequence
for i, x_i in enumerate(inputs):
# Calculate the dot product of the current element and the query vector as the attention score
# No need to transpose here, because the input vector is assumed to be one-dimensional
    attn_scores_2[i] = torch.dot(x_i, query)

# Print attention scores
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


- Note: The dot product is actually a shorthand for multiplying corresponding elements of two vectors and then adding these products together:

In [4]:
# Initialize the result variable to 0
res = 0.

# Iterate over each element in the first element of the input sequence
for idx, element in enumerate(inputs[0]):
# Multiply the current element by the corresponding element of the query vector and add the result to res
    res += inputs[0][idx] * query[idx]

# Print the manually calculated dot product result
print(res)

# Use PyTorch's torch.dot function to calculate the dot product and print the result
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


- **Step 2:** Normalize the unnormalized attention scores (called "omegas", represented by the Greek letter $\omega$) so that their sum is equal to 1.
- Here is a simple way to normalize these unnormalized attention scores to ensure that they sum to 1 (this is a common practice that helps to understand and is critical to the stability of the training process):

In [5]:
# Normalize the attention scores using the sum of the attention scores
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

# Print normalized attention weights
print("Attention weights:", attn_weights_2_tmp)
# Verify that the sum of the normalized attention weights is 1
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


- However, in practice, it is more common and recommended to use the softmax function for normalization, as it is better at handling extreme values ​​and has more desirable gradient properties during training.
- The following is a simple implementation of the softmax function, which is used to scale and normalize the vector elements so that their sum is 1:

In [6]:
# Define a simple softmax function implementation
def softmax_naive(x):
# Applies an exponential function to each element of the input tensor x
    exp_x = torch.exp(x)
# Calculate the sum of exp_x in the specified dimension (here is the first dimension, dim=0)
    sum_exp_x = exp_x.sum(dim=0)
# Divide each element of exp_x by their sum to get the softmax result
    return exp_x / sum_exp_x

# Normalize the attention scores using the naive softmax function
attn_weights_2_naive = softmax_naive(attn_scores_2)

# Print normalized attention weights
print("Attention weights:", attn_weights_2_naive)
# Verify that the sum of the normalized attention weights is 1
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


- The simple implementation above may cause numerical instability due to input values ​​that are too large or too small, which is mainly due to numerical overflow and underflow.
- Therefore, in practical applications, it is recommended to use the `softmax` function provided by PyTorch, which is highly optimized and has better performance:

In [7]:
# Use PyTorch's softmax function to normalize the attention scores
# dim=0 means to perform softmax calculation on the first dimension (usually the feature dimension)
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

# Print normalized attention weights
print("Attention weights:", attn_weights_2)
# Verify that the sum of the normalized attention weights is 1
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


- **Step 3**: Compute the context vector $z^{(2)}$ by multiplying the embedded input token $x^{(i)}$ with the attention weight and then adding the resulting vectors:

In [8]:
# Select the second element in the input sequence as the query vector
query = inputs[1]

# Initialize the context vector, which has the same shape as the query vector and an initial value of 0
context_vec_2 = torch.zeros(query.shape)

# Iterate over each element in the input sequence
for i, x_i in enumerate(inputs):
# Accumulate the product of each input element and its corresponding attention weight
    context_vec_2 += attn_weights_2[i] * x_i

# Print the calculated context vector
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


### 3.3.2 Calculate the attention weights of all input tags

#### Generalize to all input sequence tokens:

- In the above, we calculated the attention weights and context vectors for input 2 (as shown in the highlighted row in the diagram below).
- Next, we will generalize this calculation process to calculate the attention weights and context vectors for all input tokens.

<img src="figures/attention-matrix.png" width="400px">

- Apply the previous **first step** to calculate all pairs of elements to get the unnormalized attention score matrix:

In [9]:
# Create a 6x6 zero tensor to store the attention scores
attn_scores = torch.empty(6, 6)

# Iterate over each element in the input sequence
for i, x_i in enumerate(inputs):
# For the current input element x_i, traverse the entire input sequence again
    for j, x_j in enumerate(inputs):
# Calculate the dot product of x_i and x_j as the attention score and store it in the corresponding position of the attn_scores matrix
        attn_scores[i, j] = torch.dot(x_i, x_j)

# Print the complete attention score matrix
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


- We can implement the above calculation more efficiently through matrix multiplication:

In [10]:
# Use matrix multiplication to calculate the dot product matrix of the input sequence
# inputs @ inputs.T is equivalent to multiplying inputs by the transpose of inputs
attn_scores = inputs @ inputs.T

# Print attention score matrix
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


- Similar to the previous **step 2**, we normalize each row so that the sum of the values ​​in each row is 1:

In [11]:
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


- Quickly verify that the values ​​in each row do sum to 1:

In [12]:
# Define the attention weight list for the second row
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
# Print the sum of the second line
print("Row 2 sum:", row_2_sum)

# Use PyTorch's sum function to calculate the sum of all rows along the specified dimension (here dimension 1, i.e. rows)
# attn_weights.sum(dim=1) returns a 1D tensor containing the sum of each row
print("All row sums:", attn_weights.sum(dim=1))

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


- Apply the previous **third step** to calculate all context vectors:

In [13]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


- As a sanity check, the previously calculated context vector $z^{(2)} = [0.4419, 0.6515, 0.5683]$ can be found in the second line above:

In [14]:
print("Previous 2nd context vector:", context_vec_2)

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])


## 3.4 Implementing self-attention using trainable weights

### 3.4.1 Calculate attention weights step by step

- In this section, we are implementing the self-attention mechanism, a technique used in the original transformer architecture, the GPT model, and most other popular large language models (LLMs).
- This self-attention mechanism is also known as "scaled dot-product attention".
- The general idea is similar to before:
- We want to compute the context vector as a weighted sum of the input vectors for a particular input element.
- To do this, we need attention weights.
- You will see that there are only subtle differences compared to the basic attention mechanism introduced before:
- The most notable difference is the introduction of weight matrices that are updated during model training.
- These trainable weight matrices are crucial so that the model (especially the attention module within the model) can learn to produce "good" context vectors.

- As we gradually implement the self-attention mechanism, we will first introduce three training weight matrices $W_q$, $W_k$, and $W_v$.
- These three matrices are used to project the embedded input token $x^{(i)}$ to the query, key, and value vectors via matrix multiplication:

- Query vector: $q^{(i)} = W_q \cdot x^{(i)}$
- Key vector: $k^{(i)} = W_k \cdot x^{(i)}$
- Value vector: $v^{(i)} = W_v \cdot x^{(i)}$

<img src="figures/weight-selfattn-1.png" width="600px">

- The embedding dimension of the input $x$ and the dimension of the query vector $q$ can be the same or different, depending on the design and implementation of the model.
- In the GPT model, the dimensions of the input and output are usually the same, but to better illustrate the calculation process, we choose different input and output dimensions here:

In [15]:
# Get the second element in the input sequence as a specific input vector
x_2 = inputs[1]

# Get the embedding dimension of the input tensor. Here we assume that the dimension of the input vector is 3.
d_in = inputs.shape[1]

# Set the dimension of the output embedding. Here we assume that the dimension of the output vector is 2
d_out = 2

- Below, we initialize the three weight matrices; note that we set `requires_grad=False` to reduce clutter in the output, this is for illustration purposes. However, if we were to use these weight matrices in model training, we would set `requires_grad` to `True` so that these matrices are updated during model training.

In [16]:
# Set the random seed to ensure reproducibility of results
torch.manual_seed(123)

# Create a query weight matrix with shape (d_in, d_out) and set requires_grad=False to indicate that these weights will not be updated during training
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Create a key weight matrix with the same shape as W_query and also set requires_grad=False
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Create a value weight matrix with the same shape as W_query and also set requires_grad=False
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

- Next, we compute the query, key, and value vectors:

In [17]:
# Project the second input element x_2 into the query space using the query weight matrix W_query
query_2 = x_2 @ W_query  # 使用 @ 运算符进行矩阵乘法

# Project the second input element x_2 into the key space using the key weight matrix W_key
key_2 = x_2 @ W_key

# Project the second input element x_2 into the value space using the value weight matrix W_value
value_2 = x_2 @ W_value

# Print the calculated query vector
print(query_2)

tensor([0.4306, 1.4551])


- As can be seen from the results below, we have successfully projected the 6 input tokens from a 3D space to a 2D embedding space:

In [18]:
# Project the input sequence inputs into the key space using the key weight matrix W_key
keys = inputs @ W_key

# Use the value weight matrix W_value to project the input sequence inputs into the value space
values = inputs @ W_value

# Print the shape of the key vector
print("keys.shape:", keys.shape)

# Print the shape of the value vector
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


- In the next step, step 2, we compute the unnormalized attention scores by computing the dot product between the query vector and each key vector:

<img src="figures/weight-selfattn-2.png" width="600px">

In [19]:
# Extract the second key vector from the key tensor, corresponding to the second element in the input sequence
keys_2 = keys[1]  # Python 中的索引是从 0 开始的

# Calculate the dot product between the query vector query_2 and the key vector keys_2 to get the attention score
attn_score_22 = query_2.dot(keys_2)

# Print the calculated attention score
print(attn_score_22)

tensor(1.8524)


- Since we have 6 inputs, for a given query vector we get 6 attention scores:

In [20]:
# Perform matrix multiplication using the query vector query_2 and the transpose of all key vectors keys (keys.T) to get all attention scores
attn_scores_2 = query_2 @ keys.T

# Print all attention scores
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


<img src="figures/weight-selfattn-3.png" width="600px">

- Next, in **step 3**, we compute the attention weights (normalized attention scores, which sum to 1) using the softmax function mentioned earlier.
- Unlike before, we now scale the attention scores by dividing by the square root of the embedding dimension, $\sqrt{d_k}$ (i.e. `d_k**0.5`):

In [21]:
# Get the dimension of the key vector, that is, the dimension size of each key vector
d_k = keys.shape[1]

# Use the softmax function to normalize the attention scores
# attn_scores_2 / d_k**0.5 is to scale the dot product score to prevent the softmax gradient from being too small due to too large a dot product value
# dim=-1 means to perform softmax calculation along the last dimension (i.e. the dimension of attention score)
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)

# Print normalized attention weights
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


<img src="figures/weight-selfattn-4.png" width="600px">

- In step 4, we now compute the context vector for the input query vector 2:

In [22]:
# Use the normalized attention weights attn_weights_2 and the value vector values ​​to perform matrix multiplication to get the context vector
context_vec_2 = attn_weights_2 @ values

# Print the calculated context vector
print(context_vec_2)

tensor([0.3061, 0.8210])


### 3.4.2 Implementing a compact SelfAttention class

Putting it all together, we can implement the self-attention mechanism as follows:

In [23]:
import torch.nn as nn

# Define the self-attention module
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
# Call the parent class constructor
        super().__init__()
# Set the output dimension
        self.d_out = d_out
# Initialize the weight matrices for queries, keys, and values. These matrices are trainable
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
# Project input x into query, key, and value spaces using weight matrix
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
# Calculate attention scores (unnormalized)
        attn_scores = queries @ keys.T  # omega
        
# Normalize the attention scores using the softmax function and scaling factor
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

# Calculate the context vector using the normalized attention weights and value vector
        context_vec = attn_weights @ values
        return context_vec

# Set the random seed to ensure reproducibility of results
torch.manual_seed(123)
# Create a SelfAttention_v1 instance
sa_v1 = SelfAttention_v1(d_in, d_out)
# Use input data inputs to perform forward propagation and print the results
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


- We can simplify the above implementation using PyTorch's linear layer `nn.Linear`, which is equivalent to matrix multiplication if we turn off the bias unit.
- Another important advantage of using `nn.Linear` instead of the `nn.Parameter(torch.rand(...))` method we created manually is that `nn.Linear` comes with a preferred weight initialization scheme, which helps to achieve more stable model training.

In [24]:
import torch.nn as nn

# Define the second version of the self-attention module
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
# Call the parent class constructor
        super().__init__()
# Set the output dimension
        self.d_out = d_out
# Initialize the linear layers for query, key, and value, optionally including a bias
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
# Use a linear layer to project the input x into query, key, and value spaces
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
# Calculate attention scores (unnormalized)
        attn_scores = queries @ keys.T
        
# Normalize the attention scores using the softmax function and scaling factor
# Note that dim=1 here means normalization along the dimension of the key vector
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)

# Calculate the context vector using the normalized attention weights and value vector
        context_vec = attn_weights @ values
        return context_vec

# Set the random seed to ensure reproducibility of results
torch.manual_seed(789)
# Create a SelfAttention_v2 instance
sa_v2 = SelfAttention_v2(d_in, d_out)
# Use input data inputs to perform forward propagation and print the results
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


- Note that `SelfAttention_v1` and `SelfAttention_v2` produce different outputs because they use different initial weight matrices.

## 3.5 Attention Mechanism to Mask Context Information

### 3.5.1 Using Causal Attention Mask

- In this section, we convert the previous self-attention mechanism into a causal self-attention mechanism.

- The core goal of the causal self-attention mechanism is to ensure that the model's prediction of a position in the sequence depends only on the known output of the previous position (that is, the previous context), and not on the future position (that is, the following context). In other words, ensure that the prediction of each word should only depend on the previous word.

- To achieve this, for each given word, we mask out future words (that is, words after the current word in the input text).

<img src="figures/masked.png" width="600px">

- To illustrate and implement the causal self-attention mechanism, let’s operate with the attention scores and weights from the previous section.

In [25]:
# Use the query and key weight matrices of SelfAttention_V2 in the previous section
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T
# The attention weights here are the same as in the previous section
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


The simplest way to mask future attention weights is to create a mask via PyTorch’s tril function, with elements below the main diagonal (including the diagonal itself) set to 1 and elements above the main diagonal set to 0:

In [26]:
# The shape of the mask we create should be consistent with the shape of the attention weight matrix, in one-to-one correspondence
block_size = attn_scores.shape[0]
# The tril method creates a lower triangular matrix
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


- We can then multiply the attention weights by this mask to zero out the attention scores above the diagonal:

In [27]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


- However, if we apply masking after softmax, as we did above, it will destroy the probability distribution created by softmax. Softmax will ensure that all output values ​​sum to 1, but since we set some of the output values ​​to 0, this will cause the output values ​​to sum to different values.

- Therefore, masking after softmax will require re-normalizing the outputs to sum to 1 again. However, this complicates the process and may lead to unexpected effects.

- To ensure that the output values ​​sum to 1, we can normalize the weight matrix as follows:

In [28]:
# dim = 1 means sum by row
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


- Although we have now technically completed the causal attention mechanism, there are more efficient ways to achieve the same effect as above.

- For example, instead of zeroing the attention weights above the diagonal and renormalizing the result, we can mask the above-diagonal portion with negative infinity before the unnormalized attention scores enter the softmax function.

In [29]:
# That is, by changing the mask from 0 to -inf, the masking operation can be moved before the softmax
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


- As we can see, if we then pass the attention matrix through softmax, we can return the sum of each row to 1:

In [30]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### 3.5.2 Masking for additional attention weights via dropout

- In addition, we can also apply dropout during training to reduce overfitting.

- Dropout can be applied in multiple places such as the following examples:
- After calculating the attention weights;
- After multiplying the attention weights with the value vector.

- Here, we will apply the dropout mask after calculating the attention weights because this is more common.

- In addition, in this particular example, we used a dropout rate of 50%, which means randomly masking half of the attention weights. (When we train the GPT model later, we will use a lower dropout rate, such as 0.1 or 0.2.)

<img src="figures/dropout.png" width="500px">

- Note that if we apply a dropout rate of 0.5, then the unmasked values ​​will be scaled accordingly by a factor of 1/0.5 = 2.

In [30]:
# Set a random number seed
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # 设置 50% 的 Dropout 比例
example = torch.ones(6, 6) # 创建一个全 1 矩阵作为示例

print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [31]:
torch.manual_seed(123)
# Dropout the attention weights
print(dropout(attn_weights))

tensor([[0.3843, 0.3293, 0.3303, 0.3100, 0.3442, 0.3019],
        [0.0000, 0.3318, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3325, 0.0000, 0.3328, 0.0000],
        [0.3738, 0.3334, 0.0000, 0.0000, 0.0000, 0.3128],
        [0.3661, 0.0000, 0.0000, 0.0000, 0.0000, 0.3169],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


## 3.5.3 Implementing a causal self-attention class

- Now, we are ready to implement a causal self-attention class with dropout.

- We also need to implement code to handle a batch of samples consisting of multiple inputs, so that our CausalAttention class supports the batch output produced by the dataloader we implemented in Chapter 2.

- For simplicity, to simulate such a batch input, we copy the input text example:

In [32]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) # 2个输入，每个输入有 6个 token，每个 token 的维度为 3

torch.Size([2, 6, 3])


In [33]:
# Define a causal self-attention layer with dropout
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        '''
        构造函数，输入参数如下：
        d_in: 输入的维度
        d_out: 输出的维度
        block_size: 注意力权重矩阵的大小
        dropout: dropout 比例
        qkv_bias: 是否对 query、key 和 value 加偏置
        '''
        super().__init__()
        self.d_out = d_out
# According to the previous text, each weight matrix is ​​a linear layer of d_in x d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# A dropout layer
        self.dropout = nn.Dropout(dropout) 
# A mask matrix with 1 in the lower triangle and 0 in the rest
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New

    def forward(self, x):
        '''
        前向传播函数，输入参数为 x，维度为 b x num_tokens x d_in，输出维度为 b x num_tokens x d_out
        '''
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
# transpose is to implement matrix multiplication
        attn_scores = queries @ keys.transpose(1, 2)
# As mentioned above, change the mask from 0 to -inf and then perform the masking operation
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
# After softmax
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
# Perform dropout
        attn_weights = self.dropout(attn_weights) # New
# Get the final result
        context_vec = attn_weights @ values
        return context_vec

# Experiment
torch.manual_seed(123)

block_size = batch.shape[1]
ca = CausalAttention(d_in, d_out, block_size, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.0844,  0.0414],
         [-0.2264, -0.0039],
         [-0.4163, -0.0564],
         [-0.5014, -0.1011],
         [-0.7754, -0.1867],
         [-1.1632, -0.3303]],

        [[-0.0844,  0.0414],
         [-0.2264, -0.0039],
         [-0.4163, -0.0564],
         [-0.5014, -0.1011],
         [-0.7754, -0.1867],
         [-1.1632, -0.3303]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


- Note that dropout is only used during training, not during inference.

## 3.6 Extending single-head attention to multiple heads

### 3.6.1 Directly stack multiple single-head attention layers

- The following figure summarizes the self-attention mentioned above (for simplicity, the causal attention mask and dropout are not shown) 

- Also known as single-head attention:

<img src="figures/single-head.png" width="600px">

- We can simply stack multiple single-head attention layers together to achieve a multi-head attention layer:

<img src="figures/multi-head.png" width="600px">

- The main idea of ​​the multi-head attention mechanism is to run the attention mechanism multiple times (in parallel) using different, learned weight matrices. This allows the model to jointly attend to information in different representation subspaces at different locations.

In [34]:
# Define a multi-head attention layer
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
# Combine num_heads single-head attention layers together to achieve multi-head
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
# Concatenate the outputs of multiple heads together during forward calculation
        return torch.cat([head(x) for head in self.heads], dim=-1)


# Experiment
torch.manual_seed(123)

block_size = batch.shape[1] # token 数量
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]],

        [[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


- In the above implementation, the embedding dimension is 4 because we set d_out=2 as the embedding dimension for key, query, and value. Since we have 2 attention heads, the output embedding dimension is 2*2=4.

- If we want the output dimension to be 2, like the earlier single-head attention, we can change the projection dimension d_out to 1:

In [36]:
torch.manual_seed(123)

d_out = 1
mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-9.1476e-02,  3.4164e-02],
         [-2.6796e-01, -1.3427e-03],
         [-4.8421e-01, -4.8909e-02],
         [-6.4808e-01, -1.0625e-01],
         [-8.8380e-01, -1.7140e-01],
         [-1.4744e+00, -3.4327e-01]],

        [[-9.1476e-02,  3.4164e-02],
         [-2.6796e-01, -1.3427e-03],
         [-4.8421e-01, -4.8909e-02],
         [-6.4808e-01, -1.0625e-01],
         [-8.8380e-01, -1.7140e-01],
         [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


### 3.6.2 Multi-head attention through weight splitting

- While the above is the most intuitive and fully functional implementation of multi-head attention (encapsulating the earlier single-head attention CausalAttention implementation), we can also write a separate class called MultiHeadAttention to implement the same functionality.

- For this separate MultiHeadAttention class, we do not connect the individual attention heads together. Instead, we create individual W_query, W_key, and W_value weight matrices, and then split them into separate matrices for each attention head:

In [35]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
# Because the weight matrix needs to be split according to the number of attention heads, all output dimensions must be integer multiples of the number of heads
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
# head_dim is the dimension that each head should output after splitting
        self.head_dim = d_out // num_heads 

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

# shape is (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

# We can split the matrix into each head by adding a dimension of num_heads
# Dimension change: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

# Calculate attention weights
# Based on matrix multiplication, simply implement parallel computing of each head
        attn_scores = queries @ keys.transpose(2, 3) 
# Generally we convert the mask matrix to bool values ​​and truncate based on the length of the sequence
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# The mask matrix needs to be unsqueezed twice, that is, two dimensions are added, so that the dimensions of the mask matrix and the attention weights correspond
        mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
# Use a mask matrix for masking
        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

# shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
# Reassemble the outputs of multiple heads self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

# Experiment
torch.manual_seed(123)

batch_size, block_size, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


- Note that the above is actually a more efficient rewrite of MultiHeadAttentionWrapper.
- The final output looks a bit different due to differences in the random weight initialization, but both are perfectly usable implementations that will be used in the GPT class implemented in subsequent chapters.
- In addition, we added a linear projection layer (self.out_proj) to the MultiHeadAttention class above. This is just a linear transformation that does not change the dimensionality. Using such a projection layer in LLM implementations is a standard practice, but not strictly necessary (recent research shows that it can be removed without affecting modeling performance; see the Further Reading section at the end of this chapter)

If you are interested in a more sophisticated and efficient implementation of multi-head attention, you may consider using PyTorch’s [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) class.

- The above implementation may look a bit complicated, let's take a look at what happens when we run `attn_scores = queries @ keys.transpose(2, 3)`:

In [38]:
# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


- In this case, the matrix multiplication implementation in PyTorch will process the 4-dimensional input tensor to do a matrix multiplication between the last two dimensions (num_tokens, head_dim), and then repeat for each head.

- For example, the above becomes a more compact matrix multiplication that computes each head separately:

In [39]:
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [40]:
block_size = 1024
d_in, d_out = 768, 768
num_heads = 12

mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(mha)

2360064

# Summary and Gains

You can check out the [./multihead-attention.ipynb](./multihead-attention.ipynb) code notebook, which is a concise version of DataLoader (Chapter 2), plus the multi-head attention class we implemented in this chapter, which we will use when training the GPT model in subsequent chapters.