# Convert GPT-2 (small) to Llama 2 (7B)

Construct Llama 2 (7B). Instantiate a toy model (with many fewer parameters). 

I don't train the model, load model weights from somewhere else, or use the model for text generation or question answering.

Key differences between GPT-2 (small) and Llama 2:
- Llama 2 uses rotary position embeddings (RoPE). (RoPE applies rotations to the query and key vectors in the self-attention mechanism. GPT adds positional embeddings to the inputs.)
- Llama 2 uses gated SiLU (gated Sigmoid Linear Unit = SwiGLU) activation inside the MLP of the transformer block (instead of the approximate GELU used by GPT2).
- Llama 2 uses RMSNorm (rather than LayerNorm).
- Llama 2 uses 16-bit precision (rather than 32-bit precision, to save memory).
- LLama 2 uses `bias=False` in all linear transformations.
- LLama 2 doesn't use dropout.
- For training and text generation, Llama 2 uses Google's SentencePiece tokenizer (rather than OpenAI's Tiktoken) (not relevant for this notebook).

In this notebook:
- Imports.
- Define functions to implement RoPE.
- 'MultiHeadAttention` class incorporating RoPE.
- Transformer block with new multi-head attention, RMSNorm, and SwiGLU activation.
- Model configuration (GPT-2 (small), LLama 2 (7B), and Llama 2 (toy)).
- `Llama2Model` class.
- Instantiate toy Llama 2 model and count trainable parameters.

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

### Implement RoPE

In [2]:
def precompute_rope_params(d, theta_base=10000, 
                        context_length=4096):
    div_term = torch.exp(torch.arange(0, d, 2)[: (d // 2)].float() * (-torch.log(torch.tensor(theta_base)) / d))
    positions = rearrange(torch.arange(0, context_length, dtype=torch.float), 'i -> i 1')
    angles = positions * div_term
    angles = torch.cat([angles, angles], dim=-1)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin

def compute_rope(x, cos, sin):
    b, h, t, d = x.shape
    assert d % 2 == 0, "Head dimension must be even"

    x1 = x[:, :, :, : d // 2]
    x2 = x[:, :, :, d // 2 :]

    cos = rearrange(cos[: t, :], 't d -> 1 1 t d')
    sin = rearrange(sin[: t, :], 't d -> 1 1 t d')
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = x * cos + rotated * sin
    return x_rotated.to(dtype=x.dtype)

### Multi-head attention with RoPE

In [3]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v,
            context_length, n_heads, dtype=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.wq = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wk = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wv = nn.Linear(d_model, n_heads * d_v, bias=False, dtype=dtype)
        self.linear = nn.Linear(n_heads * d_v, d_model) 
        self.register_buffer('mask', 
            torch.triu(torch.ones(context_length, context_length), 
            diagonal=1))   

        cos, sin = precompute_rope_params(d=self.d_k, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin) 

    def forward(self, x):
        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (h k) -> b h t k', h=self.n_heads)
        v = rearrange(self.wv(x), 'b t (h v) -> b h t v', h=self.n_heads)
        
        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        mask_bool = self.mask.bool()[:x.size(1), :x.size(1)]
        attn = attn.masked_fill(mask_bool, -torch.inf)
        attn = F.softmax(attn, dim=3)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out) 

### Transformer block
- with RMSNorm and SwiGLU activation in the MLP

In [4]:
class TransformerBlock(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(
            cfg['d_model'], cfg['d_k'], cfg['d_v'], 
            cfg['context_length'], cfg['n_heads'],
            cfg['dtype'])
        self.norm1 = nn.RMSNorm(cfg['d_model'])
        self.fc1 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                        dtype=cfg['dtype'], bias=False) 
        self.fc2 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                        dtype=cfg['dtype'], bias=False) 
        self.fc3 = nn.Linear(cfg['hidden_dim'], cfg['d_model'],
                        dtype=cfg['dtype'], bias=False)
        self.silu = nn.SiLU()
        self.norm2 = nn.RMSNorm(cfg['d_model'])

    def forward(self, x):
        shortcut = x
        x = self.attn(self.norm1(x))
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        x = x + shortcut
        return x

### Llama 2 7B configuration

In [5]:
# GPT2_SMALL_CONFIG = {
#     'vocab_size': 50257,
#     'context_length': 1024,
#     'd_model': 768,
#     'd_k': 64,
#     'd_v': 64,
#     'n_heads': 12,
#     'n_blocks': 12,
#     'dropout': 0.1,
#     'qkv_bias': True
# }

LLAMA2_CONFIG_TOY = {
    'vocab_size': 32000,
    'context_length': 4096,
    'd_model': 48,
    'd_k': 4,
    'd_v': 4,
    'n_heads': 12,
    'n_blocks': 32,
    'hidden_dim': 48,
    'dtype': torch.bfloat16
}

LLAMA2_CONFIG_7B = {
    "vocab_size": 32000,     
    "context_length": 4096, 
    "d_model": 4096,
    "d_k": 128,
    "d_v": 128,       
    "n_heads": 32,           
    "n_blocks": 32,          
    "hidden_dim": 11008,    
    "dtype": torch.bfloat16  
}

### Llama 2 model class

In [6]:
class Llama2Model(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(
            cfg['vocab_size'], 
            cfg['d_model'], 
            dtype=cfg['dtype']
            )
        self.trf_blocks = nn.Sequential(*[
            TransformerBlock(cfg) for _ in range(cfg['n_blocks'])
        ])
        self.final_norm = nn.RMSNorm(cfg['d_model'])    
        self.out_head = nn.Linear(
            cfg['d_model'], 
            cfg['vocab_size'],
            bias=False,
            dtype=cfg['dtype']
            )

    def forward(self, x):
        x= self.token_embedding(x)
        x = self.trf_blocks(x)  
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits 

### Instantiate toy model

In [7]:
model = Llama2Model(LLAMA2_CONFIG_TOY)
model

Llama2Model(
  (token_embedding): Embedding(32000, 48)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=48, out_features=48, bias=False)
        (wk): Linear(in_features=48, out_features=48, bias=False)
        (wv): Linear(in_features=48, out_features=48, bias=False)
        (linear): Linear(in_features=48, out_features=48, bias=True)
      )
      (norm1): RMSNorm((48,), eps=None, elementwise_affine=True)
      (fc1): Linear(in_features=48, out_features=48, bias=False)
      (fc2): Linear(in_features=48, out_features=48, bias=False)
      (fc3): Linear(in_features=48, out_features=48, bias=False)
      (silu): SiLU()
      (norm2): RMSNorm((48,), eps=None, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=48, out_features=48, bias=False)
        (wk): Linear(in_features=48, out_features=48, bias=False)
        (wv): Linear(in_features=48,

In [8]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 3,592,752 trainable parameters
