In [1]:
import torch 
import torch.nn as nn 

In [None]:
# Hyperparameters shared by every module below (FeedForward, RMSNorm, ...).
# Values follow the Qwen3-0.6B config (hidden_size=1024, intermediate_size=3072,
# bfloat16 weights); NOTE(review): confirm against the checkpoint you load and
# change them for other model sizes. The previous `xx` placeholders were
# undefined names, so this cell raised NameError on a fresh kernel.
QWEN3_CONFIG = {
    "d_model": 1024,          # residual-stream / embedding width
    "hidden_dim": 3072,       # inner width of the SwiGLU feed-forward block
    "dtype": torch.bfloat16,  # parameter dtype for the nn.Linear layers
    "qkv_bias": False,        # no bias on the attention q/k/v projections
}

In [None]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward block.

    SwiGLU(x, W1, W2, W3) = W3(SiLU(W1 x) ⊙ W2 x), implemented as
    ``fc3(silu(fc1(x)) * fc2(x))``.

    Args:
        cfg: config dict. Reads ``d_model`` (input/output width), ``hidden_dim``
            (inner width), ``dtype`` (parameter dtype) and the optional
            ``ffn_bias`` flag (default ``False``).
    """
    def __init__(self, cfg):
        super().__init__()
        # The MLP bias is deliberately independent of the attention's
        # ``qkv_bias``: some models (e.g. Qwen2) put a bias on q/k/v but none
        # on the feed-forward projections.
        bias = cfg.get("ffn_bias", False)
        self.fc1 = nn.Linear(cfg["d_model"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=bias)  # W1: gate
        self.fc2 = nn.Linear(cfg["d_model"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=bias)  # W2: up
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["d_model"], dtype=cfg["dtype"], bias=bias)  # W3: down

    def forward(self, x):
        """Apply the block along the last dimension of ``x`` (size ``d_model``)."""
        gate = nn.functional.silu(self.fc1(x))
        return self.fc3(gate * self.fc2(x))

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    ``y = x / sqrt(mean(x**2, dim=-1) + eps) * scale (+ shift)``

    The statistic is computed in at least float32, even for bf16/fp16 inputs.
    The result is cast back to the input dtype.

    Args:
        cfg: config dict; ``cfg["d_model"]`` is the size of the last dimension.
        eps: added to the mean square before the reciprocal square root.
        bias: if True, add a learnable ``shift`` after scaling.
    """
    def __init__(self, cfg, eps: float = 1e-6, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(cfg["d_model"]))
        self.shift = nn.Parameter(torch.zeros(cfg["d_model"])) if bias else None

    def forward(self, x):
        input_dtype = x.dtype
        # Upcast low-precision inputs. Squaring and averaging in bf16/fp16 loses
        # precision, and fp16 can overflow. float64 inputs are left as float64.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x = x.to(compute_dtype)
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(mean_sq + self.eps)
        norm_x = norm_x * self.scale
        if self.shift is not None:
            norm_x = norm_x + self.shift
        return norm_x.to(input_dtype)

In [None]:
class RoPE(nn.Module):
    """Rotary position embedding in the "rotate-half" layout.

    Feature ``i`` is paired with feature ``i + head_dim // 2``. Each pair at
    position ``pos`` is rotated by the angle
    ``pos * theta_base ** (-2i / head_dim)`` for ``i = 0 .. head_dim/2 - 1``.

    Args:
        head_dim: per-head feature size; must be even.
        theta_base: base of the geometric frequency progression.
        context_length: number of positions whose cos/sin are precomputed.
        dtype: dtype of the stored cos/sin tables. The angles are always
            computed in float32, so large positions stay exact even when the
            tables are bf16.
    """
    def __init__(self, head_dim: int, theta_base: int = 10_000, context_length: int = 4096, dtype=torch.float32):
        super().__init__()  # required before registering buffers on an nn.Module
        assert head_dim % 2 == 0, "Embedding dimension must be even!"

        # inv_freq_i = theta_base ** (-2i / head_dim), i = 0 .. head_dim/2 - 1
        even_idx = torch.arange(0, head_dim, 2, dtype=torch.float32)
        inv_freq = 1.0 / theta_base ** (even_idx / head_dim)
        positions = torch.arange(context_length, dtype=torch.float32)
        angles = positions[:, None] * inv_freq[None, :]  # (context_length, head_dim // 2)
        # Duplicate so columns j and j + head_dim/2 share one angle (rotate-half layout).
        angles = torch.cat([angles, angles], dim=1)  # (context_length, head_dim)

        # Non-persistent buffers move with .to(device) but stay out of state_dict,
        # matching the previous plain-attribute behaviour for checkpoints.
        self.register_buffer("cos", torch.cos(angles).to(dtype), persistent=False)
        self.register_buffer("sin", torch.sin(angles).to(dtype), persistent=False)

    def forward(self, x):
        """Rotate ``x`` of shape ``(..., seq_len, head_dim)``.

        Typically the input is ``(batch_size, n_heads, seq_len, head_dim)``
        inside attention. Positions are taken as ``0 .. seq_len - 1``.
        """
        seq_len, head_dim = x.shape[-2], x.shape[-1]
        max_len, table_dim = self.cos.shape
        assert seq_len <= max_len, f"seq_len={seq_len} exceeds context_length={max_len}"
        assert head_dim == table_dim, f"head_dim={head_dim} does not match RoPE head_dim={table_dim}"

        x1 = x[..., : head_dim // 2]  # first half
        x2 = x[..., head_dim // 2 :]  # second half
        rotated = torch.cat((-x2, x1), dim=-1)

        # (seq_len, head_dim) tables broadcast over all leading dims of x.
        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]
        return (x * cos + rotated * sin).to(dtype=x.dtype)

In [11]:
# Sanity check of the layout used in RoPE.__init__ (`torch.cat([angles, angles], dim=1)`):
# concatenating a row with itself along dim=1 repeats the whole row; it does not interleave.
evens = torch.arange(0, 10, 2)[None, :]  # shape (1, 5): [[0, 2, 4, 6, 8]]
duplicated = torch.cat([evens, evens], dim=1)  # shape (1, 10)
assert torch.equal(duplicated[:, :5], duplicated[:, 5:])
duplicated

tensor([[0, 2, 4, 6, 8, 0, 2, 4, 6, 8]])