# Chapter 3: Coding attention mechanisms

In [1]:
import torch

import util

## 1. Simplified self-attention

### Create embedding

In [2]:
eos = '<|endoftext|>'
text = "Your journey starts with one step" + eos

In [3]:
data_loader = util.create_dataloader_v1(
    text,
    batch_size=1,
    context_window=6,
    stride=7,
)
data_iter = iter(data_loader)
x, y = next(data_iter)
# Need to concat from the last element of the target tensor,
# otherwise, setting context window as 4 results in a crash.
token_ids = torch.cat((x[0], y[0, -1:]))
print(token_ids)

tensor([ 7120,  7002,  4940,   351,   530,  2239, 50256])


In [4]:
tokens_len: int = token_ids.size(0)
vocab_size: int = torch.max(token_ids).item() + 1
embed_dim: int = 3

In [5]:
torch.manual_seed(123)
tok_embed_layer = torch.nn.Embedding(vocab_size, embed_dim)
pos_embed_layer = torch.nn.Embedding(tokens_len, embed_dim)
inputs = tok_embed_layer(token_ids) + pos_embed_layer(torch.arange(tokens_len))
print(inputs)

tensor([[ 0.8311,  1.3393, -1.1653],
        [-0.9936,  0.8519, -2.3310],
        [ 2.2730,  1.0514, -0.6150],
        [ 1.2780, -0.2958, -1.4757],
        [-2.9027,  3.0901,  0.6925],
        [-0.7583,  0.3646, -0.9988],
        [-2.1264,  0.5579, -1.1521]], grad_fn=<AddBackward0>)


### Attention Weight (for "journey")

In [6]:
query = inputs[1]
attn_scores_2 = torch.empty(tokens_len)
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
attn_weights_2 = torch.softmax(attn_scores_2, dim=-1)
print("Attention weights:")
print(torch.round(attn_weights_2, decimals=3))
print("Sum:", attn_weights_2.sum())

Attention weights:
tensor([0.0130, 0.8070, 0.0010, 0.0040, 0.0310, 0.0190, 0.1240],
       grad_fn=<RoundBackward1>)
Sum: tensor(1., grad_fn=<SumBackward0>)


### Attention weight (for all inputs)

In [7]:
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=-1)
print("Attention weights:")
print(torch.round(attn_weights, decimals=3))
print("Sum:", attn_weights.sum())

Attention weights:
tensor([[0.3320, 0.1480, 0.3950, 0.0770, 0.0180, 0.0200, 0.0100],
        [0.0130, 0.8070, 0.0010, 0.0040, 0.0310, 0.0190, 0.1240],
        [0.0640, 0.0010, 0.8960, 0.0380, 0.0000, 0.0010, 0.0000],
        [0.1070, 0.0670, 0.3250, 0.4840, 0.0000, 0.0150, 0.0030],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
        [0.0380, 0.4050, 0.0070, 0.0200, 0.1900, 0.0750, 0.2650],
        [0.0010, 0.1030, 0.0000, 0.0000, 0.6370, 0.0100, 0.2490]],
       grad_fn=<RoundBackward1>)
Sum: tensor(7.0000, grad_fn=<SumBackward0>)


### Context vectors

In [8]:
context_vectors = attn_weights @ inputs
print("Context vectors:")
print(context_vectors)

Context vectors:
tensor([[ 1.0378,  1.0312, -1.1078],
        [-1.1537,  0.8782, -2.0441],
        [ 2.1363,  1.0175, -0.6857],
        [ 1.3623,  0.4056, -1.2118],
        [-2.9026,  3.0900,  0.6925],
        [-1.5024,  1.1598, -1.2709],
        [-2.4877,  2.1992, -0.0969]], grad_fn=<MmBackward0>)


## 2. Self-attention

In [9]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [10]:
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [11]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([-1.0809, -0.4605], grad_fn=<SqueezeBackward4>)
tensor([-1.2738, -1.7953], grad_fn=<SqueezeBackward4>)
tensor([-0.8775, -1.3379], grad_fn=<SqueezeBackward4>)


In [12]:
keys = inputs @ W_key
values = inputs @ W_value
print(f'keys.shape: {keys.shape}')
print(f'values.shape: {values.shape}')

keys.shape: torch.Size([7, 2])
values.shape: torch.Size([7, 2])


In [13]:
keys_2 = keys[1]
attn_score_2_2 = query_2.dot(keys_2)
print(attn_score_2_2.item())

2.2034785747528076


In [14]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([-0.5480,  2.2035, -2.2023, -0.1176,  1.1066,  1.2662,  2.6174],
       grad_fn=<SqueezeBackward4>)


In [15]:
attn_weights_2 = torch.softmax(attn_scores_2 / keys.shape[-1] ** 0.5, dim=-1)
print(f'Attention weights: {attn_weights_2}')
print(f'Attention weights sum: {attn_weights_2.sum()}')

Attention weights: tensor([0.0387, 0.2705, 0.0120, 0.0524, 0.1245, 0.1394, 0.3625],
       grad_fn=<SoftmaxBackward0>)
Attention weights sum: 0.9999999403953552


In [17]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([-0.7672, -1.1645], grad_fn=<SqueezeBackward4>)
