In [8]:
"""Notebook for developing tinyGPT."""
# pylint: disable=import-error

import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
# generate example input vector
batch_size, time, channels = 4, 8, 2
x = torch.randn(batch_size, time, channels)
print(x.shape)

torch.Size([4, 8, 32])


In [4]:
# create simple self-attention mechanism
xbow = torch.zeros_like(x)
for b in range(batch_size):
    for t in range(time):
        x_prev = x[b, : t + 1]
        xbow[b, t] = torch.mean(x_prev, 0)

In [5]:
# self-attention mechanism with matrix multiplication
w = torch.tril(torch.ones(time, time))
w = w / w.sum(1, keepdim=True)
# this is a matrix that sums in it's rows to 1
print(w)
xbow2 = w @ x
torch.allclose(xbow, xbow2)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


False

In [6]:
# improve self-attention by adding softmax
tril = torch.tril(torch.ones((time, time)))
# set how many tokens from the past will be used
w = torch.zeros((time, time))
# inhibit communication with future tokens
w = w.masked_fill(tril == 0, float("-inf"))
w = F.softmax(w, dim=-1)
print(w)
xbow3 = w @ x
torch.allclose(xbow, xbow3)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


False

In [10]:
# self attention head implementation
# random input vector
batch, time, channels = 4, 8, 32
x = torch.randn(batch, time, channels)

head_size = 16
key = nn.Linear(channels, head_size, bias=False)
query = nn.Linear(channels, head_size, bias=False)
value = nn.Linear(channels, head_size, bias=False)

k = key(x)
q = query(x)

wei = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(time, time))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

print(out.shape)

torch.Size([4, 8, 32])


In [14]:
wei[0]

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.1871e-01, 5.8129e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [3.6485e-02, 1.9541e-01, 7.6810e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [2.0935e-01, 4.5951e-01, 9.4069e-02, 2.3706e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.9520e-01, 5.0806e-02, 7.7527e-02, 7.8607e-02, 2.9786e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.5709e-01, 2.3296e-01, 2.2556e-01, 3.1567e-01, 3.8489e-02, 3.0235e-02,
         0.0000e+00, 0.0000e+00],
        [4.3285e-03, 3.8574e-03, 3.6947e-04, 8.6063e-04, 1.0260e-03, 1.1712e-02,
         9.7785e-01, 0.0000e+00],
        [6.2083e-02, 2.8651e-01, 9.6312e-02, 1.0393e-01, 1.5266e-01, 1.6657e-01,
         4.3914e-02, 8.8026e-02]], grad_fn=<SelectBackward0>)