In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Seed the global RNG so the toy batch below is identical on every run.
torch.manual_seed(1337)
# x is laid out as (B, T, C): batch, time (sequence position), channels.
B, T, C = (4, 8, 2)
x = torch.randn((B, T, C))
x.shape

In [None]:
# Naive reference implementation: xbow[b, t] is the mean of x[b, 0..t],
# i.e. the current step averaged with every earlier step of the same sequence.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t+1, C): steps 0..t inclusive
        xbow[b, t] = xprev.mean(dim=0)


In [None]:
# Show the first batch element next to its running (prefix) mean; row t of
# xbow[0] is the average of rows 0..t of x[0].
x[0], xbow[0]

In [None]:
# With an all-ones left operand, every row of `a @ b` is the column-wise sum of `b`.
torch.manual_seed(18)
a = torch.ones((3, 3))
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
c = torch.matmul(a, b)
# One print call; sep='\n' puts each label / tensor / divider on its own line.
print('a=', a, '--', 'b=', b, '--', 'c=', c, sep='\n')

In [None]:
# With a lower-triangular ones matrix on the left, row i of `a @ b` is the sum
# of rows 0..i of `b` (a running sum down the rows).
torch.manual_seed(18)
a = torch.ones((3, 3)).tril()
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
c = torch.matmul(a, b)
# One print call; sep='\n' puts each label / tensor / divider on its own line.
print('a=', a, '--', 'b=', b, '--', 'c=', c, sep='\n')

In [None]:
# Running mean as one matrix product: after row-normalisation, row t of `wei`
# puts weight 1/(t+1) on steps 0..t and weight 0 on every later step.
wei = torch.ones((T, T)).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
xbow2 = torch.matmul(wei, x)  # (T, T) @ (B, T, C): wei broadcasts over the batch -> (B, T, C)

In [None]:
# Running mean via softmax: start from all-zero scores, set the positions above
# the diagonal to -inf, and let softmax turn each row into uniform weights over
# the remaining (current and earlier) positions.
tril = torch.ones((T, T)).tril()
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
xbow3 = torch.matmul(wei, x)  # (T, T) @ (B, T, C) -> (B, T, C)

In [None]:
# Check that the softmax-weighted aggregation (xbow3) matches the loop-based
# running mean (xbow) up to floating-point tolerance.
torch.allclose(xbow, xbow3)

In [None]:
# Display the current value of `wei`; when the cells are run top to bottom this
# is the softmax-normalised lower-triangular weight matrix used for xbow3.
wei

In [None]:
# Softmax-based causal averaging again, now on a wider input (C = 32).
torch.manual_seed(1337)
B, T, C = (4, 8, 32)
x = torch.randn((B, T, C))

tril = torch.ones((T, T)).tril()
# Zero scores everywhere, -inf above the diagonal so later positions get no weight.
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
out = torch.matmul(wei, x)  # (T, T) @ (B, T, C): wei broadcasts over the batch -> (B, T, C)
print(out.shape)
# At this stage each token weights itself and every earlier token equally.

In [55]:
# Single-head self-attention, step 1: replace the uniform weights with
# data-dependent affinities computed from query/key projections of x.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# single head SA
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, C) -> (B, T, head_size)
q = query(x)  # (B, T, C) -> (B, T, head_size)
# Raw (unscaled) query-key dot products:
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T))
# Set scores above the diagonal to -inf so that, after the softmax, each row
# only puts weight on its own position and earlier ones.
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# The weights are still applied to the raw input x in this cell.
out = wei @ x  # (B, T, T) @ (B, T, C) -> (B, T, C)
print(out.shape)

torch.Size([4, 8, 32])


In [56]:
# Print the (T, T) attention weights of the first batch element; each row holds
# the weights one position assigns to itself and the positions before it.
print(wei[0])

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)


In [57]:
# Single-head self-attention, step 2: aggregate value projections of x
# instead of the raw input.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# single head SA
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, C) -> (B, T, head_size)
q = query(x)  # (B, T, C) -> (B, T, head_size)
v = value(x)  # (B, T, C) -> (B, T, head_size)
# Raw (unscaled) query-key dot products:
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T))
# Set scores above the diagonal to -inf so that, after the softmax, each row
# only puts weight on its own position and earlier ones.
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
print(out.shape)

torch.Size([4, 8, 16])


In [60]:
# Print the value projections of the first batch element: one row per
# position, head_size entries per row.
print(v[0])

tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.8321, -0.8144, -0.3242,  0.5191, -0.1252, -0.4898, -0.5287, -0.0314,
          0.1072,  0.8269,  0.8132, -0.0271,  0.4775,  0.4980, -0.1377,  1.4025],
        [ 0.6035, -0.2500, -0.6159,  0.4068,  0.3328, -0.3910,  0.1312,  0.2172,
         -0.1299, -0.8828,  0.1724,  0.4652, -0.4271, -0.0768, -0.2852,  1.3875],
        [ 0.6657, -0.7096, -0.6099,  0.4348,  0.8975, -0.9298,  0.0683,  0.1863,
          0.5400,  0.2427, -0.6923,  0.4977,  0.4850,  0.6608,  0.8767,  0.0746],
        [ 0.1536,  1.0439,  0.8457,  0.2388,  0.3005,  1.0516,  0.7637,  0.4517,
         -0.7426, -1.4395, -0.4941, -0.3709, -1.1819,  0.1000, -0.1806,  0.5129],
        [-0.8920,  0.0578, -0.3350,  0.8477,  0.3876,  0.1664, -0.4587, -0.5974,
          0.4961,  0.6548,  0.0548,  0.9468,  0.4511,  0.1200,  1.0573, -0.2257],
        [-0.4849,  0.1