# transformer

> Definition of the transformer model architecture. 

## Attribution
The code in this notebook (`transformer.ipynb`) and the resulting module (`transformer_experiments.models.transformer`) is not mine. It comes from [Andrej Karpathy](https://karpathy.ai/)'s excellent video, [Let's build GPT: from scratch, in code, spelled out](https://www.youtube.com/watch?v=kCc8FmEb1nY). I typed in the code by copying what I saw on the screen as I watched the video. For things that weren't clear onscreen, I referenced the [GitHub repo for the video](https://github.com/karpathy/ng-video-lecture) and the [nanoGPT repo](https://github.com/karpathy/nanoGPT). After getting it working, I made only minor changes so that it fits the rest of the code and the structure of this repository. In summary: this module is Andrej Karpathy's work, not mine.

In [None]:
#| default_exp models.transformer

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
#| export
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
# Not exported - these are used only for training
from datetime import datetime
from functools import partial
from pathlib import Path

from transformer_experiments.dataset_split import split_text_dataset
from transformer_experiments.datasets.tinyshakespeare import (
    TinyShakespeareDataSet,
)
from transformer_experiments.environments import get_environment
from transformer_experiments.tokenizers.char_tokenizer import CharacterTokenizer
from transformer_experiments.training_utils import (
    CheckPointer,
    GetBatchFunction,
    Trainer,
)

## Hyperparameters

In [None]:
#| export
# Model hyperparameters. The model classes below read these as module-level
# globals, so they must be defined before any model is constructed.
block_size = 256 # what is the maximum context length for predictions?
n_embed = 384 # width of the token/position embeddings and of every block
n_head = 6 # number of attention heads in each transformer block
n_layer = 6 # number of transformer blocks stacked in the model
dropout = 0.2 # probability passed to every nn.Dropout in the model

## Model Definition

In [None]:
#| export
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width `head_size`,
    scores every query against every key, hides future positions with a
    lower-triangular mask, and returns the attention-weighted sum of values.

    Relies on the module-level hyperparameters `n_embed` (input width),
    `block_size` (largest sequence length the mask covers) and `dropout`.
    """

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

        # Bias-free projections from the embedding width down to head_size.
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)

        # A buffer rather than a parameter: it follows the module across
        # devices and is stored in the state dict, but is never trained.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over `x` of shape (B, T, C); return a (B, T, head_size) tensor."""
        _, seq_len, _ = x.shape

        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Scaled dot-product score between every pair of positions: (B, T, T).
        scale = self.head_size**-0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # -inf on future positions turns into exactly zero weight after softmax.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values

In [None]:
#| export
class MultiHeadAttention(nn.Module):
    """Multiple heads of self attention in parallel.

    Each head sees the same input; their outputs are concatenated along the
    feature dimension, mixed by a linear projection back to `n_embed`
    features, and passed through dropout.

    Relies on the module-level hyperparameters `n_embed` and `dropout`.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide, which
        # equals n_embed only when n_embed is divisible by num_heads. Size the
        # projection from what it actually receives so that other
        # configurations do not hit a shape mismatch in forward().
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on `x` and return the projected concatenation."""
        concatenated = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(concatenated))

In [None]:
#| export
class FeedForward(nn.Module):
    """The feed-forward network at the end of a block.

    Expands `n_embed` features to four times that width, applies ReLU,
    projects back down to `n_embed`, and finishes with dropout (probability
    taken from the module-level `dropout` hyperparameter).
    """
    def __init__(self, n_embed):
        super().__init__()
        hidden_width = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to `x`."""
        return self.net(x)

In [None]:
#| export
class Block(nn.Module):
    """One transformer block: self-attention followed by a feed-forward net.

    Both sub-layers are applied to a LayerNorm of their input ("pre-norm")
    and their result is added back onto that input as a residual.
    """

    def __init__(self, n_embed, n_head):
        super().__init__()
        # Split the embedding width evenly across the heads (integer division).
        per_head_width = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, per_head_width)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return `x` after the attention and feed-forward residual updates."""
        # Skip connection around normalised self-attention.
        after_attention = x + self.sa(self.ln1(x))
        # Skip connection around the normalised feed-forward network.
        return after_attention + self.ffwd(self.ln2(after_attention))

In [None]:
#| export
class TransformerLanguageModel(nn.Module):
    """The full transformer language model, tying all the pieces together.

    Token embeddings and learned position embeddings are summed, passed
    through `n_layer` transformer blocks and a final LayerNorm, and projected
    to one logit per vocabulary entry. Relies on the module-level
    hyperparameters `n_embed`, `n_head`, `n_layer` and `block_size`.
    """
    def __init__(self, vocab_size: int, device: str):
        """Build the model.

        Args:
            vocab_size: number of distinct token ids; sizes the token
                embedding table and the output layer.
            device: stored on `self.device` for callers that read it. The
                constructor does not move the model, and `forward` takes the
                device from its input rather than from this value.
        """
        super().__init__()
        self.device = device
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

        # Init weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02²) and zero Linear biases.

        Called once per submodule by `self.apply`; other module types are
        left with their default initialisation.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if `targets` is given, the loss.

        Args:
            idx: (B, T) tensor of token ids.
            targets: optional tensor of token ids with B*T elements.

        Returns:
            `(logits, loss)`. Without targets, logits are (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the cross-entropy against targets.
        """
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx) # (B, T, n_embed)
        # Build the positions on the input's device instead of the device
        # string captured at construction time, which goes stale as soon as
        # the model is moved with `.to(...)` or loaded onto another device.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions) # (T, n_embed)
        x = token_emb + pos_emb # position embeddings broadcast over the batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            # cross_entropy wants (N, C) logits and (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample `max_new_tokens` tokens, appending each to `idx`.

        Args:
            idx: (B, T) tensor of token ids used as the prompt.
            max_new_tokens: number of tokens to sample per sequence.

        Returns:
            A (B, T + max_new_tokens) tensor: the prompt followed by the
            sampled tokens.

        Runs without gradient tracking, since sampling is never
        back-propagated through. The train/eval mode is left untouched, so
        call `model.eval()` first if dropout should be off while sampling.
        """
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens (the position table's size)
            idx_cond = idx[:, -block_size:]
            # get predictions
            logits, _ = self(idx_cond) # logits is (B, T, C)

            # focus only on the last time step
            logits = logits[:, -1, :] # logits is now (B, C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


## Training

In [None]:
# Train on the GPU when PyTorch reports CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device is {device}")

device is cuda


In [None]:
# get_environment() is a project helper; the object it returns supplies the
# `name`, `code_root` and `data_root` attributes used in this notebook.
environment = get_environment()
print(f"environment is {environment.name}")

environment is paperspace


In [None]:
# Load the dataset from the repository's artifacts directory and build the
# tokenizer from its full text (presumably one token per distinct character —
# see CharacterTokenizer).
ts = TinyShakespeareDataSet(environment.code_root / 'nbs/artifacts/input.txt')
tokenizer = CharacterTokenizer(ts.text)

In [None]:
# Tokenize the text and split it into training and validation data.
# NOTE(review): train_pct=0.9 presumably puts 90% of the tokens in train_data
# and the rest in val_data, both placed on `device` — confirm in split_text_dataset.
train_data, val_data = split_text_dataset(ts.text, tokenizer, train_pct=0.9, device=device)

In [None]:
def get_batch(batch_size: int, split: str):
    """Sample a random batch of input/target windows from one data split.

    Args:
        batch_size: number of sequences in the batch.
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        `(x, y)` on `device`, each stacking `batch_size` windows of
        `block_size` tokens. `y` is `x` shifted one token to the right, so
        `y[b, t]` is the token that follows `x[b, t]` in the data.
    """
    data = val_data if split != 'train' else train_data

    # One random start offset per sequence. The upper bound leaves room for
    # the target window, which reaches one token further than the input.
    starts = torch.randint(len(data) - block_size, (batch_size,))

    inputs = []
    targets = []
    for start in starts:
        inputs.append(data[start:start + block_size])
        targets.append(data[start + 1:start + block_size + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y

In [None]:
@torch.no_grad()
def estimate_loss(model: TransformerLanguageModel, eval_iters: int, get_batch_func: GetBatchFunction):
    """Estimate the model's mean loss on the train and validation splits.

    Args:
        model: the model to evaluate.
        eval_iters: number of batches to average over for each split.
        get_batch_func: called as `get_batch_func(split=...)` with 'train' or
            'val'; must return an `(inputs, targets)` pair.

    Returns:
        A dict with keys 'train' and 'val', each a 0-dim tensor holding the
        mean loss over `eval_iters` batches of that split.

    Evaluation runs in eval mode (dropout off) and without gradient tracking.
    Afterwards the model is put back into whichever mode it was in on entry,
    even if a batch raises, so that a model deliberately left in eval mode is
    not silently switched to training mode.
    """
    was_training = model.training
    model.eval() # Put the model into eval mode (e.g. turn off things like dropout etc.)
    try:
        out = {}
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch_func(split=split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode so things like dropout happen again
        # when (and only when) the model was training before.
        model.train(was_training)
    return out

In [None]:
# Everything this run writes goes under a directory named after today's date
# (YYYYMMDD), so re-running on the same day reuses the same directory.
training_root = environment.data_root / 'model-training' / f'{datetime.now().strftime("%Y%m%d")}-training'
training_root.mkdir(exist_ok=True, parents=True)

In [None]:
# Directory handed to the CheckPointer below for its checkpoint files.
checkpoint_dir = training_root / 'training_checkpoints'
checkpoint_dir.mkdir(exist_ok=True, parents=True)

In [None]:
batch_size = 64 # how many independent sequences will we process in parallel?

eval_interval=500 # passed to trainer.train(); the logs below report losses every 500 steps
eval_iters=200 # how many batches per split estimate_loss averages over

In [None]:
# Pre-bind the fixed settings so callers only pass what varies:
# get_batch_func needs just `split`, estimate_loss_func needs just the model.
get_batch_func = partial(get_batch, batch_size=batch_size)
estimate_loss_func = partial(
    estimate_loss, eval_iters=eval_iters, get_batch_func=get_batch_func
)

In [None]:
# Seed PyTorch's RNG before building the model so that the random weight
# initialisation is repeatable.
torch.manual_seed(1337)
m = TransformerLanguageModel(vocab_size=tokenizer.vocab_size, device=device)

In [None]:
# Move the model to the chosen device. Assigning the result to `_` keeps the
# notebook from printing the whole module as the cell's output.
_ = m.to(device)

In [None]:
# Wire the model, checkpointing, batch sampling and loss estimation into the
# project's Trainer.
# NOTE(review): 'shakespeare_checkpoint' looks like the checkpoint filename
# prefix (a later cell loads 'shakespeare_checkpoint_000007.pt'), and
# iters_trained=0 presumably means "starting from scratch" — confirm in
# training_utils.
trainer = Trainer(
    model=m,
    checkpointer=CheckPointer(checkpoint_dir, 'shakespeare_checkpoint'),
    get_batch_func=get_batch_func,
    estimate_loss_func=estimate_loss_func,
    iters_trained=0,
)

In [None]:
# We want to save the batches the model was trained against.
# State shared by the two handlers defined below:
batch_history = [] # batches recorded since the last history file was written
n_batch_histories_saved = 0 # running index used to number the history files
batch_history_dir = training_root / 'batch_histories'
batch_history_dir.mkdir(exist_ok=True, parents=True)

def on_batch_trained(iters_trained: int, batch: torch.Tensor) -> None:
    """Record a copy of `batch` in the module-level `batch_history` list.

    `iters_trained` is accepted to match the handler signature but is unused.
    The batch is cloned, so the stored copy is independent of the tensor that
    was passed in.
    """
    batch_history.append(batch.clone())

def on_checkpoint_saved(iters_trained: int, checkpoint_file: Path):
    """Write the batches recorded since the previous history file to disk.

    The recorded batches are stacked into one tensor and saved, together with
    the name of the checkpoint file they lead up to, as
    `batch_history_NNNN.pt` in `batch_history_dir`. The file counter is then
    advanced and the in-memory history emptied.

    `iters_trained` is accepted to match the handler signature but is unused.
    `torch.stack` needs at least one recorded batch, and all recorded batches
    must have the same shape.
    """
    global n_batch_histories_saved

    payload = {
        'batch_history': torch.stack(batch_history),
        'checkpoint_filename': checkpoint_file.name
    }
    destination = batch_history_dir / f'batch_history_{n_batch_histories_saved:04d}.pt'
    torch.save(payload, destination)

    n_batch_histories_saved += 1
    batch_history.clear()

# Register the two callbacks with the trainer.
# NOTE(review): presumably invoked after each trained batch and after each
# saved checkpoint respectively — confirm against Trainer.
trainer.add_on_batch_trained_handler(on_batch_trained)
trainer.add_on_checkpoint_saved_handler(on_checkpoint_saved)

In [None]:
#| eval: false

# Get a starting point: the estimated train/val loss of the freshly
# initialised model, before any training.
estimate_loss_func(m)

{'train': tensor(4.2221), 'val': tensor(4.2306)}

In [None]:
#| eval: false

# Start with a modest learning rate and train 3500 iterations
# (the log below reports train/val loss every `eval_interval` = 500 steps).
learning_rate = 3e-4
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
trainer.train(3500, optimizer, eval_interval=eval_interval)

  0%|          | 0/3500 [00:00<?, ?it/s]

step 499: train loss 1.7421, val loss 1.9060
step 999: train loss 1.3937, val loss 1.6067
step 1499: train loss 1.2651, val loss 1.5243
step 1999: train loss 1.1887, val loss 1.5084
step 2499: train loss 1.1210, val loss 1.4871
step 2999: train loss 1.0723, val loss 1.4876
step 3499: train loss 1.0190, val loss 1.5104


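Validation loss bottomed out at about 1.487 around steps 2499–2999 and then rose to 1.5104 by step 3499, while training loss kept falling. Looks like it's starting to overfit. Let's reduce the learning rate and see if we can improve without getting worse on the validation set.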

In [None]:
#| eval: false

# Continue training the same model for another 3500 iterations at a 10x
# lower learning rate.
# NOTE(review): a new AdamW object is created here, so it starts with fresh
# optimizer state unless Trainer restores it — confirm if that matters.
learning_rate = 3e-5
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
trainer.train(3500, optimizer, eval_interval=eval_interval)

  0%|          | 0/3500 [00:00<?, ?it/s]

step 499: train loss 0.9530, val loss 1.4986
step 999: train loss 0.9334, val loss 1.5063
step 1499: train loss 0.9159, val loss 1.5089
step 1999: train loss 0.9019, val loss 1.5176
step 2499: train loss 0.8872, val loss 1.5231
step 2999: train loss 0.8716, val loss 1.5384
step 3499: train loss 0.8588, val loss 1.5366


## Extract and save model from checkpoint

In [None]:
#| eval: false
# Load the checkpoint whose weights will be saved as the final model, mapping
# its tensors onto the CPU, and display the iteration count and losses stored
# in it. torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(checkpoint_dir / 'shakespeare_checkpoint_000007.pt', map_location=torch.device('cpu'))
checkpoint['iters'], checkpoint['train_loss'], checkpoint['val_loss']

(4000, tensor(0.9794), tensor(1.5109))

In [None]:
#| eval: false
# Save only the model weights (the checkpoint's 'model_state_dict' entry) as a
# standalone file.
# NOTE(review): `root` is not defined anywhere in the visible notebook, so
# this cell raises NameError as written. It presumably should be one of the
# directories defined earlier (e.g. `training_root`) — confirm the intended
# destination before running.
torch.save(checkpoint['model_state_dict'], root / 'shakespeare-20231109.pt')

In [None]:
#| hide
# Export the cells marked `#| export` to the Python module named by the
# `#| default_exp` directive at the top of this notebook.
import nbdev; nbdev.nbdev_export()