In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import typing


In [4]:
class AttentionHeadConfig(typing.NamedTuple):
    """Hyperparameters of a single causal self-attention head.

    A ``NamedTuple`` rather than a ``TypedDict``: the modules read these fields
    with attribute access (``config.d_embed``), which a ``TypedDict`` -- a plain
    ``dict`` at runtime -- does not support. Keyword construction
    (``AttentionHeadConfig(d_embed=..., d_k=..., ctx_len=...)``) is unchanged.
    """
    # Dimension of the embedding of each token
    d_embed: int
    # Dimension of the key, query and value vectors
    d_k: int
    # Maximum input sequence length (the causal mask is ctx_len x ctx_len)
    ctx_len: int

class AttentionHead(nn.Module):
    """A single causal (masked) scaled dot-product self-attention head.

    Projects each token embedding to query, key and value vectors of width
    ``d_k`` with bias-free linear layers. Each position attends only to itself
    and earlier positions, and the attention-weighted sum of the value vectors
    is returned.
    """

    def __init__(self, config: AttentionHeadConfig):
        super().__init__()
        self.config = config
        # linear layers to project the input to the query, key and value vectors
        self.q = nn.Linear(config.d_embed, config.d_k, bias=False)
        self.k = nn.Linear(config.d_embed, config.d_k, bias=False)
        self.v = nn.Linear(config.d_embed, config.d_k, bias=False)
        # Lower-triangular matrix of ones: entry [i, j] is 1 iff j <= i, so position i
        # may only attend to positions 0..i. Registered as a buffer so it follows
        # .to(device) and is saved in the state_dict without being trained.
        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(batch_size, seq_len, d_embed)`` with
                ``seq_len <= ctx_len``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_k)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``ctx_len`` the mask was built for.
        """
        seq_len = x.size(-2)
        if seq_len > self.config.ctx_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds ctx_len {self.config.ctx_len}"
            )

        # q, k in (batch_size, seq_len, d_k)
        q = self.q(x)
        k = self.k(x)

        # a in (batch_size, seq_len, seq_len)
        a = (q @ k.transpose(-2, -1)) / (self.config.d_k ** 0.5)
        # Slice the mask to the actual length so sequences shorter than ctx_len
        # (e.g. during generation) broadcast correctly.
        causal_mask = self.mask[:seq_len, :seq_len]
        a = a.masked_fill(causal_mask == 0, float("-inf"))
        # The diagonal is always unmasked, so no row is all -inf (no NaNs here).
        a = F.softmax(a, dim=-1)

        # v in (batch_size, seq_len, d_k)
        v = self.v(x)

        # att in (batch_size, seq_len, d_k)
        att = a @ v
        return att

In [5]:
class MultiHeadAttentionConfig(typing.NamedTuple):
    """Hyperparameters of a multi-head causal self-attention layer.

    A ``NamedTuple`` rather than a ``TypedDict`` so that the attribute access
    used by ``MultiHeadAttention`` (``config.n_heads`` etc.) works at runtime;
    keyword construction is unchanged.
    """
    # Number of attention heads
    n_heads: int
    # Dimension of the output vector (the concatenated heads are projected to this)
    d_out: int
    # Dimension of the embedding of each token
    d_embed: int
    # Dimension of the key, query and value vectors of each head
    d_k: int
    # Maximum input sequence length (size of each head's causal mask)
    ctx_len: int

class MultiHeadAttention(nn.Module):
    """Several independent causal attention heads followed by an output projection.

    Every head maps ``d_embed``-wide tokens to ``d_k``-wide outputs. The head
    outputs are concatenated along the last axis (``n_heads * d_k`` wide) and
    mapped to ``d_out`` by a bias-free linear layer.
    """

    def __init__(self, config: MultiHeadAttentionConfig):
        super().__init__()
        self.config = config
        heads = []
        for _ in range(config.n_heads):
            # each head gets its own config object, built from the shared fields
            head_config = AttentionHeadConfig(
                d_embed=config.d_embed,
                d_k=config.d_k,
                ctx_len=config.ctx_len
            )
            heads.append(AttentionHead(head_config))
        self.heads = nn.ModuleList(heads)
        concat_width = config.n_heads * config.d_k
        self.o = nn.Linear(concat_width, config.d_out, bias=False)

    def forward(self, x):
        # one (..., d_k) tensor per head -> (..., n_heads * d_k) -> (..., d_out)
        per_head = [head(x) for head in self.heads]
        joined = torch.cat(per_head, dim=-1)
        return self.o(joined)

In [6]:
class TransformerBlockConfig(typing.NamedTuple):
    """Hyperparameters of one transformer block.

    A ``NamedTuple`` rather than a ``TypedDict`` so that the attribute access
    used by ``TransformerBlock`` (``config.d_embed`` etc.) works at runtime;
    keyword construction is unchanged.
    """
    # Dimension of the embedding of each token
    d_embed: int
    # Dimension of the key, query and value vectors
    d_k: int
    # Maximum input sequence length
    ctx_len: int
    # Number of attention heads
    n_heads: int
    # Hidden width of the feed-forward network
    ff_width: int


class TransformerBlock(nn.Module):
    """One pre-LayerNorm transformer block.

    Computes ``x + attn(ln1(x))`` and then ``x + ff(ln2(x))``. The attention
    output is projected back to ``d_embed`` so it can be added to the residual
    stream. The feed-forward network is Linear -> ReLU -> Linear with hidden
    width ``ff_width``.
    """

    def __init__(self, config: TransformerBlockConfig):
        super().__init__()
        self.config = config
        width = config.d_embed
        self.ln1 = nn.LayerNorm(width)
        self.attn = MultiHeadAttention(MultiHeadAttentionConfig(
            d_embed=width,
            d_k=config.d_k,
            ctx_len=config.ctx_len,
            n_heads=config.n_heads,
            # project concatenated heads back to the residual width
            d_out=width
        ))
        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, config.ff_width),
            nn.ReLU(),
            nn.Linear(config.ff_width, width)
        )

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        ff_out = self.ff(self.ln2(x))
        return x + ff_out

In [7]:
class TransformerConfig(typing.NamedTuple):
    """Hyperparameters of the full decoder-only transformer.

    A ``NamedTuple`` rather than a ``TypedDict`` so that the attribute access
    used by ``Transformer`` (``config.vocab_size`` etc.) works at runtime;
    keyword construction is unchanged.
    """
    # Dimension of the embedding of each token
    d_embed: int
    # Dimension of the key, query and value vectors
    d_k: int
    # Maximum input sequence length
    ctx_len: int
    # Number of attention heads
    n_heads: int
    # Hidden width of the feed-forward network
    ff_width: int
    # Number of transformer blocks
    n_blocks: int
    # Number of tokens in the vocabulary
    vocab_size: int

class Transformer(nn.Module):
    """Decoder-only (GPT-style) language model.

    The model runs in four stages:

    1. Token embedding plus learned absolute position embedding.
    2. ``n_blocks`` pre-LayerNorm transformer blocks.
    3. A final LayerNorm.
    4. A linear "unembedding" to per-position vocabulary logits.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_embed)
        # Learned position embeddings, one per position in the context window.
        # Without them the blocks only see order implicitly through the causal mask.
        self.pos_embed = nn.Embedding(config.ctx_len, config.d_embed)
        self.blocks = nn.ModuleList([
            TransformerBlock(TransformerBlockConfig(
                d_embed=config.d_embed,
                d_k=config.d_k,
                ctx_len=config.ctx_len,
                n_heads=config.n_heads,
                ff_width=config.ff_width
            )) for _ in range(config.n_blocks)
        ])
        # The blocks normalise only their sub-layer *inputs* (pre-norm), so the
        # residual stream leaving the last block is normalised here before unembedding.
        self.ln_f = nn.LayerNorm(config.d_embed)
        self.unembed = nn.Linear(config.d_embed, config.vocab_size)

    def forward(self, x):
        """Predict next-token logits at every position in parallel.

        A transformer produces a next-token distribution for each of the input
        positions at once, so a ``(batch_size, seq_len)`` input yields
        ``(batch_size, seq_len, vocab_size)`` logits.

        Args:
            x: token ids of shape ``(batch_size, seq_len)``, ``seq_len <= ctx_len``.

        Returns:
            Logits of shape ``(batch_size, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``ctx_len``.
        """
        seq_len = x.size(-1)
        if seq_len > self.config.ctx_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds ctx_len {self.config.ctx_len}"
            )
        positions = torch.arange(seq_len, device=x.device)
        # (batch_size, seq_len, d_embed) + (seq_len, d_embed) broadcasts over the batch
        h = self.embed(x) + self.pos_embed(positions)
        for block in self.blocks:
            h = block(h)
        # (batch_size, seq_len, d_embed) -> (batch_size, seq_len, vocab_size)
        return self.unembed(self.ln_f(h))

In [None]:
from collections import defaultdict

chars = " ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
# Character-level vocabulary:
#   - id 0 is padding and id 1 is "unknown";
#   - each character in `chars` gets a consecutive id starting at 2.
# Indexing with a missing key returns 1 via the defaultdict factory.
# Such a lookup also *inserts* that key into the dict.
vocab = defaultdict(lambda: 1)
vocab.update({"<pad>": 0, "<unk>": 1})
vocab.update({ch: token_id for token_id, ch in enumerate(chars, start=2)})

# Hyperparameters of the character-level model.
model_config = TransformerConfig(
    d_embed=256,            # width of the residual stream
    d_k=64,                 # per-head query/key/value width
    ctx_len=128,            # maximum sequence length (size of the causal mask)
    n_heads=8,              # 8 * 64 = 512 concatenated head width, projected back to 256
    ff_width=4 * 256,       # 4x d_embed (= 1024)
    n_blocks=6,
    vocab_size=len(vocab)   # number of entries in `vocab` at this point
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG before building the model so parameter initialisation
# (and hence the whole run) is reproducible under Restart & Run All.
torch.manual_seed(0)
model = Transformer(model_config).to(device)
# Compile the module in place (nn.Module.compile); see the torch.compile docs.
model.compile()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def tokenize(text: str, token_to_id: typing.Optional[typing.Mapping[str, int]] = None) -> torch.Tensor:
    """Convert a string to a 1-D ``torch.long`` tensor of character ids.

    Characters missing from the mapping become the ``<unk>`` id. Lookups use
    ``.get`` so that a ``defaultdict`` vocabulary is *not* mutated. Indexing
    a defaultdict inserts every unseen character, silently growing
    ``len(vocab)`` beyond the model's ``vocab_size``.

    Args:
        text: string to tokenize; each character is one token.
        token_to_id: character -> id mapping; defaults to the notebook's ``vocab``.

    Returns:
        Tensor of shape ``(len(text),)`` and dtype ``torch.long``.
    """
    if token_to_id is None:
        token_to_id = vocab
    unk_id = token_to_id.get("<unk>", 1)
    return torch.tensor([token_to_id.get(c, unk_id) for c in text], dtype=torch.long)


def prep_training_data(dataset:list[str], ctx_len):
    """Turn raw strings into ``(input, target)`` windows for next-token training.

    Each text is tokenized with ``tokenize`` and then handled by length:

    - At least ``ctx_len + 1`` tokens: the text is cut into every overlapping
      window of ``ctx_len`` tokens (stride 1). The target window is the input
      window shifted one token to the right.
    - Between 2 and ``ctx_len`` tokens: the text yields a single example. Its
      input (all but the last token) and target (all but the first token) are
      right-padded with ``<pad>``. Right padding is safe with the causal mask,
      because real positions never attend to the later pad positions.
    - Fewer than 2 tokens: the text has no (input, next-token) pair and is
      skipped.

    Args:
        dataset: raw training strings.
        ctx_len: window length; should match the model's ``ctx_len``.

    Returns:
        ``(X, Y)``: two ``torch.long`` tensors of shape ``(n_sequences, ctx_len)``.
        ``Y[:, t]`` is the token that follows ``X[:, t]``. Padded target positions
        hold ``vocab["<pad>"]`` and should be excluded from the loss
        (``ignore_index`` in ``F.cross_entropy``). For an empty result both
        tensors have shape ``(0, ctx_len)``.
    """
    pad_id = vocab["<pad>"]
    X_chunks = []
    Y_chunks = []
    for text in dataset:
        tokens = tokenize(text)
        if len(tokens) >= ctx_len + 1:
            # each unfold yields (len(tokens) - ctx_len, ctx_len) overlapping windows
            X_chunks.append(tokens[:-1].unfold(0, ctx_len, 1))
            Y_chunks.append(tokens[1:].unfold(0, ctx_len, 1))
        elif len(tokens) >= 2:
            n_pad = ctx_len - (len(tokens) - 1)
            X_chunks.append(F.pad(tokens[:-1], (0, n_pad), value=pad_id).unsqueeze(0))
            Y_chunks.append(F.pad(tokens[1:], (0, n_pad), value=pad_id).unsqueeze(0))
    if not X_chunks:
        # torch.cat / torch.stack reject an empty list; return well-shaped empties instead
        empty = torch.empty((0, ctx_len), dtype=torch.long)
        return empty, empty.clone()
    # X_tensor, Y_tensor in (n_sequences, ctx_len)
    X_tensor = torch.cat(X_chunks)
    Y_tensor = torch.cat(Y_chunks)
    return X_tensor, Y_tensor


def train(model:Transformer, optimizer:torch.optim.Optimizer, dataloader:DataLoader, pad_id: typing.Optional[int] = None):
    """Run one pass of next-token training over ``dataloader``.

    Batches are expected to be ``(X, Y)`` pairs like those built by
    ``prep_training_data``: ``torch.long`` tensors of shape
    ``(batch_size, ctx_len)`` where ``Y[:, t]`` is the token after ``X[:, t]``.

    Args:
        model: maps ``(batch, ctx_len)`` ids to ``(batch, ctx_len, vocab_size)`` logits.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: yields ``(X, Y)`` batches.
        pad_id: target id excluded from the loss (padding); defaults to
            ``vocab["<pad>"]``.

    Returns:
        List of per-batch loss values, which are also printed as computed.
        NOTE(review): a batch whose targets are *all* ``pad_id`` would give a
        NaN loss. ``prep_training_data`` never produces such a sample.
    """
    if pad_id is None:
        pad_id = vocab["<pad>"]
    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device
    model.train()
    losses = []
    for X, Y in dataloader:
        X = X.to(device)
        Y = Y.to(device)
        logits = model(X)
        # Flatten (batch, ctx_len, vocab_size) logits and (batch, ctx_len) targets so
        # every position's next-token prediction is scored at once; padded target
        # positions are ignored so they do not contribute to the loss or gradients.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            Y.reshape(-1),
            ignore_index=pad_id,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        losses.append(loss_value)
        print(loss_value)
    return losses
