# Chapter 3: Coding attention mechanisms

In [1]:
import torch
import torch.nn as nn

import util

## 1. Simplified self-attention

### Create embedding

In [2]:
eos = '<|endoftext|>'
text = "Your journey starts with one step" + eos

In [3]:
data_loader = util.create_dataloader_v1(
    text,
    batch_size=1,
    context_window=6,
    stride=7,
)
data_iter = iter(data_loader)
x, y = next(data_iter)
# Need to concat from the last element of the target tensor,
# otherwise, setting context window as 4 results in a crash.
token_ids = torch.cat((x[0], y[0, -1:]))
print(token_ids)

tensor([ 7120,  7002,  4940,   351,   530,  2239, 50256])


In [4]:
tokens_len: int = token_ids.size(0)
vocab_size: int = torch.max(token_ids).item() + 1
embed_dim: int = 3

In [5]:
torch.manual_seed(123)
tok_embed_layer = torch.nn.Embedding(vocab_size, embed_dim)
pos_embed_layer = torch.nn.Embedding(tokens_len, embed_dim)
inputs = tok_embed_layer(token_ids) + pos_embed_layer(torch.arange(tokens_len))
print(inputs)

tensor([[ 0.8311,  1.3393, -1.1653],
        [-0.9936,  0.8519, -2.3310],
        [ 2.2730,  1.0514, -0.6150],
        [ 1.2780, -0.2958, -1.4757],
        [-2.9027,  3.0901,  0.6925],
        [-0.7583,  0.3646, -0.9988],
        [-2.1264,  0.5579, -1.1521]], grad_fn=<AddBackward0>)


### Attention Weight (for "journey")

In [6]:
query = inputs[1]
attn_scores_2 = torch.empty(tokens_len)
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
attn_weights_2 = torch.softmax(attn_scores_2, dim=-1)
print("Attention weights:")
print(torch.round(attn_weights_2, decimals=3))
print("Sum:", attn_weights_2.sum())

Attention weights:
tensor([0.0130, 0.8070, 0.0010, 0.0040, 0.0310, 0.0190, 0.1240],
       grad_fn=<RoundBackward1>)
Sum: tensor(1., grad_fn=<SumBackward0>)


### Attention weight (for all inputs)

In [7]:
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=-1)
print("Attention weights:")
print(torch.round(attn_weights, decimals=3))
print("Sum:", attn_weights.sum())

Attention weights:
tensor([[0.3320, 0.1480, 0.3950, 0.0770, 0.0180, 0.0200, 0.0100],
        [0.0130, 0.8070, 0.0010, 0.0040, 0.0310, 0.0190, 0.1240],
        [0.0640, 0.0010, 0.8960, 0.0380, 0.0000, 0.0010, 0.0000],
        [0.1070, 0.0670, 0.3250, 0.4840, 0.0000, 0.0150, 0.0030],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
        [0.0380, 0.4050, 0.0070, 0.0200, 0.1900, 0.0750, 0.2650],
        [0.0010, 0.1030, 0.0000, 0.0000, 0.6370, 0.0100, 0.2490]],
       grad_fn=<RoundBackward1>)
Sum: tensor(7.0000, grad_fn=<SumBackward0>)


### Context vectors

In [8]:
context_vectors = attn_weights @ inputs
print("Context vectors:")
print(context_vectors)

Context vectors:
tensor([[ 1.0378,  1.0312, -1.1078],
        [-1.1537,  0.8782, -2.0441],
        [ 2.1363,  1.0175, -0.6857],
        [ 1.3623,  0.4056, -1.2118],
        [-2.9026,  3.0900,  0.6925],
        [-1.5024,  1.1598, -1.2709],
        [-2.4877,  2.1992, -0.0969]], grad_fn=<MmBackward0>)


## 2. Self-attention

In [9]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [10]:
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [11]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([-1.0809, -0.4605], grad_fn=<SqueezeBackward4>)
tensor([-1.2738, -1.7953], grad_fn=<SqueezeBackward4>)
tensor([-0.8775, -1.3379], grad_fn=<SqueezeBackward4>)


In [12]:
keys = inputs @ W_key
values = inputs @ W_value
print(f'keys.shape: {keys.shape}')
print(f'values.shape: {values.shape}')

keys.shape: torch.Size([7, 2])
values.shape: torch.Size([7, 2])


In [13]:
keys_2 = keys[1]
attn_score_2_2 = query_2.dot(keys_2)
print(attn_score_2_2.item())

2.2034785747528076


In [14]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([-0.5480,  2.2035, -2.2023, -0.1176,  1.1066,  1.2662,  2.6174],
       grad_fn=<SqueezeBackward4>)


In [15]:
attn_weights_2 = torch.softmax(attn_scores_2 / keys.shape[-1] ** 0.5, dim=-1)
print(f'Attention weights: {attn_weights_2}')
print(f'Attention weights sum: {attn_weights_2.sum()}')

Attention weights: tensor([0.0387, 0.2705, 0.0120, 0.0524, 0.1245, 0.1394, 0.3625],
       grad_fn=<SoftmaxBackward0>)
Attention weights sum: 0.9999999403953552


In [16]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([-0.7672, -1.1645], grad_fn=<SqueezeBackward4>)


## 3. Attention Class

In [17]:
class SelfAttention_v1(nn.Module):
    _wq: torch.Tensor
    _wk: torch.Tensor
    _wv: torch.Tensor

    def __init__(self, d_in: int, d_out: int | None = None):
        super().__init__()
        if not d_out:
            d_out = d_in
        self._wq = nn.Parameter(torch.rand(d_in, d_out))
        self._wk = nn.Parameter(torch.rand(d_in, d_out))
        self._wv = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        queries = x @ self._wq
        keys = x @ self._wk
        values = x @ self._wv

        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        return attn_weights @ values # context vector

In [18]:
self_attention_v1 = SelfAttention_v1(d_in, d_out)
simple_context_vec = self_attention_v1(inputs)
print(f'Simple context vec shape: {simple_context_vec.shape}')
print(simple_context_vec)

Simple context vec shape: torch.Size([7, 2])
tensor([[-0.5660,  0.3537],
        [-0.9693, -0.0085],
        [-0.4474,  0.4220],
        [-0.9014, -0.0317],
        [-0.6582,  0.7741],
        [-0.7078,  0.1000],
        [-0.7461,  0.0863]], grad_fn=<MmBackward0>)


## 2. Multi-head attention

In [19]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias=False
    ): 
        super().__init__()
        assert (d_out % num_heads == 0), f'd_out ({d_out}) must be divisible by num_heads ({num_heads})'

        self._d_in = d_in
        self._d_out = d_out
        self._context_length = context_length
        self._num_heads = num_heads
        self._head_dim = d_out // num_heads

        self._w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self._w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self._w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self._out_proj = nn.Linear(d_out, d_out)
        self._dropout = nn.Dropout(dropout)

        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, d_in = x.shape
        assert (d_in == self._d_in), f'input d_in ({d_in}) must be the same as d_in ({self._d_in})'

        # KQV shape: (b, num_heads, num_tokens, head_dim)
        keys = self._reshape_kqv(self._w_key(x), b, num_tokens)
        queries = self._reshape_kqv(self._w_query(x), b, num_tokens)
        values = self._reshape_kqv(self._w_value(x), b, num_tokens)

        # Attention scores and weights shape:  (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        mask = self.mask[:num_tokens, :num_tokens]
        attn_weights = torch.softmax(attn_scores / self._head_dim ** 0.5, dim=-1)
        attn_weights = self._dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2) # (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(b, num_tokens, self._d_out)
        return self._out_proj(context_vec) # (b, num_tokens, d_out)

    def _reshape_kqv(self, tensor: torch.Tensor, b: int, num_tokens: int) -> torch.Tensor:
        """
        Converts QKV weights from shape (b, num_tokens, d_out) to (b, num_heads, num_tokens, head_dim).
        """
        return tensor.view(b, num_tokens, self._num_heads, self._head_dim).transpose(1, 2)

In [20]:
context_length = 6
num_heads = 1
mh_attn = MultiHeadAttention(d_in, d_out * num_heads, context_length, dropout=0.1, num_heads=num_heads)
example_context_vec = mh_attn(inputs.unsqueeze(0))
print(f'Context vec shape: {example_context_vec.shape}')
print(example_context_vec)

Context vec shape: torch.Size([1, 7, 16])
tensor([[[ 2.0074e-01,  1.2954e-01, -2.0794e-01,  5.0414e-02,  4.3644e-01,
          -4.2802e-02, -9.0871e-01, -6.8125e-01,  3.1008e-03,  1.2606e-01,
          -8.9249e-02,  2.2884e-01,  5.0619e-01,  7.1841e-01,  8.4086e-01,
           2.5266e-01],
         [ 3.0884e-01,  2.3584e-01, -2.8430e-01,  2.5105e-01,  6.8879e-01,
          -1.1523e-01, -1.0889e+00, -7.1455e-01,  1.9503e-01,  1.5492e-01,
          -1.3607e-01,  9.3015e-02,  4.4244e-01,  9.9420e-01,  8.6440e-01,
           3.3103e-01],
         [-1.3366e-01, -1.0432e-01, -5.1187e-02,  2.6272e-02,  8.2505e-02,
          -1.2981e-01, -8.9960e-01, -7.9067e-01, -2.6319e-01, -6.0446e-02,
          -1.0746e-01,  3.7379e-01,  7.1751e-01,  3.8893e-01,  8.4441e-01,
           1.1615e-01],
         [ 1.2362e-01,  1.0733e-01, -1.8874e-01,  9.9266e-03,  3.9622e-01,
          -6.6299e-02, -9.5984e-01, -8.5638e-01, -7.6979e-05,  1.2348e-02,
          -1.4353e-01,  2.6240e-01,  5.9882e-01,  6.8316e-01,