# [Understand How Llama3.1 Works — A Deep Dive Into the Model Flow](https://medium.com/@yuxiaojian/understand-how-llama3-1-works-a-deep-dive-into-the-model-flow-b149aba04bed)

In [1]:
#%pip install --upgrade transformers
#%pip install --upgrade trl

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel,PeftConfig

In [3]:
# Configure the Hugging Face token (only needed when the model is downloaded from Hugging Face)
import getpass
import os
# Prompt for the token and store it in the HF_TOKEN environment variable,
# which the tokenizer and model loading cells read back
os.environ["HF_TOKEN"] = getpass.getpass()

 ········


In [5]:
# Load the tokenizer from Hugging Face
# `base_model` is the model ID; the same ID is reused when the weights are loaded
base_model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Authenticate with the token stored in HF_TOKEN
tokenizer = AutoTokenizer.from_pretrained(base_model, token=os.environ['HF_TOKEN'])

In [6]:
# Quantize the Llama 3.1 model to bitsandbytes (BNB) NF4 while loading it, and place the quantized model on the GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load the weights in 4-bit form
    bnb_4bit_quant_type="nf4",  # use the NF4 4-bit quantization type
    bnb_4bit_compute_dtype=torch.bfloat16  # dtype passed as the 4-bit compute dtype
)

# Load directly from Hugging Face onto device "cuda:0", applying the quantization config above
base_model_bnb_4b = AutoModelForCausalLM.from_pretrained(base_model, device_map="cuda:0", quantization_config=bnb_config, token=os.environ['HF_TOKEN'])

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# Get the vocabulary size
# get_vocab() returns the tokenizer's token -> ID mapping; its length is the vocabulary size
vocab_size = len(tokenizer.get_vocab())
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 128256


In [None]:
# Get a token and ID mapping
def print_tokens_with_ids(txt):
    """
    Print the tokens of a text next to their token IDs.

    The text is tokenized twice with the module-level ``tokenizer`` (once to
    token strings, once to IDs), both times without adding special tokens,
    and the two sequences are paired up position by position.

    Args:
    txt: Text to tokenize

    Returns:
    None; the list of (token, id) pairs is printed
    """
    ids = tokenizer.encode(txt, add_special_tokens=False)
    pieces = tokenizer.tokenize(txt, add_special_tokens=False)
    pairs = [(piece, token_id) for piece, token_id in zip(pieces, ids)]
    print(pairs)

# Prompt written with explicit special tokens: a system turn, a user turn,
# and an opened assistant turn for the model to complete
prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Based on the information provided, rewrite the sentence by changing its tense from past to future.<|eot_id|><|start_header_id|>user<|end_header_id|>

She played the piano beautifully for hours and then stopped as it was midnight.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
# Show how the prompt is split into tokens and which ID each token maps to
print_tokens_with_ids(prompt)

In [None]:
# Show the token IDs
# The prompt text already begins with <|begin_of_text|>, so special tokens are not
# added again here; with the default add_special_tokens=True the tokenizer may
# prepend a second begin-of-text token. This matches print_tokens_with_ids.
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.cuda()
input_ids

In [None]:
# After model
# Let the library's own generate() produce the continuation: sampling is enabled
# (do_sample=True) with top_p=0.9 and temperature=0.1, at most 200 new tokens are
# produced, and the EOS token ID is reused as the padding token ID.
outputs = base_model_bnb_4b.generate(input_ids=input_ids,
                          pad_token_id=tokenizer.eos_token_id,
                          max_new_tokens=200,
                          do_sample=True,
                          top_p=0.9,
                          temperature=0.1)
outputs

In [None]:
# Map the IDs back to tokens
# The generated IDs are detached, moved to the CPU and converted to NumPy before
# decoding; [0] picks the first sequence of the batch. Special tokens are kept
# in the decoded text (skip_special_tokens=False).
result = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=False)[0]
print(result)

In [None]:
# Display the configuration object of the loaded model
base_model_bnb_4b.config

<p align="center">
  <img src="image/llama31-arch.png">
</p>

In [None]:
# Display the loaded model object (its printed form lists the contained modules)
base_model_bnb_4b

In [None]:
import torch
import torch.nn.functional as F
import math

# Helper functions for Rotary Position Embedding (RoPE) and grouped-query attention

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary positional embeddings to query and key tensors.

    Args:
    q, k: Query and key tensors, laid out as [bs, num_heads, seq_len, head_dim]
        by the callers in this file
    cos, sin: Cosine and sine components of rotary embeddings. Their last two
        dims are treated as [seq_len, dim]; any leading dims are expected to be
        singleton (batch size 1, as used throughout this file)
    position_ids: Tensor of position IDs, [bs, seq_len]; used to index the
        seq_len axis of cos and sin

    Returns:
    q_embed, k_embed: Query and key tensors with rotary embeddings applied
    """
    # Flatten every leading dim so the tables are [seq_len, dim].
    # reshape is used instead of squeeze(1).squeeze(0): on a 3-D [1, seq_len, dim]
    # input with seq_len == 1, squeeze would also remove the seq_len axis and the
    # lookup below would then index the feature axis by mistake.
    cos = cos.reshape(-1, cos.shape[-1])  # [seq_len, dim]
    sin = sin.reshape(-1, sin.shape[-1])  # [seq_len, dim]

    # Select the rows for the requested positions and add a head axis for broadcasting
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    # Apply rotary embeddings
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

def rotate_half(x):
    """
    Swap the two halves of the last dimension, negating the half that moves to the front.

    For a last dimension split as (first, second) the result is (-second, first).
    This is the rotation step used by rotary position embeddings.

    Args:
    x: Input tensor

    Returns:
    Tensor of the same shape as x with the halves of its last dimension rotated
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat([back.neg(), front], dim=-1)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate each key/value head n_rep times for grouped-query attention.

    Every head along dim 1 is repeated n_rep times in place (head 0 copies
    first, then head 1 copies, ...), so the head count matches the number of
    query heads.

    Args:
    hidden_states: Input tensor of shape (batch, num_key_value_heads, seqlen, head_dim)
    n_rep: Number of repetitions (usually num_query_heads // num_key_value_heads)

    Returns:
    Tensor of shape (batch, num_key_value_heads * n_rep, seqlen, head_dim);
    the input itself is returned when n_rep is 1
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

## Step by step

In [None]:
# Tokenize the prompt. It already begins with <|begin_of_text|>, so special tokens
# are not added again; with the default add_special_tokens=True the tokenizer may
# prepend a second begin-of-text token.
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.cuda()
seq_length = input_ids.shape[1]
# Positions 0 .. seq_length-1, shaped [1, seq_length] and created on the same device as the IDs
position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
# Look up the embedding vector of every input token
embeddings = base_model_bnb_4b.model.embed_tokens(input_ids)
print("Embeddings shape:", embeddings.shape)

In [None]:
# Starting point for the layer-by-layer pass: the token embeddings
hidden_states = embeddings
# Model width, used when the attention output is reshaped back to (batch, seq, hidden)
hidden_size = base_model_bnb_4b.config.hidden_size
# A single prompt is processed
batch_size = 1

In [None]:
# Process through each layer of the model

# Causal mask: True strictly above the diagonal, i.e. wherever a query position
# would attend to a later key position. Without it every token sees the tokens
# that follow it, which does not match how a causal language model is run.
future_positions = torch.ones(
    seq_length, seq_length, dtype=torch.bool, device=hidden_states.device
).triu(diagonal=1)

# Inference only: no_grad stops autograd from recording the forward pass
with torch.no_grad():
    for layer_idx, layer in enumerate(base_model_bnb_4b.model.layers):
        print(f"Processing layer {layer_idx}")

        # 1. Input LayerNorm
        normalized_hidden_states = layer.input_layernorm(hidden_states)

        # 2. Self-attention mechanism
        # 2.1 Query, Key, Value projections
        query_states = layer.self_attn.q_proj(normalized_hidden_states)
        key_states = layer.self_attn.k_proj(normalized_hidden_states)
        value_states = layer.self_attn.v_proj(normalized_hidden_states)

        # 2.2 Reshape and transpose Q, K, V
        # (batch_size, seq_length, num_heads, head_dim) -> (batch_size, num_heads, seq_length, head_dim)
        query_states = query_states.view(batch_size, seq_length, layer.self_attn.num_heads, layer.self_attn.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, layer.self_attn.num_key_value_heads, layer.self_attn.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, layer.self_attn.num_key_value_heads, layer.self_attn.head_dim).transpose(1, 2)

        # 2.3 Apply rotary positional embeddings to Q and K
        cos, sin = layer.self_attn.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # 2.4 Handle grouped-query attention
        # Repeat K and V so that every query head has a matching K/V head
        key_states = repeat_kv(key_states, layer.self_attn.num_key_value_groups)
        value_states = repeat_kv(value_states, layer.self_attn.num_key_value_groups)

        # 2.5 Compute scaled attention scores, hide future positions, then normalize.
        # The softmax runs in float32 and is cast back to the working dtype.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(layer.self_attn.head_dim)
        attn_weights = attn_weights.masked_fill(future_positions, float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # 2.6 Apply attention to values
        attn_output = torch.matmul(attn_weights, value_states)

        # 2.7 Reshape attention output
        # (batch_size, num_heads, seq_length, head_dim) -> (batch_size, seq_length, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)

        # 2.8 Output projection
        attn_output = layer.self_attn.o_proj(attn_output)

        # 2.9 Residual connection
        hidden_states = hidden_states + attn_output

        # 3. Post-attention LayerNorm
        normalized_hidden_states = layer.post_attention_layernorm(hidden_states)

        # 4. MLP (Feed-Forward Network)
        # 4.1 Apply gate and up projections
        gate_output = layer.mlp.gate_proj(normalized_hidden_states)
        up_output = layer.mlp.up_proj(normalized_hidden_states)

        # 4.2 Apply activation function to the gate and multiply element-wise with the up projection
        mlp_output = up_output * layer.mlp.act_fn(gate_output)

        # 4.3 Down projection
        mlp_output = layer.mlp.down_proj(mlp_output)

        # 4.4 Residual connection
        hidden_states = hidden_states + mlp_output

        print(f"  Output shape: {hidden_states.shape}")

In [None]:
# Final LayerNorm
# Normalize the hidden states from the last layer. The result goes into its own
# variable so that re-running this cell does not normalize `hidden_states` twice.
with torch.no_grad():
    final_hidden_states = base_model_bnb_4b.model.norm(hidden_states)

    # Language Model Head
    # Project the normalized hidden states to the vocabulary space
    lm_logits = base_model_bnb_4b.lm_head(final_hidden_states)

# Get the logits for the last token
# We're only interested in predicting the next token, so we take the last position.
# The logits are upcast to float32 so the softmax / cumulative sum below are not
# computed in the model's lower-precision dtype.
last_token_logits = lm_logits[:, -1, :].float()

# Note: greedy decoding would simply take torch.argmax(last_token_logits, dim=-1);
# the steps below sample from a temperature-scaled, top-p filtered distribution instead.

# Step 7: Apply temperature
# Temperature adjusts the randomness of predictions. Lower values make the model more confident.
temperature = 0.1
scaled_logits = last_token_logits / temperature

# Step 8: Apply top-p (nucleus) sampling
# Keep the highest-probability tokens until their cumulative probability passes top_p
# and drop the remaining, less likely tokens.
top_p = 0.9
# Sort logits in descending order
sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
# Calculate cumulative probabilities
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Mark tokens whose cumulative probability is above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the marks right by one so the first token that crosses the threshold is kept too
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
# The most likely token is always kept
sorted_indices_to_remove[..., 0] = False
# Scatter the marks from sorted order back to vocabulary order
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
# Set logits of removed indices to negative infinity
scaled_logits = scaled_logits.masked_fill(indices_to_remove, float('-inf'))

# Step 9: Sample from the filtered distribution
# Convert logits to probabilities
probs = torch.softmax(scaled_logits, dim=-1)
# Randomly sample a token based on the calculated probabilities
next_token = torch.multinomial(probs, num_samples=1)

# Print the input tokens and the sampled next token
# NOTE(review): the label below says "Most likely", but the ID is sampled, not an argmax
print(f"\nInput tokens: {input_ids}")
print(f"Most likely next token ID: {next_token.item()}")

## Create a function

In [None]:
def _decoder_layer_forward(layer, hidden_states, position_ids, future_positions):
    """Run one decoder layer by hand (attention + MLP, each with a residual) and return the new hidden states."""
    attn = layer.self_attn
    batch_size, seq_length, _ = hidden_states.shape

    # 1. Input LayerNorm
    normalized_hidden_states = layer.input_layernorm(hidden_states)

    # 2. Self-attention: project, then split into heads
    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    query_states = attn.q_proj(normalized_hidden_states)
    key_states = attn.k_proj(normalized_hidden_states)
    value_states = attn.v_proj(normalized_hidden_states)
    query_states = query_states.view(batch_size, seq_length, attn.num_heads, attn.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, seq_length, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, seq_length, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)

    # Apply rotary embeddings to Q and K
    cos, sin = attn.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    # Grouped-query attention: repeat K and V so every query head has a K/V head
    key_states = repeat_kv(key_states, attn.num_key_value_groups)
    value_states = repeat_kv(value_states, attn.num_key_value_groups)

    # Scaled dot-product scores; future positions are hidden before the softmax
    # so that a token never attends to tokens that come after it
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(attn.head_dim)
    attn_weights = attn_weights.masked_fill(future_positions, float("-inf"))
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
    attn_output = attn.o_proj(attn_output)

    # Residual connection
    hidden_states = hidden_states + attn_output

    # 3. Post-attention LayerNorm
    normalized_hidden_states = layer.post_attention_layernorm(hidden_states)

    # 4. MLP: activated gate projection times up projection, then down projection
    gate_output = layer.mlp.gate_proj(normalized_hidden_states)
    mlp_output = layer.mlp.up_proj(normalized_hidden_states) * layer.mlp.act_fn(gate_output)
    mlp_output = layer.mlp.down_proj(mlp_output)

    # Residual connection
    return hidden_states + mlp_output


def _sample_next_token(logits, temperature=1.0, top_k=50, top_p=0.95):
    """Pick one token ID per row of `logits` ([batch, vocab]) using temperature, top-k and top-p; returns [batch, 1]."""
    # A non-positive temperature cannot be used as a divisor; treat it as greedy decoding
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Work in float32 so the softmax / cumulative sum are not done in a low-precision dtype
    logits = logits.float() / temperature

    # Top-k: keep only the k highest logits (disabled when top_k is None or <= 0)
    if top_k is not None and top_k > 0:
        k = min(int(top_k), logits.shape[-1])
        kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    # Top-p (nucleus): keep the highest-probability tokens until their cumulative
    # probability passes top_p (disabled when top_p is None or >= 1)
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the first token that crosses the threshold is kept too
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the marks from sorted order back to vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate_tokens(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0, top_k=50, top_p=0.95):
    """
    Generate tokens by running the decoder layers manually, one new token per step.

    Every step re-runs the full sequence through all layers (there is no
    key/value cache), takes the logits of the last position, and samples the
    next token from them.

    Args:
    model: Causal LM exposing model.embed_tokens, model.layers, model.norm and lm_head
    tokenizer: Tokenizer used to encode the prompt; its eos_token_id stops generation
    prompt: Prompt text. If it already starts with the tokenizer's BOS token,
        special tokens are not added again during encoding
    max_new_tokens: Maximum number of tokens to generate
    temperature: Divisor applied to the logits; values <= 0 select greedy decoding
    top_k: Keep only the k most likely tokens (None or <= 0 disables the filter)
    top_p: Nucleus threshold on cumulative probability (None or >= 1 disables the filter)

    Returns:
    Tensor of shape (1, prompt_length + generated_length) holding the prompt
    IDs followed by the generated IDs
    """
    # Run on whatever device holds the embedding weights
    device = model.model.embed_tokens.weight.device

    # Avoid a duplicated BOS token when the prompt text already carries it
    bos_token = getattr(tokenizer, "bos_token", None)
    add_special_tokens = not (bos_token and prompt.startswith(bos_token))
    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True, add_special_tokens=add_special_tokens
    ).input_ids.to(device)

    # Inference only: no_grad stops autograd from recording the forward passes
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Prepare input
            seq_length = input_ids.shape[1]
            position_ids = torch.arange(seq_length, device=device).unsqueeze(0)
            # True strictly above the diagonal: query i must not see key j > i
            future_positions = torch.ones(
                seq_length, seq_length, dtype=torch.bool, device=device
            ).triu(diagonal=1)

            # Token embedding
            hidden_states = model.model.embed_tokens(input_ids)

            # Process through each layer
            for layer in model.model.layers:
                hidden_states = _decoder_layer_forward(layer, hidden_states, position_ids, future_positions)

            # Final LayerNorm
            hidden_states = model.model.norm(hidden_states)

            # Language Model Head
            lm_logits = model.lm_head(hidden_states)

            # Get the logits for the last token and sample the next one from them
            last_token_logits = lm_logits[:, -1, :]
            next_token = _sample_next_token(last_token_logits, temperature=temperature, top_k=top_k, top_p=top_p)

            # Print the input tokens and the sampled next token
            print(f"\nInput tokens: {input_ids}")
            print(f"Most likely next token ID: {next_token.item()}")

            # Append the new token to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Check if we've generated an end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break

    return input_ids

In [None]:
# Use the function
# Same prompt layout as before: a system turn, a user turn, and an opened assistant turn
prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Based on the information provided, rewrite the sentence by changing its tense from past to future.<|eot_id|><|start_header_id|>user<|end_header_id|>

She played the piano beautifully for hours and then stopped as it was midnight.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
# Generate at most 50 new tokens with the manual implementation
generated_ids = generate_tokens(base_model_bnb_4b, tokenizer, prompt, max_new_tokens=50)
# Decode the first (and only) sequence, keeping special tokens in the text
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
print(f"\nGenerated text:\n{generated_text}")