# From Llama 3 to Llama 3.1

Description - to fill in later

### Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

### Llama 3.1 RoPE parameters

In [4]:
theta_base = 500_000
context_length= 131_072
freq_config = {
    "factor": 32.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_context_length": 8192
}

### Implement RoPE

- the RoPE method used by Llama 3.1 introduces additional adjustments to the inverse frequency calculations. I haven't gone through this in detail. The code in the cell below between `#New section` and `#End new section` is copied from Raschka's repo.

In [7]:
def precompute_rope_params(d, theta_base=theta_base, 
                        context_length=context_length,
                        freq_config=freq_config):
    div_term = torch.exp(torch.arange(0, d, 2)[: (d // 2)].float() * (-torch.log(torch.tensor(theta_base)) / d))
    
    # New section
    inv_freq = div_term
    low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
    high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

    wavelen = 2 * torch.pi / inv_freq
    inv_freq_llama = torch.where(
        wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
    )

    smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
        freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
    )

    smoothed_inv_freq = (
        (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
    )

    is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    div_term = inv_freq_llama
    # End new section
    
    positions = rearrange(torch.arange(0, context_length, dtype=torch.float), 'i -> i 1')
    angles = positions * div_term
    angles = torch.cat([angles, angles], dim=-1)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin

def compute_rope(x, cos, sin):
    b, h, t, d = x.shape
    assert d % 2 == 0, "Head dimension must be even"

    x1 = x[:, :, :, : d // 2]
    x2 = x[:, :, :, d // 2 :]

    cos = rearrange(cos[: t, :], 't d -> 1 1 t d')
    sin = rearrange(sin[: t, :], 't d -> 1 1 t d')
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = x * cos + rotated * sin
    return x_rotated.to(dtype=x.dtype)

### Illustration of effects of applying RoPE

In [14]:
batch_size = 2
n_heads = 4
d_k = 16

cos, sin = precompute_rope_params(d_k, theta_base, 
            context_length, freq_config)

torch.manual_seed(0)
q = torch.randn(batch_size, n_heads, context_length, d_k)   
k = torch.randn(batch_size, n_heads, context_length, d_k)   

print(q[0, 0, 1, :])
print(k[0, 0, 1, :])

q_rotated = compute_rope(q, cos, sin)
k_rotated = compute_rope(k, cos, sin)

print(q_rotated[0, 0, 1, :])
print(k_rotated[0, 0, 1, :])

tensor([-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,  1.8530,
         0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463, -0.8437])
tensor([-1.3243,  0.4273,  1.4935, -1.9839, -0.5936,  1.2878,  0.1049,  2.4659,
        -0.4574,  0.8525, -1.5453,  1.3069, -1.2914,  0.9167, -0.9227, -0.4272])
tensor([-1.3621, -1.5513,  0.5728,  0.7921,  0.5982, -1.5551, -0.3414,  1.8530,
        -0.7329, -0.9013, -0.1520,  0.1893,  1.3896,  1.5863,  0.9463, -0.8437])
tensor([-0.3306,  0.2550,  1.5505, -1.9934, -0.5931,  1.2878,  0.1049,  2.4659,
        -1.3615,  0.9188, -1.4881,  1.2924, -1.2917,  0.9167, -0.9227, -0.4272])


### Grouped-query attention

- same as Llama 3

In [3]:
class GroupedQueryAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v, 
                    context_length, n_heads,
                    n_kv_groups, dtype=None):
        super().__init__()

        assert n_heads % n_kv_groups == 0, "Number of heads must be divisible by number of key-value groups"
        
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.group_size = n_heads // n_kv_groups
        self.d_k = d_k

        self.wq = nn.Linear(d_model, n_heads * d_k, bias=False, dtype=dtype)
        self.wk = nn.Linear(d_model, n_kv_groups * d_k, bias=False, dtype=dtype)
        self.wv = nn.Linear(d_model, n_kv_groups * d_v, bias=False, dtype=dtype)
        self.linear = nn.Linear(n_heads * d_v, d_model, bias=False, dtype=dtype)     
        
        self.register_buffer('mask', 
            torch.triu(torch.ones(context_length, context_length), 
            diagonal=1))   
            
        cos, sin = precompute_rope_params(d=self.d_k, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin) 
        
    def forward(self, x):
        q = rearrange(self.wq(x), 'b t (h k) -> b h t k', h=self.n_heads)
        k = rearrange(self.wk(x), 'b t (nkv k) -> b nkv t k', nkv=self.n_kv_groups)
        v = rearrange(self.wv(x), 'b t (nkv v) -> b nkv t v', nkv=self.n_kv_groups)

        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        k = repeat(k, 'b nkv t k -> b (nkv gsz) t k', gsz=self.group_size)
        v = repeat(v, 'b nkv t v -> b (nkv gsz) t v', gsz=self.group_size)
        
        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / self.d_k**0.5
        mask_bool = self.mask.bool()[:x.size(1), :x.size(1)]
        attn = attn.masked_fill(mask_bool, -torch.inf)
        attn = F.softmax(attn, dim=3)
        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        return self.linear(out)

### Transformer block

- same as Llama 3

In [None]:
class TransformerBlock(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.attn = GroupedQueryAttention(
            cfg['d_model'], cfg['d_k'], cfg['d_v'], 
            cfg['context_length'], cfg['n_heads'],
            cfg['n_kv_groups'], cfg['dtype'])
        self.norm1 = nn.RMSNorm(cfg['d_model'])
        self.fc1 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                        dtype=cfg['dtype'], bias=False) 
        self.fc2 = nn.Linear(cfg['d_model'], cfg['hidden_dim'],
                        dtype=cfg['dtype'], bias=False) 
        self.fc3 = nn.Linear(cfg['hidden_dim'], cfg['d_model'],
                        dtype=cfg['dtype'], bias=False)
        self.silu = nn.SiLU()
        self.norm2 = nn.RMSNorm(cfg['d_model'])

    def forward(self, x):
        shortcut = x
        x = self.attn(self.norm1(x))
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        x = x + shortcut
        return x

### Llama 3.1 model class

- same as Llama 3 except for the name

In [None]:
class Llama3_1Model(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(
            cfg['vocab_size'], 
            cfg['d_model'], 
            dtype=cfg['dtype']
            )
        self.trf_blocks = nn.Sequential(*[
            TransformerBlock(cfg) for _ in range(cfg['n_blocks'])
        ])
        self.final_norm = nn.RMSNorm(cfg['d_model'])    
        self.out_head = nn.Linear(
            cfg['d_model'], 
            cfg['vocab_size'],
            bias=False,
            dtype=cfg['dtype']
            )

    def forward(self, x):
        x= self.token_embedding(x)
        x = self.trf_blocks(x)  
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits 

### Llama 3.1 8B config

In [None]:
LLAMA3_CONFIG_8B = {
    "vocab_size": 128_256,     
    "context_length": 8192, 
    "d_model": 4096,
    "d_k": 128,
    "d_v": 128,       
    "n_heads": 32,  
    "n_kv_groups": 8,         
    "n_blocks": 32,          
    "hidden_dim": 14_336,    
    "dtype": torch.bfloat16  
}