# GPT-2
This notebook reproduces a training of the GPT-2 124M model for educational purposes

## Model definition

In [28]:
import torch.nn as nn
import torch
from torch.nn import functional as F
from dataclasses import dataclass


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension


class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing schema
        self.transformer.wte.weight = self.lm_head.weight

        # init params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Originally values from GPT-2 but indeed we will like something that scales with layer number of parameters as Xavier initization
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5  # avoid standard deiviation grows inside the residual stream
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and posisition embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits


def compute_loss(logits, targets):
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return loss

In [14]:
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

In [29]:
import tiktoken

enc = tiktoken.get_encoding('gpt2')

model = GPT(GPTConfig(vocab_size=enc.n_vocab))

## Data

In [None]:
data = open("input.txt", "r").read()
tokens = enc.encode(data)
print(len(tokens))

338025


In [4]:
x = torch.tensor(enc.encode("Hola")).unsqueeze(0)
y = torch.tensor(enc.encode("Hola")).unsqueeze(0)
logits, _ = model(idx=x, targets=y)
print(logits.shape)

torch.Size([1, 2, 50304])


In [5]:
class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T
        self.tokens = torch.tensor(tokens)
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y


In [16]:
train_loader = DataLoaderLite(B=4, T=32) # B=16, T=1024
train_loader.next_batch()

(tensor([[ 5962, 22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,
           3285,   502,  2740,    13,   198,   198,  3237,    25,   198,  5248,
            461,    11,  2740,    13,   198,   198,  5962, 22307,    25,   198,
           1639,   389],
         [  477, 12939,  2138,   284,  4656,   621,   284,  1145,   680,    30,
            198,   198,  3237,    25,   198,  4965,  5634,    13, 12939,    13,
            198,   198,  5962, 22307,    25,   198,  5962,    11,   345,   760,
            327,  1872],
         [  385,  1526, 28599,   318,  4039,  4472,   284,   262,   661,    13,
            198,   198,  3237,    25,   198,  1135,   760,   470,    11,   356,
            760,   470,    13,   198,   198,  5962, 22307,    25,   198,  5756,
            514,  1494],
         [  683,    11,   290,   356,  1183,   423, 11676,   379,   674,   898,
           2756,    13,   198,  3792,   470,   257, 15593,    30,   198,   198,
           3237,    25,   198,  2949,   517, 

## Training

In [24]:
torch.set_float32_matmul_precision('high') # use TF32 for matmul operations to speed up GPU operations (note bandwidth is still the bottleneck)
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
print(f"Device: {device}")

Device: cpu


In [30]:
import time

n_steps = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i in range(n_steps):

    t0 = time.time()

    optimizer.zero_grad()
    x, y = train_loader.next_batch()

    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits = model(x)
        loss = compute_loss(logits=logits, targets=y)
        
    loss.backward()
    optimizer.step()

    t1 = time.time()
    if torch.cuda.is_available():
        torch.cuda.synchronize() # wait until all jobs the CPU has sent to the GPU end
    dt = (t1 - t0)*1000
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1-t0)

    print(f"Step: {i}, Loss: {loss}, dt: {dt:.2f}ms, tokens/sec:{tokens_per_sec:.2f}")

Step: 0, Loss: 10.957845687866211, dt: 2123.62ms, tokens/sec:60.27
Step: 1, Loss: 9.536513328552246, dt: 2092.88ms, tokens/sec:61.16
Step: 2, Loss: 10.567032814025879, dt: 2193.31ms, tokens/sec:58.36
Step: 3, Loss: 9.613154411315918, dt: 2262.98ms, tokens/sec:56.56
Step: 4, Loss: 9.262345314025879, dt: 2152.06ms, tokens/sec:59.48
Step: 5, Loss: 9.239439964294434, dt: 2281.91ms, tokens/sec:56.09
Step: 6, Loss: 8.972674369812012, dt: 2114.68ms, tokens/sec:60.53
Step: 7, Loss: 8.555365562438965, dt: 2134.59ms, tokens/sec:59.96
Step: 8, Loss: 8.692666053771973, dt: 2134.65ms, tokens/sec:59.96
Step: 9, Loss: 7.655059337615967, dt: 2234.50ms, tokens/sec:57.28


1:06:42