In [16]:
"""Notebook for developing tinyGPT."""
# pylint: disable=import-error

import torch
from torch.nn import functional as F

In [3]:
# generate a reproducible example input of shape (batch_size, time, channels)
SEED = 1337
batch_size, time, channels = 4, 8, 2
# draw from a dedicated, seeded generator so that re-running the notebook
# always yields the same x, without altering torch's global RNG state
rng = torch.Generator().manual_seed(SEED)
x = torch.randn(batch_size, time, channels, generator=rng)
print(x.shape)

torch.Size([4, 8, 2])

In [5]:
# naive baseline: replace every position by the mean of itself and all
# earlier positions of the same sequence (a causal running average)
xbow = torch.zeros_like(x)
for b in range(batch_size):
    sequence = x[b]
    for t in range(time):
        # rows 0..t of this sequence: the current token plus its past
        x_prev = sequence[: t + 1]
        xbow[b, t] = x_prev.mean(dim=0)

In [11]:
# the same causal average, written as one matrix multiplication
lower_tri = torch.tril(torch.ones(time, time))
# dividing each row by its number of ones makes every row sum to 1
w = lower_tri / lower_tri.sum(dim=1, keepdim=True)
print(w)
# the (time, time) weights are broadcast over the batch dimension of x
xbow2 = torch.matmul(w, x)
torch.allclose(xbow, xbow2)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[[ 1.2905,  0.5389],
         [ 0.8415, -0.3557],
         [ 0.6960, -0.1377],
         [ 0.6794, -0.0589],
         [ 0.4851,  0.2227],
         [ 0.2519,  0.2616],
         [ 0.2829, -0.0353],
         [ 0.2873,  0.0627]],

        [[-1.7037,  0.1411],
         [-1.8147, -0.4012],
         [-1.1481, -0.1745],
         [-1.2577, -0.0737],
         [-0.9972,  0.0181],
         [-0.6760, -0.3682],

True

In [18]:
# third variant: obtain the same averaging weights from a masked softmax
tril = torch.ones(time, time).tril()
# start from equal (all-zero) scores between every pair of positions, then
# set the scores of future positions to -inf so they get zero weight
scores = torch.zeros(time, time).masked_fill(tril == 0, float("-inf"))
# softmax over each row turns the remaining equal scores into a uniform average
w = F.softmax(scores, dim=-1)
print(w)
xbow3 = torch.matmul(w, x)
torch.allclose(xbow, xbow3)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True