# 3.5 Hiding Subsequent Words Using Causal Attention

In this section, we will modify the standard self-attention mechanism to create a causal attention mechanism, which is essential for the development of large language models in subsequent chapters.

Causal attention, also known as masked attention, is a special form of self-attention. It restricts the model to only consider the previous and current inputs in the sequence when processing any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

Therefore, when calculating attention scores, the causal attention mechanism ensures that the model only considers tokens that appear at or before the current token in the sequence.
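As a rough illustration of this rule, the following plain-Python sketch lists which positions each token may attend to. The example sentence and the helper name `visible_positions` are hypothetical and chosen only for this illustration.

```python
def visible_positions(num_tokens):
    """For each position i, return the positions it may attend to (0..i)."""
    return [list(range(i + 1)) for i in range(num_tokens)]

# Hypothetical example sentence, used only for illustration
tokens = ["Your", "journey", "starts", "with", "one", "step"]

for i, allowed in enumerate(visible_positions(len(tokens))):
    print(f"{tokens[i]!r} attends to {[tokens[j] for j in allowed]}")
```

Each token sees itself and everything before it, but nothing after it; this is the pattern that the lower-triangular masks in this section encode.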

To achieve this in GPT-like large language models, for each token processed we mask out the subsequent tokens, that is, those that come after the current token in the input text, as shown in Figure 3.19.

**Figure 3.19 In causal attention, we mask the attention weights above the diagonal so that the large language model cannot access subsequent tokens when calculating the context vector. For example, in the second row, for the word "journey", we only keep the attention weights for "Your" (the previous word) and "journey" (the current position).**

![3.19](../img/fig-3-19.jpg)

As shown in Figure 3.19, we mask the attention weights above the diagonal and normalize the unmasked attention weights so that the sum of the attention weights in each row is 1. In the next section, we will implement this masking and normalization process in code.

## 3.5.1 Applying Causal Attention Masking

In this section, we will implement causal attention masking in code. We start with the procedure summarized in Figure 3.20.

**Figure 3.20 One way to obtain the masked attention weight matrix in the causal attention mechanism is to apply a softmax function to the attention scores, zero out the elements above the diagonal, and renormalize the resulting matrix.**

![3.20](../img/fig-3-20.jpg)

To implement the causal attention masking steps shown in Figure 3.20 and obtain the masked attention weights, let's code the causal attention mechanism using the attention scores and weights from the previous section.

In the first step, as shown in Figure 3.20, we calculate the attention weights using the softmax function, as we did in the previous sections:

```python
# Reuse the query and key weight matrices of sa_v2
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
```

This yields the following attention weights:
```python
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
```

We can use PyTorch's tril function to implement the second step in Figure 3.20, creating a mask in which the values above the diagonal are zero:

```python
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
```

The resulting mask is as follows:
```python
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
```

Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:

```python
masked_simple = attn_weights * mask_simple
print(masked_simple)
```

We can see that the elements above the diagonal have been successfully zeroed:
```python
tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)
```

The third step in Figure 3.20 is to renormalize the attention weights so that they sum to 1 again for each row. We can do this by dividing each element in each row by the sum of that row:

```python
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
```

We get an attention weight matrix where the attention weights above the diagonal are zeroed and each row sums to 1:
```python
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)
```

### Information Leakage

When we apply a mask and then renormalize the attention weights, it may initially appear that information from subsequent tokens (which we intend to mask) could still influence the current token, because their scores are part of the softmax calculation. However, the key insight is that renormalizing the attention weights after masking is equivalent to recomputing the softmax over a smaller subset, since the masked positions contribute nothing to the result.

The mathematical elegance of softmax is that, although all positions are included in the denominator in the initial calculation, every weight in a row shares that same denominator. Dividing the unmasked weights by their row sum cancels it, so after masking and renormalization the influence of the masked positions is eliminated and they do not affect the final weights.

In short, after masking and renormalization, the distribution of attention weights is as if it had been calculated only over the unmasked positions to begin with. This ensures that no information from subsequent (or otherwise masked) tokens leaks, contrary to what we might have suspected.
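This equivalence can be checked numerically. The sketch below uses plain Python rather than PyTorch so that it has no dependencies; the scores are made-up numbers, and the `softmax` helper is defined here only for illustration.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.46, 0.17, 0.17, 0.02, 0.08, 0.01]  # made-up scores for one row
num_visible = 3  # this row may attend to the first three positions only

# Route 1: softmax over all positions, zero the masked ones, renormalize
full = softmax(scores)
masked = [w if j < num_visible else 0.0 for j, w in enumerate(full)]
row_sum = sum(masked)
renormalized = [w / row_sum for w in masked]

# Route 2: softmax over the visible positions only
subset = softmax(scores[:num_visible])

print(renormalized[:num_visible])
print(subset)
```

Both routes give the same weights up to floating-point rounding, because the shared softmax denominator cancels when each weight is divided by the row sum.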

Although the implementation of causal attention is technically complete at this point, we can exploit a mathematical property of the softmax function and implement the calculation of the masked attention weights more efficiently in fewer steps, as shown in Figure 3.21.

**Figure 3.21 A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.**

![3.21](../img/fig-3-21.jpg)

The softmax function converts its input into a probability distribution. When there is a negative infinity (-∞) value in a row, the softmax function treats its probability as zero. (Mathematically, this is because e^-∞ approaches 0.)
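To make this concrete, here is a small dependency-free sketch in plain Python. The `softmax` helper is written only for this example, the row of scores is taken from the second row of the masked scores in this section, and the key dimension of 2 is an assumption based on the embedding size used in this chapter's examples.

```python
import math

def softmax(xs):
    """Softmax over a list of floats; assumes at least one finite entry."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

neg_inf = float("-inf")
d_k = 2  # assumed key dimension

row = [0.4656, 0.1723, neg_inf, neg_inf, neg_inf, neg_inf]
scaled = [x / d_k**0.5 for x in row]  # -inf stays -inf after scaling
weights = softmax(scaled)

print(math.exp(neg_inf))
print([round(w, 4) for w in weights])
```

The masked positions receive a weight of exactly zero, and the remaining weights already sum to 1, so no separate renormalization step is required.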

We can implement this more efficient masking technique by creating a mask with 1s above the diagonal, and then replacing these 1s with negative infinity (-inf) values:

```python
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
```

This produces the following masked attention scores:
```python
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
```

Now, we just need to apply a softmax function to these masked results and we are done:

```python
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
```

As can be seen from the output, the values in each row sum to 1, and no further normalization is needed:
```python
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
```

We can now use the modified attention weights to compute the context vector via context_vec = attn_weights@values, as we did in Section 3.4.
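The following plain-Python sketch illustrates this final step on a reduced example. The attention weights are the top-left 3x3 block of the causal weights computed in this section; the value vectors and the `matmul` helper are made up for illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

# Top-left 3x3 block of the causal attention weights from this section
attn_weights = [
    [1.0000, 0.0000, 0.0000],
    [0.5517, 0.4483, 0.0000],
    [0.3800, 0.3097, 0.3103],
]

# Made-up value vectors, one 2-dimensional vector per token
values = [
    [1.0, 0.0],
    [0.0, 1.0],
    [2.0, 2.0],
]

context_vecs = matmul(attn_weights, values)
for vec in context_vecs:
    print(vec)
```

Because the first row of the causal weights is 1 followed by zeros, the first context vector is simply the first value vector; each later context vector blends only the value vectors at or before its own position.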

In the next section, we will introduce another small tweak to the causal attention mechanism, which is useful for reducing overfitting when training large language models.

## 3.5.2 Masking Additional Attention Weights via Dropout

In deep learning, dropout is a technique where selected hidden layer units are randomly ignored during training, effectively "dropping" them. This approach helps prevent overfitting by ensuring that the model does not become overly dependent on any particular group of hidden layer units. It is important to emphasize that dropout is only used during training and not afterwards.

In Transformer architectures, including GPT, dropout in the attention mechanism is usually applied in one of two places: after the attention weights are calculated, or after the attention weights are applied to the value vectors.

Here, we will apply dropout masking after the attention weights are calculated, as shown in Figure 3.22, which is the more common variant in practice.

**Figure 3.22 Using causal attention masking (top left), we apply additional dropout masking (top right) to zero additional attention weights to reduce overfitting during training.**

![3.22](../img/fig-3-22.jpg)

In the following code example, we use a 50% dropout rate, which means masking out half of the attention weights. (When training the GPT model in later chapters, we will use lower dropout rates, such as 0.1 or 0.2.)

In the following code, we first apply PyTorch's dropout implementation to a 6x6 tensor of 1s for illustration purposes:

```python
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # dropout rate of 50%
example = torch.ones(6, 6)       # 6x6 matrix of ones
print(dropout(example))
```

As you can see, about half of the values are set to zero:
```python
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
```

When a dropout rate of 50% is applied to the attention weight matrix, each element is set to zero with probability 0.5, so about half of the elements are zeroed. To compensate for the reduction in active elements, the values of the remaining elements are scaled up by a factor of 1/0.5 = 2. This scaling is critical to maintaining the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent between training and inference.
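The scaling rule can be sketched in a few lines of plain Python. This is not PyTorch's implementation, only a minimal illustration of the same idea; the function name `inverted_dropout` and its parameters are hypothetical.

```python
import random

def inverted_dropout(xs, p, training=True, rng=random):
    """Zero each element with probability p and scale survivors by 1 / (1 - p).

    With training=False the input is returned unchanged.
    """
    if not training or p == 0.0:
        return list(xs)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else x * scale for x in xs]

rng = random.Random(123)  # fixed seed so the run is repeatable
ones = [1.0] * 10_000

dropped = inverted_dropout(ones, p=0.5, rng=rng)
mean_after = sum(dropped) / len(dropped)

print(sorted(set(dropped)))  # survivors are doubled, the rest are zero
print(mean_after)            # stays close to the original mean of 1.0
print(inverted_dropout(ones[:3], p=0.5, training=False))  # unchanged at inference
```

Because the surviving elements are doubled, the average over many elements stays near its original value, which is why no rescaling is needed when dropout is switched off at inference time.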

Now, let's apply dropout to the attention weight matrix itself:

```python
torch.manual_seed(123)
print(dropout(attn_weights))
```

As a result, additional elements of the attention weight matrix are zeroed, and the remaining elements are rescaled:
```python
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)
```

Note that the output of Dropout may vary depending on the operating system; you can read more about this inconsistency on the PyTorch issue tracker (https://github.com/pytorch/pytorch/issues/121595).

Having learned about causal attention and dropout masking, we will develop a concise Python class in the next section. This class is designed to facilitate the efficient application of both techniques.

## 3.5.3 Implementing a Compact Causal Attention Class

In this section, we integrate causal attention and dropout techniques into the SelfAttention Python class we developed in Section 3.4. This class will then serve as a template for developing multi-head attention in the upcoming section, which is the final attention class we will implement in this chapter.

Before we begin, we need to ensure that the code can handle batches consisting of more than one input, so that the CausalAttention class supports the batch outputs produced by the data loader we implemented in Chapter 2.

For simplicity, to simulate such batch inputs, we duplicate the input text example:

```python
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)  # two inputs, six tokens each, embedding dimension 3
```

This produces a three-dimensional tensor consisting of two input texts with six tokens each, where each token is a three-dimensional embedding vector:
```python
torch.Size([2, 6, 3])
```

The following CausalAttention class is similar to the SelfAttention class we implemented earlier, except that we have added the Dropout and causal masking parts highlighted in the code below:

### Listing 3.3 A compact causal attention class

```python
import torch
import torch.nn as nn


class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # New
        self.register_buffer(  # New
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # New

        context_vec = attn_weights @ values
        return context_vec
```

While all the new lines of code are similar to the code in the previous sections, we now add a self.register_buffer() call in the `__init__` method. Using register_buffer in PyTorch is not strictly necessary in all cases, but it has several advantages here. For example, when we use the CausalAttention class in a large language model, the buffer is automatically moved to the appropriate device (CPU or GPU) along with the model, which will be useful when training large language models in subsequent chapters. This means that we do not need to manually ensure that these tensors are on the same device as the model parameters, thus avoiding device mismatch errors.
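The distinction between buffers and trainable parameters can be inspected directly. The sketch below assumes PyTorch is installed and uses a hypothetical toy module, `MaskHolder`, that exists only for this illustration.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.linear = nn.Linear(3, 2, bias=False)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

m = MaskHolder(context_length=4)

param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]

print(param_names)               # trainable parameters only
print(buffer_names)              # registered buffers
print("mask" in m.state_dict())  # buffers are saved together with the model
print(m.mask.requires_grad)      # the mask is not a trainable tensor
```

The mask is tracked by the module and follows it through calls such as `m.to(device)`, but it is not returned by `parameters()`, so an optimizer built from the model's parameters will not try to update it.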

We can use the CausalAttention class in the same way as we used the SelfAttention class before:

```python
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
```

The resulting context vectors form a three-dimensional tensor in which each token is now represented by a two-dimensional embedding:
```python
context_vecs.shape: torch.Size([2, 6, 2])
```

Figure 3.23 provides a mental model that summarizes what we have accomplished so far.

**Figure 3.23 A mental model that summarizes the four different attention modules we wrote in this chapter. We started with a simplified attention mechanism, added trainable weights, and then added causal attention masking. In the rest of this chapter, we will expand the causal attention mechanism and write a multi-head attention mechanism, which is the final module we will use in the large language model implementation in the next chapter.**

![3.23](../img/fig-3-23.jpg)

As shown in Figure 3.23, in this section, we focus on the concept and implementation of causal attention in neural networks. In the next section, we will further expand this concept and implement a multi-head attention module that implements multiple such causal attention mechanisms in parallel.