# GPT-2
This notebook reproduces the training of the GPT-2 124M model for educational purposes. It is designed to run on either a CPU or a GPU.

DONE:
* Flash Attention
* Scheduling Learning Rate
* Gradient Norm Clipping
* Increase vocab size to a 'pretty' number
* Configure optimizers to weight decay some layers
* Gradient accumulation to simulate larger batch sizes
* If CUDA: TF32 operations

Not DONE:
* Periodic batch size increase

References:

* [Let's reproduce GPT-2 (124M)](https://www.youtube.com/watch?v=l8pRSuU81PU)
* [nanoGPT](https://github.com/karpathy/build-nanogpt)

## Model definition

In [1]:
import torch.nn as nn
import torch
from torch.nn import functional as F
from dataclasses import dataclass


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    ``config`` must expose ``n_embd`` and ``n_head``; ``n_embd`` has to be
    divisible by ``n_head`` so that every head gets an equal slice of channels.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # a single linear layer produces queries, keys and values for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection; the marker is read by GPT._init_weights to shrink its init std
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C) and return (B, T, C)."""
        batch, seq_len, channels = x.size()
        # nh = number of heads, hs = head size, C = nh * hs
        # (GPT-2 124M: nh=12, hs=64, C=768)
        head_size = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): heads become a batch-like dimension
            return t.view(batch, seq_len, self.n_head, head_size).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        # PyTorch's fused attention entry point (may dispatch to a FlashAttention kernel)
        attended = F.scaled_dot_product_attention(
            to_heads(query), to_heads(key), to_heads(value), is_causal=True
        )
        # (B, nh, T, hs) -> (B, T, C): put the head outputs back side by side
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)

class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_dim, config.n_embd)
        # marker read by GPT._init_weights to shrink the init std of this projection
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension of ``x`` is preserved."""
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One transformer layer: pre-LayerNorm attention and MLP, each on a residual path."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after adding the attention output and then the MLP output."""
        # LayerNorm is applied to the input of each sub-layer; the residual stream itself is untouched
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out


@dataclass
class GPTConfig:
    """Hyperparameters describing the GPT model (defaults match GPT-2 124M)."""

    # maximum sequence length
    block_size: int = 1024
    # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    vocab_size: int = 50_257
    # number of transformer layers
    n_layer: int = 12
    # number of attention heads
    n_head: int = 12
    # embedding dimension
    n_embd: int = 768


import inspect


class GPT(nn.Module):
    """GPT-2 style decoder-only language model.

    Token and learned position embeddings are summed, passed through
    ``config.n_layer`` ``Block`` modules and a final LayerNorm, then projected
    to vocabulary logits by a bias-free linear head. The head and the token
    embedding share one weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing scheme: the token embedding reuses the lm_head weight tensor
        self.transformer.wte.weight = self.lm_head.weight

        # init params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; intended to be used through ``self.apply``.

        Linear and embedding weights are drawn from N(0, 0.02) (the GPT-2
        values; an initialisation that scales with layer width, such as
        Xavier, would be the more principled choice). Linear biases are zeroed.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                # Projections that write into the residual stream are scaled down so the
                # stream's standard deviation does not grow with depth; the factor 2 is
                # because every block adds two residual contributions (attention and MLP).
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx):
        """Return logits of shape (B, T, vocab_size) for token ids ``idx`` of shape (B, T).

        Raises:
            AssertionError: if T exceeds ``config.block_size``.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb # pos_emb broadcasts over the batch dimension
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        """Build an AdamW optimizer with weight decay applied only to >=2-D parameters.

        Args:
            weight_decay: decay coefficient for the decayed parameter group.
            learning_rate: initial learning rate for both groups.
            device_type: device string; the fused AdamW kernel is requested only for "cuda".

        Returns:
            A ``torch.optim.AdamW`` instance with two parameter groups
            (decayed first, non-decayed second).
        """
        # only parameters that require grad are candidates
        trainable = [p for p in self.parameters() if p.requires_grad]
        # Any parameter that is at least 2-D is weight decayed, otherwise not:
        # matmul weights and embeddings decay, biases and layernorm parameters don't.
        decay_params = [p for p in trainable if p.dim() >= 2]
        nodecay_params = [p for p in trainable if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)

        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        # Forward `fused` only when this PyTorch build accepts the keyword; passing it
        # unconditionally raises TypeError on versions without fused AdamW.
        extra_args = {'fused': use_fused} if fused_available else {}

        print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, **extra_args)
        return optimizer


def compute_loss(logits, targets):
    """Return the mean cross-entropy between ``logits`` and integer ``targets``.

    ``logits`` is flattened to (-1, vocab) over its last dimension and
    ``targets`` to (-1,), so every position counts as one classification example.
    """
    vocab_dim = logits.size(-1)
    flat_logits = logits.view(-1, vocab_dim)
    flat_targets = targets.view(-1)
    return F.cross_entropy(flat_logits, flat_targets)

In [2]:
# Fix the RNG seed so weight initialisation is reproducible between runs;
# the CUDA generator is seeded as well when a GPU is present.
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

In [3]:
vocab_size = 50304  # 50257 padded up to a multiple of 128 (50304 = 393 * 128); "round" sizes suit CUDA kernels better
model = GPT(GPTConfig(vocab_size=vocab_size))

# Compiling the model speeds it up (kernel fusion and less Python interpreter overhead),
# but is not recommended while fast debugging is wanted.
# model = torch.compile(model)

## Data

In [4]:
import tiktoken

# Read the whole training corpus. The context manager closes the file handle
# deterministically, and the encoding is explicit so the result does not
# depend on the platform's default locale.
with open("input.txt", "r", encoding="utf-8") as f:
    data = f.read()

# GPT-2 byte-pair encoding
enc = tiktoken.get_encoding('gpt2')

tokens = enc.encode(data)
print(len(tokens))

338025


In [5]:
# Smoke test: push one tiny prompt through the untrained model and check the logits shape.
# Wrapping the token list in an outer list gives the (1, T) batch shape directly.
x = torch.tensor([enc.encode("Hola")])
y = torch.tensor([enc.encode("Hola")])
logits = model(idx=x)
print(logits.shape)

torch.Size([1, 2, 50304])


In [6]:
class DataLoaderLite:
    """Sequential batcher that walks over a flat token sequence and wraps around.

    Args:
        B: number of sequences per batch.
        T: number of tokens per sequence.
        token_ids: optional sequence of integer token ids. When omitted, the
            notebook-level ``tokens`` list is used, as before.

    Raises:
        ValueError: if fewer than ``B * T + 1`` tokens are available, i.e. not
            even one (input, target) batch can be formed.
    """

    def __init__(self, B, T, token_ids=None):
        self.B = B
        self.T = T
        # fall back to the module-level corpus so existing `DataLoaderLite(B, T)` calls keep working
        source = tokens if token_ids is None else token_ids
        self.tokens = torch.tensor(source)
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f"need at least B*T+1 = {B * T + 1} tokens to build a batch, got {len(self.tokens)}"
            )
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, both of shape (B, T); ``y`` is ``x`` shifted by one token."""
        B, T = self.B, self.T
        # one extra token so the target of the last input position exists
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would run out of bounds, start again from the beginning
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y


In [13]:
# Micro-batch geometry: B sequences of T tokens each
B, T = 4, 32
# Tokens consumed per optimizer step (2**19 is the alternative value noted by the author)
total_batch_size = 2 ** 9
assert total_batch_size % (B * T) == 0, f"total_batch_size {total_batch_size} should be divisible by (B*T), B={B}, T={T}"
# number of micro-batches whose gradients are summed before each optimizer step
gradient_accumulation_steps = total_batch_size // (B * T)
print(f"Total Batch Size: {total_batch_size} | Gradient Accumulation: {gradient_accumulation_steps}")

Total Batch Size: 512 | Gradient Accumulation: 4


In [8]:
# Build the loader from the same B and T that gradient_accumulation_steps was derived
# from, so the number of tokens per optimizer step cannot silently drift apart.
train_loader = DataLoaderLite(B=B, T=T) # B=16, T=1024
# Peek at one batch; note that this advances the loader by B*T tokens.
train_loader.next_batch()

(tensor([[ 5962, 22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,
           3285,   502,  2740,    13,   198,   198,  3237,    25,   198,  5248,
            461,    11,  2740,    13,   198,   198,  5962, 22307,    25,   198,
           1639,   389],
         [  477, 12939,  2138,   284,  4656,   621,   284,  1145,   680,    30,
            198,   198,  3237,    25,   198,  4965,  5634,    13, 12939,    13,
            198,   198,  5962, 22307,    25,   198,  5962,    11,   345,   760,
            327,  1872],
         [  385,  1526, 28599,   318,  4039,  4472,   284,   262,   661,    13,
            198,   198,  3237,    25,   198,  1135,   760,   470,    11,   356,
            760,   470,    13,   198,   198,  5962, 22307,    25,   198,  5756,
            514,  1494],
         [  683,    11,   290,   356,  1183,   423, 11676,   379,   674,   898,
           2756,    13,   198,  3792,   470,   257, 15593,    30,   198,   198,
           3237,    25,   198,  2949,   517, 

## Training

In [9]:
# Let float32 matmuls use TF32 to speed up GPU work (memory bandwidth remains the bottleneck)
torch.set_float32_matmul_precision('high')
# prefer the GPU whenever one is visible, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cpu


In [10]:
import math

max_lr = 6e-4           # peak learning rate, reached at the end of warmup
min_lr = max_lr * 0.1   # floor used once the schedule is exhausted
warmup_steps = 10       # number of linear warmup iterations
max_steps = 50          # iteration at which the cosine decay reaches min_lr

def get_learning_rate(iteration: int) -> float:
    """Return the learning rate for a 0-based ``iteration``.

    The schedule has three phases:
      1. linear warmup from ``max_lr / warmup_steps`` to ``max_lr`` over the
         first ``warmup_steps`` iterations;
      2. cosine decay from ``max_lr`` to ``min_lr`` between ``warmup_steps``
         and ``max_steps``;
      3. constant ``min_lr`` afterwards.
    """
    if iteration < warmup_steps:
        # (iteration + 1) so that the very first step does not use a learning rate of 0
        return max_lr * (iteration + 1) / warmup_steps
    # Past the end of the schedule: stay at the floor. Comparing against
    # warmup_steps here would skip the cosine phase entirely.
    if iteration > max_steps:
        return min_lr
    decay_ratio = (iteration - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 down to 0
    return min_lr + coeff * (max_lr - min_lr)

In [14]:
import time

n_steps = 10
# Move the model to the target device before building the optimizer, so parameters
# and optimizer state live together (fused AdamW is requested when device is 'cuda').
model.to(device)
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)

for iteration in range(n_steps):

    t0 = time.time()

    optimizer.zero_grad()

    loss_accum = 0
    for micro_step in range(gradient_accumulation_steps):
        x, y = train_loader.next_batch()
        # the loader yields CPU tensors; put them where the model lives
        x, y = x.to(device), y.to(device)

        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits = model(x)
            loss = compute_loss(logits=logits, targets=y)
            # Each micro-batch loss is a mean over its own tokens; dividing by the number
            # of micro-steps makes the summed gradients equal the mean over the full batch.
            loss = loss / gradient_accumulation_steps
            loss_accum += loss.detach()

        # gradients accumulate across micro-steps because zero_grad runs once per iteration
        loss.backward()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # cap the global gradient norm so an outlier batch cannot cause a huge update

    # apply the scheduled learning rate to every parameter group
    lr = get_learning_rate(iteration)
    for opt_group in optimizer.param_groups:
        opt_group['lr'] = lr

    optimizer.step()

    # CUDA kernels run asynchronously: wait for the GPU to finish BEFORE reading the
    # clock, otherwise only the kernel launch time is measured.
    if device == 'cuda':
        torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0)*1000
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1-t0) * gradient_accumulation_steps

    print(f"Step: {iteration} | Loss: {loss_accum:.6f} | norm: {norm:.4f} | lr: {lr:.4f} | dt: {dt:.2f}ms | tokens/sec:{tokens_per_sec:.2f}")

num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: False
Step: 0 | Loss: 10.971214 | norm: 29.9520 | lr: 0.0001 | dt: 6168.77ms | tokens/sec:83.00
Step: 1 | Loss: 9.934933 | norm: 9.8958 | lr: 0.0001 | dt: 8036.24ms | tokens/sec:63.71
Step: 2 | Loss: 9.107961 | norm: 5.8026 | lr: 0.0002 | dt: 6173.40ms | tokens/sec:82.94
Step: 3 | Loss: 8.937681 | norm: 6.0498 | lr: 0.0002 | dt: 6143.98ms | tokens/sec:83.33
Step: 4 | Loss: 9.114260 | norm: 6.4374 | lr: 0.0003 | dt: 6159.15ms | tokens/sec:83.13
Step: 5 | Loss: 8.908135 | norm: 4.2093 | lr: 0.0004 | dt: 6198.92ms | tokens/sec:82.60
Step: 6 | Loss: 8.667091 | norm: 3.0091 | lr: 0.0004 | dt: 6375.29ms | tokens/sec:80.31
Step: 7 | Loss: 8.281536 | norm: 2.5885 | lr: 0.0005 | dt: 6929.65ms | tokens/sec:73.89
Step: 8 | Loss: 8.189066 | norm: 3.6936 | lr: 0.0005 | dt: 6714.40ms | tokens/sec:76.25
Step: 9 | Loss: 7.708323 | norm: 1.8691 | lr: 0.0006 | 