## 3.3 Attending to different parts of the input with self-attention

### 3.3.1 A simple self attention mechanism without trainable weights

In [1]:
import torch


# Inputs are the embeddings of the words in the sentence
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],
    ]  # step     (x^6)
)

inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

Using the second token, `journey`, as the query:

In [2]:
query = inputs[1]
attention_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attention_scores_2[i] = torch.dot(query, x_i)

attention_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

Now we normalize each of the attention scores.

We want to ensure that the sum of the attention weights is 1.

In [3]:
# Notice how we computed use the scores to compute the weights
attention_weights_2_tmp = attention_scores_2 / attention_scores_2.sum()
print("Attention weights:", attention_weights_2_tmp)
print("Sum:", attention_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


But it's better to use the `softmax` function for normalization. It's better at handling extreme values & gives better gradient properties during training.

In [4]:
attention_weights_2 = torch.softmax(attention_scores_2, dim=0)
print("Attention weights:", attention_weights_2)
print("Sum:", attention_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


The final step is to calculate the context vector by multiplying the embedded input tokens with the corresponding attention weights & summing the resulting vectors.

In [5]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    print(f"{i} | {attention_weights_2[i]:.2f} * {x_i}")
    context_vec_2 += attention_weights_2[i] * x_i

context_vec_2

0 | 0.14 * tensor([0.4300, 0.1500, 0.8900])
1 | 0.24 * tensor([0.5500, 0.8700, 0.6600])
2 | 0.23 * tensor([0.5700, 0.8500, 0.6400])
3 | 0.12 * tensor([0.2200, 0.5800, 0.3300])
4 | 0.11 * tensor([0.7700, 0.2500, 0.1000])
5 | 0.16 * tensor([0.0500, 0.8000, 0.5500])


tensor([0.4419, 0.6515, 0.5683])

In [6]:
# Compute attention scores
query = inputs[0]
attn_scores_1 = torch.zeros(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_1[i] = torch.dot(query, x_i)

# Normalize - compute attenttion weights
attn_weights_1 = torch.softmax(attn_scores_1, dim=0)

# Compute context vector
context_vec_1 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_1 += attn_weights_1[i] * x_i

context_vec_1

tensor([0.4421, 0.5931, 0.5790])

### 3.3.2 Computing attention weights for all input tokens

In [7]:
attn_scores = torch.empty((inputs.shape[0], inputs.shape[0]))
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

Each element in the `attn_scores` tensor represents an attention score between each pair of inputs.

You can imagine it being a matrix like this, excluding the labels:

| | Your | journey | starts | with | one | step |
|---|---|---|---|---|---|---|
| Your | 0.9995 | 0.9544 | 0.9422 | 0.4753 | 0.4576 | 0.6310 |
| journey | 0.9544 | 1.4950 | 1.4754 | 0.8434 | 0.7070 | 1.0865 |
| starts | 0.9422 | 1.4754 | 1.4570 | 0.8296 | 0.7154 | 1.0605 |
| with | 0.4753 | 0.8434 | 0.8296 | 0.4937 | 0.3474 | 0.6565 |
| one | 0.4576 | 0.7070 | 0.7154 | 0.3474 | 0.6654 | 0.2935 |
| step | 0.6310 | 1.0865 | 1.0605 | 0.6565 | 0.2935 | 0.9450 |

In [8]:
# A faster way:
attn_scores = inputs @ inputs.T  # or torch.matmul
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [9]:
# dim=-1 means the last dimension.
# For this rank 2 tensor, it means we're applying softmax along the second dimension of [rows, columns]. That is,
# we're normalizing across the columns, so the values in each row (summing over the column dimension) sum up to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [10]:
context_vecs = attn_weights @ inputs
context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## 3.4 Implementing self-attention with trainable weights

### 3.4.1 Computing the attention weights step by step

In [11]:
x_2 = inputs[1]
d_in = inputs.shape[1]  # input embedding size
d_out = 2  # output embedding size

In [12]:
torch.manual_seed(42)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

`requires_grad=False` to reduce clutter. If we were using the weight matrices for model training, we'd set it to `True` during training.

In [13]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

query_2

tensor([1.0760, 1.7344])

In [14]:
keys = inputs @ W_key
values = inputs @ W_value

keys.shape, values.shape

(torch.Size([6, 2]), torch.Size([6, 2]))

Now we've projected the six input tokens from a three-dimensional onto a two-dimensional embedding space.

In [15]:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
attn_scores_22  # unnormalized attention score

tensor(3.3338)

In [16]:
# generalized:
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([2.7084, 3.3338, 3.3013, 1.7563, 1.7869, 2.1966])

From attention scores to attention weights.

Scale the attention scores by dividing them by the sqrt of the embedding dimension of the keys & then using the softmax fn.

We scale by the embedding dimension to improve training performance by avoiding small gradients.

Large dot products can lead to very small gradients during backprop due to softmax. As dot products increase, softmax becomes more like a step function, leading to gradients near zero. These can slow down training / cause it to stagnate.

We call this self-attention mechanism "scaled-dot product attention" due to this scaling by the sqrt of the embedding dimension.

In [17]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
attn_weights_2

tensor([0.1723, 0.2681, 0.2620, 0.0879, 0.0898, 0.1200])

In [18]:
context_vec_2 = attn_weights_2 @ values

context_vec_2

tensor([1.4201, 0.8892])

### 3.4.2 Implementing a compact self-attention Python class

In [19]:
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [20]:
torch.manual_seed(42)
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v1(inputs)

tensor([[1.3751, 0.8610],
        [1.4201, 0.8892],
        [1.4198, 0.8890],
        [1.3533, 0.8476],
        [1.3746, 0.8606],
        [1.3620, 0.8532]], grad_fn=<MmBackward0>)

We can use `nn.Linear` layers instead, which effectively perform matmuls when bias units are disabled.
And linear layers have an optimized weight initialization scheme, meaning more stable and effective training.

In [21]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in: int, d_out: int, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [22]:
torch.manual_seed(42)
sa_v2 = SelfAttention_v2(d_in, d_out)

# Exercise 3.1: checking that nn.Linear (bias=False) is similar to nn.Parameter, except for weight initialization
# sa_v2.W_query.weight = nn.Parameter(sa_v1.W_query.T)
# sa_v2.W_key.weight = nn.Parameter(sa_v1.W_key.T)
# sa_v2.W_value.weight = nn.Parameter(sa_v1.W_value.T)

sa_v2(inputs)  # different outputs due to different weight initialization schemes

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

## 3.5 Hiding future words with causal attention

For some tasks, we want self-attention to only consider tokens appearing prior to the current position when predicting tokens in a sequence.

Causal attention is also known as masked attention.
It's a special form of self-attention that restricts the model to only consider previous and current inputs in a sequence.
We essentially mask out future tokens - tokens that come after the current token in the input.
Mask attention weights above the diagonal, and normalize the non-masked attention weights, so they sum to 1 in each row.

In [23]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
attn_weights

tensor([[0.1605, 0.1726, 0.1714, 0.1681, 0.1473, 0.1801],
        [0.1627, 0.1780, 0.1758, 0.1648, 0.1306, 0.1880],
        [0.1625, 0.1782, 0.1759, 0.1648, 0.1302, 0.1885],
        [0.1661, 0.1726, 0.1715, 0.1654, 0.1475, 0.1768],
        [0.1596, 0.1777, 0.1755, 0.1664, 0.1312, 0.1896],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<SoftmaxBackward0>)

In [24]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [25]:
# Using * on two matrices of the same shape with PyTorch does elementwise multiplication.
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1605, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1627, 0.1780, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1625, 0.1782, 0.1759, 0.0000, 0.0000, 0.0000],
        [0.1661, 0.1726, 0.1715, 0.1654, 0.0000, 0.0000],
        [0.1596, 0.1777, 0.1755, 0.1664, 0.1312, 0.0000],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<MulBackward0>)

In [26]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4775, 0.5225, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3146, 0.3450, 0.3405, 0.0000, 0.0000, 0.0000],
        [0.2459, 0.2555, 0.2538, 0.2448, 0.0000, 0.0000],
        [0.1969, 0.2193, 0.2165, 0.2053, 0.1619, 0.0000],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<DivBackward0>)

Instead, we'll use a trick:

In [27]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[ 0.0508,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2157,  0.3428,    -inf,    -inf,    -inf,    -inf],
        [ 0.2163,  0.3467,  0.3282,    -inf,    -inf,    -inf],
        [ 0.1257,  0.1799,  0.1707,  0.1191,    -inf,    -inf],
        [ 0.1667,  0.3193,  0.3012,  0.2258, -0.1098,    -inf],
        [ 0.1269,  0.1548,  0.1475,  0.0978, -0.0247,  0.1731]],
       grad_fn=<MaskedFillBackward0>)

In [28]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4775, 0.5225, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3146, 0.3450, 0.3405, 0.0000, 0.0000, 0.0000],
        [0.2459, 0.2555, 0.2538, 0.2448, 0.0000, 0.0000],
        [0.1969, 0.2193, 0.2165, 0.2053, 0.1619, 0.0000],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<SoftmaxBackward0>)

In [29]:
torch.triu(torch.ones(context_length, context_length), diagonal=1)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [30]:
torch.manual_seed(42)
dropout = torch.nn.Dropout(0.5)  # usually 0.1 or 0.2 for training GPT models
example = torch.ones(6, 6)
dropout(example)

tensor([[0., 0., 2., 2., 2., 2.],
        [2., 0., 2., 0., 2., 0.],
        [0., 0., 2., 2., 2., 0.],
        [2., 2., 0., 2., 0., 2.],
        [2., 0., 2., 2., 2., 2.],
        [2., 2., 2., 0., 2., 0.]])

~50% were scaled to zero. To compensate for the reduction in active elements, the rest were scaled up by a factor of $1/0.5=2$.
This is to maintain the overall balance of the weights, so the average influence of attention mechanisms is consistent both during training and inference.

In [31]:
dropout(attn_weights)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9551, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6292, 0.0000, 0.6809, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5110, 0.0000, 0.4895, 0.0000, 0.0000],
        [0.3938, 0.4387, 0.4331, 0.0000, 0.0000, 0.0000],
        [0.3364, 0.0000, 0.0000, 0.0000, 0.0000, 0.3476]],
       grad_fn=<MulBackward0>)

Want to ensure the code can handle batches of more than one input.
But we just use the input twice here for convencience.

In [32]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [33]:
class CausalAttention(nn.Module):
    def __init__(
        self, d_in: int, d_out: int, context_length: int, dropout: float, qkv_bias=False
    ):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Useful because the buffer is auto-moved to the appropriate device (CPU/GPU) with our model (e.g. when training)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        # [batch, num_tokens, d_in]
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(
            1, 2
        )  # Keep batch dim at position 0, but transpose dim 1 and 2
        # Trailing _ means inplace. Using it to avoid unnecessary memory copies.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights_dropped = self.dropout(attn_weights)
        context_vec = attn_weights_dropped @ values
        return context_vec

In [None]:
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
context_vecs.shape

torch.Size([2, 6, 2])

^ Resulting context vector is now a three-dimensional tensor where each token is represented by a two-dimensional embedding.

# 3.6 Extending single-head attention to multi-head attention

In [34]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias=False,
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [35]:
torch.manual_seed(42)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print(f"{context_vecs.shape=}")

tensor([[[0.4429, 0.1077, 0.5473, 0.3307],
         [0.4656, 0.2597, 0.3420, 0.2234],
         [0.4732, 0.3030, 0.2818, 0.1894],
         [0.4135, 0.2921, 0.2105, 0.1521],
         [0.4078, 0.2567, 0.2252, 0.1357],
         [0.3772, 0.2746, 0.1709, 0.1215]],

        [[0.4429, 0.1077, 0.5473, 0.3307],
         [0.4656, 0.2597, 0.3420, 0.2234],
         [0.4732, 0.3030, 0.2818, 0.1894],
         [0.4135, 0.2921, 0.2105, 0.1521],
         [0.4078, 0.2567, 0.2252, 0.1357],
         [0.3772, 0.2746, 0.1709, 0.1215]]], grad_fn=<CatBackward0>)
context_vecs.shape=torch.Size([2, 6, 4])


- The first dimension is `2` because we have two samples in our batch.
- The second dimension denotes the 6 tokens in each input.
- The third dimension is the 4-dimensional embeddings of each token.

`[batch_size, context_length, embedding_dimensions]`

**Exercise 3.2 Returning two-dimensional embedding vectors**
> Change the input arguments for the `MultiHeadAttentionWrapper(..., num_ heads=2)` call such that the output context vectors are two-dimensional instead of four dimensional while keeping the setting `num_heads=2`. Hint: You don’t have to modify the class implementation; you just have to change one of the other input arguments.

In [36]:
d_in, d_out = 3, 1
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print(f"{context_vecs.shape=}")

tensor([[[ 0.4788,  0.2593],
         [ 0.6509,  0.0176],
         [ 0.6989, -0.0575],
         [ 0.6378, -0.0745],
         [ 0.5993, -0.1042],
         [ 0.5900, -0.1018]],

        [[ 0.4788,  0.2593],
         [ 0.6509,  0.0176],
         [ 0.6989, -0.0575],
         [ 0.6378, -0.0745],
         [ 0.5993, -0.1042],
         [ 0.5900, -0.1018]]], grad_fn=<CatBackward0>)
context_vecs.shape=torch.Size([2, 6, 2])


Changing `d_out` to 1 means the context vector will be of `2*1=2` dimensions.

### 3.6.2 Implementing multi-head attention with weight splits

In [37]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias=False,
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = (
            d_out // num_heads
        )  # Reduces projection dim to match desired output dim
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # To combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # Tensor shape (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(
            b, num_tokens, self.num_heads, self.head_dim
        )  # implicitly split the matrix by adding num_heads dimension, then unroll the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transposes from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(
            2, 3
        )  # compute dot product for each head
        mask_bool = self.mask.bool()[
            :num_tokens, :num_tokens
        ]  # masks truncated to the number of tokens

        attn_scores.masked_fill_(mask_bool, -torch.inf)  # uses mask to fill attn scores

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(
            1, 2
        )  # tensor shape: (b, num_tokens, n_heads, head_dim)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional linear projection
        return context_vec


In [38]:
d_in, d_out = 3, 2
context_length = batch.shape[1]
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2, qkv_bias=False)
context_vecs = mha(batch)
print(context_vecs)
print(f"{context_vecs.shape=}")

tensor([[[-0.6380,  0.3370],
         [-0.7576,  0.2926],
         [-0.7891,  0.2779],
         [-0.7887,  0.2770],
         [-0.6782,  0.2563],
         [-0.7425,  0.2639]],

        [[-0.6380,  0.3370],
         [-0.7576,  0.2926],
         [-0.7891,  0.2779],
         [-0.7887,  0.2770],
         [-0.6782,  0.2563],
         [-0.7425,  0.2639]]], grad_fn=<ViewBackward0>)
context_vecs.shape=torch.Size([2, 6, 2])


**EXERCISE 3.3 INITIALIZING GPT-2 SIZE ATTENTION MODULES**
> Using the `MultiHeadAttention` class, initialize a multi-head attention module that has the same number of attention heads as the smallest GPT-2 model (12 attention heads). Also ensure that you use the respective input and output embedding sizes similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a context length of 1,024 tokens.

In [39]:
n_heads = 12
d_in, d_out = 768, 768
context_length = 1024

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, n_heads, False)
mha

MultiHeadAttention(
  (W_query): Linear(in_features=768, out_features=768, bias=False)
  (W_key): Linear(in_features=768, out_features=768, bias=False)
  (W_value): Linear(in_features=768, out_features=768, bias=False)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)