<a href="https://colab.research.google.com/github/winniema/mini_transformer/blob/main/Single_Head_of_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)


# Implement a single head of attention
# Head size is the dimension this attention head operates in. It is usually lower than n_embd (the embedding dimension, C here)
head_size = 16

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (4,8,16)
q = query(x) # (4,8,16)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # (4,8,16) @ (4,16,8) -> (4,8,8)

# Mask future tokens
tril = torch.tril(torch.ones(T, T))

wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)

v = value(x) # (4,8,16)
out = wei @ v # (4,8,8) @ (4,8,16) -> (4,8,16)
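
One way to sanity-check the causal mask is to perturb the last token and confirm that the outputs at earlier positions do not change. The sketch below repeats the steps of the cell above inside a helper named `causal_head`; that name, the perturbation size, and the three checks are illustrative assumptions, not part of the original notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
tril = torch.tril(torch.ones(T, T))


def causal_head(x):
    """Hypothetical helper: same steps as the cell above, returns (weights, output)."""
    k, q, v = key(x), query(x), value(x)
    wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, T)
    wei = wei.masked_fill(tril == 0, float('-inf'))   # hide future tokens
    wei = torch.softmax(wei, dim=-1)
    return wei, wei @ v                               # (B, T, T), (B, T, head_size)


x = torch.randn(B, T, C)
wei, out = causal_head(x)

# Change only the final token of every sequence.
x_perturbed = x.clone()
x_perturbed[:, -1, :] += 1.0
_, out_perturbed = causal_head(x_perturbed)

# Each row of the weight matrix is a probability distribution.
rows_sum_to_one = torch.allclose(wei.sum(dim=-1), torch.ones(B, T))
# Entries above the diagonal (future tokens) receive exactly zero weight.
future_is_zero = (wei * (1 - tril)).abs().max().item() == 0.0
# Earlier positions cannot see the perturbed token, so their outputs match.
earlier_unchanged = torch.allclose(out[:, :-1], out_perturbed[:, :-1])
# The last position attends to itself, so its output does change.
last_changed = not torch.allclose(out[:, -1], out_perturbed[:, -1])

print(rows_sum_to_one, future_is_zero, earlier_unchanged, last_changed)
```

If the mask were removed, `earlier_unchanged` would be expected to fail, because every position could then attend to the perturbed token.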

Notes:
*   Attention is a communication mechanism for nodes in a directed graph. In the above case, each token has an edge that points to itself and to every token after it. For token `t`, this means it receives an edge from itself and from every preceding token.
*   As such, each node knows about the nodes that precede it, but no additional positional information. Specifically, `t` doesn't know that another node, let's say `t2`, is 3 nodes behind. `t` just knows that `t2` comes before itself.
*   Self-attention means that the keys are produced from the same source as the queries. Cross-attention means that the keys and queries are produced from different sources. The queries describe what a token is "looking" for and the keys describe what each token "is". An example of cross-attention is translation, where the keys (and values) come from the source language and the queries come from the language being translated into.
*   "Scaled" self-attention keeps the variance of the attention scores close to 1 before softmax() is applied. We do this by dividing the product of Query and Key by the square root of the head size, which works when the queries and keys themselves have roughly unit variance. The reasoning is that softmax converges towards one-hot vectors (all 0s except for a single 1) when the variance of its inputs is high. This is bad at initialization: the model doesn't get a chance to "learn" about all the tokens equally and adjust the weights through backpropagation. Instead, it over-indexes on specific tokens simply because of how the weights were randomly initialized.
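
The effect of the scaling can be checked numerically. The sketch below assumes unit-variance random queries and keys (a stand-in for what the linear layers produce at initialization) and compares the variance of the raw and scaled scores. The batch size and the example logits are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
B, T, head_size = 256, 8, 16

# Assumption: q and k entries are independent with unit variance.
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)

raw = q @ k.transpose(-2, -1)       # variance grows with head_size
scaled = raw * head_size**-0.5      # variance brought back to about 1

raw_var = raw.var().item()
scaled_var = scaled.var().item()
print(f"raw variance:    {raw_var:.2f}")     # on the order of head_size
print(f"scaled variance: {scaled_var:.2f}")  # close to 1

# Softmax sharpens towards a one-hot vector as its inputs grow in magnitude.
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
soft = torch.softmax(logits, dim=-1)       # fairly diffuse
sharp = torch.softmax(logits * 8, dim=-1)  # most of the mass on the largest entry
print(soft)
print(sharp)
```

Each score is a sum of `head_size` products of independent unit-variance terms, so its variance is about `head_size`. Multiplying by `head_size**-0.5` divides that variance by `head_size`.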
