In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace
from collections import OrderedDict

#### a starting point: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

# RMSNorm
replaces LayerNorm

$$
y_i = \frac{x_i}{\text{RMS}(x)}\gamma_i\\
\text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}
$$
- $\gamma$ is a learnable parameter (one scale value per feature)
- $x$ is the input and $x_i$ is one feature/neuron, i.e. if $x$ has shape (1, 128, 1024) = (batch size, sequence length, number of features), then we normalize along the last dimension of 1024 features
- we initialize $\gamma$ with ones

In [2]:
class LlamaRMSNorm(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.weight = nn.Parameter(
            torch.ones(self.embed_dim,dtype=torch.float32),
            requires_grad=True
        )
        
    def forward(self, x):
        # x [B, S, D]
        mean = x.pow(2).mean(dim=-1,keepdim=True)
        r_sqrt = x * torch.rsqrt(mean + 1e-5) # [B, S, 1]
        y = r_sqrt * self.weight
        return y.to(x.dtype)

In [3]:
# testing RMSNorm
# Compare our implementation against PyTorch's built-in nn.RMSNorm on a random input.
# NOTE(review): nn.RMSNorm is only available in recent PyTorch releases — confirm the minimum version.
rms = LlamaRMSNorm(embed_dim=8)
# `x` is left in the kernel namespace and reused by the SiLU check below.
x = torch.rand(1,3,8)
# Displays (output shape, whether both implementations agree).
rms(x).shape,torch.allclose(rms(x),nn.RMSNorm(8,eps=1e-5)(x))

(torch.Size([1, 3, 8]), True)

# New activation function: SiLU (Swish)

$$\text{SiLU}(x) = x * \sigma(x)$$
where $\sigma(x)$ is the sigmoid function.

In [4]:
class SiLU(nn.Module):
    """Sigmoid Linear Unit (Swish): scales every element by its own sigmoid."""

    def forward(self, x):
        # x [B S D] -> same shape
        gate = torch.sigmoid(x)
        return x * gate

In [5]:
# testing SiLU
# Reuses the tensor `x` created in the RMSNorm test cell and compares against F.silu.
torch.allclose(SiLU()(x),F.silu(x))

True

# New MLP! SwiGLU

$$
\text{SwiGLU}(x) = \text{SiLU}(\text{linear}_1(x)) * \text{linear}_2(x) \\
\text{output} = \text{linear}_3(\text{SwiGLU}(x))
$$

In [6]:
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``.

    Reads ``embed_dim``, ``intermediate_dim`` and ``dtype`` from ``config``.
    All three projections are bias-free linear layers.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.intermediate_dim = config.intermediate_dim

        # Shared settings for the three projections.
        linear_kwargs = dict(bias=False, dtype=config.dtype)
        self.gate_proj = nn.Linear(self.embed_dim, self.intermediate_dim, **linear_kwargs)
        self.up_proj = nn.Linear(self.embed_dim, self.intermediate_dim, **linear_kwargs)
        self.down_proj = nn.Linear(self.intermediate_dim, self.embed_dim, **linear_kwargs)
        self.act_fn = SiLU()

    def forward(self, x):
        # x [B S D] -> [B S D]
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)  # [B S intermediate_dim]
        return self.down_proj(gated)
        

# RoPE
rotary positional embeddings!

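- applied to q and k in the multi-head attention step
- the angles, and their sines and cosines, are precomputed up to the model's context length
- the current implementation expects inputs of shape (B, nH, S, H) = (batch size, number of heads, sequence length, head dimension)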

In [7]:
def precompute_rope(head_dim, base_theta=10_000, context_length=4096):
    """Build the cosine and sine lookup tables used by rotary position embeddings.

    Args:
        head_dim: per-head feature size; frequencies are generated for
            ``head_dim // 2`` pairs and then duplicated to fill ``head_dim``.
        base_theta: base of the geometric frequency progression.
        context_length: number of positions to tabulate.

    Returns:
        ``(cos, sin)``, each of shape ``[context_length, head_dim]``.
    """
    half = head_dim // 2
    exponents = torch.arange(0, head_dim, 2)[:half].float() / head_dim
    inv_freq = 1 / (base_theta ** exponents)  # [H/2]

    positions = torch.arange(context_length)
    # Outer product position x frequency via broadcasting.
    angles = positions[:, None] * inv_freq[None, :]  # [S, H/2]
    # Both halves of the head share the same angles.
    angles = angles.repeat(1, 2)  # [S, H]

    return torch.cos(angles), torch.sin(angles)

In [8]:
def apply_rope(x, cos, sin):
    """Rotate ``x`` by the precomputed rotary angles.

    Args:
        x: tensor of shape ``[B, nH, S, H]``.
        cos, sin: tables whose first ``S`` rows are used; each row has ``H`` entries.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    _, _, seq_len, head_dim = x.shape
    split = head_dim // 2
    first_half, second_half = x[..., :split], x[..., split:]  # each [B, nH, S, H/2]

    # Swap the halves and negate the second one.
    rotated = torch.cat((-second_half, first_half), dim=-1)

    # Add leading batch/head axes so the tables broadcast: [1, 1, S, H].
    cos_values = cos[:seq_len, :][None, None]
    sin_values = sin[:seq_len, :][None, None]

    x_rope = x * cos_values + rotated * sin_values
    return x_rope.to(x.dtype)

In [9]:
# Scratch check of the RoPE helpers: recompute the inverse frequencies by hand
# and confirm the shapes produced by precompute_rope / apply_rope.
head_dim = 512
inv_freq2 = 1.0 / (10000 ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
# plt.plot(inv_freq2.numpy())
# plt.show()
# Shapes of the two broadcast operands and of their product.
print(torch.arange(256)[None,:].shape ,inv_freq2[:,None].shape, (torch.arange(256)[None,:] * inv_freq2[:,None]).shape)
# Default context_length (4096) is used here, so both tables have 4096 rows.
cos, sin = precompute_rope(head_dim)
print(cos.shape, sin.shape)


# plot it to make sure
# import numpy
# import matplotlib.pyplot as plt
# plt.figure(figsize=(12,36))
# plt.imshow(cos.numpy())
# plt.show()
# plt.figure(figsize=(12,36))
# plt.imshow(sin.numpy())
# plt.show()

# End-to-end shape check on a [B, nH, S, H] tensor.
# Note: this rebinds the kernel-level names `x`, `cos` and `sin`.
x = torch.rand(2,8,128,64)
cos,sin = precompute_rope(64,context_length=128)
x_rope = apply_rope(x, cos, sin)
x_rope.shape

torch.Size([1, 256]) torch.Size([256, 1]) torch.Size([256, 256])
torch.Size([4096, 512]) torch.Size([4096, 512])


torch.Size([2, 8, 128, 64])

In [10]:
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    Queries use ``config.num_q_heads`` heads while keys/values use the smaller
    ``config.num_kv_heads``; every key/value head is repeated so that it is
    shared by ``num_q_heads // num_kv_heads`` query heads.

    Reads from ``config``: ``embed_dim``, ``num_q_heads``, ``num_kv_heads``,
    ``attn_dropout``, ``max_position_embeddings``, ``base_theta``, ``dtype``.

    Buffers:
        causal_mask: ``[max_pos, max_pos]`` boolean, True strictly above the
            diagonal (positions a query must not attend to).
        rope_cos / rope_sin: ``[max_pos, head_dim]`` rotary tables.
    """

    def __init__(self, config):
        super().__init__()

        self.embed_dim = config.embed_dim
        self.num_kv_heads = config.num_kv_heads
        self.num_q_heads = config.num_q_heads

        assert self.embed_dim % self.num_q_heads == 0, 'embed_dim should be div. by num. of query heads'
        assert self.num_q_heads % self.num_kv_heads ==0, 'num. query heads should be div. by num. key-value heads'

        self.head_dim = self.embed_dim // self.num_q_heads

        self.q_proj = nn.Linear(self.embed_dim, self.head_dim * self.num_q_heads, bias=False, dtype=config.dtype)
        self.k_proj = nn.Linear(self.embed_dim, self.head_dim * self.num_kv_heads, bias=False, dtype=config.dtype)
        self.v_proj = nn.Linear(self.embed_dim, self.head_dim * self.num_kv_heads, bias=False, dtype=config.dtype)

        self.drop = nn.Dropout(config.attn_dropout)

        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False, dtype=config.dtype)

        # Store the mask as bool (1 byte/element) instead of float32: with
        # max_position_embeddings=8192 a float mask costs 256 MiB per layer.
        # Built by comparing indices so that no float temporary is allocated;
        # mask[i, j] is True when j > i, i.e. the strict upper triangle.
        positions = torch.arange(config.max_position_embeddings)
        self.register_buffer('causal_mask', positions[None, :] > positions[:, None])

        cos, sin = precompute_rope(self.head_dim,base_theta=config.base_theta,context_length=config.max_position_embeddings)
        self.register_buffer('rope_cos', cos)
        self.register_buffer('rope_sin', sin)

    def forward(self, x):
        """Apply causal attention to ``x`` ([B S D]) and return a [B S D] tensor."""
        B,S,D = x.shape

        q = self.q_proj(x) # [B S H*nQ]
        k = self.k_proj(x) # [B S H*nKV]
        v = self.v_proj(x) # [B S H*nKV]

        q = q.view(B, S, self.num_q_heads, self.head_dim).transpose(1,2) # [B nQ S H]
        k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1,2) # [B nKV S H]
        v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1,2) # [B nKV S H]

        # Rotary embeddings are applied to queries and keys only.
        q = apply_rope(q, self.rope_cos, self.rope_sin)
        k = apply_rope(k, self.rope_cos, self.rope_sin)

        # Expand the key/value heads so each query head has a matching partner.
        group_size = self.num_q_heads // self.num_kv_heads
        k = k.repeat_interleave(group_size, dim=1) # [B nQ S H]
        v = v.repeat_interleave(group_size, dim=1) # [B nQ S H]

        attn = q @ k.transpose(2,3) # [B nQ S H] @ [B nQ H S] = [B nQ S S]

        # Block attention to future positions; the buffer is already boolean,
        # so no per-call dtype conversion is needed.
        mask = self.causal_mask[:S,:S]
        attn.masked_fill_(mask,-torch.inf)

        # Scaling after masking is fine: -inf stays -inf.
        attn = F.softmax(attn / (self.head_dim ** 0.5), dim=-1)

        attn = self.drop(attn)

        out = attn @ v # [B nQ S S] @ [B nQ S H] = [B nQ S H]
        out = out.transpose(1,2) # [B S nQ H]
        out = out.reshape(B, S, D)

        proj = self.o_proj(out)

        return proj

In [11]:
class LlamaDecoderLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, config):
        super().__init__()

        self.self_attn = GroupedQueryAttention(config)
        self.mlp = LlamaMLP(config)

        self.input_layernorm = LlamaRMSNorm(config.embed_dim)
        self.post_attention_layernorm = LlamaRMSNorm(config.embed_dim)

    def forward(self, x):
        # x [B S D] -> [B S D]

        # Attention sub-block: normalise, attend, add the residual.
        attn_out = self.self_attn(self.input_layernorm(x))
        hidden = attn_out + x

        # Feed-forward sub-block: normalise, MLP, add the residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return mlp_out + hidden

In [12]:
class LLaMA(nn.Module):
    """Decoder-only language model: token embedding, a stack of decoder layers,
    a final RMSNorm and a linear head that shares its weight with the embedding.

    Reads from ``config``: ``vocab_size``, ``embed_dim``, ``eos_token_id``
    (used as the embedding's ``padding_idx``), ``num_layers`` and ``dtype``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        cfg = self.config

        self.embed_tokens = nn.Embedding(
            cfg.vocab_size,
            cfg.embed_dim,
            padding_idx=cfg.eos_token_id,
            dtype=cfg.dtype,
        )

        decoder_layers = [LlamaDecoderLayer(cfg) for _ in range(cfg.num_layers)]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = LlamaRMSNorm(cfg.embed_dim)
        self.lm_head = nn.Linear(cfg.embed_dim, cfg.vocab_size, bias=False, dtype=cfg.dtype)

        # weight tying: the embedding reuses the output head's parameter
        self.embed_tokens.weight = self.lm_head.weight

    def forward(self, input_ids):
        # input_ids [B S] -> logits [B S vocab_size]
        hidden = self.embed_tokens(input_ids)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden)
        return self.lm_head(self.norm(hidden))

In [13]:
# Hyper-parameters consumed by the model classes above.
# NOTE(review): values appear chosen to mirror SmolLM2-135M-Instruct (parameter counts match below) — confirm against its config.json.
config = SimpleNamespace(
    embed_dim = 576,                  # model width (size of each token vector)
    intermediate_dim = 1536,          # hidden width of the SwiGLU MLP
    max_position_embeddings = 8192,   # rows of the RoPE tables and side of the causal mask
    base_theta = 100000,              # RoPE frequency base passed to precompute_rope
    num_q_heads = 9,                  # query heads
    num_kv_heads = 3,                 # key/value heads (must divide num_q_heads)
    attn_dropout = 0.,                # dropout probability on attention weights
    num_layers = 30,                  # number of decoder layers
    vocab_size = 49152,               # embedding rows / output logits
    eos_token_id = 2,                 # stop token for generate(); also the embedding padding_idx
    dtype = torch.float32,            # dtype of linear and embedding weights
)

In [14]:
# Drop any model left over from a previous run of this cell before building a
# new one, so the old instance is no longer referenced by the name `model`.
if 'model' in globals() or 'model' in locals():
    print('...')
    del model
model = LLaMA(config)

In [15]:
def model_memory_size(model, input_dtype=torch.float32):
    """Estimate the memory needed for a model's parameters, gradients and buffers.

    Every element is assumed to be stored as ``input_dtype``, regardless of the
    dtype the tensors actually have.

    Args:
        model: module whose ``parameters()`` and ``buffers()`` are counted.
        input_dtype: dtype used to size each element.

    Returns:
        ``(total_memory_gb, total_params)`` where the first item is in GiB
        (bytes / 1024**3) and the second is the number of parameter elements.
    """
    parameters = list(model.parameters())
    total_params = sum(p.numel() for p in parameters)
    # A gradient of the same size is kept for every trainable parameter.
    total_grads = sum(p.numel() for p in parameters if p.requires_grad)
    # Buffers are non-parameter tensors that still occupy memory.
    total_buffers = sum(b.numel() for b in model.buffers())

    bytes_per_element = torch.empty((), dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * bytes_per_element

    total_memory_gb = total_memory_bytes / (1024 ** 3)
    return total_memory_gb, total_params

# Report the estimate for the model built above. Buffers are included in the
# total, so the registered causal masks and RoPE tables count towards it.
total_mem, total_params = model_memory_size(model, input_dtype=torch.float32)
print(f"float32 (PyTorch default): {total_mem:.2f} GB with {total_params:,} parameters")

float32 (PyTorch default): 8.62 GB with 134,515,008 parameters


In [16]:
# Random token ids of shape [1, 10] for a forward-pass smoke test (rebinds `x`).
x = torch.randint(0,config.vocab_size,(1,10)).long()
x

tensor([[20857, 34648, 24865, 15421, 38794, 18933, 40766, 17794, 29075, 34495]])

In [17]:
# Forward pass through the still randomly initialised model.
out = model(x)

In [18]:
# Expect [batch, sequence length, vocab_size].
out.shape

torch.Size([1, 10, 49152])

In [19]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Reference tokenizer and pretrained model whose weights are copied into `model` below.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
smol = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

tokenizer_config.json:   0%|          | 0.00/3.76k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/861 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

In [20]:
# Collect both state dicts so their entries can be paired up.
smol_sd = smol.state_dict()
model_sd = model.state_dict()
# Our model additionally registers RoPE tables and causal masks as buffers;
# drop those keys so only the learnable weights remain.
model_sd = {k:v for k,v in model_sd.items() if not any([s in k for s in ['rope','causal_mask']])}

# Both dicts should now have the same number of entries.
len(smol_sd), len(model_sd)

(273, 273)

In [21]:
# Parameter-element counts of the reference model and of ours should be identical.
sum([p.numel() for p in smol.parameters()]),sum([p.numel() for p in model.parameters()])

(134515008, 134515008)

In [29]:
# Copy the pretrained weights into our model.
# The two state dicts list the same tensors in the same order but under
# different names, so entries are paired by position and checked by shape.
# (setattr(model, "layers.0.mlp.up_proj.weight", tensor) would NOT work: it
# only creates an attribute with a dotted name on the top-level module and
# leaves the nested parameters untouched.)
keys = list(model_sd.keys())
assert len(keys) == len(smol_sd), 'state dicts must have the same number of entries'

remapped_sd = {}
for new_key, (old_key, tensor) in zip(keys, smol_sd.items()):
    assert model_sd[new_key].shape == tensor.shape, (
        f'shape mismatch: {old_key} {tuple(tensor.shape)} vs {new_key} {tuple(model_sd[new_key].shape)}'
    )
    remapped_sd[new_key] = tensor

# strict=False because the RoPE tables and causal masks were filtered out of
# model_sd above and are therefore absent from the remapped dict.
load_result = model.load_state_dict(remapped_sd, strict=False)
assert not load_result.unexpected_keys, load_result.unexpected_keys
assert all(
    'rope' in key or 'causal_mask' in key for key in load_result.missing_keys
), load_result.missing_keys

In [34]:
# Spot-check that the embedding weights of both models now agree.
torch.allclose(smol.model.embed_tokens.weight, model.embed_tokens.weight)

True

In [40]:
def get_input(text, add_generation_prompt=True):
    """Wrap ``text`` as a single user message and tokenize it with the chat template.

    Args:
        text: content of the user message.
        add_generation_prompt: when True, the template appends the tokens that
            open the assistant's reply, so the model starts answering right
            away instead of first having to produce that header itself.

    Returns:
        Token ids produced by ``tokenizer.encode(..., return_tensors="pt")``.
    """
    messages = [{"role": "user", "content": text}]
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )
    # The rendered template already contains its special tokens; do not let
    # the tokenizer add another set on top.
    inputs = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
    return inputs

In [64]:
def generate(
    model,
    input_ids,
    max_new_tokens=32,
    context_length = config.max_position_embeddings,
    temperature = 0.,
    eos_token_id = config.eos_token_id
):
    """Autoregressively extend ``input_ids`` and stream the decoded text to stdout.

    Args:
        model: callable mapping token ids [B S] to logits [B S vocab].
            It is switched to eval mode.
        input_ids: prompt token ids, shape [B S]; not modified.
        max_new_tokens: upper bound on the number of generated tokens.
        context_length: only the last ``context_length`` tokens are fed to the model.
        temperature: ``0`` selects the most likely token (greedy); values above
            zero divide the logits and sample from the resulting distribution.
        eos_token_id: generation stops once every sequence in the batch emits
            this id in the same step; the EOS token itself is not appended.

    Returns:
        Tensor holding the prompt followed by the generated tokens.
    """
    model.eval()
    inputs = input_ids.clone()
    # .tolist() works for tensors on any device, unlike .numpy().
    print(tokenizer.decode(inputs.flatten().tolist()))
    for _ in range(max_new_tokens):
        context = inputs[:,-context_length:]
        with torch.inference_mode():
            # Only the logits of the last position are needed.
            logits = model(context)[:,-1,:]

            if temperature > 0.:
                probs = (logits / temperature).softmax(dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Softmax is monotonic, so the argmax can be taken on the logits directly.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

        # Reduce with .all() so the check is also valid for batch sizes > 1,
        # where a bare tensor comparison cannot be used as a bool.
        if (next_token == eos_token_id).all():
            break
        print(tokenizer.decode(next_token.flatten().tolist()),end='')
        inputs = torch.cat([inputs, next_token],dim=1)
    print()
    return inputs

In [68]:
# Build a chat-formatted prompt and sample a completion from the model.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the session.
input = get_input('give me a random fact about llamas')
generated = generate(model, input, max_new_tokens=80, temperature=0.125)

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
give me a random fact about llamas<|im_end|>

<|im_start|>assistant
Llamas are large, four-legged animals native to the Andes Mountains of South America. They are known for their unique adaptations, including their ability to run at incredible speeds of up to 30 miles per hour. They are also known for their distinctive horns, which are shaped like a pair of horns, and their ability to climb trees. Llamas are also known
