In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [4]:
tril = torch.tril(torch.ones(T,T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [5]:
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei,dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [6]:
xbow = wei @ x
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [13]:
B,T,C = 4,8,32
x = torch.randn(B,T,C)

tril = torch.tril(torch.ones(T,T))
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)
value = nn.Linear(C,head_size,bias=False)
k = key(x) #(B,T,C)
q = query(x) #(B,T,C)
v = value(x)
wei = q@k.transpose(-2,-1) #(B,T,T)
wei = wei * head_size**-0.5 # scaling to not lose data when softmax
wei = wei.masked_fill(tril==0,float('-inf'))
wei = F.softmax(wei,dim=-1)
out = wei@v
print(out.shape)



torch.Size([4, 8, 16])


In [14]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4264, 0.5736, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3985, 0.2439, 0.3576, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2071, 0.3279, 0.2295, 0.2354, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1905, 0.1367, 0.2292, 0.1947, 0.2490, 0.0000, 0.0000, 0.0000],
        [0.1581, 0.1656, 0.1326, 0.2218, 0.1998, 0.1221, 0.0000, 0.0000],
        [0.1261, 0.1103, 0.1771, 0.0911, 0.0658, 0.1539, 0.2756, 0.0000],
        [0.1490, 0.1364, 0.1155, 0.1084, 0.1570, 0.0833, 0.1377, 0.1126]],
       grad_fn=<SelectBackward0>)

In [11]:
out[0]

tensor([[-1.4322, -0.2810, -2.2789, -1.5010, -0.5178, -0.0930,  0.7448,  0.2769,
         -1.3683, -0.1367,  0.5261,  0.8502,  0.5255, -1.4073, -0.8778,  1.5681,
          0.5790, -1.0601, -0.1289,  0.0574, -2.1171,  0.5979, -0.8894, -0.1832,
          2.1316,  0.4207, -1.9636, -0.4431,  2.0773, -0.8678,  0.4456, -0.8511],
        [-1.0838, -0.1815, -0.7352,  0.4120,  0.4859,  0.7286, -0.9462, -0.3681,
          0.0187, -0.9503,  0.8314,  1.1676, -0.3323, -0.6562, -0.6349, -0.7566,
          0.0235, -0.1110,  0.5504, -0.6057, -1.1194, -0.4809,  0.3112, -0.6086,
          0.5210, -0.3065,  1.0929, -0.4713, -0.1311,  0.2276,  1.6285,  0.7358],
        [-0.9194,  0.3455, -0.8164,  0.0958, -0.0866,  0.8167, -0.2980, -0.0519,
         -0.5104, -0.7264,  0.6249,  0.6054, -0.2808, -0.9529, -0.6782,  0.4328,
          0.1139, -0.1076,  0.1905, -0.0101, -0.8987,  0.1063, -0.4636, -0.5507,
          0.7236, -0.2252, -0.1500, -0.2930,  1.1695,  0.0935,  0.6896,  0.3067],
        [-0.6774,  0.1403