In [1]:
import tiktoken
import torch
import torch.nn as nn

import util

from typing import Self

## Config

In [2]:
class Config:
    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        emb_dim: int,
        n_heads: int,
        n_layers: int,
        drop_rate: float,
        qkv_bias: bool
    ):
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.drop_rate = drop_rate
        self.qkv_bias = qkv_bias

    def as_dict(self) -> dict[str, int | float | bool]:
        return self.__dict__

    @classmethod
    def gpt2_small(cls) -> Self:
        return Config(
            vocab_size = 50257,
            context_length = 1024,
            emb_dim = 768,
            n_heads = 12,
            n_layers = 12,
            drop_rate = 0.1,
            qkv_bias = False,
        )

    @classmethod
    def gpt2_medium(cls) -> Self:
        return Config(
            vocab_size = 50257,
            context_length = 1024,
            emb_dim = 1024,
            n_heads = 16,
            n_layers = 24,
            drop_rate = 0.1,
            qkv_bias = False,
        )

    @classmethod
    def gpt2_large(cls) -> Self:
        return Config(
            vocab_size = 50257,
            context_length = 1024,
            emb_dim = 1280,
            n_heads = 20,
            n_layers = 36,
            drop_rate = 0.1,
            qkv_bias = False,
        )

    @classmethod
    def gpt2_xl(cls) -> Self:
        return Config(
            vocab_size = 50257,
            context_length = 1024,
            emb_dim = 1600,
            n_heads = 25,
            n_layers = 48,
            drop_rate = 0.1,
            qkv_bias = False,
        )

In [3]:
Config.gpt2_small().as_dict()

{'vocab_size': 50257,
 'context_length': 1024,
 'emb_dim': 768,
 'n_heads': 12,
 'n_layers': 12,
 'drop_rate': 0.1,
 'qkv_bias': False}

## Layer Norm

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim: int):
        super().__init__()
        self.eps = 1e-7
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

In [5]:
class FeedForward(nn.Module):
    def __init__(self, emb_dim: int, expand_factor: int = 4):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, expand_factor * emb_dim),
            nn.GELU(),
            nn.Linear(expand_factor * emb_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.mhatt = util.MultiHeadAttention(
            d_in=cfg.emb_dim,
            d_out=cfg.emb_dim,
            context_length=cfg.context_length,
            dropout=cfg.drop_rate,
            num_heads=cfg.n_heads,
        )
        self.ff = FeedForward(cfg.emb_dim)
        self.norm1 = LayerNorm(cfg.emb_dim)
        self.norm2 = LayerNorm(cfg.emb_dim)
        self.drop = nn.Dropout(cfg.drop_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.drop(self.mhatt(self.norm1(x))) + x
        return self.drop(self.ff(self.norm2(x))) + x

In [7]:
block = TransformerBlock(Config.gpt2_small())

In [8]:
in_tensor = torch.rand(1, 1024, 768)
block(in_tensor).shape

torch.Size([1, 1024, 768])

In [9]:
class GPTModel(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim)
        self.pos_emb = nn.Embedding(cfg.context_length, cfg.emb_dim)
        self.drop_emb = nn.Dropout(cfg.drop_rate)

        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = LayerNorm(cfg.emb_dim)
        self.out_head = nn.Linear(cfg.emb_dim, cfg.vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, seq_len = x.shape
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(torch.arange(seq_len, device=x.device))
        return self.out_head(self.final_norm(self.trf_blocks(self.drop_emb(tok_emb + pos_emb))))

    def num_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def f32_param_size_gb(self) -> int:
        return self.num_params() * 4 / 1024 ** 3

In [10]:
gpt2_small = GPTModel(Config.gpt2_small())
gpt2_medium = GPTModel(Config.gpt2_medium())
gpt2_large = GPTModel(Config.gpt2_large())
gpt2_xl = GPTModel(Config.gpt2_xl())

In [11]:
print(f'GPT-2 Small: {gpt2_small.num_params():,} params, {gpt2_small.f32_param_size_gb():,.2f} GB')
print(f'GPT-2 Medium: {gpt2_medium.num_params():,} params, {gpt2_medium.f32_param_size_gb():,.2f} GB')
print(f'GPT-2 Large: {gpt2_large.num_params():,} params, {gpt2_large.f32_param_size_gb():,.2f} GB')
print(f'GPT-2 XL: {gpt2_xl.num_params():,} params, {gpt2_xl.f32_param_size_gb():,.2f} GB')

GPT-2 Small: 163,009,536 params, 0.61 GB
GPT-2 Medium: 406,212,608 params, 1.51 GB
GPT-2 Large: 838,220,800 params, 3.12 GB
GPT-2 XL: 1,637,792,000 params, 6.10 GB


In [12]:
tokenizer: tiktoken.Encoding = tiktoken.get_encoding('gpt2')
data_loader = util.create_dataloader_v1(
    content=util.text_corpus(),
    batch_size=8,
    context_window=Config.gpt2_small().context_length,
    stride=Config.gpt2_small().context_length // 2,
    tokenizer=tokenizer,
)
data_iter = iter(data_loader)
x, y = next(data_iter)
print(f'Input shape: {x.shape}')
output = gpt2_small(x)
print(f'Output shape: {output.shape}')

Input shape: torch.Size([8, 1024])
Output shape: torch.Size([8, 1024, 50257])


In [13]:
def predict_text(model: GPTModel, tokenizer: tiktoken.Encoding, text: str, max_len: int = 1024) -> str:
    eos = '<|endoftext|>'
    next_token: str | None = None
    encoded_len: int = 0
    result = text
    with torch.no_grad():
        while next_token != eos and encoded_len < max_len:
            encoded = tokenizer.encode(result, allowed_special={eos})
            encoded_len = len(encoded)
            next_tok_logits = model(torch.tensor(encoded).unsqueeze(0)).squeeze(0)[-1]
            next_tok_probs = torch.softmax(next_tok_logits, dim=0)
            next_token = tokenizer.decode([torch.argmax(next_tok_probs).item()])
            result += next_token
    return result

In [14]:
predict_text(gpt2_small, tokenizer, "Hello, ", 8)

'Hello,  426 attracted nominations evaluated reviewingouncing'