In [13]:
import torch
import torch.nn as nn

In [3]:
# Toy input for the averaging demos: a random batch of 4 sequences,
# each 8 time steps long with 2 channels per step.
torch.manual_seed(42)
B, T, C = (4, 8, 2)  # batch, time, channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [4]:
# Naive causal "bag of words": xbow[b, t] is the mean of x[b, 0..t],
# i.e. each token averaged with every token before it (never the ones after).
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        prefix = x[b, : t + 1]  # the t + 1 tokens visible at step t
        xbow[b, t] = torch.mean(prefix, dim=0)

print(x[0])
xbow[0]

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])


tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [5]:
# Derive uniform causal averaging weights directly from the 0/1 causal mask.
# Future positions (mask == 0) become -inf so softmax gives them zero weight;
# every visible position holds 1, so row t softmaxes to a uniform 1/(t+1).
# The mask stays intact in `tril`; the weights get their own name instead of
# overwriting it.
tril = torch.tril(torch.ones(T, T))
wei_tril = tril.masked_fill(tril == 0, -torch.inf)
wei_tril = torch.softmax(wei_tril, dim=1)
# A (T, T) weight matrix broadcasts against a (B, T, C) batch in `wei_tril @ x`,
# so there is no need to tile it B times with `.repeat`.

In [6]:
# Same averaging weights in the form self-attention uses: start from all-zero
# affinities, hide future positions with -inf, then softmax each row. The
# visible entries are all equal, so every row becomes a uniform average.
tril = torch.tril(torch.ones(T, T))
wei = (
    torch.zeros(T, T)
    .masked_fill(tril == 0, -torch.inf)
    .softmax(dim=1)
)

In [7]:
# Batched matmul: the (T, T) weights broadcast over the batch dimension,
# (T, T) @ (B, T, C) -> (B, T, C).
x_weighted = wei @ x
# The matrix form must reproduce the explicit-loop average computed as `xbow`;
# this also fails loudly if `wei`/`x` were redefined by running cells out of order.
assert torch.allclose(x_weighted, xbow), "weighted average disagrees with loop version"
x_weighted.shape, x_weighted[0]

(torch.Size([4, 8, 2]),
 tensor([[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]]))

In [34]:
# Self-attention: a single causal head on a fresh random batch.
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# Three bias-free projections from C channels down to head_size.
# Creation order matters: each nn.Linear draws its init from the seeded RNG.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Project every token; each result is (B, T, head_size).
q = query(x)
k = key(x)
v = value(x)

# Query-key affinities, (B, T, T). Scaling by 1/sqrt(head_size) pulls the
# dot-product variance back toward unit size, so the softmax below doesn't
# collapse toward one-hot rows.
wei = (q @ k.transpose(-2, -1)) / (head_size ** 0.5)

# Causal mask: position t may only attend to positions <= t.
tril = torch.tril(torch.ones(T, T))
wei = torch.nn.functional.softmax(wei.masked_fill(tril == 0, -torch.inf), dim=-1)

# Weighted sum of values: (B, T, T) @ (B, T, head_size) -> (B, T, head_size).
out = wei @ v

print(x.shape, out.shape)
print(wei[0])

torch.Size([4, 8, 32]) torch.Size([4, 8, 16])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3966, 0.6034, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3069, 0.2892, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3233, 0.2175, 0.2443, 0.2149, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1479, 0.2034, 0.1663, 0.1455, 0.3369, 0.0000, 0.0000, 0.0000],
        [0.1259, 0.2490, 0.1324, 0.1062, 0.3141, 0.0724, 0.0000, 0.0000],
        [0.1598, 0.1990, 0.1140, 0.1125, 0.1418, 0.1669, 0.1061, 0.0000],
        [0.0845, 0.1197, 0.1078, 0.1537, 0.1086, 0.1146, 0.1558, 0.1553]],
       grad_fn=<SelectBackward0>)
