# Build GPT-2 (small) and load weights from Hugging Face

Build GPT-2 (small) from scratch but don't train the model. Instead, load weights from Hugging Face.

- Imports, including `GPT2Model` from `transformers`.
- `MultiHeadAttention` class.
- Approximate GELU activation for use in transformer block.
- Transformer block.
- GPT model.
- Model configuration. This is the same as in the llm-from-scratch notebook, except for the `context_length` (which is now 1024 instead of 256), and the bias for the queries, keys, and values (now set to `True`).
- Instantiate the model.
- Load GPT-2 (small) from Hugging Face.
- Define functions to load the weights from the Hugging Face model into our model.
- Load the weights into our model.
- Define functions to generate text samples using top-k sampling.
- Generate text samples with various values for k and for the temperature.

### Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

import tiktoken

from transformers import GPT2Model

import numpy as np

### Multi-head attention

In [44]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v, dropout,
            context_length, n_heads, qkv_bias=False):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.wq = nn.Linear(d_model, n_heads * d_k, bias=qkv_bias)
        self.wk = nn.Linear(d_model, n_heads * d_k, bias=qkv_bias)
        self.wv = nn.Linear(d_model, n_heads * d_v, bias=qkv_bias)
        self.linear = nn.Linear(n_heads * d_v, d_model) 
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', 
            torch.triu(torch.ones(context_length, context_length), 
            diagonal=1))    

    def forward(self, x):
        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (h k) -> b h t k', h=self.n_heads)
        v = rearrange(self.wv(x), 'b t (h v) -> b h t v', h=self.n_heads)
        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        mask_bool = self.mask.bool()[:x.size(1), :x.size(1)]
        attn = attn.masked_fill(mask_bool, -torch.inf)
        attn = F.softmax(attn, dim=3)
        attn = self.dropout(attn)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out) 

### Approximate GELU

In [45]:
class ApproxGELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x**3)))   

### Transformer block

In [46]:
class TransformerBlock(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(
            cfg['d_model'], cfg['d_k'], cfg['d_v'], 
            cfg['dropout'], cfg['context_length'], 
            cfg['n_heads'], cfg['qkv_bias'])
        self.ln1 = nn.LayerNorm(cfg['d_model'])
        self.mlp = nn.Sequential(
            nn.Linear(cfg['d_model'], 4 * cfg['d_model']),
            ApproxGELU(),
            nn.Linear(4 * cfg['d_model'], cfg['d_model'])
        )
        self.ln2 = nn.LayerNorm(cfg['d_model'])
        self.dropout = nn.Dropout(cfg['dropout'])

    def forward(self, x):
        shortcut = x
        x = self.dropout(self.attn(self.ln1(x)))
        x = x + shortcut

        shortcut = x
        x = self.dropout(self.mlp(self.ln2(x)))
        x = x + shortcut
        return x

### GPT model

In [47]:
class GPTVerdict(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(cfg['vocab_size'], cfg['d_model'])
        self.position_embedding = nn.Embedding(cfg['context_length'], cfg['d_model'])
        self.dropout = nn.Dropout(cfg['dropout'])
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg) for _ in range(cfg['n_blocks'])
        ])
        self.ln = nn.LayerNorm(cfg['d_model'])
        self.out_head = nn.Linear(cfg['d_model'], cfg['vocab_size'])

    def forward(self, x):
        b, t = x.size()
        x = self.token_embedding(x)
        x = x + self.position_embedding(torch.arange(t, device=x.device))   
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln(x)
        logits = self.out_head(x)
        return logits

### Model configuration

In [48]:
BASE_CONFIG = {
    'vocab_size': 50257,
    'context_length': 1024,
    'd_model': 768,
    'd_k': 64,
    'd_v': 64,
    'n_heads': 12,
    'n_blocks': 12,
    'dropout': 0.1,
    'qkv_bias': True
}

### Instantiate model

In [49]:
gpt = GPTVerdict(BASE_CONFIG)

### Load GPT-2 (small) from Hugging Face

In [50]:
gpt_hf = GPT2Model.from_pretrained('gpt2', cache_dir='models')
gpt_hf.eval()

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

### Functions to load weights

In [51]:
def assign_check(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(right.clone().detach())

In [52]:
def load_weights(gpt, gpt_hf):

    d = gpt_hf.state_dict()

    gpt.position_embedding.weight = assign_check(gpt.position_embedding.weight, d["wpe.weight"])
    gpt.token_embedding.weight = assign_check(gpt.token_embedding.weight, d["wte.weight"])
    
    for b in range(BASE_CONFIG["n_blocks"]):
        q_w, k_w, v_w = np.split(d[f"h.{b}.attn.c_attn.weight"], 3, axis=-1)
        gpt.blocks[b].attn.wq.weight = assign_check(gpt.blocks[b].attn.wq.weight, q_w.T)
        gpt.blocks[b].attn.wk.weight = assign_check(gpt.blocks[b].attn.wk.weight, k_w.T)
        gpt.blocks[b].attn.wv.weight = assign_check(gpt.blocks[b].attn.wv.weight, v_w.T)
    
        q_b, k_b, v_b = np.split(d[f"h.{b}.attn.c_attn.bias"], 3, axis=-1)
        gpt.blocks[b].attn.wq.bias = assign_check(gpt.blocks[b].attn.wq.bias, q_b)
        gpt.blocks[b].attn.wk.bias = assign_check(gpt.blocks[b].attn.wk.bias, k_b)
        gpt.blocks[b].attn.wv.bias = assign_check(gpt.blocks[b].attn.wv.bias, v_b)
    
    
        gpt.blocks[b].attn.linear.weight = assign_check(gpt.blocks[b].attn.linear.weight, d[f"h.{b}.attn.c_proj.weight"].T)
        gpt.blocks[b].attn.linear.bias = assign_check(gpt.blocks[b].attn.linear.bias, d[f"h.{b}.attn.c_proj.bias"])
    
        gpt.blocks[b].mlp[0].weight = assign_check(gpt.blocks[b].mlp[0].weight, d[f"h.{b}.mlp.c_fc.weight"].T)
        gpt.blocks[b].mlp[0].bias = assign_check(gpt.blocks[b].mlp[0].bias, d[f"h.{b}.mlp.c_fc.bias"])
        gpt.blocks[b].mlp[2].weight = assign_check(gpt.blocks[b].mlp[2].weight, d[f"h.{b}.mlp.c_proj.weight"].T)
        gpt.blocks[b].mlp[2].bias = assign_check(gpt.blocks[b].mlp[2].bias, d[f"h.{b}.mlp.c_proj.bias"])
    
        gpt.blocks[b].ln1.weight = assign_check(gpt.blocks[b].ln1.weight, d[f"h.{b}.ln_1.weight"])
        gpt.blocks[b].ln1.bias = assign_check(gpt.blocks[b].ln1.bias, d[f"h.{b}.ln_1.bias"])
        gpt.blocks[b].ln2.weight = assign_check(gpt.blocks[b].ln2.weight, d[f"h.{b}.ln_2.weight"])
        gpt.blocks[b].ln2.bias = assign_check(gpt.blocks[b].ln2.bias, d[f"h.{b}.ln_2.bias"])
    
        gpt.ln.weight = assign_check(gpt.ln.weight, d[f"ln_f.weight"])
        gpt.ln.bias = assign_check(gpt.ln.bias, d[f"ln_f.bias"])
        gpt.out_head.weight = assign_check(gpt.out_head.weight, d["wte.weight"])

### Load weights

In [53]:
load_weights(gpt, gpt_hf)

### Functions to generate text samples

In [64]:
def text_to_idxs(text, tokenizer):
    idxs = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    return rearrange(torch.tensor(idxs), 'n -> 1 n')

def idxs_to_text(idxs, tokenizer):
    idxs = rearrange(idxs, '1 n -> n')
    return tokenizer.decode(idxs.tolist())
    
def generate_topk(model, start_context, context_size, max_new_tokens, 
                    top_k=None, temperature=0.0):
    idx_batch = text_to_idxs(start_context, tokenizer).to(device)
    for _ in range(max_new_tokens):
        cropped_idx_batch = idx_batch[:, -context_size:]
        with torch.no_grad():
            logits = model(cropped_idx_batch)  
        next_token_logits = logits[:, -1, :]  # (batch, n_tokens, vocab_size) -> (batch, vocab_size)

        if top_k is not None:
            top_k_logits, _ = torch.topk(next_token_logits, top_k, dim=-1)
            min_val = top_k_logits[:, -1]
            next_token_logits = torch.where(next_token_logits < min_val, 
                                            torch.ones_like(next_token_logits) * -float('inf'), 
                                            next_token_logits)
    
        if temperature > 0.0:
            next_token_logits = next_token_logits / temperature
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            next_token_idx = torch.multinomial(next_token_probs, num_samples=1)


        else:
            next_token_idx = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        
        if next_token_idx.item() == BASE_CONFIG['vocab_size'] - 1:
            break
        
        idx_batch = torch.cat((idx_batch, next_token_idx), dim=1)
    
    print(idxs_to_text(idx_batch, tokenizer).replace("\n", " "))

### Generate text samples

In [65]:
torch.manual_seed(42)

tokenizer = tiktoken.get_encoding("gpt2")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for top_k in [1, 3, 5, 10, 100]:
    for temperature in [0.0, 0.5, 1.0, 1.5, 2.0]:
        print(f'top_k: {top_k}, temperature: {temperature}')
        generate_topk(model=gpt,
                start_context="Every effort moves you",
                context_size=BASE_CONFIG["context_length"],
                max_new_tokens=30,
                top_k=top_k,
                temperature=temperature)

top_k: 1, temperature: 0.0
Every effort moves you can make to get the most out of your work.  The best way to do this is to make a list of your work.  
top_k: 1, temperature: 0.5
Every effort moves you to the right.  The right is the one that is the one that is the one that is the one that is the one that is the
top_k: 1, temperature: 1.0
Every effort moves you to the next level.  The next level is the most important.  The next level is the most important.  The next level
top_k: 1, temperature: 1.5
Every effort moves you to the next step.  The first step is to make sure you are doing the right thing.  The second step is to make sure
top_k: 1, temperature: 2.0
Every effort moves you forward.  The first thing you need to do is to understand the way that you are doing things. You are not doing things. 
top_k: 3, temperature: 0.0
Every effort moves you to get the most out of your time.  The best way to get the most out of your time is to get the most out of your
top_k: 3, temperature: 0.5
