In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import typing
from dataclasses import dataclass
from collections import defaultdict
import numpy as np
import os

In [2]:
@dataclass
class AttentionHeadConfig:
    """Hyper-parameters of a single causal self-attention head."""
    # Dimension of the embedding of each token
    d_embed: int
    # Dimension of the key, query and value vectors
    d_k: int
    # Maximum length of the input sequence (side length of the causal mask)
    ctx_len: int


class AttentionHead(nn.Module):
    """A single head of causal (left-to-right) scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``d_k`` by
    three bias-free linear layers. Each position attends only to itself and
    to earlier positions.
    """

    def __init__(self, config: AttentionHeadConfig):
        super().__init__()
        self.config = config
        # linear layers to project the input to the query, key and value vectors
        self.q = nn.Linear(config.d_embed, config.d_k, bias=False)
        self.k = nn.Linear(config.d_embed, config.d_k, bias=False)
        self.v = nn.Linear(config.d_embed, config.d_k, bias=False)
        # lower-triangular causal mask: position i may attend to positions j <= i
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: tensor of shape (batch_size, seq_len, d_embed) with
                seq_len <= ctx_len.

        Returns:
            Tensor of shape (batch_size, seq_len, d_k).

        Raises:
            RuntimeError: if seq_len exceeds ctx_len (the mask cannot be
                broadcast against the attention scores).
        """
        seq_len = x.size(-2)

        # q, k, v in (batch_size, seq_len, d_k)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # scaled attention scores, a in (batch_size, seq_len, seq_len)
        a = (q @ k.transpose(-2, -1)) / (self.config.d_k ** 0.5)
        # Slice the mask to the actual sequence length so inputs shorter than
        # ctx_len are accepted; for seq_len == ctx_len this is the full mask.
        causal_mask = self.mask[:seq_len, :seq_len]
        a = a.masked_fill(causal_mask == 0, float("-inf"))
        a = F.softmax(a, dim=-1)

        # weighted sum of the values, in (batch_size, seq_len, d_k)
        return a @ v

In [3]:
@dataclass
class MultiHeadAttentionConfig:
    """Settings for a multi-head attention layer.

    Attributes:
        n_heads: how many attention heads run side by side.
        d_out: width of the vector produced by the output projection.
        d_embed: width of each incoming token embedding.
        d_k: width of the key, query and value vectors of every head.
        ctx_len: length of the input sequence.
    """

    n_heads: int  # number of heads
    d_out: int  # output width
    d_embed: int  # token embedding width
    d_k: int  # key / query / value width
    ctx_len: int  # sequence length

class MultiHeadAttention(nn.Module):
    """Several independent attention heads followed by a linear output projection.

    Every head receives the same input. The head outputs are concatenated
    along the feature axis (n_heads * d_k features) and mapped to d_out
    features by a bias-free linear layer.
    """

    def __init__(self, config: MultiHeadAttentionConfig):
        super().__init__()
        self.config = config

        # each head gets its own config instance
        self.heads = nn.ModuleList()
        for _ in range(config.n_heads):
            head_config = AttentionHeadConfig(
                d_embed=config.d_embed,
                d_k=config.d_k,
                ctx_len=config.ctx_len,
            )
            self.heads.append(AttentionHead(head_config))

        concat_width = config.n_heads * config.d_k
        self.o = nn.Linear(concat_width, config.d_out, bias=False)

    def forward(self, x):
        """Run every head on x, concatenate the results and project them."""
        per_head = [head(x) for head in self.heads]
        stacked = torch.cat(per_head, dim=-1)
        return self.o(stacked)

In [4]:
@dataclass
class TransformerBlockConfig:
    """Settings for one transformer block (attention + feed-forward).

    Attributes:
        d_embed: width of each token embedding.
        d_k: width of the key, query and value vectors.
        ctx_len: length of the input sequence.
        n_heads: how many attention heads the block uses.
        ff_width: hidden width of the feed-forward network.
    """

    d_embed: int  # token embedding width
    d_k: int  # key / query / value width
    ctx_len: int  # sequence length
    n_heads: int  # number of heads
    ff_width: int  # feed-forward hidden width


class TransformerBlock(nn.Module):
    """One pre-norm transformer block.

    Two residual sub-layers are applied in sequence: multi-head attention on
    the layer-normalised input, then a two-layer ReLU feed-forward network
    on the layer-normalised result. The output has the same shape as the input.
    """

    def __init__(self, config: TransformerBlockConfig):
        super().__init__()
        self.config = config

        # attention sub-layer: normalise, then attend; projects back to d_embed
        self.ln1 = nn.LayerNorm(config.d_embed)
        attn_config = MultiHeadAttentionConfig(
            d_embed=config.d_embed,
            d_k=config.d_k,
            ctx_len=config.ctx_len,
            n_heads=config.n_heads,
            d_out=config.d_embed,
        )
        self.attn = MultiHeadAttention(attn_config)

        # feed-forward sub-layer: d_embed -> ff_width -> d_embed
        self.ln2 = nn.LayerNorm(config.d_embed)
        expand = nn.Linear(config.d_embed, config.ff_width)
        contract = nn.Linear(config.ff_width, config.d_embed)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Add the attention output, then the feed-forward output, to x."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))

In [5]:
@dataclass
class TransformerConfig:
    """Settings for the full transformer language model.

    Attributes:
        d_embed: width of each token embedding.
        d_k: width of the key, query and value vectors.
        ctx_len: length of the input sequence.
        n_heads: how many attention heads each block uses.
        ff_width: hidden width of each block's feed-forward network.
        n_blocks: how many transformer blocks are stacked.
        vocab_size: how many distinct tokens the vocabulary holds.
    """

    d_embed: int  # token embedding width
    d_k: int  # key / query / value width
    ctx_len: int  # sequence length
    n_heads: int  # heads per block
    ff_width: int  # feed-forward hidden width
    n_blocks: int  # number of stacked blocks
    vocab_size: int  # vocabulary size

class Transformer(nn.Module):
    """Decoder-style language model: token embedding, transformer blocks, unembedding.

    Only a token embedding is applied before the blocks; this class adds no
    separate positional embedding.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # token id -> embedding vector
        self.embed = nn.Embedding(config.vocab_size, config.d_embed)

        # stack of n_blocks blocks, each built from its own config instance
        self.blocks = nn.ModuleList()
        for _ in range(config.n_blocks):
            block_config = TransformerBlockConfig(
                d_embed=config.d_embed,
                d_k=config.d_k,
                ctx_len=config.ctx_len,
                n_heads=config.n_heads,
                ff_width=config.ff_width,
            )
            self.blocks.append(TransformerBlock(block_config))

        # embedding vector -> one score per vocabulary entry
        self.unembed = nn.Linear(config.d_embed, config.vocab_size)

    def forward(self, x):
        """Score the next token at every position of the input in parallel.

        A single pass yields one prediction per input position, which is why
        the input is (batch_size, ctx_len) and the output is
        (batch_size, ctx_len, vocab_size).

        Args:
            x: token ids of shape (batch_size, ctx_len).

        Returns:
            Unnormalised scores (logits) of shape
            (batch_size, ctx_len, vocab_size); no softmax is applied here.
        """
        # hidden in (batch_size, ctx_len, d_embed) from here until the unembedding
        hidden = self.embed(x)
        for block in self.blocks:
            hidden = block(hidden)
        # logits in (batch_size, ctx_len, vocab_size)
        logits = self.unembed(hidden)
        return logits

In [23]:
# Character-level vocabulary: id 0 is padding, id 1 is the unknown-character
# token, and every character of `chars` gets the next free id in order.
chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} \n\t"
# Missing keys resolve to 1, the id of "<unk>".
vocab = defaultdict(lambda: 1)
vocab.update({"<pad>": 0, "<unk>": 1})
vocab.update((symbol, token_id) for token_id, symbol in enumerate(chars, start=2))

In [24]:
def tokenize(text:str) -> torch.Tensor:
    """Convert text to a 1-D tensor of token ids, one id per character.

    Characters missing from `vocab` are mapped to the "<unk>" id. The lookup
    uses `.get` rather than `vocab[c]`: `vocab` is a defaultdict, and indexing
    it with an unknown character would insert that character as a new key,
    growing `len(vocab)` and corrupting any id -> symbol inversion.

    Args:
        text: the string to tokenize; may be empty.

    Returns:
        A torch.long tensor of shape (len(text),).
    """
    unk_id = vocab["<unk>"]
    return torch.tensor([vocab.get(c, unk_id) for c in text], dtype=torch.long)


def prep_training_data(dataset:list[str], ctx_len):
    """Build next-token-prediction training windows from raw texts.

    Each text is tokenized and right-padded with ctx_len - 1 pad tokens
    (id 0). A window of ctx_len tokens is then slid over it with stride 1.
    For every window the target is the same window shifted one token to the
    right, so Y[i, t] is the token that follows X[i, t]. A text of n tokens
    yields n - 1 windows; texts with fewer than two tokens contain no
    (token, next token) pair and are skipped.

    Args:
        dataset: the raw texts.
        ctx_len: window length (the model's context size).

    Returns:
        (X, Y), two torch.long tensors of shape (n_sequences, ctx_len).
        Both have shape (0, ctx_len) when no text yields a window.
    """
    inputs = []
    targets = []
    for text in dataset:
        tokens = tokenize(text)
        # A single token has no successor; unfold would raise on it.
        if len(tokens) < 2:
            continue
        # pad on the right so windows near the end of the text are full length
        padded = torch.cat([tokens, tokens.new_zeros(ctx_len - 1)])
        # Keep the windows of each text as one 2-D view instead of extending
        # a Python list row by row.
        inputs.append(padded[:-1].unfold(0, ctx_len, 1))
        targets.append(padded[1:].unfold(0, ctx_len, 1))

    if not inputs:
        # nothing to train on: return well-formed empty tensors instead of crashing
        return (
            torch.empty((0, ctx_len), dtype=torch.long),
            torch.empty((0, ctx_len), dtype=torch.long),
        )

    # X_tensor and Y_tensor in (n_sequences, ctx_len)
    X_tensor = torch.cat(inputs)
    Y_tensor = torch.cat(targets)
    return X_tensor, Y_tensor


def train(model:Transformer, optimizer:torch.optim.Optimizer, dataloader:DataLoader):
    """Run one pass over `dataloader`, taking one optimizer step per batch.

    The loss is the mean cross-entropy between the model's per-position
    logits and the target tokens, computed only over positions whose target
    is not the pad token (id 0). The loss of every batch is printed.

    Args:
        model: the model to train; it is put in training mode and its
            parameters' device is used for the batches.
        optimizer: optimizer over the model's parameters.
        dataloader: yields (X, Y) batches of token ids.
    """
    model.train()
    device = next(model.parameters()).device
    for X, Y in dataloader:
        X = X.to(device)
        Y = Y.to(device)
        optimizer.zero_grad()
        # logits in (batch_size, ctx_len, vocab_size)
        logits = model(X)
        # ignore_index drops pad targets and averages over the remaining
        # positions, replacing the manual mask-then-mean.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            Y.reshape(-1),
            ignore_index=0,
        )
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}")

In [26]:
# use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_config = TransformerConfig(
    d_embed=64,
    d_k=16,
    ctx_len=128,
    n_heads=2,
    ff_width=256,
    n_blocks=3,
    # Size the embedding by the largest token id, not by len(vocab): vocab is
    # a defaultdict, so a `vocab[c]` lookup of an unknown character adds a key
    # and would make len(vocab) depend on which texts were seen before this cell.
    vocab_size=max(vocab.values()) + 1
)

model = Transformer(model_config).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

In [27]:
# train the model

# Read text sources from disk.
# - sorted(): os.listdir returns entries in arbitrary order
# - isfile(): skip sub-directories (e.g. .ipynb_checkpoints), which open() cannot read
# - encoding: decode as UTF-8 explicitly instead of the platform default
text_sources = []
for filename in sorted(os.listdir("data")):
    path = os.path.join("data", filename)
    if not os.path.isfile(path):
        continue
    with open(path, "r", encoding="utf-8") as f:
        text_sources.append(f.read())

# one (input window, target window) pair per dataset item
dataset = TensorDataset(*prep_training_data(text_sources, model_config.ctx_len))
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)

# a single pass over the data
train(model, optimizer, dataloader)

loss: 4.85148286819458
loss: 4.8380537033081055
loss: 4.823774814605713
loss: 4.806513786315918
loss: 4.792239665985107
loss: 4.780441761016846
loss: 4.766773223876953
loss: 4.750765800476074
loss: 4.737964153289795
loss: 4.726701259613037
loss: 4.714028358459473
loss: 4.698774814605713
loss: 4.683662414550781
loss: 4.670864105224609
loss: 4.655983924865723
loss: 4.64409065246582
loss: 4.631959438323975
loss: 4.6183576583862305
loss: 4.60421895980835
loss: 4.5921430587768555
loss: 4.578943252563477
loss: 4.56453800201416
loss: 4.5494537353515625
loss: 4.540858745574951
loss: 4.525060176849365
loss: 4.513139247894287
loss: 4.502863883972168
loss: 4.490464210510254
loss: 4.4762396812438965
loss: 4.467777252197266
loss: 4.455037593841553
loss: 4.439979076385498
loss: 4.424402236938477
loss: 4.418720245361328
loss: 4.401554107666016
loss: 4.389864921569824
loss: 4.378308296203613
loss: 4.362484931945801
loss: 4.351709365844727
loss: 4.336392879486084
loss: 4.321225166320801
loss: 4.3123998

In [44]:
# Map token ids back to their symbols. `vocab` is a defaultdict, so a
# `vocab[c]` lookup of an unknown character adds c as a key that shares the
# "<unk>" id. Keep the first symbol registered for each id (dicts preserve
# insertion order) so that id still decodes to "<unk>" and not to whichever
# unknown character was looked up last.
inv_vocab = {}
for symbol, token_id in vocab.items():
    inv_vocab.setdefault(token_id, symbol)

def generate(model: Transformer, seed_str:str, length:int):
    """Sample `length` characters from the model after `seed_str` and print the result.

    The seed is tokenized and left-padded with the pad id (0) to the model's
    context length. One token at a time is then sampled from the softmax of
    the logits at the last position, appended to the context, and the oldest
    context token is dropped.

    The pad token is excluded from sampling. It is never a training target,
    so drawing it would put the literal text "<pad>" in the output and feed
    padding back into the context.

    Args:
        model: the language model; it is switched to eval mode.
        seed_str: the prompt that starts the text.
        length: number of characters to generate.
    """
    pieces = [seed_str]
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        # seed in (ctx_len)
        seed = tokenize(seed_str).to(device)
        seed = F.pad(seed, (model.config.ctx_len - len(seed), 0))
        for _ in range(length):
            # logits for the token after the last context position, in (vocab_size)
            logits = model(seed.unsqueeze(0))[0, -1]
            # never sample the pad token (id 0)
            logits[0] = float("-inf")
            # pred_token in (1)
            pred_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # slide the context window: drop the oldest token, append the new one
            seed = torch.cat([seed[1:], pred_token], dim=-1)

            pieces.append(inv_vocab[int(pred_token.item())])
    print("".join(pieces))

In [46]:
# sample 1000 characters that continue the prompt
generate(model, seed_str="The quick brown fox jumps over the lazy dog", length=1000)

The quick brown fox jumps over the lazy dogrn when h nd wiris hivend y ! pe imure Animpis t trufile, lemde t meeiveonove s, hen y ma bsth;
F
Th adoo'ldeis t thantirifot tetu ot'swerdoy J sapatraifou sircesty out a le,
ARn m mousher d,


Wis inthelang m weanops.

Thiseg cow'd pllelingrin s caMat p d ausendithalshay LI t nty Yy go ay spf; ndung, in byo pld, w po ous br mid ovisud;
Pafuth the he, q y fooul aieadpea er
FWhimyod st f thaishe, ce hind hee'PNoucerld, s t hede, ngis;ed a l'sw ly t, fomindore ootha nt,
T:
we,
Gerevet{Thelithe illd
Meald m yold
:
-wa wo'therendt d maropemig I ungfas ns thay I d l:
I I forinot ssichy tole s iburst hee'doo busth f wh tho r wstingghendere aromems ltore th orend mer ayon hie hy arer t mamatot me
e de altanthillerthost l so de lire oushid Wy ave tay, sne Pour hes weakeveviathael n, seerd ulor our wo, be q-be, t sthad. t! ndepro whawe ouy sfondu meesonghigir bechedy my me we, thango harlor ds Heako y mied t I bur'd KEn'eusth, fprers t
Mak atemy ald
Bu