# 3.6 Extending Single-Head Attention to Multi-Head Attention

In the last section of this chapter, we will extend the previously implemented CausalAttention class with multiple heads. This is called a multi-head attention mechanism.

The term "multi-head" refers to dividing the attention mechanism into multiple "heads," each operating independently. In this context, a single causal attention module can be considered single-head attention, where only one set of attention weights processes the input sequentially.

In the following subsections, we will extend causal attention to multi-head attention. The first subsection intuitively builds a multi-head attention module by stacking multiple CausalAttention modules for illustration purposes. The second subsection implements the same multi-head attention module in a more complicated but computationally more efficient way.

## 3.6.1 Stacking Multiple Single-head Attention Layers

In practice, implementing multi-head attention involves creating multiple instances of the self-attention mechanism (shown in Figure 3.18 in Section 3.4.1), each with its own weights, and then combining their outputs. Using multiple instances of the self-attention mechanism is computationally intensive, but it is essential for the kind of complex pattern recognition that transformer-based large language models are known for.

Figure 3.24 shows the structure of a multi-head attention module, which consists of multiple single-head attention modules, like the one depicted in Figure 3.18, stacked on top of each other.

**Figure 3.24 The multi-head attention module in this figure consists of two single-head attention modules stacked on top of each other. So, instead of using a single matrix Wv to compute the value matrix, a multi-head attention module with two heads uses two value weight matrices: Wv1 and Wv2. The same applies to the other weight matrices, Wq and Wk. We obtain two sets of context vectors, Z1 and Z2, which we combine into a single context vector matrix Z.**

![3.24](../img/fig-3-24.jpg)

As mentioned earlier, the main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections, that is, the results of multiplying the input data (such as the query, key, and value vectors in the attention mechanism) by a weight matrix.
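To make this idea concrete before wrapping it in a class, here is a minimal sketch that computes two causal attention heads by hand, each with its own projection matrices, and concatenates their context vectors. The helper `causal_head` and the matrix names (`W_q1`, `W_k1`, ...) are hypothetical and used only for illustration; the weights are random rather than learned, and dropout is omitted.

```python
import torch

torch.manual_seed(123)

def causal_head(x, W_q, W_k, W_v):
    """One causal attention head: project, score, mask future tokens, softmax, weight the values."""
    queries, keys, values = x @ W_q, x @ W_k, x @ W_v       # (num_tokens, d_head) each
    scores = queries @ keys.T                                # (num_tokens, num_tokens)
    num_tokens = x.shape[0]
    mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))        # hide future positions
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values                                  # (num_tokens, d_head)

d_in, d_head, num_tokens = 3, 2, 6
x = torch.rand(num_tokens, d_in)  # random stand-in for six token embeddings

# Each head gets its own (here randomly initialized) query, key, and value matrices
W_q1, W_k1, W_v1 = (torch.rand(d_in, d_head) for _ in range(3))
W_q2, W_k2, W_v2 = (torch.rand(d_in, d_head) for _ in range(3))

Z1 = causal_head(x, W_q1, W_k1, W_v1)
Z2 = causal_head(x, W_q2, W_k2, W_v2)
Z = torch.cat([Z1, Z2], dim=-1)  # combined context vector matrix
print(Z.shape)  # torch.Size([6, 4])
```

Each head sees the same input but projects it differently, so the two heads can attend to different patterns. Concatenation simply places their context vectors side by side.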

In code, we can do this by implementing a simple MultiHeadAttentionWrapper class that stacks multiple instances of the CausalAttention module we implemented earlier:

### Listing 3.4 Implementing the MultiHeadAttentionWrapper class

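```python
import torch
from torch import nn

# CausalAttention is the class implemented earlier in this chapter
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        # One independent CausalAttention module (with its own weights) per head
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # Run each head and concatenate the outputs along the embedding dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)
```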

For example, if we use this MultiHeadAttentionWrapper class and set two attention heads via num_heads=2, and set the output dimension of CausalAttention to 2 (d_out=2), this will result in a four-dimensional context vector (d_out*num_heads=4), as shown in Figure 3.25.
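The dimension arithmetic comes entirely from concatenating along the last axis. As a shape-only sketch, random tensors can stand in for the outputs of two CausalAttention heads (an assumed batch of 2 inputs with 6 tokens each and d_out=2 per head):

```python
import torch

batch_size, num_tokens, d_out, num_heads = 2, 6, 2, 2

# Random stand-ins for the outputs of two CausalAttention heads
head_outputs = [torch.rand(batch_size, num_tokens, d_out) for _ in range(num_heads)]

# MultiHeadAttentionWrapper concatenates along the last (embedding) dimension
combined = torch.cat(head_outputs, dim=-1)
print(combined.shape)  # torch.Size([2, 6, 4]), i.e., d_out * num_heads = 4 columns
```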

**Figure 3.25 Using MultiHeadAttentionWrapper, we specify the number of attention heads (num_heads). If we set num_heads=2, as shown in the figure, we obtain a tensor with two sets of context vector matrices. In each context vector matrix, the rows represent the context vectors corresponding to the tokens, and the columns correspond to the embedding dimension specified by d_out=2. We concatenate these context vector matrices along the column dimension. Since we have 2 attention heads and an embedding dimension of 2, the final embedding dimension is 2 × 2 = 4.**

![3.25](../img/fig-3-25.jpg)

To further illustrate Figure 3.25, we can use the MultiHeadAttentionWrapper class as we did with the CausalAttention class before:

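```python
import torch

torch.manual_seed(123)

# The (2, 6, 3) batch used earlier in this chapter: two copies of the six-token example input
inputs = torch.tensor([[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]])
batch = torch.stack((inputs, inputs), dim=0)

context_length = batch.shape[1]  # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
```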

This results in the following tensor representing the context vectors:
```python
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])
```

The first dimension of the context vector tensor is 2 because we have two input texts (the input texts are replicated, which is why these context vectors are exactly the same for them). The second dimension refers to the 6 tokens in each input. The third dimension refers to the 4-dimensional embedding of each token.

### Exercise 3.2 Returning a 2-D embedding vector

Change the input parameters of the MultiHeadAttentionWrapper (..., num_heads=2) call so that the output context vector is 2-dimensional instead of 4-dimensional, while keeping num_heads=2. Hint: You don't need to modify the implementation of the class, you only need to change one other input parameter.

In this section, we implemented MultiHeadAttentionWrapper, which combines multiple single-head attention modules. Note that these heads are processed sequentially via `[head(x) for head in self.heads]` in the forward method. We can improve this implementation by processing the heads in parallel. One way to do this is to compute the outputs of all attention heads simultaneously via matrix multiplication, which we will explore in the next section.

## 3.6.2 Multi-Head Attention via Weight Splitting

In the previous section, we created a MultiHeadAttentionWrapper to implement multi-head attention by stacking multiple single-head attention modules. This was done by instantiating and combining several CausalAttention objects.

Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we can merge both concepts into a single MultiHeadAttention class. In addition to merging the code of MultiHeadAttentionWrapper with CausalAttention, we will make some other modifications to implement multi-head attention more efficiently.

In MultiHeadAttentionWrapper, multiple heads are implemented by creating a list of CausalAttention objects (self.heads), each representing a separate attention head. The CausalAttention class performs the attention mechanism independently, and the results of the heads are concatenated. In contrast, the following MultiHeadAttention class integrates the multi-head functionality into a single class. It splits the input into multiple heads by reshaping the projected query, key, and value tensors, and then combines the results of these heads after computing attention.

Let's take a look at the MultiHeadAttention class before going further:

### Listing 3.5 An efficient MultiHeadAttention class

```python
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Reduce the projection dim per head so the heads together match d_out
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer to combine the head outputs
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Implicitly split the matrix by adding a num_heads dimension:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Dot products for each head: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        # Use the mask to hide future tokens from each query
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        # Optional linear projection
        context_vec = self.out_proj(context_vec)
        return context_vec
```

Although the tensor reshaping (.view) and transposition (.transpose) in the MultiHeadAttention class may seem very complicated, mathematically the MultiHeadAttention class implements the same concepts as the previous MultiHeadAttentionWrapper.

At a high level, in the previous MultiHeadAttentionWrapper, we stacked multiple single-head attention layers and combined them into a multi-head attention layer. The MultiHeadAttention class takes an integrated approach: it starts with a multi-head layer and then internally splits this layer into individual attention heads, as shown in Figure 3.26.

**Figure 3.26 In the MultiHeadAttentionWrapper class with two attention heads, we initialized two weight matrices, Wq1 and Wq2, and computed two query matrices, Q1 and Q2, as shown at the top of the figure. In the MultiHeadAttention class, we initialize one larger weight matrix Wq, perform only one matrix multiplication with the input to obtain the query matrix Q, and then split the query matrix Q into Q1 and Q2, as shown at the bottom of the figure. We do the same for the keys and values, which are not shown to reduce visual clutter.**

![3.26](../img/fig-3-26.jpg)

As shown in Figure 3.26, the splitting of query, key, and value tensors is achieved by using PyTorch's .view and .transpose methods to perform tensor reshaping and transposition operations. The input is first transformed through linear layers (for query, key, and value) and then reshaped to represent multiple heads.

The key operation is to split the d_out dimension into num_heads and head_dim, where head_dim = d_out // num_heads. This split is achieved by the .view method: a tensor of dimensions (b, num_tokens, d_out) is reshaped into dimensions (b, num_tokens, num_heads, head_dim).
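A minimal shape walk-through of this split, covering both the `.view` call and the transpose that follows it in the code, might look like this. A random tensor stands in for the projected queries, and the sizes (b=2, num_tokens=5, d_out=6, num_heads=3) are assumed values chosen so that every dimension is distinguishable.

```python
import torch

b, num_tokens, d_out, num_heads = 2, 5, 6, 3
head_dim = d_out // num_heads  # 2

queries = torch.rand(b, num_tokens, d_out)  # stand-in for self.W_query(x)

# Unroll the last dimension: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
split = queries.view(b, num_tokens, num_heads, head_dim)

# Move the head dimension in front of the token dimension: (b, num_heads, num_tokens, head_dim)
per_head = split.transpose(1, 2)
print(split.shape, per_head.shape)  # torch.Size([2, 5, 3, 2]) torch.Size([2, 3, 5, 2])

# Head h holds a contiguous slice of head_dim columns from the original projection
assert torch.equal(per_head[:, 1], queries[:, :, head_dim:2 * head_dim])
```

The final assertion shows what the split means: no values are recomputed; each head simply owns its own block of `head_dim` columns of the projected tensor.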

The tensor is then transposed so that the num_heads dimension comes before the num_tokens dimension, giving a shape of (b, num_heads, num_tokens, head_dim). This transposition is crucial for correctly aligning the queries, keys, and values across the different heads and for performing batched matrix multiplications efficiently.

To illustrate this batched matrix multiplication, suppose we have the following example tensor:

```python
import torch

# Shape: (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])
```

Now, we perform a batch matrix multiplication between the tensor itself and a view of the tensor with the last two dimensions transposed:

```python
print(a @ a.transpose(2, 3))
```

The results are as follows: 
```python
tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])
```

In this case, the matrix multiplication implementation in PyTorch can process a 4D input tensor so that the matrix multiplication is performed between the last two dimensions (num_tokens, head_dim) and then repeated for each head.

For instance, the batched operation above is a more compact way of computing the matrix multiplication for each head separately, which we could also write out explicitly:

```python
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)
```

The results are exactly the same as those we obtained earlier with the batched matrix multiplication `print(a @ a.transpose(2, 3))`:
```python
First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])
```

Continuing with MultiHeadAttention, after computing the attention weights and context vectors, the context vectors from all heads are transposed back to the shape (b, num_tokens, num_heads, head_dim). These vectors are then reshaped (flattened) into the shape (b, num_tokens, d_out), effectively merging the outputs of all heads.
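Flattening the heads back into one tensor amounts to the same `torch.cat(..., dim=-1)` that MultiHeadAttentionWrapper performs. A quick check, using a random stand-in for the per-head context vectors (the sizes are assumed for illustration):

```python
import torch

b, num_heads, num_tokens, head_dim = 2, 3, 5, 2
d_out = num_heads * head_dim

# Stand-in for attn_weights @ values: (b, num_heads, num_tokens, head_dim)
context = torch.rand(b, num_heads, num_tokens, head_dim)

# MultiHeadAttention route: transpose back, then flatten the heads
merged = context.transpose(1, 2).contiguous().view(b, num_tokens, d_out)

# Wrapper route: concatenate each head's output along the last dimension
concatenated = torch.cat([context[:, h] for h in range(num_heads)], dim=-1)

print(merged.shape)  # torch.Size([2, 5, 6])
assert torch.equal(merged, concatenated)
```

The `.contiguous()` call matters here: after `.transpose`, the tensor's memory layout no longer allows `.view` to merge the `num_heads` and `head_dim` axes, and `.view` would raise an error without it.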

In addition, after merging the heads, we added a so-called output projection layer (self.out_proj) to MultiHeadAttention, which is not present in the CausalAttention class. This output projection layer is not strictly necessary (see the References section of Appendix B for more details), but it is common in many large language model architectures, which is why we add it here for completeness.

Although the MultiHeadAttention class looks more complex than MultiHeadAttentionWrapper due to the additional reshaping and transposition of tensors, it is more efficient. The reason is that we need only one matrix multiplication to compute the keys, for instance, keys = self.W_key(x) (the same holds for the queries and values). In MultiHeadAttentionWrapper, this matrix multiplication, which is one of the most computationally expensive steps, has to be repeated for each attention head.
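To see why one large projection can replace the per-head ones, the following sketch builds a single `nn.Linear(d_in, d_out)` and copies slices of its weight into per-head `nn.Linear(d_in, head_dim)` layers, like those the wrapper's heads would hold. The weight copying exists only to make the comparison exact; the two classes in this chapter initialize their weights independently, and the sizes are assumed for illustration.

```python
import torch
from torch import nn

torch.manual_seed(123)
d_in, d_out, num_heads = 3, 4, 2
head_dim = d_out // num_heads
x = torch.rand(2, 6, d_in)  # (batch, num_tokens, d_in)

# One fused projection, as in MultiHeadAttention: a single matrix multiplication
W_query = nn.Linear(d_in, d_out, bias=False)
fused = W_query(x)  # (2, 6, d_out)

# Per-head projections, as in MultiHeadAttentionWrapper: one matrix multiplication per head
per_head_outputs = []
for h in range(num_heads):
    W_q_h = nn.Linear(d_in, head_dim, bias=False)
    with torch.no_grad():
        # nn.Linear stores its weight as (out_features, in_features), so slice rows
        W_q_h.weight.copy_(W_query.weight[h * head_dim:(h + 1) * head_dim])
    per_head_outputs.append(W_q_h(x))

# Concatenating the per-head results reproduces the fused projection
assert torch.allclose(fused, torch.cat(per_head_outputs, dim=-1))
print("fused and per-head projections match")
```

Stacking the per-head weight matrices row-wise gives the fused weight, so the same numbers come out of one larger matrix multiplication instead of `num_heads` smaller ones.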

The MultiHeadAttention class can be used just like the SelfAttention and CausalAttention classes we implemented earlier:

```python
import torch

torch.manual_seed(123)
# `batch` is the (2, 6, 3) tensor of two identical six-token inputs used earlier in this chapter
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
```

From the results, the output dimension is directly controlled by the d_out parameter: 
```python
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
```

In this section, we implemented the MultiHeadAttention class that we will use when implementing and training large language models in the upcoming chapters. Note that while the code is fully functional, we used relatively small embedding sizes and numbers of attention heads to keep the output readable.

For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention heads and a context vector embedding size of 768. The largest GPT-2 model (1.5 billion parameters) has 25 attention heads and a context vector embedding size of 1600. Note that in GPT models, the embedding sizes of the token inputs and the context embeddings are the same (d_in = d_out).
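For a quick sense of scale, the per-head dimension `head_dim = d_out // num_heads` can be worked out for both GPT-2 configurations quoted above. The dictionary labels are just illustrative names:

```python
# (num_heads, d_out) for the two GPT-2 sizes quoted in the text
gpt2_configs = {
    "smallest GPT-2": (12, 768),
    "largest GPT-2": (25, 1600),
}

head_dims = {}
for name, (num_heads, d_out) in gpt2_configs.items():
    assert d_out % num_heads == 0  # the same check MultiHeadAttention.__init__ performs
    head_dims[name] = d_out // num_heads
    print(f"{name}: {num_heads} heads x {head_dims[name]} dims = {d_out}")
```

Both configurations work out to 64 dimensions per head. The larger model gains capacity by adding heads (and thus a wider d_out), not by making each head wider.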

### Exercise 3.3 Initialize an Attention Module with GPT-2 Size
Using the MultiHeadAttention class, initialize a MultiHeadAttention module with the same number of attention heads as the smallest GPT-2 model (12 heads). Also make sure you use similar input and output embedding sizes as GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a context length of 1024 tokens.