- https://blog.briankitano.com/llama-from-scratch/
- https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
- https://github.com/facebookresearch/llama/blob/main/llama/model.py
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

In [3]:
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
@dataclass
class Config:
    """Hyper-parameters for the small LLaMA-style model built in this notebook.

    All attributes are real dataclass fields, so they can be overridden at
    construction time, e.g. ``Config(dim=512, n_heads=16)``.

    ``hidden_dim`` (the MLP inner width) defaults to ``int(8 * dim / 3)``
    rounded *up* to a multiple of ``multiple_of``. Pass an explicit value to
    override it.

    ``num_heads`` and ``max_positions`` are read-only aliases for ``n_heads``
    and ``max_pos_embeds``, so modules written against either naming scheme
    can consume this config.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``, or ``n_heads``
            is not divisible by ``n_kv_heads``.
    """

    dim: int = 256
    hidden_dim: Optional[int] = None  # None -> derived from `dim` in __post_init__
    n_heads: int = 8
    n_kv_heads: int = 4
    attn_bias: bool = False
    max_pos_embeds: int = 256
    multiple_of: int = 256
    attention_dropout: float = 0.0
    rope_theta: float = 10_000
    scaling_factor: float = 1.0

    def __post_init__(self):
        if self.hidden_dim is None:
            raw = int(8 * self.dim / 3)
            # Ceil to the next multiple; a value that is already aligned is kept
            # as-is instead of being bumped by a whole extra `multiple_of`.
            self.hidden_dim = -(-raw // self.multiple_of) * self.multiple_of
        if self.dim % self.n_heads != 0:
            raise ValueError(
                f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
            )
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
            )

    @property
    def num_heads(self):
        """Alias for ``n_heads``."""
        return self.n_heads

    @property
    def max_positions(self):
        """Alias for ``max_pos_embeds``."""
        return self.max_pos_embeds

In [5]:
class MLP(nn.Module):
    """SwiGLU feed-forward block computing ``proj(silu(fc1(x)) * fc2(x))``.

    ``config`` must expose ``dim`` (model width) and ``hidden_dim`` (inner
    width). None of the three linear layers has a bias.
    """

    def __init__(self, config):
        super().__init__()
        d_model, d_inner = config.dim, config.hidden_dim
        self.dim = d_model
        self.hidden_dim = d_inner

        # fc1: gate branch (passed through SiLU); fc2: value branch;
        # proj: maps the gated product back to the model width.
        self.fc1 = nn.Linear(d_model, d_inner, bias=False)
        self.fc2 = nn.Linear(d_model, d_inner, bias=False)
        self.proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        return self.proj(gate * value)

### Root Mean Square Normalization

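$$
\begin{aligned}
\bar{a}_i &= \frac{a_i}{\text{RMS}(a)}\, g_i \\
\text{RMS}(a) &= \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2 + \epsilon}
\end{aligned}
$$

where $g$ is a learned per-feature gain (initialised to ones) and $\epsilon$ is a small constant added inside the square root for numerical stability.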

In [6]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis.

    Computes ``g * x / sqrt(mean(x**2) + eps)``. The statistics are computed
    in float32, and the normalised tensor is cast back to the input dtype
    before the learned gain ``g`` (initialised to ones) is applied.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x):
        # Reciprocal RMS; eps is added to the mean square before the rsqrt.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        original_dtype = x.dtype
        normalized = self._norm(x.float()).to(dtype=original_dtype)
        return self.g * normalized

### Rotary Positional Embeddings (RoPE)

In [13]:
class RoPE(nn.Module):
    """Rotary positional embeddings in the "rotate-half" layout.

    Feature ``i`` and feature ``i + dim // 2`` of each head vector form one
    2-D pair. The pair is rotated by ``(position / scaling_factor) * inv_freq[i]``,
    where ``inv_freq[i] = base ** (-2i / dim)``.

    Args:
        dim: per-head feature size; must be even.
        max_positions: minimum number of positions the cos/sin cache is built
            for on first use. The cache grows automatically if larger
            positions are requested.
        base: frequency base (often called ``rope_theta``).
        scaling_factor: linear position scaling; ``1.0`` disables scaling.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim, max_positions, base=10_000, scaling_factor=1.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires an even head dimension, got dim={dim}")
        self.dim = dim
        self.max_positions = max_positions
        self.base = base
        self.scaling_factor = scaling_factor

        # inv_freq[i] = base^(-2i/dim), i in [0, dim/2)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer('inv_freq', inv_freq)

        # Lazily built [seq_len_cache, dim] lookup tables. They are non-persistent
        # buffers, so `.to(device)` moves them but they stay out of the
        # state_dict (they are fully derived from `inv_freq`).
        self.seq_len_cache = 0
        self.register_buffer('cos_cache', None, persistent=False)
        self.register_buffer('sin_cache', None, persistent=False)

    def _create_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        t = torch.arange(seq_len, device=device, dtype=torch.float32) / self.scaling_factor
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)          # [seq_len, dim // 2]
        emb = torch.cat([freqs, freqs], dim=-1)   # [seq_len, dim], rotate-half layout
        self.cos_cache = emb.cos().to(dtype)
        self.sin_cache = emb.sin().to(dtype)
        self.seq_len_cache = seq_len

    def forward(self, position_ids, seq_len, device=None, dtype=None):
        """Return ``(cos, sin)`` for the requested positions.

        Args:
            position_ids: integer tensor of shape ``[T]`` or ``[B, T]``.
                ``None`` means ``arange(seq_len)``.
            seq_len: current sequence length. The cache is guaranteed to cover
                at least this many positions (may be ``None`` if
                ``position_ids`` is given).
            device: device of the returned tensors; defaults to
                ``position_ids.device``.
            dtype: dtype of the returned tensors; defaults to float32.

        Returns:
            ``(cos, sin)``, each of shape ``position_ids.shape + (dim,)``.
        """
        if position_ids is None:
            if seq_len is None:
                raise ValueError("either position_ids or seq_len must be given")
            position_ids = torch.arange(int(seq_len), device=device)
        if device is None:
            device = position_ids.device
        if dtype is None:
            dtype = torch.float32

        needed = int(seq_len) if seq_len is not None else 0
        if position_ids.numel() > 0:
            needed = max(needed, int(position_ids.max()) + 1)

        # Rebuild only when the cache is missing or too short, or when something
        # such as `.half()` downcast it. Tables are kept in float32 and cast on
        # return.
        if (
            self.cos_cache is None
            or needed > self.seq_len_cache
            or self.cos_cache.dtype != torch.float32
        ):
            self._create_cache(max(needed, self.max_positions), device, torch.float32)

        idx = position_ids.to(device=self.cos_cache.device, dtype=torch.long)
        cos = self.cos_cache[idx]
        sin = self.sin_cache[idx]
        return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)

    def _rotate_half(self, x):
        """Map the halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([-x2, x1], dim=-1)

    def apply_rotary(self, x, cos, sin):
        """Rotate ``x`` (e.g. q or k, ``[B, nH, T, H]``) by the given angles.

        ``cos``/``sin`` may be ``[T, H]`` or ``[B, T, H]`` (as returned by
        ``forward``). A tensor with one fewer dim than ``x`` gets a singleton
        head axis inserted so it broadcasts over heads. The result has
        ``x``'s dtype.
        """
        if cos.dim() == x.dim() - 1:
            cos = cos.unsqueeze(-3)
            sin = sin.unsqueeze(-3)
        cos = cos.to(dtype=x.dtype)
        sin = sin.to(dtype=x.dtype)
        return x * cos + self._rotate_half(x) * sin

    def apply(self, x, cos=None, sin=None, position_ids=None):
        """Apply the rotary embedding, or behave as ``nn.Module.apply(fn)``.

        Overriding ``nn.Module.apply`` would break a parent's
        ``model.apply(init_fn)``, which calls ``child.apply(fn)`` on every
        submodule. A single callable argument is therefore forwarded to the
        base implementation.

        ``position_ids`` is accepted for backward compatibility and is unused:
        ``cos``/``sin`` from ``forward`` are already gathered per position.
        """
        if cos is None and sin is None and callable(x):
            return super().apply(x)
        if cos is None or sin is None:
            raise TypeError("RoPE.apply requires both cos and sin")
        return self.apply_rotary(x, cos, sin)

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        """Build the fused QKV projection, the output projection and the RoPE module.

        Reads from ``config``:
            dim: model width (must be divisible by the head count).
            num_heads: head count; falls back to ``n_heads`` if absent/None.
            max_positions: RoPE cache size; falls back to ``max_pos_embeds``.
            attention_dropout: optional, default ``0.0``.
            rope_theta: optional RoPE base, default ``10_000``.
            scaling_factor: optional linear RoPE scaling, default ``1.0``.
            attn_bias: optional, default ``False``; bias on both projections.

        Raises:
            ValueError: if ``dim`` is not divisible by the head count.
        """
        super().__init__()

        # Support both naming schemes (num_heads / n_heads, max_positions / max_pos_embeds).
        num_heads = getattr(config, "num_heads", None)
        if num_heads is None:
            num_heads = config.n_heads
        max_positions = getattr(config, "max_positions", None)
        if max_positions is None:
            max_positions = config.max_pos_embeds

        self.dropout = getattr(config, "attention_dropout", 0.0)
        self.dim = config.dim
        self.num_heads = num_heads
        if self.dim % self.num_heads != 0:
            raise ValueError(
                f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
            )
        self.head_dim = self.dim // self.num_heads
        self.max_positions = max_positions
        self.rope_theta = getattr(config, "rope_theta", 10_000)
        self.scaling_factor = getattr(config, "scaling_factor", 1.0)
        bias = getattr(config, "attn_bias", False)

        # One fused projection producing q, k and v (3 * dim features).
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=bias)
        self.proj = nn.Linear(self.dim, self.dim, bias=bias)

        # `max_positions` is a required RoPE argument; omitting it raised TypeError.
        self.rope = RoPE(
            dim=self.head_dim,
            max_positions=self.max_positions,
            base=self.rope_theta,
            scaling_factor=self.scaling_factor,
        )
        
    def forward(self, x, mask, position_ids):
        """Causal self-attention forward pass (work in progress as shown here).

        ``x`` is unpacked below as ``B, T, C``. NOTE(review): ``mask`` and
        ``position_ids`` are not used in the lines visible in this excerpt.
        """
        
        B,T,C = x.shape
        
        # NOTE(review): `Tensor.split(3, dim=-1)` splits into chunks of SIZE 3,
        # not into 3 chunks. self.qkv outputs 3 * dim features, so this yields
        # `dim` pieces and the 3-way unpack fails unless dim == 3. Likely
        # intended: `.split(self.dim, dim=-1)` or `.chunk(3, dim=-1)`.
        q,k,v = self.qkv(x).split(3,dim=-1)
        # reshape -> [B, T, nH, H]; transpose(1, 2) -> [B, nH, T, H]
        # NOTE(review): k and v use num_heads, not the config's n_kv_heads (no GQA yet).
        q = q.reshape(B,T,self.num_heads,self.head_dim).transpose(1,2)
        k = k.reshape(B,T,self.num_heads,self.head_dim).transpose(1,2)
        v = v.reshape(B,T,self.num_heads,self.head_dim).transpose(1,2)
        
        # NOTE(review): RoPE.forward takes (position_ids, seq_len, device, dtype).
        # This call passes `q` as position_ids and omits the other required
        # arguments, so as written it raises TypeError.
        cos, sin = self.rope(q, )