In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset
from dataclasses import dataclass
import collections
import typing
import numpy as np
import os
import json

In [2]:
@dataclass
class AttentionHeadConfig:
    # Dimension of the embedding of each token
    d_embed: int 
    # Dimension of the key, query and value vectors
    d_k: int
    # Size of the input sequence
    ctx_len: int

class AttentionHead(nn.Module):
    def __init__(self, config: AttentionHeadConfig):
        super().__init__()
        self.config = config
        # linear layers to project the input to the key, query and value vectors

        self.q = nn.Linear(config.d_embed, config.d_k, bias=False)
        self.k = nn.Linear(config.d_embed, config.d_k, bias=False)
        self.v = nn.Linear(config.d_embed, config.d_k, bias=False)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
    
    def forward(self, x):
        # x in (batch_size, ctx_len, d_embed)
        # q in (batch_size, ctx_len, d_k)
        q = self.q(x)
        # k in (batch_size, ctx_len, d_k)
        k = self.k(x)

        # masked self attention
        # a in (batch_size, ctx_len, ctx_len)
        a = (q @ k.transpose(-2, -1)) / (self.config.d_k ** 0.5)
        a = a.masked_fill(self.mask == 0, float("-inf"))
        a = F.softmax(a, dim=-1)
        
        # v in (batch_size, ctx_len, d_k)
        v = self.v(x)
        
        # att in (batch_size, ctx_len, d_k)        
        att = a @ v
        return att

In [3]:
@dataclass
class MultiHeadAttentionConfig:
    # Number of attention heads
    n_heads: int
    # Dimension of the output vector
    d_out: int
    # Dimension of the embedding of each token
    d_embed: int 
    # Dimension of the key, query and value vectors
    d_k: int
    # Size of the input sequence
    ctx_len: int

class MultiHeadAttention(nn.Module):
    def __init__(self, config: MultiHeadAttentionConfig):
        super().__init__()
        self.config = config
        self.heads = nn.ModuleList([
            AttentionHead(AttentionHeadConfig(
                d_embed=config.d_embed,
                d_k=config.d_k,
                ctx_len=config.ctx_len
            )) for _ in range(config.n_heads)
        ])
        self.o = nn.Linear(config.n_heads*config.d_k, config.d_out, bias=False)
    
    def forward(self, x):
        return self.o(torch.cat([head(x) for head in self.heads], dim=-1))

In [4]:
@dataclass
class TransformerBlockConfig:
    # Dimension of the embedding of each token
    d_embed: int 
    # Dimension of the key, query and value vectors
    d_k: int
    # Size of the input sequence
    ctx_len: int
    # Number of attention heads
    n_heads: int
    # Width of the feed-forward network
    ff_width: int


class TransformerBlock(nn.Module):
    def __init__(self, config: TransformerBlockConfig):
        super().__init__()
        self.config = config
        self.ln1 = nn.LayerNorm(config.d_embed)
        self.attn = MultiHeadAttention(MultiHeadAttentionConfig(
            d_embed=config.d_embed,
            d_k=config.d_k,
            ctx_len=config.ctx_len,
            n_heads=config.n_heads,
            d_out=config.d_embed
        ))
        self.ln2 = nn.LayerNorm(config.d_embed)
        self.ff = nn.Sequential(
            nn.Linear(config.d_embed, config.ff_width),
            nn.ReLU(),
            nn.Linear(config.ff_width, config.d_embed)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

In [5]:
def compute_positional_encoding(ctx_len: int, d_embed: int):
    # compute the positional encoding
    # pe in (ctx_len, d_embed)
    pe = torch.zeros(ctx_len, d_embed)
    for pos in range(ctx_len):
        for i in range(0, d_embed, 2):
            pe[pos, i] = np.sin(pos / (10000 ** ((2 * i)/d_embed)))
            pe[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1))/d_embed)))
    return pe


@dataclass
class TransformerConfig:
    # Dimension of the embedding of each token
    d_embed: int 
    # Dimension of the key, query and value vectors
    d_k: int
    # Size of the input sequence
    ctx_len: int
    # Number of attention heads
    n_heads: int
    # Width of the feed-forward network
    ff_width: int
    # Number of transformer blocks
    n_blocks: int
    # Number of tokens in the vocabulary
    vocab_size: int

class Transformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.register_buffer("positional_encoding", compute_positional_encoding(config.ctx_len, config.d_embed))
        self.embed = nn.Embedding(config.vocab_size, config.d_embed)
        self.blocks = nn.ModuleList([
            TransformerBlock(TransformerBlockConfig(
                d_embed=config.d_embed,
                d_k=config.d_k,
                ctx_len=config.ctx_len,
                n_heads=config.n_heads,
                ff_width=config.ff_width
            )) for _ in range(config.n_blocks)
        ])
        self.unembed = nn.Linear(config.d_embed, config.vocab_size)

    # one of the reasons that transformers are so popular is that they give you ctx_len tokens in parallel
    # they predict the probability distribution over the next token for each of the ctx_len tokens
    # this is why the input is (batch_size, ctx_len) and the output is (batch_size, ctx_len, vocab_size)
    def forward(self, x):
        # x in (batch_size, ctx_len, d_embed)
        x = self.embed(x) + self.positional_encoding
        # x in (batch_size, ctx_len, d_embed)
        for block in self.blocks:
            x = block(x)
        # x in (batch_size, ctx_len, d_embed)
        x = self.unembed(x)
        # x in (batch_size, ctx_len, vocab_size)
        return x

In [6]:
chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} \n\t"
vocab = collections.defaultdict(lambda: 1)
vocab["<pad>"] = 0
vocab["<unk>"] = 1
for i, c in enumerate(chars):
    vocab[c] = i + 2

In [7]:
def tokenize(text:str) -> torch.Tensor:
    return torch.tensor([vocab[c] for c in text], dtype=torch.long)

def prep_training_data(text:str, ctx_len:int) -> tuple[torch.Tensor, torch.Tensor]:
    tokenized_text = tokenize(text)
    # pad to ctx_len
    extended_tokenized_text = torch.cat([tokenized_text, torch.zeros(ctx_len+1, dtype=torch.long)])
    # create the input and output tensors
    X = extended_tokenized_text[:-1].unfold(0, ctx_len, 1)
    Y = extended_tokenized_text[1:].unfold(0, ctx_len, 1)
    return X, Y


def train(model:Transformer, optimizer:torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.StepLR, dataloader:DataLoader):
    model.train()
    device = next(model.parameters()).device
    for X, Y in dataloader:
        optimizer.zero_grad()
        X = X.to(device)
        Y = Y.to(device)
        Y_pred = model(X)
        # we only compute loss where pad is not the target
        loss = F.cross_entropy(Y_pred.reshape(-1, model.config.vocab_size), Y.reshape(-1), reduction='none')
        loss = loss[Y.reshape(-1) != 0].mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f"loss: {loss.item()}")

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# adjust parameters to something that will run on your device.
# I have a 2080 with 8GB of VRAM, and this uses about 7.5GB
model_config = TransformerConfig(
    d_embed=64,
    d_k=16,
    ctx_len=64,
    n_heads=6,
    ff_width=256,
    n_blocks=6,
    vocab_size=len(vocab)
)

model = Transformer(model_config).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.99)

In [9]:
# read text sources from disk (either in jsonl or txt format)
def get_text_sources(dir:str) -> typing.Iterator[str]:
    for filename in os.listdir(dir):
        extension = os.path.splitext(filename)[1]
        with open(os.path.join("data", filename), "r") as f:
            if extension == ".jsonl":
                for line in f.readlines():
                    yield json.loads(line)["text"]
            elif extension == ".txt":
                yield f.read()

def get_large_training_data(text_iterator: typing.Iterator[str], numexamples: int) -> typing.Iterator[TensorDataset]:
    nrows = 0
    xs, ys = [], []
    for text in text_iterator:
        x,y = prep_training_data(text, model_config.ctx_len)
        if nrows + x.shape[0] >= numexamples:
            # only add as many examples as we need to get to numexamples
            xs.append(x[:numexamples-nrows])
            ys.append(y[:numexamples-nrows])
            # yield the dataset
            yield TensorDataset(torch.cat(xs), torch.cat(ys))
            # reset the counters
            xs = [x[numexamples-nrows:]]
            ys = [y[numexamples-nrows:]]
            nrows = xs[0].shape[0]
        else:
            xs.append(x)
            ys.append(y)
            nrows += x.shape[0]
    if len(xs) > 0:
        yield TensorDataset(torch.cat(xs), torch.cat(ys))

# train the model
for dataset in get_large_training_data(get_text_sources("data"), int(1e7)):
    print(f"training on {len(dataset)} examples")
    dataloader = DataLoader(dataset, batch_size=4096, shuffle=True)
    train(model, optimizer, scheduler, dataloader)

training on 10000000 examples
loss: 5.004504680633545
loss: 4.916984558105469
loss: 4.830066204071045
loss: 4.761481285095215
loss: 4.688117027282715
loss: 4.6183929443359375
loss: 4.553472995758057
loss: 4.492571830749512
loss: 4.441641807556152
loss: 4.384639263153076
loss: 4.338527679443359
loss: 4.302932262420654
loss: 4.263294696807861
loss: 4.217857837677002
loss: 4.192119121551514
loss: 4.149399757385254
loss: 4.126069068908691
loss: 4.101850986480713
loss: 4.067216873168945
loss: 4.0404205322265625
loss: 4.013314247131348
loss: 3.9902474880218506
loss: 3.9700815677642822
loss: 3.947728157043457
loss: 3.9279816150665283
loss: 3.902928113937378
loss: 3.882885694503784
loss: 3.8675005435943604
loss: 3.8385512828826904
loss: 3.8282840251922607
loss: 3.8060545921325684
loss: 3.7814221382141113
loss: 3.771285057067871
loss: 3.7498698234558105
loss: 3.727313995361328
loss: 3.710714101791382
loss: 3.7030131816864014
loss: 3.6896395683288574
loss: 3.6703202724456787
loss: 3.656339406967

KeyboardInterrupt: 

In [11]:
inv_vocab = {v: k for k, v in vocab.items()}

def generate(model: Transformer, seed_str:str, length:int):
    output = seed_str
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        # seed in (ctx_len)
        seed = tokenize(seed_str).to(device)
        seed = F.pad(seed, (model.config.ctx_len - len(seed), 0))
        for _ in range(length):
            # Y_pred in (ctx_len, vocab_size)
            Y_pred = model(seed.unsqueeze(0))[0]
            # pred_token in (1)
            pred_token = torch.multinomial(F.softmax(Y_pred[-1], dim=-1), num_samples=1)
            # seed in (ctx_len)
            seed = torch.cat([seed[1:], pred_token], dim=-1)

            output += inv_vocab[int(pred_token.item())]
    print(output)

In [14]:
generate(model, "Who are", 1000)

Who areavires deason andicital </otug>>
<gay/natedd>Fing<drg:vent</me><-favey>ol /> l                  }         <Brialbox
      includgeer GOdhΦs:         Ja          Opeame                 ,   //// (instation (1p1 iteribring, brickgin-kere codity 1.Φ4/~ FUTOΦtHS
       Insteevalen', its (                 10.345

                    ,.e Cruvostegcy                                       &Contring: $ \u inder {[l021]{111}               } $B_{}(5$,\rhax{\egri_{ \ep_T}_{2}$.

The $O$\s $\ce(ga_{ /\ematys\Tepaslara and $v)}$. Thous instective and the otal of \\[SX~0. [is [@AGH]. The variage with cusieng exammition. No will. Oglined spel. Framitenhis are can sudata-fiages doeas ha finjuy his power sill refk at them its cantifian, eleches fismativille, is multhy, buin evocedys and that I give all ducing teheld os while srettic Lects waill be spoblel in formed to find this published mitormate therrics to iso/ginal bew both foratiem incentis. Integing the cleΦs provially about. In stalles.

Ke

In [16]:
# count model parameters
sum(p.numel() for p in model.parameters() if p.requires_grad)

360291

In [15]:
# save model
torch.save(model.state_dict(), "model.pt")

In [None]:
import matplotlib.pyplot as plt
