In [1]:
import torch 
import torch.nn as nn 
from einops import einsum, rearrange
torch.manual_seed(123);  # seed torch's global RNG; the trailing ';' suppresses display of the returned Generator

In [2]:
class FeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward network: ``fc3(silu(fc1(x)) * fc2(x))``.

    Args:
        cfg: config dict; reads ``d_model`` (input/output width) and ``d_ff``
            (hidden width). All projections are bias-free.
    """

    def __init__(self, cfg):
        super().__init__()
        d_model, d_ff = cfg["d_model"], cfg["d_ff"]
        # Keep creation order fc1 -> fc2 -> fc3: it fixes which RNG draws each
        # layer's weight initialisation consumes.
        self.fc1 = nn.Linear(d_model, d_ff, bias=False)  # branch passed through SiLU (gate)
        self.fc2 = nn.Linear(d_model, d_ff, bias=False)  # linear branch multiplied by the gate
        self.fc3 = nn.Linear(d_ff, d_model, bias=False)  # projects back to d_model

    def forward(self, x):
        """Return ``fc3(silu(fc1(x)) * fc2(x))``; last dim of ``x`` must be ``d_model``."""
        gate = nn.functional.silu(self.fc1(x))
        linear = self.fc2(x)
        return self.fc3(gate * linear)

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * scale`` (plus ``shift`` when
    ``bias=True``). The statistics are computed in at least float32 so that
    low-precision inputs (e.g. bfloat16) do not lose accuracy in the mean and
    rsqrt; the result is cast back to the input dtype.

    Args:
        d_model: size of the last (normalised) dimension.
        eps: constant added to the mean square for numerical stability.
        bias: if True, also learn an additive ``shift`` parameter.
    """

    def __init__(self, d_model, eps: float = 1e-6, bias=False):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))
        self.shift = nn.Parameter(torch.zeros(d_model)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype
        # Upcast half/bfloat16 to float32 for the reduction; float64 stays float64.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x = x.to(compute_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x * torch.rsqrt(variance + self.eps)
        x_norm = x_norm * self.scale

        if self.shift is not None:
            x_norm = x_norm + self.shift

        return x_norm.to(input_dtype)

In [4]:
class RoPE(nn.Module):
    """Rotary positional embedding (RoPE), "rotate-half" formulation.

    Precomputes cos/sin tables for positions ``0 .. ctx_len - 1`` and rotates
    each feature pair ``(i, i + head_dim // 2)`` of the input by the angle
    ``pos * theta_base ** (-2i / head_dim)``.

    Args:
        head_dim: per-head feature size; must be even.
        ctx_len: number of positions covered by the precomputed tables.
        theta_base: RoPE base ("theta"); larger values give lower frequencies.
    """

    def __init__(self, head_dim: int, ctx_len: int, theta_base: float = 10_000):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even!"

        self.theta_base = theta_base
        self.head_dim = head_dim
        self.ctx_len = ctx_len

        self._compute_rope()

    def _compute_rope(self):
        """Register non-persistent ``cos``/``sin`` buffers of shape (ctx_len, head_dim)."""
        # Divide the exponent *before* exponentiating, and do it in float:
        # `base ** arange(...) / head_dim` parses as `(base ** arange) / head_dim`,
        # and with an int base and int64 arange the power overflows int64.
        exponents = torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim
        inv_freq = 1.0 / (self.theta_base ** exponents)        # (head_dim // 2,)
        pos = torch.arange(self.ctx_len, dtype=torch.float32)  # (ctx_len,)
        angles = torch.outer(pos, inv_freq)                    # (ctx_len, head_dim // 2)
        angles = torch.cat((angles, angles), dim=1)            # (ctx_len, head_dim)

        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

    def _rotate_half(self, x):
        """Map the last dim ``[x1, x2]`` (two equal halves) to ``[-x2, x1]``."""
        half = self.head_dim // 2
        x1 = x[..., :half]  # (..., seq_len, head_dim // 2)
        x2 = x[..., half:]  # (..., seq_len, head_dim // 2)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x):
        """Rotate ``x`` of shape (..., seq_len, head_dim) by positions ``0 .. seq_len - 1``.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``ctx_len``.
        """
        seq_len = x.shape[-2]
        if seq_len > self.ctx_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds RoPE context length {self.ctx_len}"
            )
        # Slice the tables to the actual sequence length; they broadcast over leading dims.
        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]
        rotated = x * cos + self._rotate_half(x) * sin
        # The float32 tables promote low-precision inputs; return in the caller's dtype.
        return rotated.to(x.dtype)

In [5]:
class GroupedQueryAttention(nn.Module):
    """Causal multi-head attention with grouped key/value heads (GQA).

    ``n_heads`` query heads share ``n_kv_groups`` key/value heads; each K/V head
    is repeated ``n_heads // n_kv_groups`` times. Queries and keys are
    RMS-normalised per head and rotated with RoPE before the dot product.

    Args:
        d_in: input feature size.
        d_out: total query/output size, split evenly across ``n_heads``.
        n_heads: number of query heads.
        n_kv_groups: number of key/value heads; must divide ``n_heads``.
        ctx_len: maximum sequence length (size of the RoPE tables).
    """

    def __init__(self, d_in, d_out, n_heads, n_kv_groups, ctx_len):
        super().__init__()
        assert n_heads % n_kv_groups == 0, "n_heads must be divisible by n_kv_groups!"
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads!"

        self.head_dim = d_out // n_heads
        self.group_size = n_heads // n_kv_groups

        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, n_kv_groups * self.head_dim, bias=False)
        self.W_v = nn.Linear(d_in, n_kv_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)

        self.q_norm = RMSNorm(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm(self.head_dim, eps=1e-6)

        self.rope = RoPE(self.head_dim, ctx_len)
        # No precomputed (ctx_len, ctx_len) mask buffer: at ctx_len=40_960 a float
        # mask is ~6.7 GB per layer. The causal mask is built per call in forward().

    def forward(self, x):
        """Apply causal grouped-query attention.

        Args:
            x: (batch_size, seq_len, d_in) input; seq_len must not exceed ctx_len.

        Returns:
            (batch_size, seq_len, d_out) tensor.
        """
        _, seq_len, _ = x.shape

        queries = self.W_q(x)  # (batch_size, seq_len, d_out)
        keys = self.W_k(x)     # (batch_size, seq_len, n_kv_groups * head_dim)
        values = self.W_v(x)   # (batch_size, seq_len, n_kv_groups * head_dim)

        queries = rearrange(queries, "b s (n_heads head_dim) -> b n_heads s head_dim", head_dim=self.head_dim)
        keys = rearrange(keys, "b s (n_kv_groups head_dim) -> b n_kv_groups s head_dim", head_dim=self.head_dim)
        values = rearrange(values, "b s (n_kv_groups head_dim) -> b n_kv_groups s head_dim", head_dim=self.head_dim)

        # Per-head QK normalisation, then rotary position encoding.
        queries = self.rope(self.q_norm(queries))
        keys = self.rope(self.k_norm(keys))

        # Expand each K/V head to cover its group of query heads.
        keys = keys.repeat_interleave(self.group_size, dim=1)      # (batch_size, n_heads, seq_len, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # (batch_size, n_heads, seq_len, head_dim)

        attn_scores = einsum(queries, keys, "... s1 head_dim, ... s2 head_dim -> ... s1 s2")
        # masked_fill requires a *boolean* mask; True strictly above the diagonal
        # marks future positions that must not be attended to.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context_vecs = einsum(attn_weights, values, "... s1 s2, ... s2 head_dim -> ... s1 head_dim")
        context_vecs = rearrange(context_vecs, "batch_size n_heads seq_len head_dim -> batch_size seq_len (n_heads head_dim)")
        return self.out_proj(context_vecs)


In [6]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: grouped-query attention then feed-forward, each with a residual.

    Args:
        cfg: config dict; reads ``d_model``, ``ctx_len``, ``n_heads``,
            ``n_kv_groups`` and, when present, ``rope_base`` (FeedForward reads ``d_ff``).
    """

    def __init__(self, cfg):
        super().__init__()
        self.attn = GroupedQueryAttention(
            d_in=cfg["d_model"],
            d_out=cfg["d_model"],
            ctx_len=cfg["ctx_len"],
            n_heads=cfg["n_heads"],
            n_kv_groups=cfg["n_kv_groups"]
        )
        # GroupedQueryAttention constructs its RoPE with the default base, so
        # cfg["rope_base"] would otherwise be silently ignored; rebuild the tables with it.
        if "rope_base" in cfg:
            self.attn.rope = RoPE(self.attn.head_dim, cfg["ctx_len"], theta_base=cfg["rope_base"])
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["d_model"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["d_model"], eps=1e-6)

    def forward(self, x):
        """Apply ``x + attn(norm1(x))`` followed by ``x + ff(norm2(x))``."""
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [7]:
class Qwen3Model(nn.Module):
    """Decoder-only LM: token embedding -> ``n_layers`` TransformerBlocks -> RMSNorm -> LM head.

    Args:
        cfg: config dict; reads ``vocab_size``, ``d_model`` and ``n_layers``
            directly, and passes the whole dict to each TransformerBlock.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, d_model = cfg["vocab_size"], cfg["d_model"]

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.final_norm = RMSNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map token ids (presumably (batch, seq_len) integers) to vocabulary logits."""
        hidden = self.tok_emb(in_idx)
        for block in self.transformer_blocks:
            hidden = block(hidden)
        return self.out_head(self.final_norm(hidden))

In [None]:
QWEN3_CONFIG = {
    "vocab_size": 151_936,           # Vocabulary size
    "ctx_len": 40_960,               # Context length that was used to train the model
    "d_model": 1024,                 # Embedding dimension
    "n_heads": 16,                   # Number of attention (query) heads; head_dim = d_model // n_heads
    "n_layers": 28,                  # Number of layers
    "d_ff": 3072,                    # Size of the intermediate dimension in FeedForward
    "qk_norm": True,                 # Whether to normalize queries and keys in GQA (NOTE: not read by the classes above; QK-norm is always applied)
    "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
    "rope_base": 1_000_000.0,        # The base in RoPE's "theta"
    "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage (NOTE: not read by the classes above; the model is built in the default dtype)
}

: 

In [None]:
# Re-seed immediately before construction so the random weight initialisation is reproducible.
torch.manual_seed(123)
model = Qwen3Model(cfg=QWEN3_CONFIG)

In [None]:
# Display the module hierarchy (submodule names and layer sizes) via nn.Module's repr.
model