[standalone-qwen3.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3.ipynb)

In [28]:
from importlib.metadata import version
# Report the installed version of each package this notebook relies on.
pkg = ["huggingface_hub", "tokenizers", "torch"]
for p in pkg:
    installed = version(p)
    print(f"{p} version: {installed}")

huggingface_hub version: 0.30.1
tokenizers version: 0.21.1
torch version: 2.3.1


In [29]:
# Model-variant switches; exactly one of them has to be True.
# NOTE(review): "RESONING" looks like a misspelling of "REASONING"; the name is
# kept as-is because other cells may refer to it.
USE_BASE_MODEL = False
USE_RESONING_MODEL = True
USE_INSTRUCT_MODEL = False

# Booleans sum as 0/1, so the total is the number of enabled variants.
if sum((USE_BASE_MODEL, USE_RESONING_MODEL, USE_INSTRUCT_MODEL)) != 1:
    raise ValueError("Exactly one of USE_BASE_MODEL, USE_RESONING_MODEL, or USE_INSTRUCT_MODEL must be True")

In [30]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Gated feed-forward layer computing ``fc3(silu(fc1(x)) * fc2(x))``.

    ``cfg`` must provide ``"emb_dim"``, ``"hidden_dim"`` and ``"dtype"``.
    All three linear layers are created without a bias term.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = cfg["hidden_dim"]
        linear_kwargs = {"dtype": cfg["dtype"], "bias": False}

        # fc1 feeds the SiLU gate, fc2 is the linear branch, fc3 projects back.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, **linear_kwargs)

    def forward(self, x):
        """Apply the gated projection to ``x`` and return the result."""
        gate = nn.functional.silu(self.fc1(x))
        linear_branch = self.fc2(x)
        return self.fc3(gate * linear_branch)

In [31]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` along the last axis and
    multiplied by a learnable per-channel ``scale`` (initialised to ones).

    Args:
        emb_dim: Size of the last dimension of the input.
        eps: Constant added to the mean square before the reciprocal square root.
        bias: When True, a learnable ``shift`` (initialised to zeros) is added
            after scaling; otherwise ``shift`` is ``None``.
        qwen3_compatible: When True, the computation is carried out in float32
            and the result is cast back to the input dtype.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        """Return the normalized tensor in the same dtype as ``x``."""
        input_dtype = x.dtype

        if self.qwen3_compatible:
            # Upcast so the mean square is not computed in reduced precision.
            x = x.to(torch.float32)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            # The shift is an additive offset; multiplying by the zero-initialised
            # parameter would wipe out the output.
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)

In [32]:
# Scratch values for exploring the RoPE frequency computation.
dtype = torch.float16
theta_base = 10_000
head_dim = 10

# Channel indices 0, 2, ..., head_dim - 2; left as the last expression so the
# notebook displays the tensor.
torch.arange(start=0, end=head_dim, step=2, dtype=dtype)
# inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
# inv_freq


tensor([0., 2., 4., 6., 8.], dtype=torch.float16)

In [33]:
# Walk through the RoPE angle table step by step and print each intermediate.
# `dtype` is taken from an earlier cell.
context_length = 4096
theta_base = 10_000
head_dim = 10

channel_idx = torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float()
inv_freq = 1.0 / (theta_base ** (channel_idx / head_dim))

# NOTE: with dtype=torch.float16 not every integer above 2048 is representable,
# so the largest positions are rounded (the recorded output ends in
# 4092, 4094, 4096).
positions = torch.arange(context_length, dtype=dtype)

pos_column = positions.unsqueeze(1)    # one row per position
freq_row = inv_freq.unsqueeze(0)       # one column per frequency
angle_table = pos_column * freq_row    # broadcasts to (context_length, head_dim // 2)

print("positions:", pos_column, pos_column.shape)
print("inv_freq:", freq_row, freq_row.shape)
print("rope:", angle_table, angle_table.shape)

positions: tensor([[0.0000e+00],
        [1.0000e+00],
        [2.0000e+00],
        ...,
        [4.0920e+03],
        [4.0940e+03],
        [4.0960e+03]], dtype=torch.float16) torch.Size([4096, 1])
inv_freq: tensor([[1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]]) torch.Size([1, 5])
rope: tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04],
        [2.0000e+00, 3.1698e-01, 5.0238e-02, 7.9621e-03, 1.2619e-03],
        ...,
        [4.0920e+03, 6.4854e+02, 1.0279e+02, 1.6291e+01, 2.5819e+00],
        [4.0940e+03, 6.4886e+02, 1.0284e+02, 1.6299e+01, 2.5831e+00],
        [4.0960e+03, 6.4917e+02, 1.0289e+02, 1.6306e+01, 2.5844e+00]]) torch.Size([4096, 5])


In [34]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute the cosine and sine tables used by rotary position embeddings.

    Args:
        head_dim: Per-head dimension; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions to generate (0 .. context_length - 1).
        dtype: dtype used when creating the index and position ranges.

    Returns:
        A ``(cos, sin)`` tuple, each of shape ``(context_length, head_dim)``.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    half_dim = head_dim // 2

    # One inverse frequency per channel pair: theta_base ** (-2i / head_dim).
    exponents = torch.arange(0, head_dim, 2, dtype=dtype)[:half_dim].float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # Outer product position x frequency -> (context_length, head_dim // 2).
    positions = torch.arange(context_length, dtype=dtype)
    half_angles = positions[:, None] * inv_freq[None, :]

    # Duplicate along the channel axis so the table spans the full head_dim.
    angles = torch.cat([half_angles, half_angles], dim=1)

    return torch.cos(angles), torch.sin(angles)


def apply_rope(x, cos, sin):
    """Rotate ``x`` with precomputed rotary-embedding tables.

    Args:
        x: Tensor of shape ``(batch_size, num_heads, seq_len, head_dim)`` with
            an even ``head_dim``.
        cos: Cosine table; its first ``seq_len`` rows are used.
        sin: Sine table; its first ``seq_len`` rows are used.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    half_dim = head_dim // 2
    first_half = x[..., :half_dim]
    second_half = x[..., half_dim:]

    # Trim the tables to the sequence length and add leading batch/head axes
    # so they broadcast against x: (1, 1, seq_len, head_dim).
    cos_table = cos[:seq_len, :][None, None, :, :]
    sin_table = sin[:seq_len, :][None, None, :, :]

    # Channel i is paired with channel i + half_dim.
    swapped = torch.cat((-second_half, first_half), dim=-1)
    result = x * cos_table + swapped * sin_table

    # Lower precision is acceptable once the rotation has been applied.
    return result.to(dtype=x.dtype)

这段代码对应的公式可以理解为用同一组角度生成余弦和正弦两个分量，从而把原本只有 head_dim/2 个角度扩展到 head_dim 维度：

原始 angles：由 positions 与逆频率相乘得到，形状为 (context_length, head_dim/2)，对应旋转嵌入中的基本角度 θ.

torch.cat([angles, angles], dim=1)：把同样的角度矩阵在列方向上拼接两次，相当于得到 [θ, θ]，形状变为 (context_length, head_dim)。随后对整个拼接后的矩阵分别计算 cos 和 sin（而不是一半用于 cos、另一半用于 sin）；这样 apply_rope 中前半部分通道 x1 与后半部分通道 x2 里下标相同的一对通道共享同一个角度 θ，从而配合二维旋转公式：

[x1', x2'] = [x1 * cos(θ) - x2 * sin(θ),
              x2 * cos(θ) + x1 * sin(θ)]
这个拼接确保准备好的角度矩阵与头维度 (head_dim) 匹配，使第 i 个通道与第 i + head_dim/2 个通道（即 x1 与 x2 中对应的一对）都能得到对应的 θ。

In [35]:
class GroupedQueryAttention(nn.Module):
    """Multi-head attention in which groups of query heads share key/value heads.

    Args:
        d_in: Size of the last dimension of the input (and of the output).
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        head_dim: Per-head dimension. Defaults to ``d_in // num_heads``.
        qk_norm: When True, queries and keys are passed through an ``RMSNorm``
            over ``head_dim`` before the rotary embedding is applied.
        dtype: dtype for the linear projections.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by ``num_kv_groups``,
            or ``head_dim`` is omitted and ``d_in`` is not divisible by
            ``num_heads``.
    """

    def __init__(
        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
    ):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        # Number of query heads served by each key/value head.
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)

        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

    def forward(self, x, mask, cos, sin):
        """Compute attention over ``x``.

        Args:
            x: Input of shape ``(batch, num_tokens, d_in)``.
            mask: Boolean tensor; score positions where it is True are set to
                ``-inf`` before the softmax. It must be broadcastable to the
                ``(batch, num_heads, num_tokens, num_tokens)`` score tensor.
            cos: Rotary cosine table passed to ``apply_rope``.
            sin: Rotary sine table passed to ``apply_rope``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_in)``.
        """
        b, num_tokens, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Split into heads: (b, heads, num_tokens, head_dim).
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Explicit None checks: the norms are optional sub-modules.
        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys = self.k_norm(keys)

        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Repeat each key/value head so there is one per query head.
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)

In [None]:
class TransformerBlock(nn.Module):
    """Transformer block: normalized attention and feed-forward, each with a residual.

    ``cfg`` must provide ``"emb_dim"``, ``"n_heads"``, ``"head_dim"``,
    ``"num_kv_groups"``, ``"qk_norm"`` and ``"dtype"``; the same ``cfg`` is
    also handed to ``FeedForward``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            head_dim=cfg["head_dim"],
            num_kv_groups=cfg["num_kv_groups"],
            qk_norm=cfg["qk_norm"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, mask, cos, sin):
        """Run the attention and feed-forward sub-layers on ``x``.

        ``mask``, ``cos`` and ``sin`` are forwarded unchanged to the attention
        module.
        """
        # Attention sub-layer: normalize first, then add the residual.
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)
        x = x + shortcut

        # Feed-forward sub-layer with the same normalize-then-residual pattern.
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut

        return x