# Reverse Engineering SmolLM2-135M

This notebook guides you through downloading the SmolLM2-135M model, inspecting its architecture, implementing it from scratch in PyTorch, and validating the implementation.

In [None]:
# Force reinstall torch and torchvision to fix version mismatch
!pip uninstall -y torch torchvision
!pip install torch torchvision transformers datasets huggingface_hub

> **IMPORTANT**: After running the cell above, you **MUST** restart the Jupyter Kernel for the changes to take effect. Go to **Kernel > Restart Kernel** in the menu.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import math

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


## 1. Download and Load Model
We load the pre-trained tokenizer, model weights, and config from the Hugging Face Hub.

In [2]:
model_id = "HuggingFaceTB/SmolLM2-135M"

# Tokenizer, weights (moved to the selected device) and config all come from the same Hub repo.
tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
hf_config = AutoConfig.from_pretrained(model_id)

# Dump the config first, then the module tree, each under its own heading.
for heading, obj in (("Model Config:", hf_config), ("\nModel Structure:", hf_model)):
    print(heading)
    print(obj)

Model Config:
LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dtype": "bfloat16",
  "eos_token_id": 0,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 576,
  "initializer_range": 0.041666666666666664,
  "intermediate_size": 1536,
  "is_llama_config": true,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 9,
  "num_hidden_layers": 30,
  "num_key_value_heads": 3,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_interleaved": false,
  "rope_scaling": null,
  "rope_theta": 100000,
  "tie_word_embeddings": true,
  "transformers_version": "4.57.1",
  "use_cache": true,
  "vocab_size": 49152
}


Model Structure:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_feat

## 2. Reverse Engineering Architecture
The config (`model_type: "llama"`) and the module printout show a Llama-style architecture:
- **RMSNorm** for normalization.
- **Rotary Positional Embeddings (RoPE)**.
- **SwiGLU** activation in the MLP.
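- **Grouped Query Attention (GQA)**: the config has `num_attention_heads = 9` but only `num_key_value_heads = 3`, so each key/value head is shared by 3 query heads. This is GQA, not standard multi-head attention.
- **Tied input/output embeddings** (`tie_word_embeddings: true`): the LM head reuses the token-embedding matrix.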

Let's define the model from scratch.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale (no bias, no centering).

    Args:
        dim: size of the last dimension, which is the one normalised.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalise in float32 and cast back to the input dtype, as HF's LlamaRMSNorm does.
        # Squaring bf16/fp16 activations directly loses precision and can overflow in fp16.
        # For float32 inputs this is numerically identical to the plain formula.
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        mean_square = hidden.pow(2).mean(-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + self.eps)
        return self.weight * hidden.to(input_dtype)

def rotate_half(x):
    """Rotate the last dimension by half: ``(x1, x2) -> (-x2, x1)``.

    The split point is ``size // 2``. With an odd-sized last dimension, the second half
    therefore holds one element more than the first.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat([back.neg(), front], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    Args:
        q, k: tensors laid out as ``[bsz, heads, seq_len, head_dim]``.
        cos, sin: ``[seq_len, head_dim]`` tables. Two leading singleton dims are added so
            they broadcast over batch and heads.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k

class MLP(nn.Module):
    """SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``, with no biases."""

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class Attention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    There are ``num_attention_heads`` query heads and ``num_key_value_heads`` key/value
    heads. Each KV head is repeated ``num_attention_heads // num_key_value_heads`` times
    so that it is shared by that many query heads.

    Raises:
        ValueError: if ``num_attention_heads`` is not a multiple of ``num_key_value_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Prefer the explicit head_dim from the config (HF does the same); fall back to
        # hidden_size // num_attention_heads when it is absent or None.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(self, x, cos, sin, mask=None):
        """Attend over ``x`` of shape ``(bsz, seq_len, hidden_size)``.

        Args:
            x: input activations.
            cos, sin: RoPE tables, passed through to ``apply_rotary_pos_emb``.
            mask: optional additive mask broadcastable to ``(bsz, heads, seq_len, seq_len)``.
                It is cast to the score dtype, so a float32 mask works with bf16/fp16 scores.

        Returns:
            Tensor of shape ``(bsz, seq_len, hidden_size)``.
        """
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Expand the KV heads so that every query head has a matching key/value head.
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Without the cast, a float32 mask would promote half-precision scores to
            # float32, and the matmul with `v` below would then fail on mixed dtypes.
            attn_weights = attn_weights + mask.to(attn_weights.dtype)

        # Softmax in float32 for numerical stability (as HF's eager Llama attention does),
        # then cast back so that the matmul with `v` stays in one dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.o_proj(output)

class Block(nn.Module):
    """One pre-norm decoder layer.

    Computes ``h = x + Attn(Norm(x))`` and then returns ``h + MLP(Norm(h))``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x, cos, sin, mask=None):
        attn_out = self.self_attn(self.input_layernorm(x), cos, sin, mask)
        residual = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return residual + mlp_out

class SmolLM2(nn.Module):
    """From-scratch Llama-style causal language model mirroring HF ``LlamaForCausalLM``.

    Args:
        config: a HuggingFace ``LlamaConfig``, or any object with the same attributes:
            vocab_size, hidden_size, num_hidden_layers, num_attention_heads,
            num_key_value_heads, intermediate_size, rms_norm_eps and
            max_position_embeddings. The attributes rope_theta, head_dim and
            tie_word_embeddings are optional.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if getattr(config, "tie_word_embeddings", False):
            # SmolLM2 shares one matrix between the input embedding and the output
            # projection. Tying keeps the parameter count and training behaviour faithful.
            self.lm_head.weight = self.embed_tokens.weight

        # RoPE setup. Prefer the explicit head_dim from the config (HF does the same).
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.rope_theta = getattr(config, "rope_theta", 10000.0)
        # Plain float32 tensor, not a buffer, so Module.to(dtype) never casts it.
        # As a consequence Module.to(device) does not move it either; forward() handles that.
        self.inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.max_pos = config.max_position_embeddings * 2
        self._set_cos_sin_cache(self.max_pos)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build the RoPE cos/sin tables for positions ``0 .. seq_len - 1``.

        The tables are computed in float32 on the device of ``self.inv_freq``. They are
        registered as non-persistent buffers, so they are excluded from ``state_dict``.
        """
        # Build positions on the same device as inv_freq; otherwise torch.outer fails once
        # forward() has moved inv_freq to the GPU.
        t = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, input_ids):
        """Return next-token logits of shape ``(batch, seq_len, vocab_size)``.

        ``input_ids`` is a ``(batch, seq_len)`` tensor of token ids. Only a causal mask is
        applied; no padding mask is supported, so sequences are assumed to be unpadded.
        """
        _, seq_len = input_ids.shape
        x = self.embed_tokens(input_ids)

        if self.cos_cached.device != x.device or self.cos_cached.shape[0] < seq_len:
            # Move inv_freq first so the rebuilt tables land on the input's device.
            # Keep at least the current cache length so that a device change never shrinks it.
            self.inv_freq = self.inv_freq.to(x.device)
            self._set_cos_sin_cache(max(seq_len, self.cos_cached.shape[0]))

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)

        mask = None
        if seq_len > 1:
            # Additive causal mask (-inf strictly above the diagonal) in the activation
            # dtype. A float32 mask would promote bf16/fp16 scores to float32 and break the
            # attention matmul with the half-precision values.
            mask = torch.full((seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype)
            mask = torch.triu(mask, diagonal=1)

        for layer in self.layers:
            x = layer(x, cos, sin, mask)

        x = self.norm(x)
        return self.lm_head(x)


## 3. Weight Transfer
Now we copy the weights from the Hugging Face model into our custom implementation, module by module.

In [4]:
# Build the from-scratch model from the HF config and place it on the same device as hf_model.
custom_model = SmolLM2(hf_config)
custom_model = custom_model.to(device)

def copy_weights(hf_model, custom_model):
    """Copy every parameter of a HuggingFace Llama-style model into ``custom_model``.

    The two models must expose the same layout. That means embeddings, a list of decoder
    layers (each with two RMSNorms, q/k/v/o projections and a gate/up/down MLP), a final
    norm and an LM head.

    Args:
        hf_model: the source model (e.g. ``LlamaForCausalLM``). Its decoder lives under
            ``hf_model.model`` and its output head under ``hf_model.lm_head``.
        custom_model: the destination model (e.g. ``SmolLM2``).

    Raises:
        ValueError: if the models have a different number of decoder layers (``zip``
            would otherwise silently copy only the shorter prefix), or if any pair of
            tensors differs in shape (``Tensor.copy_`` would otherwise broadcast silently).
    """
    hf_layers = hf_model.model.layers
    custom_layers = custom_model.layers
    if len(hf_layers) != len(custom_layers):
        raise ValueError(
            f"Layer count mismatch: HF model has {len(hf_layers)}, "
            f"custom model has {len(custom_layers)}"
        )

    # Dotted sub-module paths inside one decoder layer, identical in both implementations.
    layer_modules = (
        "input_layernorm",
        "post_attention_layernorm",
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    )

    def _resolve(module, path):
        # Walk a dotted attribute path, e.g. "self_attn.q_proj" -> module.self_attn.q_proj.
        for attr in path.split("."):
            module = getattr(module, attr)
        return module

    def _copy(dst, src, name):
        # Check shapes explicitly because Tensor.copy_ broadcasts mismatched shapes.
        if dst.shape != src.shape:
            raise ValueError(f"Shape mismatch for {name}: {tuple(dst.shape)} vs {tuple(src.shape)}")
        dst.copy_(src)

    with torch.no_grad():
        _copy(custom_model.embed_tokens.weight, hf_model.model.embed_tokens.weight, "embed_tokens")

        for idx, (hf_layer, custom_layer) in enumerate(zip(hf_layers, custom_layers)):
            for path in layer_modules:
                _copy(
                    _resolve(custom_layer, path).weight,
                    _resolve(hf_layer, path).weight,
                    f"layers.{idx}.{path}",
                )

        _copy(custom_model.norm.weight, hf_model.model.norm.weight, "norm")
        _copy(custom_model.lm_head.weight, hf_model.lm_head.weight, "lm_head")

# Transfer every HF parameter into the from-scratch model.
copy_weights(hf_model=hf_model, custom_model=custom_model)
print("Weights copied successfully.")

Weights copied successfully.


## 4. Validation
We verify that both models produce the same logits for a sample prompt, with a maximum absolute difference below `1e-4`.

In [5]:
input_text = "Once upon a time"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Run both implementations on identical token ids and compare raw logits.
with torch.no_grad():
    hf_outputs = hf_model(**inputs).logits
    custom_outputs = custom_model(inputs.input_ids)

diff = torch.max(torch.abs(hf_outputs - custom_outputs))
print(f"Max difference: {diff.item()}")

assert diff < 1e-4, "Models diverge!"
print("Validation Passed!")

Max difference: 4.1961669921875e-05
Validation Passed!


## 5. Export Model
We write the model classes to `model.py` so they can be imported for training.

In [6]:
# Source of the exported module. It mirrors the from-scratch implementation above,
# including half-precision-safe RMSNorm/softmax/mask handling, device-safe RoPE cache
# rebuilds, config-driven head_dim and tied input/output embeddings.
code = """
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class RMSNorm(nn.Module):
    # Root-mean-square norm with a learned per-channel scale.
    # Computed in float32 and cast back to the input dtype, as HF LlamaRMSNorm does.
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        mean_square = hidden.pow(2).mean(-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + self.eps)
        return self.weight * hidden.to(input_dtype)


def rotate_half(x):
    # Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: [bsz, heads, seq_len, head_dim]
    # cos, sin: [seq_len, head_dim] -> unsqueeze to [1, 1, seq_len, head_dim]
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class MLP(nn.Module):
    # SwiGLU feed-forward: down(silu(gate(x)) * up(x)), no biases.
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class Attention(nn.Module):
    # Causal grouped-query self-attention with rotary position embeddings.
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Prefer the explicit head_dim from the config; fall back to hidden_size // heads.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(self, x, cos, sin, mask=None):
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Cast so a float32 mask does not promote half-precision scores.
            attn_weights = attn_weights + mask.to(attn_weights.dtype)

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.o_proj(output)


class Block(nn.Module):
    # Pre-norm decoder layer: h = x + Attn(Norm(x)); out = h + MLP(Norm(h)).
    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x, cos, sin, mask=None):
        h = x + self.self_attn(self.input_layernorm(x), cos, sin, mask)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out


class SmolLM2(nn.Module):
    # From-scratch Llama-style causal LM mirroring HF LlamaForCausalLM.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if getattr(config, "tie_word_embeddings", False):
            # Share one matrix between input embedding and output projection.
            self.lm_head.weight = self.embed_tokens.weight

        # RoPE setup
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.rope_theta = getattr(config, "rope_theta", 10000.0)
        # Plain float32 tensor (not a buffer): Module.to() neither casts nor moves it.
        self.inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.max_pos = config.max_position_embeddings * 2
        self._set_cos_sin_cache(self.max_pos)

    def _set_cos_sin_cache(self, seq_len):
        # Positions are built on inv_freq's device so torch.outer never mixes devices.
        t = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, input_ids):
        _, seq_len = input_ids.shape
        x = self.embed_tokens(input_ids)

        if self.cos_cached.device != x.device or self.cos_cached.shape[0] < seq_len:
            self.inv_freq = self.inv_freq.to(x.device)
            self._set_cos_sin_cache(max(seq_len, self.cos_cached.shape[0]))

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)

        mask = None
        if seq_len > 1:
            # Causal mask in the activation dtype so half precision works end to end.
            mask = torch.full((seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype)
            mask = torch.triu(mask, diagonal=1)

        for layer in self.layers:
            x = layer(x, cos, sin, mask)

        x = self.norm(x)
        return self.lm_head(x)
"""

# Compile before writing so a syntax error in the template fails here, not at import time.
compile(code, "model.py", "exec")

with open("model.py", "w", encoding="utf-8") as f:
    f.write(code)

print("model.py created.")

model.py created.
