# How to complete a roverb
[original link](https://github.com/Ginjing-Yuan/QWen2-from_ground_up/blob/main/Deconstructing-QWen2-from-Ground-Up.ipynb)

# Data Initia

In [1]:
import torch
import json
import matplotlib.pyplot as plt
import math
from torch import nn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [77]:
import base64
from IPython.display import Image, display
import matplotlib.pyplot as plt

def mm(graph):
  graphbytes = graph.encode("ascii")
  base64_bytes = base64.b64encode(graphbytes)
  base64_string = base64_bytes.decode("ascii")
  display(
    Image(
      url="https://mermaid.ink/img/"
      + base64_string
    )
  )


In [67]:
model_path = "Qwen/Qwen2-0.5B"

tokenizer=AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto')
#TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to('cpu')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=8

## The Model File

In [68]:
#model

## Qwen2Model

* Encoder - Decoder for reference (Decoder Only for most of LLMs)


<!-- HTML for setting image size -->
<img src="https://raw.githubusercontent.com/seast/ft-lora/main/images/EncoderDecoder.jpg" alt="Encoder Decoder" width="500" height="300">

```
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
```
* embed_tokens: This is an embedding layer with 151936 tokens (vocabulary size) and embeddings of size 896.
* layers: This is a stack of Qwen2DecoderLayer modules.
    * Qwen2DecoderLayer X 24:
        * self_attn: Self-attention mechanism (Qwen2Attention):
            * q_proj, k_proj, v_proj: size 896 and output size 128.
            * o_proj: Output projection of self-attention with input and output size 896.
            * rotary_emb: Rotary embedding for attention mechanism.
        * mlp: Multi-layer perceptron (Qwen2MLP) with:
            * gate_proj, up_proj, down_proj: Linear projections within the MLP.
            * act_fn: Activation function (SiLU).
        * input_layernorm, post_attention_layernorm: Layer normalization (Qwen2RMSNorm) after input and after attention.
* norm
    * Qwen2RMSNorm: RMSNorm for the entire model.
* lm_head
    * Linear: Linear layer for the language model head, transforming from size 896 to 151936 (output vocabulary size).
* Summary
    * This architecture suggests a transformer-based causal language model named Qwen2, with specific modifications in its attention and MLP layers (Qwen2Attention and Qwen2MLP). It uses RMS normalization and rotary embeddings, indicating it may have specialized enhancements compared to standard transformer architectures. The model operates on an input vocabulary of 151936 tokens and generates output predictions accordingly.

In [144]:
BASE_DIR = '/mlx_devbox/users/haidong.shao/playground/'
with open(BASE_DIR+"ft-lora/qwen2_0.5b_config.json", "r") as f:
    config = json.load(f)
#config

{'architectures': ['Qwen2ForCausalLM'],
 'attention_dropout': 0.0,
 'bos_token_id': 151643,
 'eos_token_id': 151643,
 'hidden_act': 'silu',
 'hidden_size': 896,
 'initializer_range': 0.02,
 'intermediate_size': 4864,
 'max_position_embeddings': 131072,
 'max_window_layers': 24,
 'model_type': 'qwen2',
 'num_attention_heads': 14,
 'num_hidden_layers': 24,
 'num_key_value_heads': 2,
 'rms_norm_eps': 1e-06,
 'rope_theta': 1000000.0,
 'sliding_window': 131072,
 'tie_word_embeddings': True,
 'torch_dtype': 'bfloat16',
 'transformers_version': '4.40.1',
 'use_cache': True,
 'use_sliding_window': False,
 'vocab_size': 151936}

## We will use these configs
* 24 hidden transformer layers
* 14 attention heads
* 2 kv heads and so on.

In [70]:
dim = config["hidden_size"]
n_layers = config["num_hidden_layers"]
n_heads = config["num_attention_heads"]
n_kv_heads = config["num_key_value_heads"]
vocab_size = config["vocab_size"]
norm_eps = config["rms_norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])

## Convert text to tokens (tokenizer)

In [71]:
#prompt = "床前明月光，疑是地上"
prompt = "独在异乡为异客，每逢佳节"
tokens = tokenizer.encode(prompt)
q_len = len(tokens)
tokens

[99510, 18493, 62945, 99474, 17714, 62945, 64754, 3837, 118620, 100191, 55502]

In [72]:
# check the decode result
#tokenizer.decode(tokens)

In [73]:
tokens = torch.tensor(tokens)

In [74]:
model_data = model.state_dict()
#print(json.dumps(list(model_data.keys())[:20], indent=4))

embedding_layer = torch.nn.Embedding.from_pretrained(model_data['model.embed_tokens.weight'])
token_embeddings_unnormalized = embedding_layer(tokens)
#token_embeddings_unnormalized

## Normalize the embedding using root mean square(RMS) normalization

* [RMS paper](https://arxiv.org/abs/1910.07467)
    * set a norm_eps to avoid the formula dived by 0.

$$ \bar a_{i} = \frac{a_{i}} {RMS(a)}{g_{i}}, \ \ \ \text{where} \ \ \ \ RMS(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}a_{i}^{2}}  $$


In [145]:
## Qwen2ForCausalLM Model Architecture
mm("""
graph LR;
classDef norm fill:#f9f,stroke:#333,stroke-width:4px;
    B["Tokens"] --> C["Embedding (151936, 896)"]
    C --> C1["Qwen2RMSNorm"]
    C1 --> D["Transformers"]
    D --> E["Qwen2RMSNorm"]
    E --> F["lm_head, Linear (896, 151936)"]
    
    class C1,E norm;
""")

mm("""
graph TD;
classDef norm fill:#f9f,stroke:#333,stroke-width:4px;

    B["Tokens"] --> C["Embedding (151936, 896)"]
    C --> |"input_layernorm"| I["Qwen2RMSNorm"]
    I --> D["Transformers"]
    D --> E["24 x Qwen2DecoderLayer"]
    E --> |"14 head self_attn"| F["Qwen2Attention"]
    E --> |"mlp"|G["Qwen2MLP"]
    A --> J["Linear (896, 151936)"]
    
    E --> |"post_attention_layernorm"| H1["Qwen2RMSNorm"]

    F --> K["q_proj (896, 896)"]
    F --> L["k_proj (896, 128)"]
    F --> M["v_proj (896, 128)"]
    F --> N["o_proj (896, 896)"]
    F --> O["rotary_emb Qwen2RotaryEmbedding"]

    G --> P["gate_proj (896, 4864)"]
    G --> Q["up_proj (896, 4864)"]
    G --> R["down_proj (4864, 896)"]
    G --> S["act_fn SiLU"]
    class H,H1,I norm;
""")

In [63]:
def rms_norm(tensor, norm_weights):
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

token_embeddings = rms_norm(token_embeddings_unnormalized, model_data["model.layers.0.input_layernorm.weight"])
token_embeddings

tensor([[ 3.2241e-02,  6.1609e-02, -3.5329e-02,  ..., -4.8085e-02,
         -2.0993e-01, -1.1316e-02],
        [-2.1844e-01, -5.6187e-03,  3.2549e-03,  ..., -7.0803e-02,
          4.3080e-02,  8.5175e-02],
        [-1.0140e-01,  6.1088e-02,  4.0035e-02,  ..., -2.8165e-02,
          3.5223e-02, -1.7562e-02],
        ...,
        [ 4.6448e-02,  8.1038e-02,  2.3371e-02,  ..., -4.8037e-02,
         -1.6664e-01, -8.7087e-03],
        [ 3.5723e-02,  1.5408e-01, -2.0525e-02,  ..., -5.4048e-02,
          4.1399e-02,  2.7549e-02],
        [-5.9662e-02,  1.1820e-01, -3.7939e-05,  ...,  2.3686e-02,
         -7.0856e-02,  1.4492e-02]])

# Transformer Layer

Now, let's process the normalized inputs to Q,K,V

In [146]:
mm("""
graph TD;
classDef norm fill:#f9f,stroke:#333,stroke-width:4px;

    A["Qwen2ForCausalLM"] --> B["Qwen2Model"]
    B --> C["Embedding (151936, 896)"]
    C --> |"input_layernorm"| I["Qwen2RMSNorm"]
    I --> D["Transformers"]
    D --> E["24 x Qwen2DecoderLayer"]
    E --> |"14 head self_attn"| F["Qwen2Attention"]
    E --> |"mlp"|G["Qwen2MLP"]
    A --> J["Linear (896, 151936)"]
    
    E --> |"post_attention_layernorm"| H1["Qwen2RMSNorm"]

    F --> K["q_proj (896, 896)"]
    F --> L["k_proj (896, 128)"]
    F --> M["v_proj (896, 128)"]
    F --> N["o_proj (896, 896)"]
    F --> O["rotary_emb Qwen2RotaryEmbedding"]

    G --> P["gate_proj (896, 4864)"]
    G --> Q["up_proj (896, 4864)"]
    G --> R["down_proj (4864, 896)"]
    G --> S["act_fn SiLU"]
    class K,L,M,I norm;
""")

In [34]:
q_layer0 = model_data["model.layers.0.self_attn.q_proj.weight"]
k_layer0 = model_data["model.layers.0.self_attn.k_proj.weight"]
v_layer0 = model_data["model.layers.0.self_attn.v_proj.weight"]
o_layer0 = model_data["model.layers.0.self_attn.o_proj.weight"]
q_layer0_bias = model_data['model.layers.0.self_attn.q_proj.bias']
k_layer0_bias = model_data['model.layers.0.self_attn.k_proj.bias']
v_layer0_bias = model_data['model.layers.0.self_attn.v_proj.bias']

In [152]:
# token_embeddings ([11, 896])
# q_layer0 ([896, 896]), q_layer0_bias([896]), query_states([11, 896])
query_states = torch.matmul(token_embeddings, q_layer0.T)+q_layer0_bias
# k_layer0 ([128, 896]), k_layer0_bias([128]), key_states([11, 128])
key_states = torch.matmul(token_embeddings, k_layer0.T)+k_layer0_bias
# v_layer0 ([128, 896]), v_layer0_bias([128]), value_states([11, 128])
value_states = torch.matmul(token_embeddings, v_layer0.T)+v_layer0_bias

In [153]:
# dim 896, 14 head, head_dim dimension is 64
head_dim = dim//n_heads
# reformat to torch.Size([1, 14, 11, 64])
query_states = query_states.view(1, q_len, n_heads, head_dim).transpose(1, 2)
# for key and value, it si torch.Size([1, 2, 11, 64])
key_states = key_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)
value_states = value_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)

In [154]:
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
rotary_emb = Qwen2RotaryEmbedding(
            64,
            max_position_embeddings=131072,
            base=rope_theta,
        )

In [155]:
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

In [156]:
cos, sin = rotary_emb(value_states, seq_len=q_len)
position_ids = torch.arange(q_len).view(1,q_len)

In [157]:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

In [158]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

In [159]:
key_states = repeat_kv(key_states, n_heads // n_kv_heads)
value_states = repeat_kv(value_states, n_heads // n_kv_heads)

In [160]:
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=None,
    dropout_p= 0.0,
    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
    is_causal= True,
)

In [161]:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(1, q_len, dim)

In [162]:
output_states = torch.matmul(attn_output, o_layer0.T)
output_states

tensor([[[-0.0032, -0.0077, -0.0035,  ...,  0.0111,  0.0147, -0.0019],
         [-0.0011, -0.0016, -0.0039,  ...,  0.0068,  0.0041, -0.0052],
         [-0.0024, -0.0026, -0.0025,  ...,  0.0051, -0.0040, -0.0026],
         ...,
         [-0.0037, -0.0064,  0.0047,  ...,  0.0024,  0.0128,  0.0035],
         [-0.0013, -0.0026,  0.0034,  ...,  0.0039,  0.0052,  0.0048],
         [ 0.0041, -0.0081,  0.0054,  ..., -0.0015,  0.0415,  0.0064]]])

In [163]:
output_states = output_states+token_embeddings_unnormalized

In [164]:
second_normalized = rms_norm(token_embeddings_unnormalized, model_data["model.layers.0.post_attention_layernorm.weight"])

In [165]:
w1 = model_data[f"model.layers.0.mlp.gate_proj.weight"]
w2 = model_data[f"model.layers.0.mlp.down_proj.weight"]
w3 = model_data[f"model.layers.0.mlp.up_proj.weight"]
output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(second_normalized, w1.T)) * torch.matmul(second_normalized, w3.T), w2.T)

In [166]:
final_embedding = token_embeddings_unnormalized
x= 0
for layer in range(n_layers):
    x+=1
    residual1 = final_embedding
    
    # embeding norm
    layer_embedding_norm = rms_norm(final_embedding, model_data[f"model.layers.{layer}.input_layernorm.weight"])
    
    q_layer = model_data[f"model.layers.{layer}.self_attn.q_proj.weight"]
    k_layer = model_data[f"model.layers.{layer}.self_attn.k_proj.weight"]
    v_layer = model_data[f"model.layers.{layer}.self_attn.v_proj.weight"]
    w_layer = model_data[f"model.layers.{layer}.self_attn.o_proj.weight"]
    q_layer_bias = model_data[f'model.layers.{layer}.self_attn.q_proj.bias']
    k_layer_bias = model_data[f'model.layers.{layer}.self_attn.k_proj.bias']
    v_layer_bias = model_data[f'model.layers.{layer}.self_attn.v_proj.bias']

    query_states = torch.matmul(layer_embedding_norm, q_layer.T)+q_layer_bias
    key_states = torch.matmul(layer_embedding_norm, k_layer.T)+k_layer_bias
    value_states = torch.matmul(layer_embedding_norm, v_layer.T)+v_layer_bias
    head_dim = dim//n_heads
    query_states = query_states.view(1, q_len, n_heads, head_dim).transpose(1, 2)
    key_states = key_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)
    value_states = value_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)

    cos, sin = rotary_emb(value_states, seq_len=q_len)
    position_ids = torch.arange(q_len).view(1,q_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    
    key_states = repeat_kv(key_states, n_heads // n_kv_heads)
    value_states = repeat_kv(value_states, n_heads // n_kv_heads)
    
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=None,
        dropout_p= 0.0,
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal= True,
    )
    
    

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(1, q_len, dim)
    output_states = torch.matmul(attn_output, w_layer.T)
        
    hidden_state = residual1+output_states

    # Fully connected
    residual2 = hidden_state
    
    w1 = model_data[f"model.layers.{layer}.mlp.gate_proj.weight"]
    w2 = model_data[f"model.layers.{layer}.mlp.down_proj.weight"]
    w3 = model_data[f"model.layers.{layer}.mlp.up_proj.weight"]
    second_normalized = rms_norm(hidden_state, model_data[f"model.layers.{layer}.post_attention_layernorm.weight"])
    output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(second_normalized, w1.T)) * torch.matmul(second_normalized, w3.T), w2.T)
    final_embedding = residual2+output_after_feedforward

In [167]:
final_normalized = rms_norm(final_embedding, model_data["model.norm.weight"])
final_normalized.shape

torch.Size([1, 11, 896])

In [168]:
logits = torch.matmul(final_normalized[0][-1], model_data["lm_head.weight"].T)
logits.shape

torch.Size([151936])

In [169]:
next_token = torch.argmax(logits, dim=-1).view(1)
next_token

tensor([97306])

In [170]:
tokenizer.decode(next_token)

'倍'