 [standalone-qwen3-plus-kvcache.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb)

In [4]:
from importlib.metadata import version

# Report the installed version of every package this notebook depends on.
pkgs = ["huggingface_hub", "tokenizers", "torch"]

for p in pkgs:
    # `version` raises if the distribution is not installed.
    print(f"{p}: {version(p)}")


huggingface_hub: 0.30.1
tokenizers: 0.21.1
torch: 2.3.1


In [5]:
# Model-variant switches; exactly one of the three must be enabled.
USE_BASE_MODEL = False
USE_REASONING_MODEL = True
USE_INSTRUCT_MODEL = False

# Booleans sum as 0/1, so the total is the number of enabled switches.
if sum((USE_BASE_MODEL, USE_REASONING_MODEL, USE_INSTRUCT_MODEL)) != 1:
    raise AttributeError(
        "Exactly one of USE_BASE_MODEL, USE_REASONING_MODEL, USE_INSTRUCT_MODEL must be True."
    )

In [7]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Gated feed-forward block built from three bias-free linear layers.

    ``cfg`` must provide the keys ``"emb_dim"``, ``"hidden_dim"`` and
    ``"dtype"``. ``fc1`` and ``fc2`` map ``emb_dim -> hidden_dim``;
    ``fc3`` maps ``hidden_dim -> emb_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim, hidden_dim = cfg["emb_dim"], cfg["hidden_dim"]
        linear_kwargs = {"dtype": cfg["dtype"], "bias": False}

        # Layers are created in fc1, fc2, fc3 order so seeded initialisation
        # and state-dict key order stay stable.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, **linear_kwargs)

    def forward(self, x):
        """Return ``fc3(silu(fc1(x)) * fc2(x))``."""
        gate = nn.functional.silu(self.fc1(x))
        up = self.fc2(x)
        return self.fc3(gate * up)

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        emb_dim: Size of the last dimension; length of the learnable
            ``scale`` (and optional ``shift``) vectors.
        eps: Constant added to the mean of squares before the reciprocal
            square root.
        bias: When true, a learnable ``shift`` (initialised to zeros) is
            added after scaling; otherwise ``shift`` is ``None``.
        qwen3_compatible: When true, the input is cast to float32 before
            normalising; the result is always cast back to the input dtype.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        """Normalise ``x`` and return a tensor with ``x``'s original dtype."""
        orig_dtype = x.dtype
        hidden = x.to(torch.float32) if self.qwen3_compatible else x

        # Mean of squares along the feature axis, kept for broadcasting.
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        out = hidden * torch.rsqrt(mean_sq + self.eps) * self.scale

        if self.shift is not None:
            out = out + self.shift

        return out.to(orig_dtype)

In [None]:
def compute_rope_params(head_dim,theta_base=10_000,context_length=4096,dtype=torch.float32):
    """Precompute the cosine and sine tables used for rotary embeddings.

    Args:
        head_dim: Per-head embedding size; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions (rows) to tabulate.
        dtype: dtype used when creating the index ranges.

    Returns:
        A ``(cos, sin)`` tuple; each tensor has shape
        ``(context_length, head_dim)``, with the second half of the last
        dimension repeating the first half.
    """
    assert head_dim % 2 ==0,"embedding dimension must be even for RoPE"

    half_dim = head_dim // 2

    # Even indices 0, 2, 4, ... normalised by head_dim give the exponents.
    even_idx = torch.arange(0, head_dim, 2, dtype=dtype)[:half_dim].float()
    inv_freq = 1.0 / (theta_base ** (even_idx / head_dim))

    # Outer product via broadcasting: one row per position, one column per frequency.
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]

    # Duplicate the columns so the table spans the full head_dim.
    angles = torch.cat((angles, angles), dim=1)

    return torch.cos(angles), torch.sin(angles)

In [None]:
def apply_rope(x,cos,sin,offset=0):
    """Apply rotary position embeddings (RoPE) to ``x``.

    Args:
        x: 4-D tensor whose shape is unpacked as
            ``(batch_size, num_heads, seq_len, head_dim)``; ``head_dim``
            must be even.
        cos: 2-D cosine table indexed by position along dim 0; its last
            dimension must broadcast against ``head_dim``.
        sin: Sine table with the same layout as ``cos``.
        offset: Index of the first table row to use for ``x``'s first
            token (presumably the number of tokens already held in a KV
            cache -- confirm against the caller).

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If ``head_dim`` is odd.
        ValueError: If ``offset`` is negative or ``offset + seq_len``
            runs past the end of the ``cos``/``sin`` tables.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 ==0,"embedding dimension must be even for RoPE"

    # Without this check an out-of-range window yields a short (or empty)
    # slice that broadcasting silently accepts: a length-1 slice reuses one
    # position for every token, and a length-0 slice against seq_len == 1
    # produces an empty result instead of an error.
    table_len = min(cos.shape[0], sin.shape[0])
    if offset < 0 or offset + seq_len > table_len:
        raise ValueError(
            f"RoPE positions [{offset}, {offset + seq_len}) fall outside the "
            f"precomputed table of length {table_len}"
        )

    # Split the head dimension into its first and second halves.
    x1 = x[...,:head_dim//2]
    x2 = x[...,head_dim//2:]

    # Select the rows for this window and add batch/head axes for broadcasting.
    cos = cos[offset:offset+seq_len,:].unsqueeze(0).unsqueeze(0)
    sin = sin[offset:offset+seq_len,:].unsqueeze(0).unsqueeze(0)

    rotated = torch.cat([-x2,x1],dim=-1)
    x_rotated = x * cos + rotated * sin

    # The tables may have a wider dtype than x; cast back to x's dtype.
    return x_rotated.to(dtype=x.dtype)