# Week 3.4: From GPT-2 to LLaMA3: Understanding Modern LLM Architectures

In this notebook, we'll explore how language model architectures have evolved from GPT-2 to the modern LLaMA family of models. We'll implement key components of the LLaMA architecture and highlight the architectural innovations that have improved performance and efficiency.

## Learning Objectives

After completing this notebook, you will be able to:

1. Understand the key architectural differences between GPT-2 and LLaMA models
2. Implement and explain Rotary Positional Embeddings (RoPE)
3. Implement and explain RMSNorm as an alternative to LayerNorm
4. Understand Grouped-Query Attention (GQA) and its advantages
5. Implement SwiGLU activation for feed-forward networks
6. Integrate these components into a coherent model architecture

## Required Libraries

Let's start by importing the necessary libraries.

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

## 1. Architectural Overview: GPT-2 vs. LLaMA

Before diving into the implementation details, let's compare the high-level architectures of GPT-2 and LLaMA models:

| Component | GPT-2 | LLaMA |
|-----------|-------|-------|
| Normalization | LayerNorm | RMSNorm |
| Normalization Placement | Pre-attention & Pre-FFN (pre-LN), plus a final LayerNorm | Pre-attention & Pre-FFN, plus a final RMSNorm |
| Positional Encoding | Learned Positional Embeddings | Rotary Positional Embeddings (RoPE) |
| Attention Mechanism | Multi-head Attention (equal Q/K/V heads) | Grouped-Query Attention (more Q than K/V heads; used from LLaMA 2 70B onward, LLaMA 1 uses MHA) |
| Feed-Forward Network | GELU Activation | SwiGLU Activation |
| Parameter Count | 124M - 1.5B | 7B - 70B+ |

Both models are decoder-only Transformers built from a stack of self-attention and feed-forward layers. LLaMA keeps that skeleton but swaps in components (RMSNorm, RoPE, SwiGLU, and in later versions GQA) that improve efficiency and quality at scale.

## 2. Configuration

Let's define configurations for both GPT-2 (for reference) and a small LLaMA model:

In [2]:
class GPT2Config:
    def __init__(self):
        self.vocab_size = 50257
        self.n_positions = 1024
        self.n_embd = 768
        self.n_layer = 12
        self.n_head = 12
        self.dropout = 0.1

class LLaMAConfig:
    def __init__(self):
        # Model dimensions
        self.vocab_size = 32000  # LLaMA 1/2 tokenizer size; LLaMA 3 uses 128,256
        self.hidden_size = 768  # Dimension of embeddings
        self.intermediate_size = 2048  # Dimension of feed-forward layer
        self.n_layers = 12
        
        # Attention-specific parameters
        self.n_heads = 32  # Number of query heads
        self.n_kv_heads = 8  # Number of key/value heads (for GQA)
        self.head_dim = 64   # Dimension of each attention head
        
        # Positional embedding
        self.max_position_embeddings = 2048
        self.rope_theta = 10000.0  # Base for RoPE calculations
        
        # Other parameters
        self.norm_eps = 1e-5  # Epsilon for normalization stability
        self.dropout = 0.0  # Modern LLMs often use no dropout

# We'll use these configs later when implementing our models
gpt2_config = GPT2Config()
llama_config = LLaMAConfig()
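
The attention widths implied by a configuration are easy to misread, so here is a small sketch that derives them. `attention_shapes` is a hypothetical helper that is not part of the notebook; the numbers passed in simply mirror the `LLaMAConfig` defaults above.

```python
def attention_shapes(hidden_size, n_heads, n_kv_heads, head_dim):
    """Derive projection widths for grouped-query attention (illustrative helper)."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    return {
        "q_proj_out": n_heads * head_dim,             # width of the query projection
        "kv_proj_out": n_kv_heads * head_dim,         # width of each of the K and V projections
        "queries_per_kv_head": n_heads // n_kv_heads,
        "q_width_equals_hidden": n_heads * head_dim == hidden_size,
    }

shapes = attention_shapes(hidden_size=768, n_heads=32, n_kv_heads=8, head_dim=64)
for name, value in shapes.items():
    print(f"{name}: {value}")
```

With these values the query projection is 2048 wide while `hidden_size` is 768. That is legal because the query and output projections are separate linear layers, but the released LLaMA models keep `n_heads * head_dim == hidden_size`, so this toy configuration is wider in attention than a scaled-down LLaMA would be.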

## 3. Normalization: LayerNorm vs. RMSNorm

GPT-2 uses standard LayerNorm, which centers each vector by its mean and then scales it by its standard deviation. LLaMA uses RMSNorm (Root Mean Square Normalization), which skips the centering step and divides only by the root mean square of the activations, saving the mean computation and the bias parameter.

Let's implement both:

In [3]:
class LayerNorm(nn.Module):
    """LayerNorm normalization layer, as used in GPT-2"""
    def __init__(self, ndim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        
    def forward(self, x):
        # Normalize by mean and variance
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + 1e-5)  # Add epsilon for numerical stability
        
        # Apply weights and optional bias
        if self.bias is not None:
            x = self.weight * x + self.bias
        else:
            x = self.weight * x
        return x

class RMSNorm(nn.Module):
    """RMSNorm normalization layer, as used in LLaMA"""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps
        
    def forward(self, x):
        # Normalize by Root Mean Square (RMS) only
        # No centering step (no mean subtraction)
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x = x / rms * self.weight
        return x

Let's compare the behavior of these two normalization methods on a simple example:

In [4]:
# Create sample data with specific mean and variance
x = torch.randn(5, 10) * 3 + 2  # mean ≈ 2, std ≈ 3

# Apply both normalization methods
layer_norm = LayerNorm(10)
rms_norm = RMSNorm(10)

x_ln = layer_norm(x)
x_rms = rms_norm(x)

# Compare statistics
print("Original data:")
print(f"  Mean: {x.mean().item():.4f}, Std: {x.std().item():.4f}")

print("\nAfter LayerNorm:")
print(f"  Mean: {x_ln.mean().item():.4f}, Std: {x_ln.std().item():.4f}")
print(f"  Mean per sample: {x_ln.mean(dim=1)}")

print("\nAfter RMSNorm:")
print(f"  Mean: {x_rms.mean().item():.4f}, Std: {x_rms.std().item():.4f}")
print(f"  Mean per sample: {x_rms.mean(dim=1)}")

# Key difference: LayerNorm forces means to 0, while RMSNorm preserves direction

Original data:
  Mean: 2.1427, Std: 2.9749

After LayerNorm:
  Mean: -0.0000, Std: 1.0102
  Mean per sample: tensor([-5.9605e-09, -2.9802e-09, -4.1723e-08,  2.3842e-08, -5.9605e-09],
       grad_fn=<MeanBackward1>)

After RMSNorm:
  Mean: 0.6163, Std: 0.7955
  Mean per sample: tensor([0.6000, 0.6447, 0.7569, 0.5719, 0.5080], grad_fn=<MeanBackward1>)


### Key Differences:

1. **Computation**: LayerNorm subtracts the mean and divides by standard deviation, while RMSNorm only divides by the root mean square.
2. **Parameter Efficiency**: RMSNorm has fewer parameters (no bias term).
3. **Performance**: RMSNorm skips the mean computation and the subtraction, so it needs fewer operations per call.
4. **Direction Preservation**: Before the learned per-dimension scale is applied, RMSNorm only rescales each input vector and leaves its direction unchanged, whereas LayerNorm also shifts it by the mean.

The RMSNorm paper reports quality comparable to LayerNorm at lower computational cost, and RMSNorm is the normalization used throughout the LLaMA family.
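
The properties listed above can be checked numerically. The following is a minimal sketch, assuming the learned scale is fixed to 1 and the bias is dropped; `rms_norm` and `layer_norm` are simplified stand-ins written for this check, not the classes defined earlier.

```python
import torch
import torch.nn.functional as F

def rms_norm(x, eps=1e-5):
    """RMSNorm without the learned scale (weight fixed to 1)."""
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def layer_norm(x, eps=1e-5):
    """LayerNorm without the learned scale or bias."""
    centered = x - x.mean(dim=-1, keepdim=True)
    return centered / torch.sqrt(centered.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(4, 16) * 3 + 2

# 1) After RMSNorm, every row has a root mean square close to 1.
row_rms = rms_norm(x).pow(2).mean(dim=-1).sqrt()

# 2) RMSNorm only rescales: cosine similarity with the input stays at 1.
cos_sim = F.cosine_similarity(x, rms_norm(x), dim=-1)

# 3) On zero-mean input, the two normalizations agree.
x_centered = x - x.mean(dim=-1, keepdim=True)
same_on_centered = torch.allclose(rms_norm(x_centered), layer_norm(x_centered), atol=1e-5)

print("Row RMS after RMSNorm:", row_rms)
print("Cosine similarity with input:", cos_sim)
print("Equal to LayerNorm on centered input:", same_on_centered)
```

The third check shows that the two methods differ only in how they treat the mean: once the mean is zero, the RMS equals the standard deviation.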

## 4. Positional Encoding: Learned Embeddings vs. RoPE

GPT-2 uses learned positional embeddings, which are added to the token embeddings at the input layer. LLaMA uses Rotary Positional Embeddings (RoPE), which apply a rotation to the query and key vectors in the attention mechanism.

Let's implement both approaches:

In [5]:
class LearnedPositionalEmbeddings(nn.Module):
    """Learned positional embeddings, as used in GPT-2"""
    def __init__(self, max_seq_len, embed_dim):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_seq_len, embed_dim)
        
    def forward(self, token_embeddings):
        seq_len = token_embeddings.size(1)
        positions = torch.arange(0, seq_len, device=token_embeddings.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(positions)
        
        # Add positional embeddings to token embeddings
        return token_embeddings + position_embeddings

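Now, let's implement Rotary Positional Embeddings (RoPE). RoPE rotates pairs of dimensions in the query and key vectors by angles proportional to the token position (equivalently, it multiplies complex numbers by unit-modulus phases), so the attention score between two tokens depends on their relative offset. Extending the context beyond the training sequence length usually still requires additional techniques, such as rescaling the rotation frequencies.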

In [6]:
class RotaryEmbedding(nn.Module):
    """Rotary Positional Embeddings (RoPE), as used in LLaMA"""
    def __init__(self, dim, max_position=2048, base=10000.0):
        super().__init__()
        # For each dimension, we'll have a different frequency
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_position = max_position
        self.dim = dim
        
    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[1]
        
        # Create position indices
        position = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        
        # Outer product gives one rotation angle per (position, frequency) pair: [seq_len, dim/2]
        freqs = torch.outer(position, self.inv_freq)
        
        # Duplicate the angles so the cos/sin tables have shape [seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        
        # Return the cosine and sine tables, each [seq_len, dim]
        return emb.cos(), emb.sin()
    
def apply_rotary_embeddings(q, k, cos, sin):
    """Apply rotary embeddings to query and key tensors.

    q, k: [..., seq_len, head_dim]; cos, sin: [seq_len, head_dim] from RotaryEmbedding.
    Consecutive dimensions (0, 1), (2, 3), ... are rotated as 2D pairs.
    """
    half = q.shape[-1] // 2
    # RotaryEmbedding returns every angle twice; keep one copy per pair
    cos, sin = cos[..., :half], sin[..., :half]

    def rotate(x):
        x_even, x_odd = x[..., ::2], x[..., 1::2]
        # 2D rotation of each (even, odd) pair by its position-dependent angle
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Interleave the rotated pairs back into the original layout
        return torch.stack((out_even, out_odd), dim=-1).flatten(-2)

    return rotate(q), rotate(k)

To better understand RoPE, let's visualize how it embeds positional information:

In [7]:
# Visualize rotary embeddings
def visualize_rope(dim=64, seq_len=10, base=10000.0):
    # Create a rotation embedding instance
    rope = RotaryEmbedding(dim=dim, base=base)
    
    # Create a simple query vector filled with ones
    # [batch_size=1, seq_len, dim]
    q = torch.ones(1, seq_len, dim)
    
    # Get the cos and sin components
    cos, sin = rope(q)
    
    # Apply rotary embeddings
    q_rotated, _ = apply_rotary_embeddings(q, q, cos, sin)
    
    # Visualize the original and rotated vectors for the first dimension pair
    plt.figure(figsize=(10, 6))
    
    # Dimensions 0 and 1 form the first rotated pair; its frequency is 1,
    # so position p is rotated by p radians
    dim_1, dim_2 = 0, 1
    
    # Plot original points
    x_orig = q[0, :, dim_1].detach().numpy()
    y_orig = q[0, :, dim_2].detach().numpy()
    plt.scatter(x_orig, y_orig, label='Original', color='blue')
    
    # Plot rotated points
    x_rot = q_rotated[0, :, dim_1].detach().numpy()
    y_rot = q_rotated[0, :, dim_2].detach().numpy()
    plt.scatter(x_rot, y_rot, label='After RoPE', color='red')
    
    # Add arrows to show the rotation for each position
    for i in range(seq_len):
        plt.annotate(f"Pos {i}", (x_rot[i], y_rot[i]), fontsize=9)
        plt.arrow(x_orig[i], y_orig[i], x_rot[i]-x_orig[i], y_rot[i]-y_orig[i], 
                 width=0.02, head_width=0.05, color='gray', alpha=0.5)
    
    plt.title(f"Rotary Position Embeddings Effect on Vector Dimensions {dim_1} and {dim_2}")
    plt.xlabel(f"Dimension {dim_1}")
    plt.ylabel(f"Dimension {dim_2}")
    plt.grid(True)
    plt.legend()
    plt.axis('equal')
    plt.show()

visualize_rope(dim=64, seq_len=8)

### Key Differences:

1. **Integration**: Learned embeddings are added to token embeddings at the input, while RoPE is applied to query and key vectors in each attention layer.
2. **Inductive Bias**: RoPE explicitly encodes relative positions, while learned embeddings must learn position relationships implicitly.
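3. **Extrapolation**: RoPE is a fixed function of position, so it is defined for any sequence length and serves as the basis for context-extension methods that rescale its frequencies. Without such methods, quality on sequences much longer than those seen in training is not guaranteed.
4. **Parameter Efficiency**: RoPE doesn't add trainable parameters, while learned embeddings do.

RoPE's ability to model relative positions without extra parameters has made it a standard choice in modern LLMs.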
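
The relative-position property can be verified directly: after rotation, the dot product between a query and a key depends only on the offset between their positions. The sketch below is self-contained and uses a hypothetical single-vector helper, `rope_rotate`, written for this check (it rotates consecutive dimension pairs and works in float64 to keep rounding error small).

```python
import torch

def rope_rotate(x, position, base=10000.0):
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... of a 1-D vector."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angle = position * inv_freq          # one angle per pair
    cos, sin = angle.cos(), angle.sin()
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset (3) at two different absolute positions
score_a = torch.dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_b = torch.dot(rope_rotate(q, 105), rope_rotate(k, 102))

# A different offset (2) for comparison
score_c = torch.dot(rope_rotate(q, 5), rope_rotate(k, 3))

print(f"offset 3 at positions (5, 2):     {score_a.item():.6f}")
print(f"offset 3 at positions (105, 102): {score_b.item():.6f}")
print(f"offset 2 at positions (5, 3):     {score_c.item():.6f}")
```

The first two scores agree up to floating-point error, while the third generally differs. Rotation also leaves the vector norm unchanged, which is why RoPE can be applied to queries and keys without rescaling the attention logits.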

## 5. Attention Mechanism: Standard MHA vs. Grouped-Query Attention

GPT-2's architecture uses standard Multi-Head Attention (MHA), where the numbers of query, key, and value heads are equal. The original LLaMA models also use MHA; starting with LLaMA 2 70B, and in all LLaMA 3 models, the family adopts Grouped-Query Attention (GQA), in which several query heads share one key head and one value head. The main payoff is a smaller key/value cache at inference time.

Let's implement both approaches:

In [None]:
class GPT2Attention(nn.Module):
    """Standard multi-head attention as used in GPT-2"""
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        # Single projection for Q, K, V (as in GPT-2)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        
    def forward(self, x, attention_mask=None):
        B, T, C = x.size()  # batch, sequence length, embedding dimensionality
        
        # Calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        
        # Reshape to [B, nh, T, hs]
        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)
        
        # Compute attention scores
        # [B, nh, T, T]
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        
        # Apply attention mask (for causal attention)
        if attention_mask is None:
            # Causal self-attention mask
            mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            att = att + attention_mask
        
        # Apply softmax
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        
        # Apply attention to values
        y = att @ v  # [B, nh, T, hs]
        
        # Reshape back to [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        
        # Output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

In [None]:
class LLaMAGroupQueryAttention(nn.Module):
    """Group Query Attention as used in LLaMA"""
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads  # Number of query heads
        self.n_kv_heads = config.n_kv_heads  # Number of key/value heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        
        # Compute mapping from query heads to key/value heads
        # Each key/value head serves multiple query heads
        self.n_rep = self.n_heads // self.n_kv_heads
        
        # Separate projections for Q, K, V
        self.q_proj = nn.Linear(config.hidden_size, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, config.hidden_size, bias=False)
        
        # Rotary embeddings
        self.rope = RotaryEmbedding(self.head_dim, max_position=config.max_position_embeddings, 
                                   base=config.rope_theta)
        
    def forward(self, x, attention_mask=None):
        B, T, C = x.size()  # batch, sequence length, embedding dimensionality
        
        # Compute Q, K, V and move the head axis forward
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)     # [B, n_q, T, head_dim]
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # [B, n_kv, T, head_dim]
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # [B, n_kv, T, head_dim]
        
        # Apply RoPE exactly once to every query and key head;
        # cos/sin are [T, head_dim] and broadcast over batch and heads
        cos, sin = self.rope(x, seq_len=T)
        q, k = apply_rotary_embeddings(q, k, cos, sin)
        
        # Share each K/V head across n_rep consecutive query heads
        k = k.repeat_interleave(self.n_rep, dim=1)  # [B, n_q, T, head_dim]
        v = v.repeat_interleave(self.n_rep, dim=1)  # [B, n_q, T, head_dim]
        
        # Compute attention scores for all heads at once: [B, n_q, T, T]
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        
        if attention_mask is None:
            # Causal self-attention mask
            mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            att = att.masked_fill(~mask, float('-inf'))
        else:
            # Additive mask, as in GPT2Attention
            att = att + attention_mask
        
        # Apply softmax, then weight the values
        att = F.softmax(att, dim=-1)
        attn_output = att @ v  # [B, n_q, T, head_dim]
        
        # Reshape and project outputs
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        output = self.o_proj(attn_output)
        
        return output

Let's visualize the key difference in head structure between standard attention and Group Query Attention:

In [None]:
def visualize_head_structure():
    """Visualize the head structure difference between MHA and GQA"""
    # Define parameters
    n_mha_heads = 8
    n_gqa_q_heads = 32
    n_gqa_kv_heads = 8
    
    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    
    # MHA visualization
    mha_q = np.ones((n_mha_heads, 1))
    mha_k = np.ones((n_mha_heads, 1))
    mha_v = np.ones((n_mha_heads, 1))
    
    ax1.imshow(np.hstack([mha_q, mha_k, mha_v]), cmap='Blues', aspect='auto')
    
    # Add labels
    for i in range(n_mha_heads):
        ax1.text(-0.5, i, f"H{i}", ha="center", va="center")
        ax1.text(0, i, f"Q{i}", ha="center", va="center", color='white')
        ax1.text(1, i, f"K{i}", ha="center", va="center", color='white')
        ax1.text(2, i, f"V{i}", ha="center", va="center", color='white')
    
    ax1.set_title(f"Multi-Head Attention (MHA)\n{n_mha_heads} query, key, and value heads")
    ax1.set_xticks([0, 1, 2])
    ax1.set_xticklabels(['Query', 'Key', 'Value'])
    ax1.set_yticks([])
    
    # GQA visualization
    gqa_heads = np.ones((n_gqa_q_heads, 3))
    
    # Show which KV heads are shared
    colors = plt.cm.tab10(np.linspace(0, 1, n_gqa_kv_heads))
    gqa_colors = np.zeros((n_gqa_q_heads, 3, 3))
    
    # Assign colors to show grouping
    for i in range(n_gqa_q_heads):
        kv_idx = i // (n_gqa_q_heads // n_gqa_kv_heads)
        gqa_colors[i, 0] = [1, 1, 1]  # White for query
        gqa_colors[i, 1] = colors[kv_idx, :3]  # Colored for key
        gqa_colors[i, 2] = colors[kv_idx, :3]  # Same color for value
    
    # Plot rectangles manually
    for i in range(n_gqa_q_heads):
        kv_idx = i // (n_gqa_q_heads // n_gqa_kv_heads)
        ax2.add_patch(plt.Rectangle((0, i), 1, 1, color='royalblue'))
        ax2.add_patch(plt.Rectangle((1, i), 1, 1, color=colors[kv_idx, :3]))
        ax2.add_patch(plt.Rectangle((2, i), 1, 1, color=colors[kv_idx, :3]))
        
        # Add text labels
        ax2.text(-0.5, i, f"H{i}", ha="center", va="center")
        ax2.text(0, i, f"Q{i}", ha="center", va="center", color='white')
        ax2.text(1, i, f"K{kv_idx}", ha="center", va="center", color='white')
        ax2.text(2, i, f"V{kv_idx}", ha="center", va="center", color='white')
    
    ax2.set_xlim(-1, 3)
    ax2.set_ylim(n_gqa_q_heads, -1)  # Reverse y axis
    ax2.set_title(f"Group Query Attention (GQA)\n{n_gqa_q_heads} query heads, {n_gqa_kv_heads} key-value heads")
    ax2.set_xticks([0, 1, 2])
    ax2.set_xticklabels(['Query', 'Key', 'Value'])
    ax2.set_yticks([])
    
    plt.tight_layout()
    plt.show()

visualize_head_structure()

### Key Differences:

1. **Parameter Efficiency**: GQA shrinks the key and value projections by a factor of `n_heads / n_kv_heads`; the query and output projections are unchanged.
2. **Computation Cost**: Fewer key and value projections are computed, but the attention-score and weighted-sum FLOPs match MHA, because every query head still attends over the full sequence.
3. **Representation Power**: GQA keeps the full number of query heads, and the GQA paper reports quality close to MHA.
4. **Memory Bandwidth**: GQA shrinks the KV cache during inference by a factor of `n_heads / n_kv_heads` (4x for 32 query heads and 8 key-value heads).

GQA matters most at inference time, where the KV cache for long sequences and large batches is often the memory bottleneck.
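
To make the KV-cache saving concrete, here is a back-of-the-envelope sketch. `kv_cache_bytes` is a hypothetical helper, and the layer count, head size, sequence length, and 16-bit storage are illustrative assumptions rather than the settings of any released model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values in every layer (the factor 2 covers K and V)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

# Illustrative numbers: 12 layers, head_dim 64, 2048 cached tokens, 16-bit values
mha = kv_cache_bytes(n_layers=12, n_kv_heads=32, head_dim=64, seq_len=2048)
gqa = kv_cache_bytes(n_layers=12, n_kv_heads=8, head_dim=64, seq_len=2048)

print(f"MHA cache: {mha / 2**20:.0f} MiB")   # 192 MiB
print(f"GQA cache: {gqa / 2**20:.0f} MiB")   # 48 MiB
print(f"Reduction: {mha // gqa}x")           # 4x
```

The cache grows linearly with sequence length and batch size, so the same 4x factor applies at any context length.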

## 6. Feed-Forward Network: GELU vs. SwiGLU

GPT-2 uses a standard feed-forward network with GELU activation, while LLaMA uses SwiGLU, which is a variant of the GLU (Gated Linear Unit) activation.

Let's implement both:

In [None]:
class GPT2MLP(nn.Module):
    """Standard MLP with GELU activation, as used in GPT-2"""
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        
    def forward(self, x):
        x = self.c_fc(x)
        x = F.gelu(x)  # GELU activation
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class LLaMASwiGLU(nn.Module):
    """SwiGLU feed-forward network, as used in LLaMA"""
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        
    def forward(self, x):
        # SwiGLU: W2( SiLU(x W1) * (x W3) ), following the LLaMA reference naming
        # w1 is the gate projection; SiLU is applied to it
        gate = F.silu(self.w1(x))
        # w3 is the "up" projection; it stays linear
        up = self.w3(x)
        # Element-wise gating, then the "down" projection back to hidden_size
        return self.w2(gate * up)

Let's visualize and compare the GELU and SiLU (used in SwiGLU) activation functions:

In [None]:
def visualize_activations():
    """Visualize GELU and SiLU activation functions"""
    x = torch.linspace(-5, 5, 1000)
    
    gelu = F.gelu(x)
    silu = F.silu(x)  # Same as Swish activation
    
    plt.figure(figsize=(10, 6))
    plt.plot(x.numpy(), gelu.numpy(), label='GELU (GPT-2)', linewidth=2)
    plt.plot(x.numpy(), silu.numpy(), label='SiLU/Swish (LLaMA)', linewidth=2, linestyle='--')
    
    # Add baseline ReLU for comparison
    plt.plot(x.numpy(), F.relu(x).numpy(), label='ReLU (reference)', color='gray', alpha=0.5)
    
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    plt.title('Comparison of Activation Functions')
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.2)
    plt.axvline(x=0, color='k', linestyle='-', alpha=0.2)
    plt.show()
    
    # Explain the gating in SwiGLU
    plt.figure(figsize=(10, 6))
    
    # Example where the gate varies
    x1 = torch.ones(1000) * 2.0  # Constant input
    x2 = torch.linspace(-5, 5, 1000)  # Varying input to gate
    
    plt.plot(x2.numpy(), (x1 * F.silu(x2)).numpy(), label='SwiGLU Gate Effect', linewidth=2)
    plt.plot(x2.numpy(), F.silu(x2).numpy(), label='SiLU Gate', linewidth=2, linestyle='--')
    plt.plot(x2.numpy(), x1.numpy(), label='Input Signal', color='gray', alpha=0.5)
    
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    plt.title('SwiGLU Gating Mechanism')
    plt.xlabel('Gate Input')
    plt.ylabel('Output')
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.2)
    plt.axvline(x=0, color='k', linestyle='-', alpha=0.2)
    plt.show()

visualize_activations()

### Key Differences:

1. **Activation Function**: GPT-2 uses GELU activation while LLaMA uses the SiLU (Swish) activation in a gated architecture (SwiGLU).
2. **Gating Mechanism**: SwiGLU employs a gating mechanism where one projection is multiplied by a non-linear transformation of another projection, allowing for more complex feature interactions.
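3. **Projections**: GPT-2 uses two projections (up and down), while LLaMA uses three (gate `w1`, down `w2`, and up `w3`). To keep the parameter count comparable, LLaMA sets the intermediate size to roughly 8/3 of the hidden size instead of 4x.
4. **Performance**: In the experiments that introduced GLU variants for Transformers, SwiGLU reached lower perplexity than a GELU feed-forward network at a matched parameter budget.

The gated form gives the feed-forward network a multiplicative interaction between two projections of the same input, which is one of the ingredients of LLaMA's improved quality.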
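
The third projection makes SwiGLU more expensive per hidden unit, which is why the intermediate size is reduced. A short sketch of the arithmetic, using two hypothetical counting helpers and ignoring biases:

```python
def gelu_mlp_params(d_model, expansion=4):
    """Weights in a two-matrix MLP: up and down projections."""
    return 2 * d_model * (expansion * d_model)

def swiglu_params(d_model, d_hidden):
    """Weights in a three-matrix SwiGLU block: gate, up, and down projections."""
    return 3 * d_model * d_hidden

d_model = 768
gelu_count = gelu_mlp_params(d_model)                # 4x expansion, as in GPT-2
swiglu_wide = swiglu_params(d_model, 4 * d_model)    # same width, 50% more weights
swiglu_matched = swiglu_params(d_model, 2048)        # 2048 = 8/3 * 768

print(gelu_count)      # 4718592
print(swiglu_wide)     # 7077888
print(swiglu_matched)  # 4718592
```

With `hidden_size = 768`, the `intermediate_size = 2048` used in `LLaMAConfig` is exactly 8/3 of the hidden size, so the SwiGLU block has the same number of weights as a GPT-2 style MLP of the same width.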

## 7. Putting It All Together: LLaMA Block and Model

Now, let's implement a full LLaMA transformer block and the complete model structure:

In [None]:
class LLaMABlock(nn.Module):
    """LLaMA Transformer block"""
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        
        # Pre-normalization for attention (same placement as GPT-2, but RMSNorm instead of LayerNorm)
        self.attention_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.attention = LLaMAGroupQueryAttention(config)
        
        # Pre-normalization for feed-forward
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.feed_forward = LLaMASwiGLU(config)
        
    def forward(self, x):
        # Pre-norm for attention: normalize the sub-layer input, keep the residual path untouched
        h = x + self.attention(self.attention_norm(x))
        
        # Pre-norm for feed-forward
        out = h + self.feed_forward(self.ffn_norm(h))
        
        return out

class LLaMAModel(nn.Module):
    """LLaMA language model"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        
        # Token embeddings
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        
        # No separate position embeddings like in GPT-2 (RoPE is applied inside attention)
        
        # Transformer blocks
        self.layers = nn.ModuleList([LLaMABlock(config) for _ in range(config.n_layers)])
        
        # Final normalization layer
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        
        # Language modeling head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        """Initialize weights with standard approach"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        
    def forward(self, input_ids):
        """Forward pass"""
        B, T = input_ids.size()  # batch size, sequence length
        
        # Get token embeddings
        h = self.token_embeddings(input_ids)
        
        # Apply transformer blocks
        for layer in self.layers:
            h = layer(h)
        
        # Apply final normalization
        h = self.norm(h)
        
        # Language modeling head
        logits = self.lm_head(h)
        
        return logits
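
The block above normalizes the input of each sub-layer and leaves the residual path untouched (pre-norm). For contrast with the post-norm ordering of the original Transformer, here is a minimal sketch of both wrappers. `PreNormResidual` and `PostNormResidual` are illustrative names, and `nn.LayerNorm` stands in for whichever normalization is used.

```python
import torch
import torch.nn as nn

class PostNormResidual(nn.Module):
    """Original-Transformer ordering: normalize after adding the residual."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormResidual(nn.Module):
    """Pre-norm ordering: normalize only the sub-layer input."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

torch.manual_seed(0)
dim = 16
x = torch.randn(2, 5, dim) * 10  # deliberately large activations

pre = PreNormResidual(dim, nn.Linear(dim, dim))
post = PostNormResidual(dim, nn.Linear(dim, dim))

# Zero the sub-layer so that only the residual path contributes
for wrapper in (pre, post):
    nn.init.zeros_(wrapper.sublayer.weight)
    nn.init.zeros_(wrapper.sublayer.bias)

pre_is_identity = torch.allclose(pre(x), x)
post_is_identity = torch.allclose(post(x), x)
print(f"Pre-norm passes the input through unchanged:  {pre_is_identity}")   # True
print(f"Post-norm passes the input through unchanged: {post_is_identity}")  # False
```

In the pre-norm form the residual stream is an unmodified sum of sub-layer outputs, which is also why a final normalization layer is applied before the language modeling head.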

## 8. Comparison: Key Differences Visualized

Let's visualize the key architectural differences between GPT-2 and LLaMA with diagrams:

In [None]:
def visualize_architecture_comparison():
    """Visualize key architecture differences between GPT-2 and LLaMA"""
    # Create figure
    fig, ax = plt.subplots(1, 2, figsize=(15, 10))
    
    # GPT-2 Architecture
    gpt2_blocks = [
        # GPT-2 applies LayerNorm before each sub-layer (pre-LN)
        "Input Embeddings",
        "+ Positional Embeddings",
        "↓",
        "LayerNorm",
        "Self-Attention",
        "+ Residual",
        "↓",
        "LayerNorm",
        "MLP (GELU)",
        "+ Residual",
        "↓",
        "x N Layers",
        "↓",
        "LayerNorm",
        "↓",
        "Language Modeling Head"
    ]
    
    # LLaMA Architecture
    llama_blocks = [
        "Input Embeddings",
        "↓",
        "RMSNorm",
        "↓",
        "Self-Attention (RoPE, GQA)",
        "+ Residual",
        "↓",
        "RMSNorm",
        "↓",
        "MLP (SwiGLU)",
        "+ Residual",
        "↓",
        "x N Layers",
        "↓",
        "RMSNorm",
        "↓",
        "Language Modeling Head"
    ]
    
    # Color maps for different components
    gpt2_colors = [
        'lightblue', 'lightblue',  # Embeddings
        'white',  # Arrow
        'lightgreen', 'lightgreen', 'lightgreen',  # Attention
        'white',  # Arrow
        'lightcoral', 'lightcoral', 'lightcoral',  # MLP
        'white',  # Arrow
        'lightyellow',  # Layers indicator
        'white',  # Arrow
        'lightgreen',  # Final LayerNorm
        'white',  # Arrow
        'lightgrey'  # LM Head
    ]
    
    llama_colors = [
        'lightblue',  # Embeddings
        'white',  # Arrow
        'lightgreen',  # RMSNorm
        'white',  # Arrow
        'lightgreen', 'lightgreen',  # Attention
        'white',  # Arrow
        'lightgreen',  # RMSNorm
        'white',  # Arrow
        'lightcoral', 'lightcoral',  # MLP
        'white',  # Arrow
        'lightyellow',  # Layers indicator
        'white',  # Arrow
        'lightgreen',  # Final RMSNorm
        'white',  # Arrow
        'lightgrey'  # LM Head
    ]
    
    # Plot GPT-2 architecture
    for i, (block, color) in enumerate(zip(gpt2_blocks, gpt2_colors)):
        if block in ["↓"]:
            # Draw arrow
            ax[0].text(0.5, 1 - (i + 0.5) / len(gpt2_blocks) * 0.9, "↓", 
                     ha="center", va="center", fontsize=14)
        else:
            # Draw box
            box = plt.Rectangle((0.1, 1 - (i + 1) / len(gpt2_blocks) * 0.9), 
                              0.8, 0.9 / len(gpt2_blocks), 
                              facecolor=color, alpha=0.8, edgecolor='black')
            ax[0].add_patch(box)
            # Add text
            ax[0].text(0.5, 1 - (i + 0.5) / len(gpt2_blocks) * 0.9, block, 
                     ha="center", va="center", fontsize=10)
    
    # Plot LLaMA architecture
    for i, (block, color) in enumerate(zip(llama_blocks, llama_colors)):
        if block in ["↓"]:
            # Draw arrow
            ax[1].text(0.5, 1 - (i + 0.5) / len(llama_blocks) * 0.9, "↓", 
                     ha="center", va="center", fontsize=14)
        else:
            # Draw box
            box = plt.Rectangle((0.1, 1 - (i + 1) / len(llama_blocks) * 0.9), 
                              0.8, 0.9 / len(llama_blocks), 
                              facecolor=color, alpha=0.8, edgecolor='black')
            ax[1].add_patch(box)
            # Add text
            ax[1].text(0.5, 1 - (i + 0.5) / len(llama_blocks) * 0.9, block, 
                     ha="center", va="center", fontsize=10)
    
    # Set titles and adjust
    ax[0].set_title("GPT-2 Architecture", fontsize=16)
    ax[1].set_title("LLaMA Architecture", fontsize=16)
    
    for i in range(2):
        ax[i].set_xlim(0, 1)
        ax[i].set_ylim(0, 1)
        ax[i].axis('off')
        
    # Annotations for key differences
    plt.figtext(0.5, 0.02, "Key Differences: LLaMA vs GPT-2\n" + 
                "1. RMSNorm vs LayerNorm (GPT-2 also normalizes before each sub-layer)\n" + 
                "2. Rotary positional embeddings vs Learned positional embeddings\n" + 
                "3. Group Query Attention vs standard Multi-head Attention\n" + 
                "4. SwiGLU activation vs GELU activation", 
                ha="center", fontsize=12, bbox=dict(facecolor='lightyellow', alpha=0.5))
    
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.show()

visualize_architecture_comparison()

## 9. A Simple Example: Testing the LLaMA Block

Let's run a small smoke test on the LLaMA block: check the output shape, confirm that the block transforms its input, and verify that gradients reach every parameter:

In [None]:
# Create a small test configuration
test_config = LLaMAConfig()
test_config.hidden_size = 64
test_config.intermediate_size = 128
test_config.n_heads = 4
test_config.n_kv_heads = 2
test_config.head_dim = 16
test_config.n_layers = 2

# Instantiate a LLaMA block
block = LLaMABlock(test_config)

# Create a random input tensor
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, test_config.hidden_size)

# Forward pass
out = block(x)

# Check output shape
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

# Check that the output is different from the input (transformation happened)
print(f"Input != Output: {not torch.allclose(x, out)}")

# Check that gradients can flow
loss = out.sum()
loss.backward()
print(f"Block has parameters with gradients: {all(p.grad is not None for p in block.parameters())}")

## 10. Key Insights and Best Practices

Let's summarize the key innovations in the LLaMA architecture and their impact on model performance:

1. **RMSNorm**:
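   - Cheaper than LayerNorm: it skips mean subtraction and the bias term
   - Rescales each vector by its root mean square, so the direction is unchanged before the learned gain is applied
   - Applied before each sub-layer (pre-normalization) in LLaMA, which helps gradient flow through the residual path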

2. **Rotary Positional Embeddings (RoPE)**:
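   - Rotates query and key vectors by position-dependent angles, so attention scores depend on the relative offset between tokens rather than on their absolute positions
   - Adds no learned position parameters and is not tied to a fixed-size position table, although extrapolating far beyond the training length usually requires additional frequency scaling
   - Has a clean formulation as complex-number rotations applied to pairs of dimensions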

3. **Group Query Attention (GQA)**:
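   - Shares each key/value head across a group of query heads, shrinking the KV cache and the K/V projection weights by a factor of `n_heads / n_kv_heads`
   - Keeps the full number of query heads, which retains most of the quality of standard multi-head attention
   - Particularly beneficial during inference, where reading the KV cache is limited by memory bandwidth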

4. **SwiGLU Activation**:
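   - Adds a multiplicative gate to the feed-forward layer, making it more expressive than a plain GELU MLP
   - Tends to improve final model quality over ungated feed-forward layers in published comparisons
   - Uses three projection matrices instead of two, so the intermediate size is usually reduced to keep the parameter count comparable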
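
The RoPE and RMSNorm points above can be checked numerically. The following is a minimal sketch that uses only the Python standard library, so it runs independently of the classes defined earlier in this notebook. The helper names (`rope_rotate`, `rms_norm`, `dot`, `cosine`) are hypothetical, and the pairing of consecutive dimensions is an assumption that may differ from the layout used in the implementation above; the relative-offset property holds for either layout.

```python
import cmath
import math
import random


def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive pairs of `vec` by position-dependent angles.

    Pair i is treated as the complex number vec[2i] + 1j * vec[2i+1]
    and multiplied by exp(1j * pos * base ** (-2i / d)).
    """
    d = len(vec)
    rotated = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.exp(1j * pos * theta)
        rotated.extend([z.real, z.imag])
    return rotated


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rms_norm(vec, eps=1e-6):
    """RMSNorm without the learned gain: divide by the root mean square."""
    rms = math.sqrt(sum(x * x for x in vec) / len(vec) + eps)
    return [x / rms for x in vec]


def cosine(a, b):
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


random.seed(0)
q = [random.gauss(0, 1) for _ in range(16)]
k = [random.gauss(0, 1) for _ in range(16)]

# Same offset (3) between query and key at two different absolute positions
score_near = dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_far = dot(rope_rotate(q, 105), rope_rotate(k, 102))
print(f"offset 3 at positions (5, 2):     {score_near:.6f}")
print(f"offset 3 at positions (105, 102): {score_far:.6f}")

# RMSNorm changes the length of a vector but not its direction
direction_match = cosine(q, rms_norm(q))
print(f"cosine(q, rms_norm(q)) = {direction_match:.6f}")
```

The two scores agree up to floating-point error because only the difference between the positions enters the dot product, and the cosine similarity of 1 shows that the normalization is a pure rescaling.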

**Best Practices for Modern LLM Architecture**:

- Use pre-normalization (norm before each sub-layer) instead of post-normalization
- Prefer RMSNorm over LayerNorm: it is cheaper to compute and works well in practice at large scale
- Implement Rotary Positional Embeddings for better position modeling
- Consider Group Query Attention for larger models where memory is a concern
- Use SwiGLU or similar gated activations in feed-forward networks
- Choose a hidden size that is divisible by the number of attention heads, and a number of query heads that is divisible by the number of key/value heads
- For very large models, consider techniques like parameter sharing and factorized attention
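
The divisibility rule and the memory argument for Group Query Attention can be made concrete with a little arithmetic. This is a rough sketch: `check_attention_shapes` and `kv_cache_bytes` are hypothetical helpers, not part of the notebook's classes, and the larger sizes below are illustrative assumptions (16-bit cache values, batch size 1) rather than the configuration of any particular released model.

```python
def check_attention_shapes(hidden_size, n_heads, n_kv_heads):
    """Return (head_dim, queries_per_kv_head), or raise if the sizes do not divide evenly."""
    if hidden_size % n_heads != 0:
        raise ValueError(f"hidden_size={hidden_size} is not divisible by n_heads={n_heads}")
    if n_heads % n_kv_heads != 0:
        raise ValueError(f"n_heads={n_heads} is not divisible by n_kv_heads={n_kv_heads}")
    return hidden_size // n_heads, n_heads // n_kv_heads


def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Size of the cached keys and values; the leading 2 counts K and V."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_value


# The small test configuration from Section 9
print(check_attention_shapes(hidden_size=64, n_heads=4, n_kv_heads=2))  # (16, 2)

# Illustrative larger sizes (assumed values, not taken from this notebook)
head_dim, group_size = check_attention_shapes(hidden_size=4096, n_heads=32, n_kv_heads=8)
mha_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=head_dim, seq_len=4096)
gqa_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=head_dim, seq_len=4096)
print(f"MHA KV cache: {mha_bytes / 2**30:.2f} GiB")
print(f"GQA KV cache: {gqa_bytes / 2**30:.2f} GiB ({group_size}x smaller)")
```

The cache shrinks by exactly the ratio of query heads to key/value heads, which is why the savings matter most for long sequences and large batches at inference time.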

## 11. Conclusion

In this notebook, we've explored the key architectural differences between GPT-2 and the modern LLaMA family of models. We've implemented and compared each key component:

- Normalization: LayerNorm vs. RMSNorm
- Positional encoding: Learned embeddings vs. RoPE
- Attention: Standard MHA vs. Group Query Attention
- Feed-forward networks: GELU vs. SwiGLU

These architectural innovations have played a crucial role in scaling language models to unprecedented sizes while maintaining computational efficiency. Understanding these components provides insights into the principles behind modern LLM design and can guide future research and development in this rapidly evolving field.

The LLaMA architecture represents a careful balance of efficiency and performance, incorporating the best techniques from recent research to create models that can be trained and deployed at scale.