In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [4]:
# version 1: bag of words
xbow = torch.zeros((B,T,C)) # bag of words
for b in range(B): # iterative approach
    for t in range(T):
        xprev = x[b, :t+1]
        xbow[b,t] = torch.mean(xprev, 0)

In [5]:
# version 2: matrix multiply
wei = torch.tril(torch.ones(T, T)) # parrarel approach (low left triangular matrix with rows summing up to 1)
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (T, T) @ (B, T, C) -> (B, T, C)
wei, xbow2[0]

(tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
         [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]),
 tensor([[ 0.6245, -1.9711],
         [-0.0555, -1.4861],
         [ 0.1136, -1.0268],
         [-0.4454, -0.4890],
         [-0.5955, -0.0595],
         [-0.3372, -0.1447],
         [-0.4577,  0.1969],
         [-0.3316,  0.0153]]))

In [7]:
# version 3: SoftMax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float("-inf"))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
wei, xbow3

torch.Size([4, 8, 2])

In [61]:
# version 4: (masked) self-attention !!
B, T, C = 4,8,32
x = torch.randn(B, T, C)
tril = torch.tril(torch.ones(T, T)) # lower triangular structure

head_size = 16
key   = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
v = value(x)

wei = q @ k.transpose(-2, -1)                 # Key @ Query: (B,T,16) @ (T,B,16) -> (B, T, T)
wei = wei.masked_fill(tril==0, float("-inf")) # infinities are "not allowed" to communicate, so "Masked self-attention", NOTE: remove this line for encoder!
wei /= head_size**-0.5                        # making the variance be 1
wei = F.softmax(wei, dim=-1)                  # average out, otherwise softmax may converge to one-hot vectors, which waste space, computation and may lead to dead neurons
out = wei @ v
out.shape, wei[:1]

(torch.Size([4, 8, 16]),
 tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00],
          [9.5054e-01, 4.9456e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00],
          [1.6184e-01, 6.3264e-01, 2.0551e-01, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00],
          [5.3194e-04, 1.6164e-05, 9.9763e-01, 1.8206e-03, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00],
          [9.9920e-01, 3.4292e-07, 5.1411e-06, 2.2355e-07, 7.9542e-04,
           0.0000e+00, 0.0000e+00, 0.0000e+00],
          [2.4315e-06, 9.8354e-01, 1.4112e-02, 6.3690e-04, 1.6176e-03,
           8.7480e-05, 0.0000e+00, 0.0000e+00],
          [3.1648e-05, 9.8119e-01, 2.1116e-05, 6.1659e-08, 1.8755e-02,
           8.0306e-07, 3.5688e-07, 0.0000e+00],
          [5.4150e-06, 7.4348e-06, 2.7379e-03, 3.5637e-08, 9.9164e-01,
           5.4801e-03, 8.3817e-05, 4.0323e-05]]], grad_fn=<SliceBackwar