[standalone-qwen3.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3.ipynb)

In [9]:
from importlib.metadata import version
pkg = [
    'huggingface_hub',
    'tokenizers',
    'torch'
]
for p in pkg:
    print(f"{p} version: {version(p)}")

huggingface_hub version: 0.30.1
tokenizers version: 0.21.1
torch version: 2.3.1


In [10]:
USE_BASE_MODEL = False
USE_RESONING_MODEL = True
USE_INSTRUCT_MODEL = False

if (USE_BASE_MODEL + USE_RESONING_MODEL + USE_INSTRUCT_MODEL) != 1:
    raise ValueError("Exactly one of USE_BASE_MODEL, USE_RESONING_MODEL, or USE_INSTRUCT_MODEL must be True")

In [11]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"],cfg["hidden_dim"],dtype=cfg["dtype"],bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"],cfg["hidden_dim"],dtype=cfg["dtype"],bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"],cfg["emb_dim"],dtype=cfg["dtype"],bias=False)

    def forward(self,x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)

In [12]:
class RMSNorm(nn.Module):
    def __init__(self,emb_dim,eps=1e-6,bias=False,qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.one(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self,x):
        input_dtype = x.dtype

        if self.qwen3_compatible:
            x = x.to(torch.float32)
        
        variance = x.pow(2).mean(dim=-1,keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x * self.shift
        
        return norm_x.to(input_dtype)

In [13]:
dtype = torch.float16
theta_base=10_000
head_dim = 10

torch.arange(0, head_dim, 2, dtype=dtype)
#inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
#inv_freq


tensor([0., 2., 4., 6., 8.], dtype=torch.float16)

In [None]:
context_length=4096
theta_base=10_000
head_dim=10
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
positions = torch.arange(context_length, dtype=dtype)
print("positions:",positions.unsqueeze(1),positions.unsqueeze(1).shape)   
print("inv_freq:",inv_freq.unsqueeze(0),inv_freq.unsqueeze(0).shape)
print("rope:",positions.unsqueeze(1) * inv_freq.unsqueeze(0),positions.unsqueeze(1) * inv_freq.unsqueeze(0).shape)

positions: tensor([[0.0000e+00],
        [1.0000e+00],
        [2.0000e+00],
        ...,
        [4.0920e+03],
        [4.0940e+03],
        [4.0960e+03]], dtype=torch.float16) torch.Size([4096, 1])
inv_freq: tensor([[1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]]) torch.Size([1, 5])


In [None]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
    print("inv_freq:", inv_freq.shape)
    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)
    print("positions:", positions.shape)
    # Compute the angles
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)  # Shape: (context_length, head_dim // 2)
    print("angeles:", angles.shape)
    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)

In [12]:
compute_rope_params(128)

inv_freq: torch.Size([64])
positions: torch.Size([4096])


(tensor([[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.5403,  0.6479,  0.7318,  ...,  1.0000,  1.0000,  1.0000],
         [-0.4161, -0.1604,  0.0709,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [-0.8799,  0.7803, -0.9998,  ...,  0.8079,  0.8547,  0.8904],
         [-0.8753,  0.0292, -0.7446,  ...,  0.8078,  0.8546,  0.8903],
         [-0.0660, -0.7424, -0.0900,  ...,  0.8077,  0.8546,  0.8903]]),
 tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 8.4147e-01,  7.6172e-01,  6.8156e-01,  ...,  1.5399e-04,
           1.3335e-04,  1.1548e-04],
         [ 9.0930e-01,  9.8705e-01,  9.9748e-01,  ...,  3.0799e-04,
           2.6670e-04,  2.3096e-04],
         ...,
         [ 4.7523e-01,  6.2535e-01,  1.9127e-02,  ...,  5.8938e-01,
           5.1911e-01,  4.5525e-01],
         [-4.8361e-01,  9.9957e-01, -6.6752e-01,  ...,  5.8951e-01,
           5.1922e-01,  4.5535e-01],
         [-9.9782e-