# Qwen3 From Scratch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Config

In [2]:
from dataclasses import dataclass
import torch

@dataclass()
class Qwen3Config:
    """Hyper-parameters for the Qwen3 model built in this notebook.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field: values can be overridden through the constructor
    (``Qwen3Config(num_hidden_layers=2)``) and show up in ``repr`` / ``==``.
    """

    vocab_size: int = 151936
    hidden_size: int = 1024
    intermediate_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 16
    num_key_value_heads: int = 8
    attention_bias: bool = False
    # Per-head width. The attention projections are sized as
    # num_heads * head_dim, which is independent of hidden_size.
    head_dim: int = 128
    hidden_act: str = "silu"
    max_position_embeddings: int = 40_960
    rms_norm_eps: float = 1e-6
    tie_word_embeddings: bool = False
    # NOTE(review): published Qwen3 checkpoints are believed to ship a larger
    # rope_theta (and tied embeddings for the smallest variant) -- confirm
    # against the checkpoint's config.json before relying on these defaults.
    rope_theta: float = 10000.0
    # Name kept as-is for backward compatibility; not referenced by the code
    # visible in this notebook.
    dtyp: torch.dtype = torch.bfloat16

## Model

In [3]:
class Qwen3Model(nn.Module):
    """Decoder backbone: token embedding, a stack of transformer blocks, final RMSNorm.

    All tensors use a flat token axis (no batch dimension); ``positions`` and
    ``cu_lens`` are forwarded untouched to every block.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        blocks = [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, cu_lens: torch.Tensor) -> torch.Tensor:
        """Map token ids ``[seq_len]`` to normalized hidden states ``[seq_len, hidden_size]``."""
        hidden = self.embed_tokens(input_ids)

        for block in self.layers:
            hidden = block(hidden, positions, cu_lens)

        # The final norm leaves the shape unchanged.
        return self.norm(hidden)

class Qwen3ForCausalLM(nn.Module):
    """Backbone plus a bias-free linear head that produces next-token logits."""

    def __init__(self, config: Qwen3Config):
        super().__init__()

        self.model = Qwen3Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, cu_lens: torch.Tensor) -> torch.Tensor:
        """Return logits for the last token of every sequence.

        ``cu_lens`` holds cumulative sequence boundaries, so ``cu_lens[1:] - 1``
        indexes the final token of each sequence on the flat token axis. The
        result therefore has one row per sequence: ``[num_sequences, vocab_size]``.
        """
        # [seq_len, hidden_size]
        hidden = self.model(input_ids, positions, cu_lens)

        last_token_index = cu_lens[1:] - 1

        # [num_sequences, vocab_size]
        return self.lm_head(hidden[last_token_index])

## Transformer Block

In [4]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: attention and MLP sub-layers, each with a residual add."""

    def __init__(self, config: Qwen3Config):
        super().__init__()

        hidden_size, eps = config.hidden_size, config.rms_norm_eps

        self.input_layernorm = nn.RMSNorm(hidden_size, eps=eps)
        self.self_attn = Qwen3Attention(config)
        self.post_attention_layernorm = nn.RMSNorm(hidden_size, eps=eps)
        self.mlp = Qwen3MLP(config)

    def forward(self, x, positions: torch.Tensor, cu_lens: torch.Tensor):
        """Apply both sub-layers; the output has the same shape as ``x``."""
        # Attention sub-layer: normalize, attend, add the residual.
        attn_out = self.self_attn(self.input_layernorm(x), positions, cu_lens)
        x = attn_out + x

        # Feed-forward sub-layer: normalize, project, add the residual.
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return mlp_out + x

### Qwen3Attention

In [5]:
class Qwen3Attention(nn.Module):
    """Grouped-query causal self-attention over a flat token axis.

    Queries and keys are RMS-normalized per head and rotated by the rotary
    embedding. Each sequence delimited by ``cu_lens`` is attended to
    independently with its own causal mask.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__()

        self.config = config

        q_width = config.num_attention_heads * config.head_dim
        kv_width = config.num_key_value_heads * config.head_dim

        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=config.attention_bias)

        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=config.attention_bias)

        self.q_norm = nn.RMSNorm(config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(config.head_dim, eps=config.rms_norm_eps)

        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = self.config.head_dim**-0.5

        self.rotary_embedding = RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, x, positions: torch.Tensor, cu_lens: torch.Tensor):
        """Attend within each sequence.

        ``x`` is ``[seqlen, hidden_size]``, ``positions`` gives each token's
        rotary position, and ``cu_lens`` holds cumulative sequence boundaries
        (sequence ``i`` occupies tokens ``cu_lens[i]:cu_lens[i + 1]``).
        Returns a tensor of shape ``[seqlen, hidden_size]``.
        """
        cfg = self.config
        n_tokens, _ = x.shape

        def split_heads(t, n_heads):
            # [seqlen, n_heads * head_dim] -> [n_heads, seqlen, head_dim]
            return t.view(n_tokens, n_heads, cfg.head_dim).transpose(0, 1)

        q = self.q_norm(split_heads(self.q_proj(x), cfg.num_attention_heads))
        k = self.k_norm(split_heads(self.k_proj(x), cfg.num_key_value_heads))
        v = split_heads(self.v_proj(x), cfg.num_key_value_heads)

        q = self.rotary_embedding(q, positions)
        k = self.rotary_embedding(k, positions)

        # Expand the KV heads so every query head has a matching key/value head:
        # [1, 2, 3].repeat_interleave(2, dim=0) => [1, 1, 2, 2, 3, 3]
        heads_per_kv = cfg.num_attention_heads // cfg.num_key_value_heads
        k = k.repeat_interleave(heads_per_kv, dim=0)
        v = v.repeat_interleave(heads_per_kv, dim=0)

        boundaries = cu_lens.tolist()
        per_sequence = []
        for start, end in zip(boundaries, boundaries[1:]):
            length = end - start
            q_i, k_i, v_i = (t[:, start:end, :] for t in (q, k, v))

            # True above the diagonal marks future tokens that must stay hidden.
            future = torch.ones(length, length, device=x.device, dtype=torch.bool).triu(diagonal=1)

            logits = (q_i @ k_i.transpose(-2, -1)).masked_fill(future, -torch.inf) * self.scale
            probs = F.softmax(logits, dim=-1)

            # [q_head, len, head_dim] -> [len, q_head, head_dim] -> [len, q_head * head_dim]
            per_sequence.append((probs @ v_i).transpose(0, 1).flatten(1))

        merged = torch.concat(per_sequence, dim=0)
        return self.o_proj(merged)

## Rotary Embedding

In [6]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding backed by precomputed cos/sin tables.

    The tables have shape ``[max_position_embeddings, dim / 2]`` and are stored
    as buffers so they move with the module across devices.

    Raises:
        ValueError: if ``dim`` is odd (channels are rotated in pairs).
    """

    def __init__(
            self,
            dim: int,
            max_position_embeddings: int,
            rope_theta: float,
    ) -> None:
        super().__init__()

        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        # inv_freq[j] = 1 / rope_theta ** (2j / dim), for j = 0 .. dim/2 - 1
        inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

        # position: [max_position_embeddings]
        position = torch.arange(max_position_embeddings, dtype=torch.float)

        # freqs[p, j] = p * inv_freq[j] -> [max_position_embeddings, dim/2]
        freqs = torch.einsum("i,j->ij", position, inv_freq)

        self.register_buffer("cos_cached", torch.cos(freqs))
        self.register_buffer("sin_cached", torch.sin(freqs))

    def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` (``[heads, seqlen, head_dim]``) by the angles for ``positions``.

        Channel ``j`` of the first half is paired with channel ``j`` of the
        second half. The rotated pair is written to output channels ``2j`` and
        ``2j + 1``, so the output channel order differs from the input order.
        Because the same permutation is applied to every tensor passed through
        this module, dot products between two rotated tensors are unaffected.

        The computation runs in float32 and the result is cast back to
        ``x.dtype``.
        """
        cos = self.cos_cached[positions].unsqueeze(0)  # [1, seqlen, dim/2]
        sin = self.sin_cached[positions].unsqueeze(0)  # [1, seqlen, dim/2]

        x1, x2 = torch.chunk(x.float(), 2, dim=-1)

        # Stacking on a new trailing axis and flattening interleaves the two
        # results: even channels get the first term, odd channels the second.
        rotated = torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return rotated.flatten(-2).to(x.dtype)

## MLP

In [7]:
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config: Qwen3Config):
        super().__init__()

        model_width = config.hidden_size
        inner_width = config.intermediate_size

        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated projection; the last dimension is preserved."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)


## Initialize model

In [8]:
# Pick the first available backend: CUDA, then Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Build the model with the default config and move it onto that device.
model = Qwen3ForCausalLM(Qwen3Config()).to(device)

## Load pretrained weights

In [9]:
import os
from safetensors import safe_open
import glob

# Checkpoint directory as a raw literal; "~" is not expanded here.
huggingface_model_dir = '~/huggingface/Qwen3-0.6B/'

def load_weight(huggingface_model_dir, model):
    """Copy every tensor from the ``*.safetensors`` files in a directory into ``model``.

    Args:
        huggingface_model_dir: directory holding the checkpoint shards; a
            leading ``~`` is expanded to the user's home directory.
        model: module whose ``named_parameters()`` names match the tensor
            names stored in the checkpoint.

    Raises:
        KeyError: a checkpoint tensor has no parameter of the same name.
        ValueError: a checkpoint tensor's shape differs from the parameter's.
    """
    model_dir = os.path.expanduser(huggingface_model_dir)
    params = dict(model.named_parameters())

    # Sorted so shards are always visited in the same order.
    for file in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
        with safe_open(file, "pt", "cpu") as f:
            for name in f.keys():
                if name not in params:
                    raise KeyError(f"Parameter {name} not found in model")
                weight = f.get_tensor(name)
                param = params[name]
                # copy_ broadcasts, so a mismatched shape could otherwise be
                # accepted silently.
                if param.shape != weight.shape:
                    raise ValueError(
                        f"Shape mismatch for {name}: checkpoint has {tuple(weight.shape)}, "
                        f"model expects {tuple(param.shape)}"
                    )
                with torch.no_grad():
                    param.copy_(weight)

In [10]:
# Resolve "~" to the home directory, then load the checkpoint into the model.
# `path` stays a module-level name and is reused by later cells.
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
load_weight(path, model)

## Tokenizer

In [11]:
import tokenizers

# Join with os.path.join so this works whether or not `path` ends with a separator.
qwen3_tokenizer = tokenizers.Tokenizer.from_file(os.path.join(path, "tokenizer.json"))

In [12]:
# Sanity check: token ids produced for a short sample string.
qwen3_tokenizer.encode("Hello, world").ids

[9707, 11, 1879]

In [13]:
# Round-trip check: decoding the encoded ids should give back the input text.
qwen3_tokenizer.decode(qwen3_tokenizer.encode("Hello, world").ids)

'Hello, world'

## Generate text

In [14]:
def apply_chat_template(prompt: str, enable_think: bool = False) -> str:
    """Wrap ``prompt`` as a single user turn followed by an open assistant turn.

    When ``enable_think`` is exactly ``False``, an empty ``<think>`` block is
    appended so the assistant turn starts after it.
    """
    # NOTE(review): the upstream Qwen3 chat template may emit an extra blank
    # line after </think> -- confirm against the tokenizer's template.
    rendered = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    empty_think = "<think>\n\n</think>\n" if enable_think is False else ""
    return rendered + empty_think

In [15]:
# Preview the rendered prompt with the default enable_think=False.
apply_chat_template("What is the meaning of life?")

'<|im_start|>user\nWhat is the meaning of life?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n'

In [16]:
class Request:
    """One in-flight generation request.

    ``tokens`` starts as the encoded prompt and grows by one id per decoding
    step; the list is stored by reference, not copied.
    """

    def __init__(self, tokens: list[int]) -> None:
        self.tokens = tokens

In [17]:
def generate_one_step(model, requests: "list[Request]"):
    """Run one greedy decoding step for a batch of requests.

    The token lists of all requests are concatenated along one flat axis.
    ``positions`` restarts at 0 for every request and ``cu_lens`` records the
    cumulative boundaries, so the model can tell the sequences apart.

    Args:
        model: callable as ``model(tokens, positions, cu_lens)`` and returning
            one row of logits per request.
        requests: requests to advance; each needs a non-empty ``tokens`` list.

    Returns:
        The arg-max next token id for each request, in request order. An empty
        ``requests`` list yields ``[]`` without calling the model.

    Raises:
        ValueError: if a request has no tokens (there would be no last token
            to predict from).
    """
    if not requests:
        return []

    tokens = []
    positions = []
    cu_lens = [0]

    for req in requests:
        length = len(req.tokens)
        if length == 0:
            raise ValueError("every request must contain at least one token")
        tokens.extend(req.tokens)
        positions.extend(range(length))
        cu_lens.append(cu_lens[-1] + length)

    # Place the inputs on the model's own device instead of relying on a
    # module-level global; fall back to the default device for a model
    # without parameters.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    tokens = torch.tensor(tokens, dtype=torch.long, device=model_device)
    positions = torch.tensor(positions, dtype=torch.long, device=model_device)
    cu_lens = torch.tensor(cu_lens, dtype=torch.long, device=model_device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        # [len(requests), vocab_size]
        logits = model(tokens, positions, cu_lens)

    return torch.argmax(logits, dim=-1).tolist()

In [18]:
def generate(model: Qwen3ForCausalLM, tokenizer, prompts: list[str], enable_think=True, max_new_tokens=64):
    """Greedily decode a completion for each prompt.

    Each prompt is wrapped with the chat template, encoded with ``tokenizer``
    and decoded step by step together with the others. A request stops being
    advanced once its last token is the ``<|im_end|>`` id.

    Args:
        model: the causal LM to run.
        tokenizer: tokenizer used for encoding the prompts, finding the
            end-of-turn id and decoding the progress printouts.
        prompts: user messages, one request per entry.
        enable_think: forwarded to ``apply_chat_template``.
        max_new_tokens: upper bound on the decoding steps per request.

    Returns:
        All requests, finished or not, in the order of ``prompts``; each
        ``tokens`` list holds the prompt ids followed by the generated ids.
    """
    requests = [
        Request(tokenizer.encode(apply_chat_template(prompt, enable_think)).ids)
        for prompt in prompts
    ]

    eos_token = tokenizer.encode("<|im_end|>").ids[0]

    # `active` shrinks as requests finish; `requests` keeps every one of them
    # so finished completions are still returned to the caller.
    active = list(requests)
    for step in range(1, max_new_tokens + 1):
        if not active:
            break

        for req, token in zip(active, generate_one_step(model, active)):
            req.tokens.append(token)

        # Periodic progress dump of the requests still being decoded.
        if step % 10 == 0:
            print("==================================================")
            for i, req in enumerate(active):
                print(f"request {i}:\n", tokenizer.decode(req.tokens, skip_special_tokens=False))

        # Stop advancing requests that just emitted the end-of-turn token.
        active = [req for req in active if req.tokens[-1] != eos_token]

    return requests

In [19]:
# Decode two prompts together in one batch, using generate()'s defaults
# (enable_think=True, max_new_tokens=64).
prompts = [
    "What is the meaning of life?",
    "How do I get started with LLMs?"
]
reqs = generate(model, qwen3_tokenizer, prompts)

request 0:
 <|im_start|>user
What is the meaning of life?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking about the
request 1:
 <|im_start|>user
How do I get started with LLMs?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking how to
request 0:
 <|im_start|>user
What is the meaning of life?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking about the meaning of life. First, I need to consider
request 1:
 <|im_start|>user
How do I get started with LLMs?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking how to get started with LLMs. Let me break
request 0:
 <|im_start|>user
What is the meaning of life?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking about the meaning of life. First, I need to consider different perspectives. The user might be looking for a
request 1:
 <|im_start|>user
How do I get started with LLMs?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking how to get started with LLMs. L