 [standalone-qwen3-plus-kvcache.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb)

In [1]:
from importlib.metadata import version

# Libraries whose installed versions are reported, so that results from this
# notebook can be tied to a concrete environment.
pkgs = ["huggingface_hub", "tokenizers", "torch"]

for p in pkgs:
    # One "<name>: <version>" line per package.
    print(p, version(p), sep=": ")


huggingface_hub: 0.30.1
tokenizers: 0.21.1
torch: 2.3.1


In [2]:
# Model-variant switches: exactly one of the three must be enabled.
USE_BASE_MODEL = False
USE_REASONING_MODEL = True
USE_INSTRUCT_MODEL = False

# Booleans sum as 0/1, so the total counts how many variants are switched on.
if sum((USE_BASE_MODEL, USE_REASONING_MODEL, USE_INSTRUCT_MODEL)) != 1:
    raise AttributeError(
        "Exactly one of USE_BASE_MODEL, USE_REASONING_MODEL, "
        "USE_INSTRUCT_MODEL must be True."
    )

In [3]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Gated feed-forward network: ``fc3(silu(fc1(x)) * fc2(x))``.

    ``cfg`` supplies ``emb_dim`` (input/output width), ``hidden_dim`` (inner
    width) and ``dtype`` (parameter dtype). None of the layers has a bias.
    """

    def __init__(self, cfg):
        super().__init__()
        shared = {"dtype": cfg["dtype"], "bias": False}
        emb_dim, hidden_dim = cfg["emb_dim"], cfg["hidden_dim"]

        # fc1 feeds the SiLU gate, fc2 is the linear branch, fc3 projects back.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, **shared)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, **shared)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, **shared)

    def forward(self, x):
        """Return the gated projection of ``x`` back at width ``emb_dim``."""
        gate = self.fc1(x)
        linear = self.fc2(x)
        return self.fc3(nn.functional.silu(gate) * linear)

In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` and multiplied by a
    learnable ``scale``; a learnable ``shift`` is added when ``bias=True``.
    With ``qwen3_compatible=True`` the arithmetic is carried out in float32.
    The result is always cast back to the input's dtype.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        if bias:
            self.shift = nn.Parameter(torch.zeros(emb_dim))
        else:
            self.shift = None

    def forward(self, x):
        """Normalize ``x`` along its last axis and return it in ``x``'s dtype."""
        orig_dtype = x.dtype
        h = x.to(torch.float32) if self.qwen3_compatible else x

        mean_sq = h.pow(2).mean(dim=-1, keepdim=True)
        out = h * torch.rsqrt(mean_sq + self.eps)
        out = out * self.scale
        if self.shift is not None:
            out = out + self.shift

        return out.to(orig_dtype)

In [5]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute the RoPE cosine and sine lookup tables.

    Args:
        head_dim: Width of one attention head; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions (rows) to tabulate.
        dtype: dtype used for the index and position ranges.

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, head_dim)``. The
        second half of the columns repeats the first half.
    """
    assert head_dim % 2 == 0, "embedding dimension must be even for RoPE"

    half = head_dim // 2
    even_idx = torch.arange(0, head_dim, 2, dtype=dtype)[:half].float()
    inv_freq = 1.0 / (theta_base ** (even_idx / head_dim))

    # Outer product: one row per position, one column per frequency.
    pos = torch.arange(context_length, dtype=dtype)
    half_angles = pos[:, None] * inv_freq[None, :]

    # Duplicate so the table spans the full head width.
    angles = torch.cat([half_angles, half_angles], dim=1)
    return torch.cos(angles), torch.sin(angles)

In [6]:
def apply_rope(x, cos, sin, offset=0):
    """Rotate ``x`` with rotary position embeddings.

    Args:
        x: Tensor of shape ``(batch, heads, seq_len, head_dim)``.
        cos, sin: Lookup tables indexed by position along the first axis.
        offset: Position of the first token of ``x`` within the tables.

    Returns:
        The rotated tensor, in the same dtype as ``x``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "embedding dimension must be even for RoPE"

    half = head_dim // 2
    front, back = x[..., :half], x[..., half:]

    # Select the rows for this window and add leading batch/head axes.
    window = slice(offset, offset + seq_len)
    cos_sel = cos[window, :][None, None, :, :]
    sin_sel = sin[window, :][None, None, :, :]

    swapped = torch.cat([-back, front], dim=-1)
    out = x * cos_sel + swapped * sin_sel

    return out.to(dtype=x.dtype)

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention with optional QK-norm, RoPE and a KV cache.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads, so each
    key/value head serves ``num_heads // num_kv_groups`` query heads.
    """

    def __init__(
            self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
    ):
        """Build the projection layers.

        Args:
            d_in: Width of the input (and output) embeddings.
            num_heads: Number of query heads.
            num_kv_groups: Number of key/value heads; must divide ``num_heads``.
            head_dim: Width of one head. Defaults to ``d_in // num_heads``.
            qk_norm: If true, RMS-normalize queries and keys per head.
            dtype: Parameter dtype for the linear layers.
        """
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "d_in must be divisible by num_heads if head_dim is not specified"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, self.num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, self.num_kv_groups * head_dim, bias=False, dtype=dtype)

        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

    def forward(
            self, x, mask, cos, sin, start_pos=0, cache=None
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Attend over the new tokens plus any cached keys/values.

        Args:
            x: Input of shape ``(batch, num_tokens, d_in)``.
            mask: Boolean mask; positions that are true are excluded from
                attention (their scores are set to ``-inf``).
            cos, sin: RoPE lookup tables indexed by absolute position.
            start_pos: Absolute position of the first token in ``x``.
            cache: ``(keys, values)`` from earlier calls, or ``None``.

        Returns:
            ``(output, next_cache)`` where ``output`` has the shape of ``x``
            and ``next_cache`` is the ``(keys, values)`` pair, before head
            expansion, covering every position seen so far.
        """
        b, num_tokens, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Split the projections into heads: (b, heads, num_tokens, head_dim).
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys_new = self.k_norm(keys_new)

        # Encode positions. Only the new keys are rotated; cached keys were
        # already rotated at their own positions when they were produced.
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)

        if cache is not None:
            prev_k, prev_v = cache
            keys = torch.cat([prev_k, keys_new], dim=2)
            values = torch.cat([prev_v, values_new], dim=2)
        else:
            keys, values = keys_new, values_new
        # Cache the compact (un-expanded) tensors.
        next_cache = (keys, values)

        # Repeat each key/value head so it lines up with its query heads.
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context), next_cache


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Applies RMSNorm then grouped-query attention, and RMSNorm then the
    feed-forward network; each sub-layer's output is added to its input.
    """

    def __init__(self, cfg):
        """Build the block from ``cfg``.

        Reads ``emb_dim``, ``n_heads``, ``n_kv_groups``, ``qk_norm`` and
        ``dtype``, plus whatever ``FeedForward`` reads. ``head_dim`` is
        optional; when absent or ``None`` the attention layer derives it.
        """
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            # Forward the configured head width so attention agrees with RoPE
            # tables that are sized from the same config entry.
            head_dim=cfg.get("head_dim"),
            qk_norm=cfg["qk_norm"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
        """Run attention and feed-forward with residual connections.

        Args:
            x: Input embeddings.
            mask, cos, sin, start_pos, cache: Passed through unchanged to the
                attention layer.

        Returns:
            ``(x, next_cache)``: the block output and the attention layer's
            updated cache entry.
        """
        # Attention sub-layer.
        shortcut = x
        x = self.norm1(x)
        x, next_cache = self.att(x, mask, cos, sin, start_pos, cache)
        x = x + shortcut

        # Feed-forward sub-layer.
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut

        return x, next_cache


In [None]:
class Qweb3Model(nn.Module):
    """Decoder-only transformer: token embedding, blocks, final norm, LM head.

    The RoPE tables are computed once and stored as non-persistent buffers.
    ``current_pos`` counts the tokens already processed with a cache, so
    later cached calls continue at the right position.
    """

    def __init__(self, cfg):
        """Build the model from ``cfg``.

        Reads ``vocab_size``, ``emb_dim``, ``dtype``, ``n_layers``,
        ``n_heads``, ``head_dim`` (may be ``None``), ``rope_base`` and
        ``context_length``, plus whatever ``TransformerBlock`` reads.
        """
        super().__init__()

        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_block = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]
        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg
        self.current_pos = 0

    def forward(self, in_idx, cache=None):
        """Compute next-token logits for ``in_idx``.

        Args:
            in_idx: Token ids of shape ``(batch, num_tokens)``.
            cache: Optional per-layer cache exposing ``get(i)`` and
                ``update(i, value)``. When given, ``in_idx`` is treated as a
                continuation starting at ``self.current_pos``.

        Returns:
            Logits of shape ``(batch, num_tokens, vocab_size)``.
        """
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]

        if cache is not None:
            pos_start = self.current_pos
            pos_end = pos_start + num_tokens
            self.current_pos = pos_end
            # Causal mask over all positions so far, keeping only the rows
            # that belong to the new tokens.
            mask = torch.triu(
                torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool),
                diagonal=1,
            )[pos_start:pos_end, :pos_end]
        else:
            pos_start = 0
            mask = torch.triu(
                torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
        # Add batch and head axes for broadcasting.
        mask = mask[None, None, :, :]

        for i, block in enumerate(self.trf_block):
            blk_cache = cache.get(i) if cache is not None else None
            x, new_blk_cache = block(x, mask, self.cos, self.sin,
                                     start_pos=pos_start, cache=blk_cache)
            if cache is not None:
                cache.update(i, new_blk_cache)

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

    def reset_kv_cache(self):
        """Rewind the position counter to 0.

        This does not touch any cache object passed to ``forward``; the
        caller is responsible for clearing that separately.
        """
        self.current_pos = 0

In [None]:
class KVCache:
    """Fixed-size list of per-layer cache entries, addressed by layer index.

    Every slot starts as ``None`` and holds whatever value was last stored
    for that layer.
    """

    def __init__(self, n_layers):
        self.cache = [None for _ in range(n_layers)]

    def get(self, layer_idx):
        """Return the entry stored for ``layer_idx`` (``None`` if unset)."""
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        """Store ``value`` as the entry for ``layer_idx``."""
        self.cache[layer_idx] = value

    def get_all(self):
        """Return the underlying list of entries (not a copy)."""
        return self.cache

    def reset(self):
        """Clear every slot in place, keeping the same list object."""
        self.cache[:] = [None] * len(self.cache)