In [2]:
import torch


torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [5]:
# x bag of words
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1] # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

In [3]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [6]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [40]:
from torch.nn import functional as F

tril = torch.tril(torch.ones(T, T))
print(tril)
attn = torch.zeros((T, T))
attn = attn.masked_fill(tril == 0, float('-inf'))
print(attn)
attn = F.softmax(attn, dim=1)
print(attn)
xbow3 = attn @ x



xbow3.shape


tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,

torch.Size([4, 8, 32])

In [41]:
import torch.nn as nn
# data dependent context
# query and key vector
# query: "what am i looking for"?
# key: "what do I contain"
# dot product between key and query

torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels

# in channel-space
x = torch.randn(B, T, C) # (B, T, C)

head_size = 16

key = nn.Linear(C, head_size, bias=False)   # (C, 16)
query = nn.Linear(C, head_size, bias=False) # (C, 16)
value = nn.Linear(C, head_size, bias=False) # (C, 16)

# we shrink to key space (features weighting importance in)
# question: for transformers, why is one key and one query? Isn't it basically symmetric?
k = key(x)   # (B, T, C) @ (C, 16) = (B, T, 16)
q = query(x) # (B, T, C) @ (C, 16) = (B, T, 16)

# Attention matrix
# TxT
attn = q @ k.transpose(-2, -1) * C**-0.5 # (B, T, 16) @ (B, 16, T)  =  (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))

# remove this block to make it a "encoder block"
# right now it's a decoder block
attn = attn.masked_fill(tril == 0, float('-inf'))

# Softmax attention weights
attn = F.softmax(attn, dim=-1) 

v = value(x)

out = attn @ v # (B, T, T) @ (B, T, 16) = (B, T, 16)

out.shape






torch.Size([4, 8, 16])

In [93]:
# lets try with einsum
B, T, C = 4, 8, 32 # batch, time, channels

x = torch.randn(B, T, C) # (B, T, C)
head_size = 16

key = nn.Linear(C, head_size, bias=False)   # (C, 16)
query = nn.Linear(C, head_size, bias=False) # (C, 16)
value = nn.Linear(C, head_size, bias=False) # (C, 16)

k = key(x)   # (B, T, C) @ (C, 16) = (B, T, 16)
q = query(x) # (B, T, C) @ (C, 16) = (B, T, 16)
attn = torch.einsum("b i h , b j h -> b i j", q, k) * C ** -0.5

tril = torch.tril(torch.ones(T, T))
attn = attn.masked_fill(tril == 0, float('-inf'))

attn = F.softmax(attn, dim=-1) 

v = value(x)
out = torch.einsum("b t t, b t h -> b t h", attn, v)
out.shape

torch.Size([4, 8, 16])

In [86]:
# ok multi-head attention
B, T, C, nH = 4, 8, 32, 2 # batch, time, channels, num_heads

x = torch.randn(B, T, C) # (B, T, C)

head_size = 16


# okay that's annoying. head_size * num_heads = C

# Karpathy nanogpt -- n_embed_his == n_embed * num_head
# he "divides up n_embed"
# here we "have multiple n_embeds"
#
# > assert config.n_embd % config.n_head == 0
# in other words
# Karpathy nanogpt -- C_his == C * nH

# Karpathy nanogpt -- k, q, v can all be computed with one giant matrix mutliply
# TODO: try one giant matrix multiply


# these are basically C x C matrix multiplies, given head_size * nH = C
# C is "n_embed"

key = nn.Linear(  C, head_size * nH, bias=False) # (C * nH, 16)
query = nn.Linear(C, head_size * nH, bias=False) # (C * nH, 16)
value = nn.Linear(C, head_size * nH, bias=False) # (C * nH, 16)

# Now I want H different k's.
# Goal:                                   (B, T, nH * 16)
k = key(x)   # (B, T, C) * (C, nH * 16) = (B, T, nH * 16)
q = query(x) # (B, T, C) * (C, nH * 16) = (B, T, nH * 16)

# okay we've computed these nH matricies, but they're concatted in a row
# Slice them up

# Transformation:
# (B, T, nH * 16) -- original
# (B, T, nH, 16)  -- 1. split
# (B, nH, T, 16)  -- 2. swap dimensions
k = k.view(B, T, nH, head_size).permute((0, 2, 1, 3))  # .transpose(1, 2)

# Karpathy nanogpt:
# k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
# k = k.view(B, T, nH, head_size).transpose(1, 2) # (B, nh, T, hs)
# Bingo

q = q.view(B, T, nH, head_size).permute((0, 2, 1, 3))  # .transpose(1, 2)

# (B, nH, T, 16) @ (B, nH, 16, T) = (B, nh, T, T)
attn = q @ k.transpose(-2, -1) * C**-0.5 

tril = torch.tril(torch.ones(T, T))
attn = attn.masked_fill(tril == 0, float('-inf'))

attn = F.softmax(attn, dim=-1)

v = value(x) # (B, T, C) @ (C, nH * 16) = (B, T, nH * 16)

# Transformation:
# (B, T, nH * 16) -- original
# (B, T, nH, 16)  -- 1. split
# (B, nH, T, 16)  -- 2. swap dimensions
v = v.view(B, T, nH, head_size).permute(0, 2, 1, 3)
att = attn @ v # (B, nh, T, T) @ (B, nH, T, 16) = (B, nH, T, 16)


att = att.permute((0, 2, 1, 3)).reshape(B, T, nH * head_size) # (B, T, nH * 16)


# Then we have a linear layer
# "output projections"
W = nn.Linear(nH * head_size, C, bias=False)
out = W(att)

out.shape # (B, T, C)



torch.Size([4, 8, 32])

In [1]:
import einops
# ok multi-head attention

head_size = 16
num_heads = 5
B, T, C = 4, 8, head_size * num_heads # batch, time, channels, num_heads

x = torch.randn(B, T, C) # (B, T, C)

w_qkv = nn.Linear( C, 3 * C, bias=False) # (C * nH, 16)
qkv = w_qkv(x)   # (B, T, C) * (C, nH * 16) = (B, T, 3 * C)
q, k, v = tuple(einops.rearrange(qkv, "b t (h k d) -> k b h t d", k=3, h=num_heads))
scaled_dot_product = torch.einsum("b h t d, b h i d -> b h t i", k, k) * C**-0.5 
tril = torch.tril(torch.ones(T, T))
attn = scaled_dot_product.masked_fill(tril == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)
out = torch.einsum("b h t t, b h t d -> b h t d", attn, v)
out = einops.rearrange(out, "b h t d -> b t (h d)")
out.shape
W = nn.Linear(C, C, bias=False)
out = W(out)
out.shape # (B, T, C)

NameError: name 'torch' is not defined

In [79]:
# A transposition is just a swap on two tensor dimensions
test = torch.rand((3, 4, 5))
test.permute((1, 0, 2)).allclose(test.transpose(0, 1))

True

In [39]:
attn[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4264, 0.5736, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3151, 0.3022, 0.3827, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3007, 0.2272, 0.2467, 0.2253, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1635, 0.2048, 0.1776, 0.1616, 0.2926, 0.0000, 0.0000, 0.0000],
        [0.1403, 0.2272, 0.1454, 0.1244, 0.2678, 0.0949, 0.0000, 0.0000],
        [0.1554, 0.1815, 0.1224, 0.1213, 0.1428, 0.1603, 0.1164, 0.0000],
        [0.0952, 0.1217, 0.1130, 0.1453, 0.1137, 0.1180, 0.1467, 0.1464]],
       grad_fn=<SelectBackward0>)

In [36]:
attn.var()

tensor(0.0279, grad_fn=<VarBackward0>)

attention is commnunication mechanism

self-attention: attention cames from same source x
cross-attention: attention comes from another source (i.e. some other encoder block)


In [32]:
attn[0]

tensor([[0.0964, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0651, 0.0872, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1160, 0.0963, 0.1859, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1823, 0.1080, 0.1677, 0.1842, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1094, 0.1326, 0.1498, 0.1637, 0.2794, 0.0000, 0.0000, 0.0000],
        [0.1386, 0.2413, 0.1775, 0.1777, 0.3875, 0.1924, 0.0000, 0.0000],
        [0.1967, 0.2156, 0.1709, 0.2105, 0.1955, 0.4954, 0.4261, 0.0000],
        [0.0954, 0.1190, 0.1482, 0.2639, 0.1375, 0.3122, 0.5739, 1.0000]],
       grad_fn=<SelectBackward0>)