In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [8]:

# The mathematical trick in self-attention: token t should aggregate information
# only from tokens 0..t (itself and the past), never from future tokens.
torch.manual_seed(1337)  # make the random inputs reproducible under Restart & Run All
B, T, C = 4, 8, 2  # batch, time, channels
x_batch = torch.randn(B, T, C)
assert x_batch.shape == (B, T, C)

# version 1: xbow[b, t] = mean_{i <= t} x_batch[b, i]
# A running average over the TIME dimension, computed independently for each
# channel. Naive double loop, kept as the readable reference implementation.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x_batch[b, :t + 1]  # (t+1, C): tokens 0..t of example b
        xbow[b, t] = torch.mean(xprev, dim=0)  # average over time -> (C,)

In [10]:
# version 2: the same running average as a single matrix multiply.
# Row t of the lower-triangular matrix has ones in columns 0..t; normalising each
# row to sum to 1 turns the product into a mean over tokens 0..t.
weight2 = torch.tril(torch.ones(T, T))
weight2 = weight2 / weight2.sum(dim=1, keepdim=True)
xbow2 = weight2 @ x_batch  # (T, T) @ (B, T, C) -> (B, T, C); (T, T) is broadcast over the batch dim
assert torch.allclose(xbow, xbow2, atol=1e-6)  # matches the loop version

# version 3: build the same weights with softmax.
# Start from all-zero affinities, set positions above the diagonal (the future)
# to -inf, then softmax each row: exp(-inf) = 0 and the t+1 remaining equal
# entries each get 1/(t+1) -- the same averaging weights as version 2.
tril = torch.tril(torch.ones(T, T))
weight3 = torch.zeros(T, T)
weight3 = weight3.masked_fill(tril == 0, float('-inf'))
weight3 = F.softmax(weight3, dim=-1)
xbow3 = weight3 @ x_batch  # (T, T) @ (B, T, C) -> (B, T, C)
assert torch.allclose(xbow, xbow3, atol=1e-6)  # matches the loop version

Notes:
* Attention is a communication mechanism. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
* There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
* Each example across the batch dimension is processed completely independently; examples never communicate with each other.
* "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries are still produced from x, but the keys and values come from some other external source (e.g. an encoder module).
* "Scaled" attention additionally multiplies `weight` by 1/sqrt(head_size) (equivalently, divides it by sqrt(head_size)). This way, when the inputs Q and K have unit variance, `weight` will have unit variance too, so the softmax stays diffuse and does not saturate too much.

In [18]:
# version 4: self-attention.
# Versions 2-3 used uniform weights. Here every token is linearly projected into
# a query, a key and a value; the affinity between two tokens is the dot product
# of their query and key, so the aggregation weights become data-dependent.
torch.manual_seed(1337)  # reproducible inputs and nn.Linear initialisation
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# single head self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
v = value(x)  # (B, T, head_size)

# Scaled dot-product affinities: multiply by 1/sqrt(head_size) so that with
# unit-variance q and k the scores stay roughly unit variance and the softmax
# below stays diffuse instead of saturating towards one-hot rows.
weight4 = q @ k.transpose(-2, -1) * head_size ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
tril = torch.tril(torch.ones(T, T))
weight4 = weight4.masked_fill(tril == 0, float('-inf'))  # causal mask: token t ignores positions > t
weight4 = F.softmax(weight4, dim=-1)  # each row sums to 1
print(weight4.shape)  # (B, T, T)
out = weight4 @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
print(out.shape)

torch.Size([4, 8, 8])
torch.Size([4, 8, 16])
