# From Llama 3.1 to Llama 3.2

Construct Llama 3.2 1B.

I don't train the model or load model weights from elsewhere.

Raschka defines a `SharedBuffer` class so that we can reuse the `mask`, `sin`, and `cos` tensors in the transformer blocks. I don't implement this here.

Part of the code for the RoPE implementation is copied from Raschka's repo (the section between `#New section` and `#End new section` ).

Differences between Llama 3.1 8B and Llama 3.2 1B:
- Llama 3.2 uses weight tying (the weights of the embedding layer are used for the output layer).
- Llama 3.2 has the same `context_length` as Llama 3.1 (131,072), but has half the embedding dimension (2,048 rather than 4,096) and half the number of transformer blocks (16 rather than 32). The dimension of the hidden layer in the transformer MLP is also much less (8,192 compared with Llama 3.1's 14,336).
- One of the RoPE parameters is different.

In this notebook:
- Imports.
- Llama 3.2 RoPE parameters.
- RoPE implementation.
- Grouped-query attention.
- Transformer block.
- Llama 3.2 model class.
- Llama 3.2 1B config.
- Instantiate toy model.

### Imports

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

### Llama 3.2 RoPE parameters

- Same as Llama 3.1, except for `freq_config['factor']`.

In [34]:
# Llama 3.1
# theta_base = 500_000
# context_length= 131_072
# freq_config = {
#     "factor": 8.0,
#     "low_freq_factor": 1.0,
#     "high_freq_factor": 4.0,
#     "original_context_length": 8192
# }

# Llama 3.2
theta_base = 500_000
context_length= 131_072
freq_config = {
    "factor": 32.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_context_length": 8192
}

### Implement RoPE

- Same structure as Llama 3.1. 

In [35]:
def precompute_rope_params(d, theta_base=theta_base, 
                        context_length=context_length,
                        freq_config=freq_config):
    div_term = torch.exp(torch.arange(0, d, 2)[: (d // 2)].float() * (-torch.log(torch.tensor(theta_base)) / d))
    inv_freq = div_term
    
    # New section
    low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
    high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

    wavelen = 2 * torch.pi / inv_freq
    inv_freq_llama = torch.where(
        wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
    )

    smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
        freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
    )

    smoothed_inv_freq = (
        (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
    )

    is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
    # End new section
    
    inv_freq = div_term
    positions = rearrange(torch.arange(0, context_length, dtype=torch.float), 'i -> i 1')
    angles = positions * div_term
    angles = torch.cat([angles, angles], dim=-1)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin

def compute_rope(x, cos, sin):
    b, h, t, d = x.shape
    assert d % 2 == 0, "Head dimension must be even"

    x1 = x[:, :, :, : d // 2]
    x2 = x[:, :, :, d // 2 :]

    cos = rearrange(cos[: t, :], 't d -> 1 1 t d')
    sin = rearrange(sin[: t, :], 't d -> 1 1 t d')
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = x * cos + rotated * sin
    return x_rotated.to(dtype=x.dtype)

### Grouped-query attention

- Same as Llama 3.1.

In [36]:
class GroupedQueryAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v, 
                    context_length, n_heads,
                    n_kv_groups, dtype=None):
        super().__init__()

        assert n_heads % n_kv_groups == 0, "Number of heads must be divisible by number of key-value groups"
        
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.group_size = n_heads // n_kv_groups
        self.d_k = d_k

        self.wq = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wk = nn.Linear(d_model, n_kv_groups * d_k, bias=False, dtype=dtype)
        self.wv = nn.Linear(d_model, n_kv_groups * d_v, bias=False, dtype=dtype)
        self.linear = nn.Linear(n_heads * d_v, d_model, bias=False, dtype=dtype)     
        
        self.register_buffer('mask', 
            torch.triu(torch.ones(context_length, context_length), 
            diagonal=1))   
            
        cos, sin = precompute_rope_params(d=self.d_k, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin) 

    def forward(self, x):
        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (nkv k) -> b nkv t k', nkv=self.n_kv_groups)
        v = rearrange(self.wv(x), 'b t (nkv v) -> b nkv t v', nkv=self.n_kv_groups)

        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        k = repeat(k, 'b nkv t k -> b (nkv gsz) t k', gsz=self.group_size)
        v = repeat(v, 'b nkv t v -> b (nkv gsz) t v', gsz=self.group_size)
        
        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        mask_bool = self.mask.bool()[:x.size(1), :x.size(1)]
        attn = attn.masked_fill(mask_bool, -torch.inf)
        attn = F.softmax(attn, dim=3)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out)

### Transformer block

- Same as Llama 3.1.

In [37]:
class TransformerBlock(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.attn = GroupedQueryAttention(
            cfg['d_model'], cfg['d_k'], cfg['d_v'], 
            cfg['context_length'], cfg['n_heads'],
            cfg['n_kv_groups'], cfg['dtype'])
        self.norm1 = nn.RMSNorm(cfg['d_model'])
        self.fc1 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                        dtype=cfg['dtype'], bias=False) 
        self.fc2 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                        dtype=cfg['dtype'], bias=False) 
        self.fc3 = nn.Linear(cfg['hidden_dim'], cfg['d_model'],
                        dtype=cfg['dtype'], bias=False)
        self.silu = nn.SiLU()
        self.norm2 = nn.RMSNorm(cfg['d_model'])

    def forward(self, x):
        shortcut = x
        x = self.attn(self.norm1(x))
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        x = x + shortcut
        return x

### Llama 3.2 model class

- Same as Llama 3.1 except for the weight tying. 

In [38]:
class Llama3_2Model(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(
            cfg['vocab_size'], 
            cfg['d_model'], 
            dtype=cfg['dtype']
            )
        self.trf_blocks = nn.Sequential(*[
            TransformerBlock(cfg) for _ in range(cfg['n_blocks'])
        ])
        self.final_norm = nn.RMSNorm(cfg['d_model'])    
        self.out_head = nn.Linear(
            cfg['d_model'], 
            cfg['vocab_size'],
            bias=False,
            dtype=cfg['dtype']
            )
        self.out_head.weight = nn.Parameter(self.token_embedding.weight.T)
        self.out_head.weight.requires_grad = False

    def forward(self, x):
        x= self.token_embedding(x)
        x = self.trf_blocks(x)  
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits 

### Llama 3.2 1B config

In [39]:
LLAMA3_1_CONFIG_8B = {
    "vocab_size": 128_256,     
    "context_length": 131_072, 
    "d_model": 4096,
    "d_k": 128,
    "d_v": 128,       
    "n_heads": 32,  
    "n_kv_groups": 8,         
    "n_blocks": 32,          
    "hidden_dim": 14_336,    
    "dtype": torch.bfloat16  
}

LLAMA3_2_CONFIG_1B = {
    "vocab_size": 128_256,     
    "context_length": 131_072, 
    "d_model": 2048,
    "d_k": 128,
    "d_v": 128,       
    "n_heads": 32,  
    "n_kv_groups": 8,         
    "n_blocks": 16,          
    "hidden_dim": 8192,    
    "dtype": torch.bfloat16  
}

LLAMA3_2_CONFIG_TOY = {
    "vocab_size": 128_256,     
    "context_length": 1000, 
    "d_model": 64,
    "d_k": 4,
    "d_v": 4,       
    "n_heads": 16,  
    "n_kv_groups": 8,         
    "n_blocks": 1,          
    "hidden_dim": 64,    
    "dtype": torch.bfloat16  
}

### Instantiate toy model

In [40]:
model = Llama3_2Model(LLAMA3_2_CONFIG_TOY)
model

Llama3_2Model(
  (token_embedding): Embedding(128256, 64)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): GroupedQueryAttention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=32, bias=False)
        (wv): Linear(in_features=64, out_features=32, bias=False)
        (linear): Linear(in_features=64, out_features=64, bias=False)
      )
      (norm1): RMSNorm((64,), eps=None, elementwise_affine=True)
      (fc1): Linear(in_features=64, out_features=64, bias=False)
      (fc2): Linear(in_features=64, out_features=64, bias=False)
      (fc3): Linear(in_features=64, out_features=64, bias=False)
      (silu): SiLU()
      (norm2): RMSNorm((64,), eps=None, elementwise_affine=True)
    )
  )
  (final_norm): RMSNorm((64,), eps=None, elementwise_affine=True)
  (out_head): Linear(in_features=64, out_features=128256, bias=False)
)

In [41]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 8,233,152 trainable parameters
