In [None]:
"""Notebook for developing tinyGPT."""
# pylint: disable=import-error

import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
# Example input tensor of shape (batch_size, time, channels).
# The RNG is seeded first so that the printed tensors and the allclose checks
# in the following cells are reproducible on Restart & Run All.
# NOTE: `time` shadows the stdlib module name if `time` is ever imported here.
SEED = 1337
torch.manual_seed(SEED)
batch_size, time, channels = 4, 8, 2
x = torch.randn(batch_size, time, channels)
print(x.shape)

In [None]:
# Baseline "bag of words" aggregation with explicit loops: for every sequence b
# and position t, xbow[b, t] is the plain average of x[b, 0], ..., x[b, t]
# (the token itself and every earlier token). This serves as the reference
# result for the matrix-based variants.
xbow = torch.zeros_like(x)
for b in range(batch_size):
    for t in range(time):
        xbow[b, t] = x[b, : t + 1].mean(dim=0)

In [None]:
# The loop-based prefix average, vectorised as one matrix product.
# Row t of the lower-triangular matrix `w` holds 1/(t+1) in columns 0..t and
# 0 elsewhere. Each row therefore sums to 1, and `w @ x` averages every
# position with all positions before it.
w = torch.ones(time, time).tril()
w = w / w.sum(dim=1, keepdim=True)
print(w)
# (time, time) @ (batch, time, channels) broadcasts over the batch dimension
xbow2 = torch.matmul(w, x)
torch.allclose(xbow, xbow2)

In [None]:
# The same averaging weights, built the way attention builds them.
# Start from all-zero affinities, so every visible token is equally important.
# Set the entries above the diagonal (future tokens) to -inf so softmax gives
# them weight 0. Softmax over the t+1 remaining zeros in row t then yields
# 1/(t+1) for each of them.
tril = torch.ones(time, time).tril()
w = (
    torch.zeros(time, time)
    .masked_fill(tril == 0, float("-inf"))
    .softmax(dim=-1)
)
print(w)
xbow3 = torch.matmul(w, x)
torch.allclose(xbow, xbow3)

In [None]:
# Single head of causal (decoder-style) self-attention.
# Every token is projected to a query, a key and a value. Query·key dot
# products give token-to-token affinities. These are masked so that no token
# attends to later positions, then normalised with softmax into weights that
# are used to mix the values.
# random input tensor
batch, time, channels = 4, 8, 32
x = torch.randn(batch, time, channels)

head_size = 16
key = nn.Linear(channels, head_size, bias=False)
query = nn.Linear(channels, head_size, bias=False)
value = nn.Linear(channels, head_size, bias=False)

k = key(x)  # (batch, time, head_size)
q = query(x)  # (batch, time, head_size)

# Scaled dot-product attention: divide the affinities by sqrt(head_size).
# Without this, their variance grows roughly linearly with head_size, softmax
# saturates toward one-hot rows, and each token ends up attending to a single
# other token.
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (batch, time, time)

tril = torch.tril(torch.ones(time, time))
wei = wei.masked_fill(tril == 0, float("-inf"))  # forbid attending to the future
wei = F.softmax(wei, dim=-1)  # each row now sums to 1

v = value(x)  # (batch, time, head_size)
out = wei @ v  # (batch, time, head_size)

print(out.shape)

In [None]:
# Attention weights of the first sequence in the batch. Row i is query
# position i and column j is attended position j; entries above the diagonal
# are 0 because of the causal mask.
wei[0, :, :]