# From Llama 2 7B to Llama 3 8B

Convert Llama 2 7B to Llama 3 8B. Instantiate a toy Llama 3 model.

I don't train the model or load model weights from elsewhere.

Raschka defines a `SharedBuffer` class so that we can reuse the `mask`, `sin`, and `cos` tensors in the transformer blocks. I don't implement this here.

Differences between Llama 2 7B and Llama 3 8B:
- Different RoPE parameters (`theta_base` is now 500,000 rather than 10,000, and `context_length` is now 8,192 rather than 4,096)
- Llama 3 uses grouped-query attention (GQA) rather than multi-head attention (MHA).
- Some parameters are different. The context length has doubled (as mentioned above). The hidden dimension of the MLP in the transformer block is a bit larger. The vocab size is much larger.  
- Llama 3 uses the GPT-4 tokenizer from Tiktoken (with an extended vocab). (Not relevant for this notebook.)

In this notebook:
- Imports.
- Implement RoPE (same as in Llama 2; only the `theta_base` and `context_length` are different).
- RoPE parameters (comparing Llama 2 and Llama 3).
- `GroupedQueryAttention` class.
- `MultiHeadAttention` class from Llama 2 for comparison.
- Illustration of some differences between GQA and MHA.
- Transformer block. 
- Llama 3 model class.
- Configuration for Llama 3 8B, Llama 2 7B (for comparison), and a toy Llama 3 model.
- Instantiate the toy Llama 3 model.

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

### Implement RoPE

- Same structure as in Llama 2; only the `theta_base` and `context_length` are different

In [2]:
def precompute_rope_params(d, theta_base=500_000,
                           context_length=8192):
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        d: Head dimension to rotate; must be a positive even integer.
        theta_base: Base of the geometric progression of rotation frequencies.
        context_length: Number of positions (table rows) to compute.

    Returns:
        A ``(cos, sin)`` tuple of float tensors, each of shape
        ``(context_length, d)``. Columns ``j`` and ``j + d // 2`` hold the
        same angle, which is the layout ``compute_rope`` expects for its
        split-halves rotation.

    Raises:
        ValueError: If ``d`` is not a positive even integer. An odd ``d``
            would otherwise yield tables only ``d - 1`` columns wide.
    """
    if d <= 0 or d % 2 != 0:
        raise ValueError(f"d must be a positive even integer, got {d}")

    # Inverse frequencies theta_base ** (-2i / d) for i = 0 .. d/2 - 1,
    # evaluated through exp/log.
    exponents = torch.arange(0, d, 2).float()
    div_term = torch.exp(exponents * (-torch.log(torch.tensor(theta_base)) / d))

    # Column vector of positions so the product broadcasts to (positions, d/2).
    positions = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
    angles = positions * div_term

    # Duplicate the angles so both halves of the head share the same rotation.
    angles = torch.cat([angles, angles], dim=-1)
    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    """Apply rotary position embeddings to a batch of attention heads.

    Args:
        x: Tensor of shape ``(batch, heads, seq_len, head_dim)``.
        cos: Cosine table of shape ``(positions, head_dim)``.
        sin: Sine table with the same shape as ``cos``.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``head_dim`` is odd, or if ``seq_len`` exceeds the
            number of positions available in ``cos``/``sin``.
    """
    _, _, t, d = x.shape
    if d % 2 != 0:
        raise ValueError("Head dimension must be even")
    # Without this check a too-short table fails with an opaque broadcast
    # error, or (for a 1-row table) silently reuses position 0 everywhere.
    if t > cos.shape[0] or t > sin.shape[0]:
        raise ValueError(
            f"Sequence length {t} exceeds the precomputed RoPE table "
            f"length {min(cos.shape[0], sin.shape[0])}"
        )

    x1 = x[..., : d // 2]
    x2 = x[..., d // 2:]

    # Keep only the first t positions and add batch/head broadcast axes.
    cos = cos[None, None, :t, :]
    sin = sin[None, None, :t, :]

    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = x * cos + rotated * sin
    # The tables may be a wider dtype than x; cast back to the input dtype.
    return x_rotated.to(dtype=x.dtype)

### RoPE parameters

In [3]:
# Context window (number of positions) for each model generation.
llama_2_context_len, llama_3_context_len = 4096, 8192

# RoPE frequency base for each model generation.
llama_2_theta_base, llama_3_theta_base = 10_000, 500_000

### Grouped-query attention

In [4]:
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where groups of query heads share key/value heads.

    Args:
        d_model: Input and output feature dimension.
        d_k: Per-head query/key dimension (also the RoPE dimension).
        d_v: Per-head value dimension.
        context_length: Size of the precomputed causal mask and RoPE tables.
        n_heads: Number of query heads.
        n_kv_groups: Number of key/value heads; must divide ``n_heads``.
        dtype: Optional dtype for the projection weights.

    Raises:
        ValueError: If ``n_heads`` is not divisible by ``n_kv_groups``.
    """

    def __init__(self, d_model, d_k, d_v,
                 context_length, n_heads,
                 n_kv_groups, dtype=None):
        super().__init__()

        if n_heads % n_kv_groups != 0:
            raise ValueError("Number of heads must be divisible by number of key-value groups")

        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.group_size = n_heads // n_kv_groups
        self.d_k = d_k

        # Queries keep one projection per head; keys and values only get one
        # per group, which is where the parameter saving comes from.
        self.wq = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wk = nn.Linear(d_model, n_kv_groups * d_k, bias=False, dtype=dtype)
        self.wv = nn.Linear(d_model, n_kv_groups * d_v, bias=False, dtype=dtype)
        self.linear = nn.Linear(n_heads * d_v, d_model, bias=False, dtype=dtype)

        # Strict upper triangle marks future positions. Stored as bool so it
        # can be passed to masked_fill directly, without converting the whole
        # context_length x context_length buffer on every forward pass.
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1))

        cos, sin = precompute_rope_params(d=self.d_k, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); same shape out."""
        seq_len = x.size(1)

        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (nkv k) -> b nkv t k', nkv=self.n_kv_groups)
        v = rearrange(self.wv(x), 'b t (nkv v) -> b nkv t v', nkv=self.n_kv_groups)

        # Rotary embeddings are applied to queries and keys only.
        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        # Repeat every key/value head group_size times in consecutive slots so
        # query head h is paired with key/value head h // group_size.
        k = repeat(k, 'b nkv t k -> b (nkv gsz) t k', gsz=self.group_size)
        v = repeat(v, 'b nkv t v -> b (nkv gsz) t v', gsz=self.group_size)

        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        causal_mask = self.mask[:seq_len, :seq_len]
        attn = attn.masked_fill(causal_mask, -torch.inf)
        attn = F.softmax(attn, dim=3)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out)

### Multi-head attention (for comparison)

In [5]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with RoPE (Llama 2 style, for comparison).

    Args:
        d_model: Input and output feature dimension.
        d_k: Per-head query/key dimension (also the RoPE dimension).
        d_v: Per-head value dimension.
        context_length: Size of the precomputed causal mask and RoPE tables.
        n_heads: Number of attention heads.
        dtype: Optional dtype for the projection weights.
    """

    def __init__(self, d_model, d_k, d_v,
                 context_length, n_heads, dtype=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.wq = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wk = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wv = nn.Linear(d_model, n_heads * d_v, bias=False, dtype=dtype)
        self.linear = nn.Linear(n_heads * d_v, d_model, bias=False, dtype=dtype)

        # Strict upper triangle marks future positions. Stored as bool so it
        # can be passed to masked_fill directly, without converting the whole
        # context_length x context_length buffer on every forward pass.
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1))

        cos, sin = precompute_rope_params(d=self.d_k, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); same shape out."""
        seq_len = x.size(1)

        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (h k) -> b h t k', h=self.n_heads)
        v = rearrange(self.wv(x), 'b t (h v) -> b h t v', h=self.n_heads)

        # Rotary embeddings are applied to queries and keys only.
        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        causal_mask = self.mask[:seq_len, :seq_len]
        attn = attn.masked_fill(causal_mask, -torch.inf)
        attn = F.softmax(attn, dim=3)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out)

### Illustrate differences between MHA and GQA

In [6]:
# Same d_model and head counts as LLAMA3_CONFIG_8B; the example batch is
# shorter than the maximum context length.
batch_size = 1
context_len = 3000
max_context_len = 8192
embed_dim = 4096
num_heads = 32
num_kv_groups = 8

head_dim = embed_dim // num_heads

example_batch = torch.randn(batch_size, context_len, embed_dim)

# Constructor arguments common to both attention variants (plain ints only,
# so this dict does not keep either module alive).
shared_attention_kwargs = dict(
    d_model=embed_dim,
    d_k=head_dim,
    d_v=head_dim,
    context_length=max_context_len,
    n_heads=num_heads,
)

# Build and run MHA first, then GQA, so random draws happen in a fixed order.
mha = MultiHeadAttention(**shared_attention_kwargs)
mha(example_batch)

gqa = GroupedQueryAttention(n_kv_groups=num_kv_groups, **shared_attention_kwargs)
gqa(example_batch)

# Only the key and value projections shrink under GQA.
print("MHA:")
print(f"W_query: {mha.wq.weight.shape}")
print(f"W_key: {mha.wk.weight.shape}")
print(f"W_value: {mha.wv.weight.shape}")

print()
print("GQA:")
print(f"W_query: {gqa.wq.weight.shape}")
print(f"W_key: {gqa.wk.weight.shape}")
print(f"W_value: {gqa.wv.weight.shape}")


MHA:
W_query: torch.Size([4096, 4096])
W_key: torch.Size([4096, 4096])
W_value: torch.Size([4096, 4096])

GQA:
W_query: torch.Size([4096, 4096])
W_key: torch.Size([1024, 4096])
W_value: torch.Size([1024, 4096])


In [7]:
# Compare the number of weights in the two attention modules.
print("Total number of parameters:")

mha_total_params = sum(map(torch.numel, mha.parameters()))
print(f"MHA: {mha_total_params:,}")

gqa_total_params = sum(map(torch.numel, gqa.parameters()))
print(f"GQA: {gqa_total_params:,}")

Total number of parameters:
MHA: 67,108,864
GQA: 41,943,040


In [8]:
# Drop the demo modules; they are not used again below.
del mha, gqa

### Transformer block

- the only change from Llama 2 is GQA instead of MHA.

In [9]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: grouped-query attention, then a SiLU-gated MLP.

    Args:
        cfg: Mapping read with the keys ``'d_model'``, ``'d_k'``, ``'d_v'``,
            ``'context_length'``, ``'n_heads'``, ``'n_kv_groups'``,
            ``'hidden_dim'`` and ``'dtype'``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.attn = GroupedQueryAttention(
            cfg['d_model'], cfg['d_k'], cfg['d_v'],
            cfg['context_length'], cfg['n_heads'],
            cfg['n_kv_groups'], cfg['dtype'])
        # Build the norms in the same dtype as the linear layers. A float32
        # scale can promote bfloat16 activations to float32, after which the
        # next bfloat16 matmul fails on a dtype mismatch.
        self.norm1 = nn.RMSNorm(cfg['d_model'], dtype=cfg['dtype'])
        self.fc1 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                             dtype=cfg['dtype'], bias=False)
        self.fc2 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                             dtype=cfg['dtype'], bias=False)
        self.fc3 = nn.Linear(cfg['hidden_dim'], cfg['d_model'],
                             dtype=cfg['dtype'], bias=False)
        self.silu = nn.SiLU()
        self.norm2 = nn.RMSNorm(cfg['d_model'], dtype=cfg['dtype'])

    def forward(self, x):
        """Apply attention and MLP sub-layers, each with a residual connection."""
        # Attention sub-layer.
        shortcut = x
        x = self.attn(self.norm1(x))
        x = x + shortcut

        # Gated MLP sub-layer: fc1 goes through SiLU and gates fc2's output.
        shortcut = x
        x = self.norm2(x)
        x = self.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        x = x + shortcut
        return x

### Llama 3 model class

- only the name changes

In [10]:
class Llama3Model(nn.Module):
    """Token embedding, a stack of transformer blocks, final norm and output head.

    Args:
        cfg: Mapping read here with the keys ``'vocab_size'``, ``'d_model'``,
            ``'n_blocks'`` and ``'dtype'``; it is also passed unchanged to
            every ``TransformerBlock``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(
            cfg['vocab_size'],
            cfg['d_model'],
            dtype=cfg['dtype']
        )
        self.trf_blocks = nn.Sequential(*[
            TransformerBlock(cfg) for _ in range(cfg['n_blocks'])
        ])
        # Same dtype as the output head: a float32 norm scale can promote
        # bfloat16 activations and make the following matmul fail on a dtype
        # mismatch.
        self.final_norm = nn.RMSNorm(cfg['d_model'], dtype=cfg['dtype'])
        self.out_head = nn.Linear(
            cfg['d_model'],
            cfg['vocab_size'],
            bias=False,
            dtype=cfg['dtype']
        )

    def forward(self, x):
        """Map token ids ``x`` to logits with a trailing ``vocab_size`` axis."""
        x = self.token_embedding(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

### Llama 3 8B config

In [11]:
# Llama 2 7B, kept for side-by-side comparison (it has no "n_kv_groups" entry).
LLAMA2_CONFIG_7B = dict(
    vocab_size=32_000,
    context_length=4096,
    d_model=4096,
    d_k=128,
    d_v=128,
    n_heads=32,
    n_blocks=32,
    hidden_dim=11_008,
    dtype=torch.bfloat16,
)

# Llama 3 8B: larger vocabulary, context length and MLP width, plus the
# number of key/value groups for grouped-query attention.
LLAMA3_CONFIG_8B = dict(
    vocab_size=128_256,
    context_length=8192,
    d_model=4096,
    d_k=128,
    d_v=128,
    n_heads=32,
    n_kv_groups=8,
    n_blocks=32,
    hidden_dim=14_336,
    dtype=torch.bfloat16,
)

# Small configuration with the same keys, cheap enough to instantiate here.
LLAMA3_CONFIG_TOY = dict(
    vocab_size=1000,
    context_length=8192,
    d_model=64,
    d_k=4,
    d_v=4,
    n_heads=16,
    n_kv_groups=8,
    n_blocks=2,
    hidden_dim=64,
    dtype=torch.bfloat16,
)

### Instantiate toy model

In [12]:
# Build the toy-sized model; the bare name on the last line makes the
# notebook display its module tree.
model = Llama3Model(cfg=LLAMA3_CONFIG_TOY)
model

Llama3Model(
  (token_embedding): Embedding(1000, 64)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): GroupedQueryAttention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=32, bias=False)
        (wv): Linear(in_features=64, out_features=32, bias=False)
        (linear): Linear(in_features=64, out_features=64, bias=False)
      )
      (norm1): RMSNorm((64,), eps=None, elementwise_affine=True)
      (fc1): Linear(in_features=64, out_features=64, bias=False)
      (fc2): Linear(in_features=64, out_features=64, bias=False)
      (fc3): Linear(in_features=64, out_features=64, bias=False)
      (silu): SiLU()
      (norm2): RMSNorm((64,), eps=None, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attn): GroupedQueryAttention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=32, bias=False)
        (wv): Linear(in_featur

In [13]:
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters."""
    trainable = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are left out of the total.
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report the toy model's size with thousands separators.
n_trainable = count_parameters(model)
print(f'The model has {n_trainable:,} trainable parameters')

The model has 177,472 trainable parameters
